Compare commits
10 Commits
96c91e386d
...
feature/de
| Author | SHA1 | Date | |
|---|---|---|---|
| 47bf1881e5 | |||
| 24a9c1b752 | |||
| 706bf88883 | |||
| 4ff0b27084 | |||
| 61d2a18234 | |||
| b3687719b6 | |||
| f80bdfa8f7 | |||
| 617a17db40 | |||
| 92716cb89a | |||
| cfc9d7a942 |
20
.env.example
20
.env.example
@@ -23,13 +23,21 @@ LLM_ROUTER_MODEL=gpt-4o-mini
|
|||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
|
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
||||||
|
S3_BUCKET=adiuva
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_ENDPOINT_URL=
|
||||||
|
AWS_ACCESS_KEY_ID=
|
||||||
|
AWS_SECRET_ACCESS_KEY=
|
||||||
|
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
|
||||||
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
# ── Vector Store ──────────────────────────────────────────────────────────────
|
||||||
LANGFUSE_SECRET_KEY=
|
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
|
||||||
LANGFUSE_PUBLIC_KEY=
|
PINECONE_API_KEY=
|
||||||
# LANGFUSE_HOST=https://cloud.langfuse.com # EU (default)
|
PINECONE_INDEX=adiuva
|
||||||
# LANGFUSE_HOST=https://us.cloud.langfuse.com # US
|
QDRANT_URL=
|
||||||
# LANGFUSE_HOST=http://localhost:3000 # Self-hosted
|
QDRANT_API_KEY=
|
||||||
|
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -21,7 +21,6 @@ env/
|
|||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.coverage
|
.coverage
|
||||||
tests/fixtures/private*/
|
|
||||||
|
|
||||||
# Docker
|
# Docker
|
||||||
*.log
|
*.log
|
||||||
|
|||||||
298
README.md
298
README.md
@@ -1,8 +1,8 @@
|
|||||||
# Adiuva Cloud API
|
# Adiuva Cloud API
|
||||||
|
|
||||||
**AI-powered project management backend with LLM orchestration and subscription billing.**
|
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
|
||||||
|
|
||||||
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe
|
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -20,7 +20,9 @@ Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe
|
|||||||
- [AI Agent System](#ai-agent-system)
|
- [AI Agent System](#ai-agent-system)
|
||||||
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
||||||
- [Middleware](#middleware)
|
- [Middleware](#middleware)
|
||||||
|
- [Storage Layer](#storage-layer)
|
||||||
- [Billing & Tiers](#billing--tiers)
|
- [Billing & Tiers](#billing--tiers)
|
||||||
|
- [Plugin Marketplace](#plugin-marketplace)
|
||||||
- [Testing](#testing)
|
- [Testing](#testing)
|
||||||
- [Project Structure](#project-structure)
|
- [Project Structure](#project-structure)
|
||||||
- [License](#license)
|
- [License](#license)
|
||||||
@@ -29,13 +31,15 @@ Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe
|
|||||||
|
|
||||||
## Overview
|
## Overview
|
||||||
|
|
||||||
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, text embedding generation, and Stripe-based subscription billing across four tiers.
|
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
|
||||||
|
|
||||||
### Design Principles
|
### Design Principles
|
||||||
|
|
||||||
1. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
|
||||||
2. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
||||||
3. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
|
||||||
|
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
||||||
|
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -50,26 +54,27 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
|
|||||||
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
||||||
│ │ Auth Routes │ │ Chat Routes │ │
|
│ │ Auth Routes │ │ Chat Routes │ │
|
||||||
│ │ Billing Routes │ │ ↓ │ │
|
│ │ Billing Routes │ │ ↓ │ │
|
||||||
│ │ Agent Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
||||||
│ │ Device WS │ │ ↓ classify intent │ │
|
│ │ Backup Routes │ │ ↓ classify intent │ │
|
||||||
│ └──────────────────┘ │ Agent Registry │ │
|
│ │ Plugin Routes │ │ Agent Registry │ │
|
||||||
│ │ ↓ │ │
|
│ │ Vector Routes │ │ ↓ │ │
|
||||||
│ │ TaskAgent | ProjectAgent │ │
|
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
|
||||||
│ │ NoteAgent | CheckptAgent │ │
|
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
|
||||||
│ │ (GPT-4o + LangChain) │ │
|
│ │ (GPT-4o + LangChain) │ │
|
||||||
│ └────────────────────────────┘ │
|
│ └────────────────────────────┘ │
|
||||||
└────────────────────────────────────────────────────────┘
|
└────────────────────────────────────────────────────────┘
|
||||||
│
|
│ │ │
|
||||||
┌────────▼───┐
|
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
|
||||||
│ PostgreSQL │
|
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
|
||||||
│ (Auth, │
|
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
|
||||||
│ Billing, │
|
│ Billing, │ │ backups) │ │ (Vectors) │
|
||||||
│ Agents) │
|
│ Metadata) │ └───────────────┘ └────────────────┘
|
||||||
└────────────┘
|
└────────────┘
|
||||||
│
|
│
|
||||||
┌────────▼───┐
|
┌────────▼───┐
|
||||||
│ Stripe │
|
│ Stripe │
|
||||||
│ (Billing) │
|
│ (Billing, │
|
||||||
|
│ Connect) │
|
||||||
└────────────┘
|
└────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
@@ -80,14 +85,18 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
|
|||||||
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
||||||
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
||||||
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
||||||
4. **Text embeddings** — Generates text-embedding-3-small vectors for local client-side note search.
|
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
||||||
5. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
||||||
6. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
|
||||||
7. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
|
||||||
8. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
||||||
9. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
||||||
10. **Alembic migrations** — Versioned schema management.
|
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
||||||
11. **Comprehensive test suite** — In-memory SQLite, per-tier test fixtures, and full API coverage without external dependencies.
|
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
||||||
|
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
|
||||||
|
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
||||||
|
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
|
||||||
|
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -105,6 +114,7 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
|
|||||||
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
||||||
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
||||||
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
||||||
|
| `boto3` | ≥ 1.35.0 | AWS S3 client |
|
||||||
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
||||||
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
||||||
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
||||||
@@ -114,9 +124,12 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
|
|||||||
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
||||||
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
||||||
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
||||||
|
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
|
||||||
|
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
|
||||||
| `pytest` | ≥ 8.0.0 | Test framework |
|
| `pytest` | ≥ 8.0.0 | Test framework |
|
||||||
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
||||||
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
||||||
|
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
|
||||||
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -129,6 +142,7 @@ Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron deskto
|
|||||||
- PostgreSQL 16+
|
- PostgreSQL 16+
|
||||||
- An OpenAI API key (for LLM features)
|
- An OpenAI API key (for LLM features)
|
||||||
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
||||||
|
- AWS credentials (optional — needed for S3 storage in production)
|
||||||
|
|
||||||
### Installation
|
### Installation
|
||||||
|
|
||||||
@@ -180,6 +194,11 @@ This starts two services:
|
|||||||
- **app** — FastAPI server on port `8000`
|
- **app** — FastAPI server on port `8000`
|
||||||
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
||||||
|
|
||||||
|
The compose file also includes optional services for fully local deployments:
|
||||||
|
|
||||||
|
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
|
||||||
|
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
|
||||||
|
|
||||||
### Dockerfile Details
|
### Dockerfile Details
|
||||||
|
|
||||||
The Dockerfile uses a multi-stage build:
|
The Dockerfile uses a multi-stage build:
|
||||||
@@ -197,7 +216,7 @@ gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0
|
|||||||
|
|
||||||
## Homelab / Self-Hosted Deployment
|
## Homelab / Self-Hosted Deployment
|
||||||
|
|
||||||
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**.
|
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
|
||||||
|
|
||||||
### 1. Start all services
|
### 1. Start all services
|
||||||
|
|
||||||
@@ -205,14 +224,35 @@ You can run the entire stack locally on a homelab with **no cloud dependencies e
|
|||||||
docker compose up -d
|
docker compose up -d
|
||||||
```
|
```
|
||||||
|
|
||||||
This starts PostgreSQL alongside the app.
|
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
|
||||||
|
|
||||||
### 2. Configure your `.env`
|
### 2. Create the MinIO bucket
|
||||||
|
|
||||||
|
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
|
||||||
|
docker compose exec minio mc mb local/adiuva
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Configure your `.env`
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
# Database (uses the compose PostgreSQL)
|
# Database (uses the compose PostgreSQL)
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
|
||||||
|
# S3 → MinIO
|
||||||
|
S3_BUCKET=adiuva
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
AWS_ACCESS_KEY_ID=minioadmin
|
||||||
|
AWS_SECRET_ACCESS_KEY=minioadmin
|
||||||
|
|
||||||
|
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
|
||||||
|
QDRANT_URL=http://qdrant:6333
|
||||||
|
QDRANT_API_KEY=
|
||||||
|
PINECONE_API_KEY=
|
||||||
|
|
||||||
# Billing — leave empty to stub (no Stripe needed)
|
# Billing — leave empty to stub (no Stripe needed)
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=
|
STRIPE_WEBHOOK_SECRET=
|
||||||
@@ -227,7 +267,7 @@ JWT_SECRET=your-secret-here
|
|||||||
ENV=dev
|
ENV=dev
|
||||||
```
|
```
|
||||||
|
|
||||||
### 3. Run migrations
|
### 4. Run migrations
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
docker compose exec app alembic upgrade head
|
docker compose exec app alembic upgrade head
|
||||||
@@ -238,7 +278,9 @@ docker compose exec app alembic upgrade head
|
|||||||
| Service | Runs on | Port | Notes |
|
| Service | Runs on | Port | Notes |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| FastAPI app | Docker | 8000 | API server |
|
| FastAPI app | Docker | 8000 | API server |
|
||||||
| PostgreSQL | Docker | 5432 | Auth, billing, agents |
|
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
|
||||||
|
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
|
||||||
|
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
|
||||||
| Stripe | — | — | Stubbed when keys are empty |
|
| Stripe | — | — | Stubbed when keys are empty |
|
||||||
| OpenAI / LLM | Cloud | — | Only external dependency |
|
| OpenAI / LLM | Cloud | — | Only external dependency |
|
||||||
|
|
||||||
@@ -258,7 +300,17 @@ All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/
|
|||||||
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
||||||
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
||||||
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
||||||
| `STRIPE_WEBHOOK_SECRET` | `str` | `\"\"` | Stripe webhook signature secret |\n| `OPENAI_API_KEY` | `str` | `\"\"` | OpenAI key for LLM agent calls |
|
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
|
||||||
|
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
|
||||||
|
| `S3_REGION` | `str` | `us-east-1` | AWS region |
|
||||||
|
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
|
||||||
|
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
|
||||||
|
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
|
||||||
|
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
|
||||||
|
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
|
||||||
|
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
|
||||||
|
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
|
||||||
|
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
|
||||||
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
||||||
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
||||||
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
||||||
@@ -290,7 +342,6 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
|
|||||||
| Method | Path | Auth | Description |
|
| Method | Path | Auth | Description |
|
||||||
|---|---|---|---|
|
|---|---|---|---|
|
||||||
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
||||||
| `POST` | `/api/v1/chat/embed` | JWT | Generate a 1536-dim text embedding vector (`text-embedding-3-small`). Used by Electron for local note search. |
|
|
||||||
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
||||||
|
|
||||||
### Plans
|
### Plans
|
||||||
@@ -300,6 +351,42 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
|
|||||||
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
||||||
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
||||||
|
|
||||||
|
### Storage (Cloud Records)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
|
||||||
|
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
|
||||||
|
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
|
||||||
|
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
|
||||||
|
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
|
||||||
|
|
||||||
|
### Vectors (Cloud Vector Store)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
|
||||||
|
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
|
||||||
|
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
|
||||||
|
|
||||||
|
### Backup
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
|
||||||
|
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
|
||||||
|
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
|
||||||
|
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
|
||||||
|
|
||||||
|
### Plugins (Marketplace)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
|
||||||
|
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
|
||||||
|
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
|
||||||
|
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
|
||||||
|
|
||||||
### Billing
|
### Billing
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
| Method | Path | Auth | Description |
|
||||||
@@ -313,7 +400,7 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
|
|||||||
|
|
||||||
## Data Model
|
## Data Model
|
||||||
|
|
||||||
3 tables managed by Alembic migrations. Source: `app/models.py`
|
9 tables managed by Alembic migrations. Source: `app/models.py`
|
||||||
|
|
||||||
### Tables
|
### Tables
|
||||||
|
|
||||||
@@ -322,18 +409,27 @@ All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebS
|
|||||||
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
||||||
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
||||||
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
||||||
|
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
|
||||||
|
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
|
||||||
|
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
|
||||||
|
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
|
||||||
|
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
|
||||||
|
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
|
||||||
|
|
||||||
### Enum Types
|
### Enum Types
|
||||||
|
|
||||||
| Enum | Values |
|
| Enum | Values |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
||||||
|
| `plugin_status` | `pending_review`, `approved`, `rejected` |
|
||||||
|
| `review_decision` | `approved`, `rejected` |
|
||||||
|
|
||||||
### Migrations
|
### Migrations
|
||||||
|
|
||||||
| Version | Description |
|
| Version | Description |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `001_initial_schema` | Creates core auth and billing tables with indexes and foreign key constraints |
|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
|
||||||
|
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -343,7 +439,7 @@ The agent system uses a registry pattern with LangChain tool-calling agents powe
|
|||||||
|
|
||||||
### Architecture
|
### Architecture
|
||||||
|
|
||||||
- **`BaseAgent`** — Abstract base with `user_id` and `shared_memory`.
|
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
|
||||||
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
||||||
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
||||||
|
|
||||||
@@ -458,6 +554,39 @@ Source: `app/api/middleware/sanitizer.py`
|
|||||||
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
||||||
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
||||||
- Logs sanitization events as `WARNING`.
|
- Logs sanitization events as `WARNING`.
|
||||||
|
- Binary responses (storage, backup) are never touched.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Storage Layer
|
||||||
|
|
||||||
|
### Blob Store
|
||||||
|
|
||||||
|
Source: `app/storage/blob_store.py`
|
||||||
|
|
||||||
|
- S3-backed storage for E2E encrypted blobs.
|
||||||
|
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
|
||||||
|
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
|
||||||
|
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
|
||||||
|
- The backend **never inspects or decrypts blob content**.
|
||||||
|
|
||||||
|
### Vector Store
|
||||||
|
|
||||||
|
Source: `app/storage/vector_store.py`
|
||||||
|
|
||||||
|
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
|
||||||
|
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
|
||||||
|
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
|
||||||
|
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
|
||||||
|
- Methods: `upsert()`, `search()`, `delete()`
|
||||||
|
|
||||||
|
### Encryption Utilities
|
||||||
|
|
||||||
|
Source: `app/storage/encryption.py`
|
||||||
|
|
||||||
|
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
|
||||||
|
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
|
||||||
|
- **No decryption key ever reaches the backend.**
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -471,8 +600,11 @@ Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
|||||||
|---|---|---|---|---|
|
|---|---|---|---|---|
|
||||||
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
||||||
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
||||||
|
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
||||||
| Batch Builder | — | — | ✓ | ✓ |
|
| Batch Builder | — | — | ✓ | ✓ |
|
||||||
|
| Plugin Marketplace | — | — | ✓ | ✓ |
|
||||||
| SSO | — | — | — | ✓ |
|
| SSO | — | — | — | ✓ |
|
||||||
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
||||||
|
|
||||||
@@ -488,6 +620,47 @@ Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
|||||||
- `get_tier(user_id)` — Returns the user's current billing tier.
|
- `get_tier(user_id)` — Returns the user's current billing tier.
|
||||||
- `check_feature(tier, feature)` — Boolean feature gate check.
|
- `check_feature(tier, feature)` — Boolean feature gate check.
|
||||||
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
||||||
|
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Plugin Marketplace
|
||||||
|
|
||||||
|
Source: `app/marketplace/`
|
||||||
|
|
||||||
|
### Plugin Registry
|
||||||
|
|
||||||
|
- PostgreSQL-backed catalog of submitted and approved plugins.
|
||||||
|
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
|
||||||
|
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
|
||||||
|
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
|
||||||
|
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
|
||||||
|
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
|
||||||
|
|
||||||
|
### Review Queue
|
||||||
|
|
||||||
|
- Automated security checklist before human review:
|
||||||
|
- Plugin ID must match `^[a-z0-9-]+$`
|
||||||
|
- Permissions must be from the allowed set only
|
||||||
|
- No binary blobs in the manifest
|
||||||
|
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
|
||||||
|
- `get_pending(db)` — Lists plugins awaiting review.
|
||||||
|
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
||||||
|
|
||||||
|
### Revenue Sharing
|
||||||
|
|
||||||
|
- **70% developer / 30% platform** split on all paid plugin sales.
|
||||||
|
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
|
||||||
|
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
|
||||||
|
- Gracefully stubs transfers when Stripe is not configured.
|
||||||
|
|
||||||
|
### Seed Plugins
|
||||||
|
|
||||||
|
| Plugin | Category | Price |
|
||||||
|
|---|---|---|
|
||||||
|
| GitHub Sync | Productivity | Free |
|
||||||
|
| Slack Notifier | Communication | €4.99 |
|
||||||
|
| Time Tracker | Productivity | €9.99 |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -509,8 +682,10 @@ pytest -v
|
|||||||
### Test Infrastructure
|
### Test Infrastructure
|
||||||
|
|
||||||
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
||||||
|
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
|
||||||
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
||||||
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
||||||
|
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
|
||||||
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
||||||
- **No external dependencies** — all tests run fully offline.
|
- **No external dependencies** — all tests run fully offline.
|
||||||
|
|
||||||
@@ -519,6 +694,13 @@ pytest -v
|
|||||||
| File | Coverage |
|
| File | Coverage |
|
||||||
|---|---|
|
|---|---|
|
||||||
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
||||||
|
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
|
||||||
|
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
|
||||||
|
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
|
||||||
|
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
|
||||||
|
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
|
||||||
|
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
|
||||||
|
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
|
||||||
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -528,6 +710,7 @@ pytest -v
|
|||||||
```
|
```
|
||||||
adiuva-api/
|
adiuva-api/
|
||||||
├── alembic.ini # Alembic configuration
|
├── alembic.ini # Alembic configuration
|
||||||
|
├── BACKEND_PLAN.md # Architecture & design decisions
|
||||||
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
||||||
├── Dockerfile # Multi-stage production build
|
├── Dockerfile # Multi-stage production build
|
||||||
├── requirements.txt # Python dependencies
|
├── requirements.txt # Python dependencies
|
||||||
@@ -536,12 +719,13 @@ adiuva-api/
|
|||||||
│ ├── env.py # Alembic environment config
|
│ ├── env.py # Alembic environment config
|
||||||
│ ├── script.py.mako # Migration template
|
│ ├── script.py.mako # Migration template
|
||||||
│ └── versions/
|
│ └── versions/
|
||||||
│ └── 001_initial_schema.py # Tables, indexes, FKs
|
│ ├── 001_initial_schema.py # Tables, indexes, FKs
|
||||||
|
│ └── 002_seed_plugins.py # Seed marketplace plugins
|
||||||
│
|
│
|
||||||
├── app/ # Application source
|
├── app/ # Application source
|
||||||
│ ├── main.py # FastAPI app factory, middleware, routes
|
│ ├── main.py # FastAPI app factory, middleware, routes
|
||||||
│ ├── db.py # Async SQLAlchemy engine & session
|
│ ├── db.py # Async SQLAlchemy engine & session
|
||||||
│ ├── models.py # SQLAlchemy ORM models
|
│ ├── models.py # SQLAlchemy ORM models (9 tables)
|
||||||
│ ├── schemas.py # Pydantic request/response schemas
|
│ ├── schemas.py # Pydantic request/response schemas
|
||||||
│ │
|
│ │
|
||||||
│ ├── config/
|
│ ├── config/
|
||||||
@@ -556,29 +740,47 @@ adiuva-api/
|
|||||||
│ ├── core/ # Orchestration engine
|
│ ├── core/ # Orchestration engine
|
||||||
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||||
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
||||||
│ │ └── deep_agent.py # Deep agent orchestration
|
│ │ ├── orchestrator.py # Intent classification & routing
|
||||||
|
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||||
│ │
|
│ │
|
||||||
│ ├── api/ # HTTP layer
|
│ ├── api/ # HTTP layer
|
||||||
│ │ ├── deps.py # Shared FastAPI dependencies
|
│ │ ├── deps.py # Shared FastAPI dependencies
|
||||||
│ │ ├── middleware/
|
│ │ ├── middleware/
|
||||||
|
│ │ │ ├── auth.py # JWT validation, live tier lookup
|
||||||
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
||||||
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
||||||
│ │ └── routes/
|
│ │ └── routes/
|
||||||
│ │ ├── auth.py # Register, login, refresh, me
|
│ │ ├── auth.py # Register, login, refresh, me
|
||||||
│ │ ├── chat.py # Chat + embed endpoint
|
│ │ ├── chat.py # Chat + WebSocket streaming
|
||||||
│ │ ├── billing.py # Stripe checkout, webhooks, subscription
|
│ │ ├── plans.py # Execution plan playbooks
|
||||||
│ │ ├── agents.py # Agent catalog, config, runs
|
│ │ ├── storage.py # E2E encrypted record CRUD
|
||||||
│ │ └── device_ws.py # Persistent device WebSocket
|
│ │ ├── vectors.py # Vector upsert, search, delete
|
||||||
|
│ │ ├── backup.py # Encrypted backup management
|
||||||
|
│ │ ├── plugins.py # Marketplace browse & install
|
||||||
|
│ │ └── billing.py # Stripe checkout & webhooks
|
||||||
│ │
|
│ │
|
||||||
│ └── billing/
|
│ ├── storage/ # Storage backends
|
||||||
│ ├── stripe_service.py # Stripe API wrapper
|
│ │ ├── blob_store.py # S3 blob storage
|
||||||
│ └── tier_manager.py # Feature matrix, rate limits
|
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
|
||||||
|
│ │ └── encryption.py # Checksum verification utilities
|
||||||
|
│ │
|
||||||
|
│ ├── billing/ # Subscription management
|
||||||
|
│ │ ├── stripe_service.py # Stripe API integration
|
||||||
|
│ │ └── tier_manager.py # Feature matrix & quota enforcement
|
||||||
|
│ │
|
||||||
|
│ └── marketplace/ # Plugin ecosystem
|
||||||
|
│ ├── plugin_registry.py # Catalog CRUD & search
|
||||||
|
│ ├── plugin_review.py # Security checklist & review queue
|
||||||
|
│ └── revenue_share.py # 70/30 split & Stripe Connect
|
||||||
│
|
│
|
||||||
└── tests/ # Test suite
|
└── tests/ # Test suite
|
||||||
├── conftest.py # Fixtures: DB, auth, seeds
|
├── conftest.py # Fixtures: DB, S3, auth, seeds
|
||||||
├── test_auth.py
|
├── test_auth.py
|
||||||
├── test_orchestrator.py
|
├── test_orchestrator.py
|
||||||
├── test_agents.py
|
├── test_agents.py
|
||||||
|
├── test_storage.py
|
||||||
|
├── test_backup.py
|
||||||
|
├── test_plugins.py
|
||||||
├── test_agent_registry.py
|
├── test_agent_registry.py
|
||||||
├── test_execution_plan.py
|
├── test_execution_plan.py
|
||||||
└── test_middleware.py
|
└── test_middleware.py
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Initial schema: users, refresh_tokens, subscriptions.
|
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||||
|
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||||
|
|
||||||
Revision ID: 001
|
Revision ID: 001
|
||||||
Revises:
|
Revises:
|
||||||
@@ -27,6 +28,18 @@ def upgrade() -> None:
|
|||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
END $$;
|
END $$;
|
||||||
""")
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
# ── users ─────────────────────────────────────────────────────────────
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
op.create_table(
|
op.create_table(
|
||||||
@@ -75,10 +88,122 @@ def upgrade() -> None:
|
|||||||
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
|
# ── storage_records ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"storage_records",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("table_name", sa.String(100), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||||
|
|
||||||
|
# ── backup_metadata ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"backup_metadata",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer, nullable=False),
|
||||||
|
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugins ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugins",
|
||||||
|
sa.Column("id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||||
|
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||||
|
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||||
|
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
||||||
|
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||||
|
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||||
|
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── plugin_installations ──────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_installations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||||
|
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_reviews",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
||||||
|
sa.Column("notes", sa.Text, nullable=True),
|
||||||
|
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||||
|
|
||||||
|
# ── revenue_events ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"revenue_events",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||||
|
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
|
op.drop_table("revenue_events")
|
||||||
|
op.drop_table("plugin_reviews")
|
||||||
|
op.drop_table("plugin_installations")
|
||||||
|
op.drop_table("plugins")
|
||||||
|
op.drop_table("backup_metadata")
|
||||||
|
op.drop_table("storage_records")
|
||||||
op.drop_table("subscriptions")
|
op.drop_table("subscriptions")
|
||||||
op.drop_table("refresh_tokens")
|
op.drop_table("refresh_tokens")
|
||||||
op.drop_table("users")
|
op.drop_table("users")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||||
|
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||||
op.execute("DROP TYPE IF EXISTS billing_tier")
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
|
|||||||
92
alembic/versions/002_seed_plugins.py
Normal file
92
alembic/versions/002_seed_plugins.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2026-03-03
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "002"
|
||||||
|
down_revision: Union[str, None] = "001"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
{
|
||||||
|
"id": "plugin-github-sync",
|
||||||
|
"name": "GitHub Sync",
|
||||||
|
"description": "Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 0,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-slack-notify",
|
||||||
|
"name": "Slack Notifier",
|
||||||
|
"description": "Post task and timeline updates to Slack channels.",
|
||||||
|
"version": "1.2.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "communication",
|
||||||
|
"price_cents": 499,
|
||||||
|
"permissions": json.dumps(["read:tasks", "read:timelines"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-time-tracker",
|
||||||
|
"name": "Time Tracker",
|
||||||
|
"description": "Track time spent on tasks with automatic reporting.",
|
||||||
|
"version": "0.9.1",
|
||||||
|
"author_name": "Third Party",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 999,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
plugins = sa.table(
|
||||||
|
"plugins",
|
||||||
|
sa.column("id", sa.String),
|
||||||
|
sa.column("name", sa.String),
|
||||||
|
sa.column("description", sa.Text),
|
||||||
|
sa.column("version", sa.String),
|
||||||
|
sa.column("author_name", sa.String),
|
||||||
|
sa.column("category", sa.String),
|
||||||
|
sa.column("price_cents", sa.Integer),
|
||||||
|
sa.column("permissions", sa.Text),
|
||||||
|
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
||||||
|
sa.column("s3_package_key", sa.String),
|
||||||
|
sa.column("install_count", sa.Integer),
|
||||||
|
sa.column("avg_rating", sa.Float),
|
||||||
|
)
|
||||||
|
op.bulk_insert(plugins, _SEED_PLUGINS)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"DELETE FROM plugins WHERE id IN ("
|
||||||
|
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
||||||
|
")"
|
||||||
|
)
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Deprecate backend agent config tables.
|
|
||||||
|
|
||||||
The Electron client is now the source of truth for agent configuration
|
|
||||||
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
|
||||||
billing checks and trigger/run logs only.
|
|
||||||
|
|
||||||
Revision ID: 9a1f2d0b6c7e
|
|
||||||
Revises: 818478c251dc
|
|
||||||
Create Date: 2026-03-16
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
revision: str = "9a1f2d0b6c7e"
|
|
||||||
down_revision: Union[str, None] = "818478c251dc"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
bind = op.get_bind()
|
|
||||||
inspector = sa.inspect(bind)
|
|
||||||
existing = set(inspector.get_table_names())
|
|
||||||
|
|
||||||
if "cloud_agent_configs" in existing:
|
|
||||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
|
||||||
op.drop_table("cloud_agent_configs")
|
|
||||||
|
|
||||||
if "local_agent_configs" in existing:
|
|
||||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
|
||||||
op.drop_table("local_agent_configs")
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"local_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("device_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
|
||||||
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
op.create_table(
|
|
||||||
"cloud_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"provider",
|
|
||||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
|
||||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
"""add agent_config to local_agent_configs
|
|
||||||
|
|
||||||
Revision ID: a3b9c0d1e2f3
|
|
||||||
Revises: 9a1f2d0b6c7e
|
|
||||||
Create Date: 2026-04-07 00:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "a3b9c0d1e2f3"
|
|
||||||
down_revision: Union[str, None] = "9a1f2d0b6c7e"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.add_column(
|
|
||||||
"local_agent_configs",
|
|
||||||
sa.Column("agent_config", sa.JSON(), nullable=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_column("local_agent_configs", "agent_config")
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
||||||
|
|
||||||
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
@@ -1,85 +0,0 @@
|
|||||||
"""Filesystem agent — tools for reading local directories and files on Electron.
|
|
||||||
|
|
||||||
These tools delegate to the Electron client via ``execute_on_client()`` using
|
|
||||||
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
|
||||||
handles actual disk I/O and responds with ``tool_result`` frames.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_directory(path: str) -> str:
|
|
||||||
"""List files and folders in a local directory on the user's device.
|
|
||||||
|
|
||||||
Returns a formatted listing of entries with name, type (file/directory),
|
|
||||||
and full path.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="list_directory",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
|
||||||
if not entries:
|
|
||||||
return f"Directory '{path}' is empty or does not exist."
|
|
||||||
lines: list[str] = []
|
|
||||||
for entry in entries:
|
|
||||||
entry_type = entry.get("type", "unknown")
|
|
||||||
entry_name = entry.get("name", "")
|
|
||||||
entry_path = entry.get("path", "")
|
|
||||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
|
||||||
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def read_file_content(path: str) -> str:
|
|
||||||
"""Read the text content of a local file on the user's device.
|
|
||||||
|
|
||||||
Returns the file content as a string. Large files may be truncated
|
|
||||||
by the Electron client.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="read_file_content",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
content: str = result.get("content", "")
|
|
||||||
if not content:
|
|
||||||
return f"File '{path}' is empty or could not be read."
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_file_metadata(path: str) -> str:
|
|
||||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
|
||||||
|
|
||||||
Returns a formatted summary of the file's metadata.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="get_file_metadata",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
size = result.get("size", "unknown")
|
|
||||||
created = result.get("createdAt", "unknown")
|
|
||||||
modified = result.get("modifiedAt", "unknown")
|
|
||||||
extension = result.get("extension", "unknown")
|
|
||||||
name = result.get("name", path)
|
|
||||||
return (
|
|
||||||
f"File: {name}\n"
|
|
||||||
f" Extension: {extension}\n"
|
|
||||||
f" Size: {size} bytes\n"
|
|
||||||
f" Created: {created}\n"
|
|
||||||
f" Modified: {modified}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
FILESYSTEM_TOOLS: list[Any] = [
|
|
||||||
list_directory,
|
|
||||||
read_file_content,
|
|
||||||
get_file_metadata,
|
|
||||||
]
|
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
"""Note agent — tool definitions for Markdown note CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
@@ -10,38 +9,14 @@ from langchain_core.tools import tool
|
|||||||
from app.core.llm import embed
|
from app.core.llm import embed
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
NOTE_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": normalized_project_id or None},
|
filters={"projectId": project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -130,10 +105,4 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
NOTE_TOOLS: list[Any] = [
|
|
||||||
list_notes,
|
|
||||||
get_note,
|
|
||||||
create_note,
|
|
||||||
update_note,
|
|
||||||
delete_note,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
"""Project agent — tool definitions for project lifecycle CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -8,22 +8,6 @@ from langchain_core.tools import tool
|
|||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
PROJECT_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_projects(
|
async def list_projects(
|
||||||
@@ -133,11 +117,4 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
PROJECT_TOOLS: list[Any] = [
|
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
create_project,
|
|
||||||
update_project,
|
|
||||||
delete_project,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,40 +1,14 @@
|
|||||||
"""Task agent — full CRUD for tasks and task comments."""
|
"""Task agent — tool definitions for task and task comment CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
TASK_SYSTEM_PROMPT = (
|
|
||||||
"You are a task management assistant for a project workspace.\n"
|
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -48,12 +22,11 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": normalized_project_id or None,
|
"projectId": project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -79,6 +52,7 @@ async def create_task(
|
|||||||
due_date: int = 0,
|
due_date: int = 0,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new task.
|
"""Create a new task.
|
||||||
title: task title (required)
|
title: task title (required)
|
||||||
@@ -89,6 +63,7 @@ async def create_task(
|
|||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
project_id: optional UUID of the parent project
|
project_id: optional UUID of the parent project
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms; 1 when confirmed
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -102,6 +77,7 @@ async def create_task(
|
|||||||
"dueDate": due_date or None,
|
"dueDate": due_date or None,
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -121,10 +97,12 @@ async def update_task(
|
|||||||
assignees: str = "",
|
assignees: str = "",
|
||||||
due_date: int = -1,
|
due_date: int = -1,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
|
is_approved: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
task_id: the task's UUID (required)
|
task_id: the task's UUID (required)
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the value
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
@@ -141,6 +119,8 @@ async def update_task(
|
|||||||
updates["dueDate"] = due_date or None
|
updates["dueDate"] = due_date or None
|
||||||
if project_id:
|
if project_id:
|
||||||
updates["projectId"] = project_id
|
updates["projectId"] = project_id
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
@@ -208,12 +188,8 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result.get("row", {})
|
row = result["row"]
|
||||||
row_author = row.get("author", author)
|
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
||||||
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
|
|
||||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
|
||||||
row_comment_id = row.get("id", "unknown")
|
|
||||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -223,16 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
return f"Comment {comment_id} deleted."
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
TASK_TOOLS: list[Any] = [
|
|
||||||
list_tasks,
|
|
||||||
create_task,
|
|
||||||
update_task,
|
|
||||||
delete_task,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
add_task_comment,
|
|
||||||
delete_task_comment,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,45 +1,21 @@
|
|||||||
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
"""Timeline agent — tool definitions for project milestone CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
TIMELINE_SYSTEM_PROMPT = (
|
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all timelines across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": normalized_project_id or None},
|
filters={"projectId": project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -54,12 +30,14 @@ async def create_timeline(
|
|||||||
title: str,
|
title: str,
|
||||||
date: int,
|
date: int,
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a project timeline (milestone).
|
"""Create a project timeline (milestone).
|
||||||
project_id: REQUIRED UUID of the parent project
|
project_id: REQUIRED UUID of the parent project
|
||||||
title: descriptive name for the milestone
|
title: descriptive name for the milestone
|
||||||
date: Unix timestamp in milliseconds
|
date: Unix timestamp in milliseconds
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -69,6 +47,7 @@ async def create_timeline(
|
|||||||
"title": title,
|
"title": title,
|
||||||
"date": date,
|
"date": date,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -80,16 +59,20 @@ async def update_timeline(
|
|||||||
timeline_id: str,
|
timeline_id: str,
|
||||||
title: str = "",
|
title: str = "",
|
||||||
date: int = -1,
|
date: int = -1,
|
||||||
|
is_approved: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update a timeline. Only pass fields that should change.
|
"""Update a timeline. Only pass fields that should change.
|
||||||
timeline_id: UUID of the timeline (required)
|
timeline_id: UUID of the timeline (required)
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if date != -1:
|
if date != -1:
|
||||||
updates["date"] = date
|
updates["date"] = date
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
@@ -106,9 +89,4 @@ async def delete_timeline(timeline_id: str) -> str:
|
|||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
TIMELINE_TOOLS: list[Any] = [
|
|
||||||
list_timelines,
|
|
||||||
create_timeline,
|
|
||||||
update_timeline,
|
|
||||||
delete_timeline,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -55,15 +55,12 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
|
||||||
# block local development when no Stripe subscription exists.
|
|
||||||
from app.models import Subscription, User # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ that could reveal server-side prompt IP:
|
|||||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||||
- Exact-match known prompt fingerprints
|
- Exact-match known prompt fingerprints
|
||||||
|
|
||||||
The middleware only activates for paths under /api/v1/chat.
|
Binary responses (storage blobs, backup data) are never touched — the
|
||||||
|
middleware only activates for paths under /api/v1/chat.
|
||||||
|
|
||||||
Any sanitisation event is logged as a WARNING with the request path and the
|
Any sanitisation event is logged as a WARNING with the request path and the
|
||||||
names of the fields that were modified.
|
names of the fields that were modified.
|
||||||
|
|||||||
@@ -1,72 +1,74 @@
|
|||||||
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
|
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
Endpoints:
|
||||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
POST /agents/journey/start — start a new journey session
|
||||||
frames to the functions exported here.
|
POST /agents/journey/message — continue the conversation
|
||||||
|
|
||||||
|
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
||||||
|
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. FE sends ``journey_start`` frame with basic agent info (directory,
|
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
||||||
data_types, schedule).
|
2. Server creates a session, calls the LLM with a contextual system prompt,
|
||||||
2. Server creates an in-memory session, sets up a WS executor so the
|
and returns the first question.
|
||||||
setup LLM can use file-system tools, does a first directory scrape,
|
3. Client sends follow-up messages to ``/message``.
|
||||||
and sends back a ``journey_reply`` with the first question.
|
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
||||||
3. FE sends ``journey_message`` frames for each user reply.
|
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
4. Server appends the user message, calls the LLM (which may read files
|
5. Server parses the block, sets ``done=True``, and returns the template.
|
||||||
via tools), and sends back a ``journey_reply``.
|
|
||||||
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
|
The ``prompt_template`` from the final response is meant to be stored in
|
||||||
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
||||||
6. Server parses and validates the JSON with Pydantic, sends
|
by the Electron client (via the agent CRUD endpoints).
|
||||||
``journey_reply`` with ``done=True`` and the serialised config.
|
|
||||||
FE stores it locally.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_llm
|
||||||
from app.schemas import AgentConfig
|
from app.db import get_session
|
||||||
|
from app.models import CloudAgentConfig, LocalAgentConfig
|
||||||
|
from app.schemas import (
|
||||||
|
JourneyMessageRequest,
|
||||||
|
JourneyResponse,
|
||||||
|
JourneyStartRequest,
|
||||||
|
UserProfile,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
_CONFIG_START = "AGENT_CONFIG_START"
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
_CONFIG_END = "AGENT_CONFIG_END"
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
# Minimum turns before we consider nudging the LLM to wrap up.
|
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
||||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
_MAX_TURNS: int = 5
|
||||||
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
|
||||||
_MAX_TURNS: int = 15
|
|
||||||
# Max tool-calling steps per LLM invocation.
|
|
||||||
_MAX_TOOL_STEPS: int = 6
|
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class JourneySession:
|
class _JourneySession:
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
agent_type: str # "local" | "cloud"
|
agent_type: str # "local" | "cloud"
|
||||||
directory: str
|
|
||||||
data_types: list[str]
|
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
system_prompt: str = ""
|
|
||||||
langfuse_prompt: Any = None
|
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -74,182 +76,103 @@ class JourneySession:
|
|||||||
|
|
||||||
|
|
||||||
# session_id → session
|
# session_id → session
|
||||||
_sessions: dict[str, JourneySession] = {}
|
_sessions: dict[str, _JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
||||||
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
||||||
s = _sessions.get(session_id)
|
s = _sessions.get(session_id)
|
||||||
if s is None or s.is_expired():
|
if s is None or s.is_expired():
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
return None
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
if s.user_id != user_id:
|
if s.user_id != user_id:
|
||||||
return None
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt ─────────────────────────────────────────────────────────
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
_JOURNEY_SYSTEM_PROMPT = """\
|
_LOCAL_PREAMBLE = """\
|
||||||
|
What kind of files are in the directories you want to monitor? \
|
||||||
|
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
||||||
|
|
||||||
|
_CLOUD_PREAMBLE = """\
|
||||||
|
What kind of emails or messages should I look for? \
|
||||||
|
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand what files the user has in their directory and produce a
|
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
||||||
structured AgentConfig JSON that the extraction agent will use as its instruction set.
|
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
||||||
|
|
||||||
You have access to file-system tools to explore the user's directory:
|
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
- list_directory: see folder structure and file names
|
1. The type and format of the source content.
|
||||||
- read_file_content: peek at a file's content
|
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
||||||
- get_file_metadata: check file size, extension, dates
|
3. How fields should be mapped (e.g. email subject → task title).
|
||||||
|
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
5. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
The user's configured directory is: {directory}
|
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
||||||
Target data types: {data_types}
|
these exact markers on their own lines:
|
||||||
|
|
||||||
## Your process
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
### Step 1 — Explore the directory
|
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
||||||
Use list_directory and read_file_content to understand what types of files are present
|
and must return a JSON array of records in this shape:
|
||||||
(HTML emails, plain-text documents, CSVs, etc.).
|
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
||||||
|
|
||||||
### Step 2 — Identify content types
|
|
||||||
For each distinct file type found, decide:
|
|
||||||
- A short id (e.g. "email_html", "plain_text", "csv")
|
|
||||||
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
|
|
||||||
- A human-readable label and optional detection_hint
|
|
||||||
|
|
||||||
### Step 3 — Ask focused questions (one at a time)
|
|
||||||
Cover these topics based on what you discovered:
|
|
||||||
1. How to map content to entity types (task / note / timeline entry)
|
|
||||||
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
|
|
||||||
3. Priority or status rules (e.g. "urgent" in subject → high priority)
|
|
||||||
4. Date extraction (e.g. "by Friday" → dueDate)
|
|
||||||
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
|
||||||
|
|
||||||
### Step 4 — Produce the AgentConfig JSON
|
|
||||||
Once you are ≥ 90% confident, output the final config between these exact markers
|
|
||||||
(each on its own line):
|
|
||||||
|
|
||||||
{config_start}
|
|
||||||
{{
|
|
||||||
"content_types": [
|
|
||||||
{{
|
|
||||||
"id": "email_html",
|
|
||||||
"label": "Email HTML",
|
|
||||||
"detection_hint": "HTML file with From/To/Subject headers",
|
|
||||||
"preprocessing": "email_html",
|
|
||||||
"extraction_prompt": "Detailed extraction instructions for this content type..."
|
|
||||||
}}
|
|
||||||
],
|
|
||||||
"global_rules": [
|
|
||||||
"If the file cannot be matched to any project, do not create any entity."
|
|
||||||
],
|
|
||||||
"data_types": {data_types_json}
|
|
||||||
}}
|
|
||||||
{config_end}
|
|
||||||
|
|
||||||
## Rules for the extraction_prompt field
|
|
||||||
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
|
|
||||||
- Include field mapping rules based on what you found in the directory
|
|
||||||
- Include priority/status/date rules if applicable
|
|
||||||
- Do NOT include projectId logic — the runner handles project assignment automatically
|
|
||||||
- Do NOT mention isAiSuggested — the runner always sets it to 1
|
|
||||||
|
|
||||||
## Constraints
|
|
||||||
- Never ask about projects, projectId, or how to link records to projects
|
|
||||||
- Never include projectId or project creation logic in the generated config
|
|
||||||
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
|
|
||||||
|
|
||||||
|
Rules for the generated template:
|
||||||
|
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
||||||
|
- Include concrete examples of mappings.
|
||||||
|
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
||||||
|
- Set isAiSuggested: true and isApproved: false on every record.
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Begin by exploring the directory, then ask your first question.\
|
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(
|
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
||||||
directory: str,
|
source_description = (
|
||||||
data_types: list[str],
|
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
||||||
existing_config: str | None = None,
|
)
|
||||||
) -> tuple[str, Any]:
|
|
||||||
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
|
||||||
existing_section = (
|
existing_section = (
|
||||||
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
f"```json\n{existing_config}\n```\n"
|
f"---\n{existing_template}\n---\n"
|
||||||
if existing_config
|
if existing_template
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
template, prompt_obj = get_prompt_or_fallback(
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
source_description=source_description,
|
||||||
)
|
template_start=_TEMPLATE_START,
|
||||||
compiled = compile_prompt(
|
template_end=_TEMPLATE_END,
|
||||||
template,
|
|
||||||
prompt_obj,
|
|
||||||
directory=directory,
|
|
||||||
data_types=", ".join(data_types),
|
|
||||||
data_types_json=json.dumps(data_types),
|
|
||||||
config_start=_CONFIG_START,
|
|
||||||
config_end=_CONFIG_END,
|
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
|
max_turns=_MAX_TURNS,
|
||||||
)
|
)
|
||||||
return compiled, prompt_obj
|
|
||||||
|
|
||||||
|
|
||||||
# ── AgentConfig extraction ────────────────────────────────────────────────
|
def _first_question(agent_type: str) -> str:
|
||||||
|
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
||||||
|
|
||||||
|
|
||||||
def _extract_agent_config(text: str) -> str | None:
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
"""Return validated AgentConfig JSON string from between markers, or None.
|
|
||||||
|
|
||||||
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
|
||||||
returning. Returns None if markers are absent or JSON is invalid.
|
def _extract_template(text: str) -> str | None:
|
||||||
"""
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
if _CONFIG_START not in text or _CONFIG_END not in text:
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
return None
|
|
||||||
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
|
|
||||||
end_idx = text.index(_CONFIG_END)
|
|
||||||
raw = text[start_idx:end_idx].strip()
|
|
||||||
if not raw:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = AgentConfig.model_validate_json(raw)
|
|
||||||
return parsed.model_dump_json()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
|
|
||||||
return None
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call with tool support ───────────────────────────────────────────
|
# ── LLM call ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _as_text(content: Any) -> str:
|
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
||||||
if content is None:
|
"""Build LangChain messages from history and invoke the LLM."""
|
||||||
return ""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
parts: list[str] = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, str):
|
|
||||||
parts.append(item)
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
text = item.get("text")
|
|
||||||
if isinstance(text, str):
|
|
||||||
parts.append(text)
|
|
||||||
return "".join(parts)
|
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm_with_tools(
|
|
||||||
system_prompt: str,
|
|
||||||
history: list[dict[str, Any]],
|
|
||||||
tools: list[Any],
|
|
||||||
*,
|
|
||||||
user_id: str = "",
|
|
||||||
session_id: str = "",
|
|
||||||
langfuse_prompt: Any = None,
|
|
||||||
) -> str:
|
|
||||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
|
||||||
|
|
||||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
|
||||||
continue until a final text response is produced.
|
|
||||||
"""
|
|
||||||
lf = get_langfuse()
|
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -258,238 +181,137 @@ async def _call_llm_with_tools(
|
|||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
response = await llm.ainvoke(messages)
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
return response.content # type: ignore[return-value]
|
||||||
|
|
||||||
_span_ctx = (
|
|
||||||
lf.start_as_current_observation(
|
|
||||||
as_type="span",
|
|
||||||
name="journey-setup",
|
|
||||||
metadata={"user_id": user_id or None, "session_id": session_id or None},
|
|
||||||
input=history[-1]["content"] if history else "",
|
|
||||||
)
|
|
||||||
if lf else None
|
|
||||||
)
|
|
||||||
_span = _span_ctx.__enter__() if _span_ctx else None
|
|
||||||
|
|
||||||
try:
|
|
||||||
for _ in range(_MAX_TOOL_STEPS):
|
|
||||||
_gen_ctx = (
|
|
||||||
lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="journey-setup-llm",
|
|
||||||
model=settings.LLM_MODEL,
|
|
||||||
prompt=langfuse_prompt,
|
|
||||||
input=messages,
|
|
||||||
)
|
|
||||||
if lf else None
|
|
||||||
)
|
|
||||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
if _gen_ctx:
|
|
||||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
|
||||||
_gen_ctx.__exit__(None, None, None)
|
|
||||||
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
if _span:
|
|
||||||
_span.update(output=_as_text(response.content))
|
|
||||||
return _as_text(response.content)
|
|
||||||
|
|
||||||
for call in response.tool_calls:
|
|
||||||
call_name = str(call.get("name", ""))
|
|
||||||
call_args = call.get("args", {})
|
|
||||||
logger.info(
|
|
||||||
"agent_setup: journey tool_call name=%s args=%s",
|
|
||||||
call_name,
|
|
||||||
json.dumps(call_args, ensure_ascii=True)[:500],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
|
||||||
if tool_fn is None:
|
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
|
||||||
else:
|
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_setup: journey tool_result name=%s output=%s",
|
|
||||||
call_name,
|
|
||||||
str(tool_output)[:800],
|
|
||||||
)
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
||||||
|
|
||||||
# Fallback: exceeded max steps.
|
|
||||||
final = await llm.ainvoke(messages)
|
|
||||||
final_text = _as_text(final.content)
|
|
||||||
if _span:
|
|
||||||
_span.update(output=final_text)
|
|
||||||
return final_text
|
|
||||||
finally:
|
|
||||||
if _span_ctx:
|
|
||||||
_span_ctx.__exit__(None, None, None)
|
|
||||||
if lf:
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
# ── Existing-config loader ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def handle_journey_start(
|
async def _load_existing_template(
|
||||||
|
agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
frame: dict[str, Any],
|
db: AsyncSession,
|
||||||
) -> dict[str, Any]:
|
) -> str | None:
|
||||||
"""Handle a ``journey_start`` WS frame.
|
"""Return the prompt_template of an existing agent config, or None."""
|
||||||
|
# Try local first, then cloud.
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local = local_result.scalar_one_or_none()
|
||||||
|
if local is not None:
|
||||||
|
return local.prompt_template
|
||||||
|
|
||||||
Creates a session, runs the setup LLM with directory exploration,
|
cloud_result = await db.execute(
|
||||||
and returns the ``journey_reply`` payload.
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud = cloud_result.scalar_one_or_none()
|
||||||
|
return cloud.prompt_template if cloud is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
|
async def start_journey(
|
||||||
|
body: JourneyStartRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Start a new Chatbot Journey session.
|
||||||
|
|
||||||
|
If ``agent_id`` is provided the session is pre-seeded with the existing
|
||||||
|
agent's ``prompt_template`` so the user can refine it.
|
||||||
"""
|
"""
|
||||||
agent_type = frame.get("agent_type", "local")
|
# Load existing template (may be None).
|
||||||
directory = frame.get("directory", "")
|
existing_template: str | None = None
|
||||||
data_types = frame.get("data_types", [])
|
if body.agent_id:
|
||||||
existing_config = frame.get("existing_config")
|
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
||||||
|
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
||||||
|
# the user may be starting a fresh journey for a not-yet-persisted config).
|
||||||
|
|
||||||
# Use the session_id provided by the FE so the reply matches the
|
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
||||||
# listener key; fall back to a generated one if absent.
|
first_question = _first_question(body.agent_type)
|
||||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
|
||||||
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
|
|
||||||
|
|
||||||
session = JourneySession(
|
session_id = str(uuid.uuid4())
|
||||||
|
session = _JourneySession(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=user_id,
|
user_id=current_user.id,
|
||||||
agent_type=agent_type,
|
agent_type=body.agent_type,
|
||||||
directory=directory,
|
# Seed history with the AI's first question so it stays consistent.
|
||||||
data_types=data_types,
|
history=[{"role": "assistant", "content": first_question}],
|
||||||
system_prompt=system_prompt,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
)
|
)
|
||||||
|
# Store the system prompt inside the session for reuse in /message.
|
||||||
# Seed with an initial user message — some providers require at least one
|
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
||||||
# user/input message to be present.
|
|
||||||
seed_history: list[dict[str, Any]] = [
|
|
||||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
|
||||||
]
|
|
||||||
ai_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
history=seed_history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
session.history.extend(seed_history)
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
_sessions[session_id] = session
|
_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info(
|
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
||||||
"agent_setup: journey session %s started for user %s (directory=%s)",
|
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
||||||
session_id,
|
|
||||||
user_id,
|
|
||||||
directory,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the LLM produced the config on the first turn (unlikely but possible).
|
|
||||||
agent_config = _extract_agent_config(ai_reply)
|
|
||||||
done = agent_config is not None
|
|
||||||
|
|
||||||
display_message = ai_reply
|
|
||||||
if done:
|
|
||||||
display_message = (
|
|
||||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
|
||||||
)
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": display_message,
|
|
||||||
"done": done,
|
|
||||||
"agent_config": agent_config,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_journey_message(
|
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
user_id: str,
|
async def send_journey_message(
|
||||||
frame: dict[str, Any],
|
body: JourneyMessageRequest,
|
||||||
) -> dict[str, Any]:
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
"""Handle a ``journey_message`` WS frame.
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Send a message in an existing Chatbot Journey session.
|
||||||
|
|
||||||
Appends the user message, calls the LLM, and returns the
|
The server appends the user's message to the conversation history,
|
||||||
``journey_reply`` payload.
|
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
||||||
|
``prompt_template`` block the response includes ``done=True`` and the
|
||||||
|
extracted template.
|
||||||
"""
|
"""
|
||||||
session_id = frame.get("session_id", "")
|
session = _get_session(body.session_id, current_user.id)
|
||||||
message = frame.get("message", "")
|
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
||||||
|
|
||||||
session = get_journey_session(session_id, user_id)
|
# Append user turn to history.
|
||||||
if session is None:
|
session.history.append({"role": "user", "content": body.message})
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": "Journey session not found or expired. Please start a new setup.",
|
|
||||||
"done": True,
|
|
||||||
"agent_config": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Append user turn.
|
# Call the LLM with the full conversation so far.
|
||||||
session.history.append({"role": "user", "content": message})
|
ai_reply = await _call_llm(system_prompt, session.history)
|
||||||
|
|
||||||
# Call the LLM with tools.
|
|
||||||
ai_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=session.system_prompt,
|
|
||||||
history=session.history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
user_id=session.user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=session.langfuse_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Append AI turn.
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
# Check if the LLM produced the final config.
|
# Check if the LLM produced the final template.
|
||||||
agent_config = _extract_agent_config(ai_reply)
|
prompt_template = _extract_template(ai_reply)
|
||||||
done = agent_config is not None
|
done = prompt_template is not None
|
||||||
|
|
||||||
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
|
|
||||||
if not done:
|
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
|
||||||
if turns >= _MAX_TURNS:
|
|
||||||
nudge_content = (
|
|
||||||
"[System: You have enough information. Please generate the final "
|
|
||||||
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
|
||||||
)
|
|
||||||
session.history.append({"role": "user", "content": nudge_content})
|
|
||||||
|
|
||||||
nudge_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=session.system_prompt,
|
|
||||||
history=session.history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
user_id=session.user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=session.langfuse_prompt,
|
|
||||||
)
|
|
||||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
|
||||||
|
|
||||||
agent_config = _extract_agent_config(nudge_reply)
|
|
||||||
if agent_config is not None:
|
|
||||||
done = True
|
|
||||||
ai_reply = nudge_reply
|
|
||||||
|
|
||||||
|
# Strip the sentinel markers from the message shown to the user.
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
if _CONFIG_START in ai_reply
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
else "Here is your agent configuration. You can save it or continue refining."
|
|
||||||
)
|
)
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
|
||||||
|
|
||||||
return {
|
if done:
|
||||||
"type": "journey_reply",
|
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
||||||
"session_id": session_id,
|
# Clean up the session immediately on completion.
|
||||||
"message": display_message,
|
_sessions.pop(body.session_id, None)
|
||||||
"done": done,
|
else:
|
||||||
"agent_config": agent_config,
|
# Nudge the LLM to wrap up after max turns.
|
||||||
}
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
# Add a system-level nudge as a hidden user message.
|
||||||
|
session.history.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
return JourneyResponse(
|
||||||
|
session_id=body.session_id,
|
||||||
|
message=display_message,
|
||||||
|
done=done,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,36 +1,45 @@
|
|||||||
"""Agent routes.
|
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
||||||
|
|
||||||
Backend responsibilities are intentionally minimal:
|
Endpoints:
|
||||||
GET /agents/catalog — static catalog for UI display
|
GET /agents/catalog — hardcoded agent type catalog
|
||||||
POST /agents/can-create — billing eligibility check
|
GET /agents/local — list user's local agent configs
|
||||||
POST /agents/trigger — trigger a local agent run
|
POST /agents/local — create local agent (tier-gated)
|
||||||
|
PUT /agents/local/{agent_id} — partial update (ownership check)
|
||||||
Agent configuration is owned by the Electron app and is not persisted
|
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
||||||
in backend agent-config tables.
|
GET /agents/cloud — list user's cloud agent configs
|
||||||
|
POST /agents/cloud — create cloud agent (tier-gated)
|
||||||
|
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
||||||
|
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
||||||
|
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
||||||
|
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import uuid
|
from datetime import datetime
|
||||||
from datetime import datetime, timedelta, timezone
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from sqlalchemy import func, select
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, or_, select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import is_agent_running, run_local_agent
|
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AgentCatalogItem,
|
AgentCatalogItem,
|
||||||
AgentCreationCheckRequest,
|
|
||||||
AgentCreationCheckResponse,
|
|
||||||
AgentRunLogResponse,
|
AgentRunLogResponse,
|
||||||
AgentTriggerRequest,
|
CloudAgentConfigCreate,
|
||||||
|
CloudAgentConfigResponse,
|
||||||
|
CloudAgentConfigUpdate,
|
||||||
|
LocalAgentConfigCreate,
|
||||||
|
LocalAgentConfigResponse,
|
||||||
|
LocalAgentConfigUpdate,
|
||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -47,21 +56,39 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
def _to_data_types(values: list[str]) -> list[str]:
|
# ── Model → schema converters ─────────────────────────────────────────
|
||||||
normalize = {
|
|
||||||
"task": "tasks", "tasks": "tasks",
|
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
||||||
"note": "notes", "notes": "notes",
|
return LocalAgentConfigResponse(
|
||||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
id=a.id,
|
||||||
"project": "projects", "projects": "projects",
|
name=a.name,
|
||||||
}
|
device_id=a.device_id,
|
||||||
seen: set[str] = set()
|
directory_paths=a.directory_paths,
|
||||||
result: list[str] = []
|
data_types=a.data_types,
|
||||||
for v in values:
|
prompt_template=a.prompt_template,
|
||||||
mapped = normalize.get(v)
|
file_extensions=a.file_extensions,
|
||||||
if mapped and mapped not in seen:
|
schedule_cron=a.schedule_cron,
|
||||||
seen.add(mapped)
|
enabled=a.enabled,
|
||||||
result.append(mapped)
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
return result
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
||||||
|
return CloudAgentConfigResponse(
|
||||||
|
id=a.id,
|
||||||
|
provider=a.provider, # type: ignore[arg-type]
|
||||||
|
name=a.name,
|
||||||
|
data_types=a.data_types,
|
||||||
|
prompt_template=a.prompt_template,
|
||||||
|
schedule_cron=a.schedule_cron,
|
||||||
|
filter_config=a.filter_config,
|
||||||
|
enabled=a.enabled,
|
||||||
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
@@ -78,42 +105,77 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
# ── Ownership-checked lookups ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_local_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> LocalAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_cloud_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> CloudAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier limit helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return combined enabled local + cloud agent count for the user."""
|
||||||
|
local_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(LocalAgentConfig.id)).where(
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
LocalAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
cloud_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(CloudAgentConfig.id)).where(
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
CloudAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
return local_count + cloud_count
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
if limit != -1 and current_count >= limit:
|
if limit != -1 and current_count >= limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
)
|
)
|
||||||
return limit
|
|
||||||
|
|
||||||
|
|
||||||
async def _enforce_run_frequency(
|
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
||||||
tier: str,
|
|
||||||
user_id: str,
|
|
||||||
db: AsyncSession,
|
|
||||||
) -> None:
|
|
||||||
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
|
||||||
if limit == -1:
|
|
||||||
return # unlimited
|
|
||||||
|
|
||||||
today_start = datetime.now(timezone.utc).replace(
|
class _RunsPage(BaseModel):
|
||||||
hour=0, minute=0, second=0, microsecond=0
|
total: int
|
||||||
)
|
page: int
|
||||||
result = await db.execute(
|
limit: int
|
||||||
select(func.count(AgentRunLog.id)).where(
|
items: list[AgentRunLogResponse]
|
||||||
AgentRunLog.user_id == user_id,
|
|
||||||
AgentRunLog.started_at >= today_start,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
runs_today: int = result.scalar_one()
|
|
||||||
|
|
||||||
if runs_today >= limit:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
@@ -147,61 +209,229 @@ async def get_agent_catalog(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
# ── Local agent CRUD ──────────────────────────────────────────────────
|
||||||
async def can_create_agent(
|
|
||||||
body: AgentCreationCheckRequest,
|
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
||||||
|
async def list_local_agents(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> AgentCreationCheckResponse:
|
db: AsyncSession = Depends(get_session),
|
||||||
"""Check if the user can create one more agent based on billing tier.
|
) -> list[LocalAgentConfigResponse]:
|
||||||
|
"""List all local directory agent configs owned by the authenticated user."""
|
||||||
Since configuration is client-owned, the Electron app sends its current
|
result = await db.execute(
|
||||||
active agent count and the backend applies tier limits.
|
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
||||||
"""
|
|
||||||
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
|
||||||
allowed = limit == -1 or body.active_agents < limit
|
|
||||||
return AgentCreationCheckResponse(
|
|
||||||
allowed=allowed,
|
|
||||||
tier=current_user.tier,
|
|
||||||
active_agents=body.active_agents,
|
|
||||||
limit=limit,
|
|
||||||
)
|
)
|
||||||
|
return [_to_local_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_local_agent(
|
||||||
|
body: LocalAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Create a new local directory agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = LocalAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
name=body.name,
|
||||||
|
device_id=body.device_id,
|
||||||
|
directory_paths=body.directory_paths,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
file_extensions=body.file_extensions,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
||||||
|
async def update_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: LocalAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Partially update a local agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/local/{agent_id}", response_model=dict)
|
||||||
|
async def delete_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
||||||
|
async def list_cloud_agents(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[CloudAgentConfigResponse]:
|
||||||
|
"""List all cloud connector agent configs owned by the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
return [_to_cloud_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_cloud_agent(
|
||||||
|
body: CloudAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Create a new cloud connector agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = CloudAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
provider=body.provider,
|
||||||
|
name=body.name,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
oauth_token_encrypted=body.oauth_token_encrypted,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
filter_config=body.filter_config,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
||||||
|
async def update_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: CloudAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Partially update a cloud agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/cloud/{agent_id}", response_model=dict)
|
||||||
|
async def delete_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Run logs ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/runs", response_model=_RunsPage)
|
||||||
|
async def list_run_logs(
|
||||||
|
agent_id: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=20, ge=1, le=100),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _RunsPage:
|
||||||
|
"""Return paginated run logs for the authenticated user.
|
||||||
|
|
||||||
|
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
||||||
|
"""
|
||||||
|
base_filter = [AgentRunLog.user_id == current_user.id]
|
||||||
|
if agent_id:
|
||||||
|
base_filter.append(AgentRunLog.agent_id == agent_id)
|
||||||
|
|
||||||
|
total = (
|
||||||
|
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
||||||
|
).scalar_one()
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog)
|
||||||
|
.where(*base_filter)
|
||||||
|
.order_by(AgentRunLog.started_at.desc())
|
||||||
|
.offset((page - 1) * limit)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
||||||
|
|
||||||
|
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Manual trigger stub ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
async def trigger_agent_run(
|
async def trigger_agent_run(
|
||||||
body: AgentTriggerRequest,
|
agent_id: str,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> AgentRunLogResponse:
|
) -> AgentRunLogResponse:
|
||||||
"""Trigger a local agent run using client-provided configuration."""
|
"""Manually trigger an agent run.
|
||||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
|
||||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
|
||||||
|
|
||||||
config = LocalAgentConfig(
|
Looks up the agent config (local or cloud) by ID with ownership check,
|
||||||
id=str(uuid.uuid4()),
|
creates a run log entry with ``status="running"``, and returns it.
|
||||||
user_id=current_user.id,
|
|
||||||
device_id=body.device_id,
|
Actual dispatch to the agent runner is wired in Step 3.4 once
|
||||||
name="Local Directory Monitor",
|
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||||
directory_paths=[body.directory],
|
"""
|
||||||
data_types=_to_data_types(body.what_to_extract),
|
# Determine agent type by trying local first, then cloud.
|
||||||
prompt_template=body.custom_agent_prompt,
|
# Keep the full config object so we can pass it to the agent runner.
|
||||||
file_extensions=[],
|
local_config: LocalAgentConfig | None = None
|
||||||
schedule_cron=body.batch_interval,
|
cloud_config: CloudAgentConfig | None = None
|
||||||
enabled=True,
|
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == current_user.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
|
||||||
stable_agent_id = body.agent_id or config.id
|
|
||||||
|
|
||||||
if is_agent_running(stable_agent_id):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
|
||||||
)
|
)
|
||||||
|
local_config = local_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if local_config is not None:
|
||||||
|
agent_type = "local"
|
||||||
|
else:
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud_config = cloud_result.scalar_one_or_none()
|
||||||
|
if cloud_config is not None:
|
||||||
|
agent_type = "cloud"
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=stable_agent_id,
|
agent_id=agent_id,
|
||||||
agent_type="local",
|
agent_type=agent_type,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
@@ -209,14 +439,14 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
run_context = {
|
# Dispatch the run as a background task — returns 202 immediately.
|
||||||
"type": "agent_batch",
|
if agent_type == "local" and local_config is not None:
|
||||||
"run_id": run_log.id,
|
|
||||||
"agent_id": stable_agent_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||||
|
)
|
||||||
|
elif agent_type == "cloud" and cloud_config is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
171
app/api/routes/backup.py
Normal file
171
app/api/routes/backup.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
||||||
|
PostgreSQL ``backup_metadata`` table.
|
||||||
|
|
||||||
|
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||||
|
treating "history" as a ``{backup_id}`` path parameter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import BackupMetadata as BackupMetadataModel
|
||||||
|
from app.schemas import BackupMetadata, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/backup", tags=["backup"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
|
||||||
|
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return total backup bytes stored by *user_id*."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
||||||
|
BackupMetadataModel.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_backup_quota(
|
||||||
|
user: UserProfile, size_bytes: int, db: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||||
|
current = await _current_backup_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_backup_quota(
|
||||||
|
user.tier, current_bytes=current, additional_bytes=size_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("")
|
||||||
|
async def upload_backup(
|
||||||
|
request: Request,
|
||||||
|
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
||||||
|
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
||||||
|
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Upload an E2E-encrypted backup blob.
|
||||||
|
|
||||||
|
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
||||||
|
"""
|
||||||
|
blob = await request.body()
|
||||||
|
reject_if_tampered(blob, x_backup_checksum)
|
||||||
|
await _check_backup_quota(current_user, len(blob), db)
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
row = BackupMetadataModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
|
s3_key=s3_key,
|
||||||
|
version=x_backup_version,
|
||||||
|
timestamp=x_backup_timestamp,
|
||||||
|
checksum=x_backup_checksum,
|
||||||
|
size_bytes=len(blob),
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history", response_model=list[BackupMetadata])
|
||||||
|
async def backup_history(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[BackupMetadata]:
|
||||||
|
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
BackupMetadata(
|
||||||
|
version=r.version,
|
||||||
|
timestamp=r.timestamp,
|
||||||
|
checksum=r.checksum,
|
||||||
|
chunk_count=1,
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def download_backup(
|
||||||
|
request: Request,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> Response:
|
||||||
|
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
latest = result.scalar_one_or_none()
|
||||||
|
if latest is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
||||||
|
|
||||||
|
ims_header = request.headers.get("If-Modified-Since")
|
||||||
|
if ims_header:
|
||||||
|
try:
|
||||||
|
ims_dt = parsedate_to_datetime(ims_header)
|
||||||
|
ims_ms = int(ims_dt.timestamp() * 1000)
|
||||||
|
if latest.timestamp <= ims_ms:
|
||||||
|
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
||||||
|
except Exception:
|
||||||
|
pass # malformed header — ignore and serve the blob
|
||||||
|
|
||||||
|
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
||||||
|
return Response(
|
||||||
|
content=blob,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={
|
||||||
|
"X-Backup-Version": str(latest.version),
|
||||||
|
"X-Backup-Timestamp": str(latest.timestamp),
|
||||||
|
"X-Checksum": latest.checksum,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{backup_id}", response_model=dict)
|
||||||
|
async def delete_backup(
|
||||||
|
backup_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a specific backup by ID."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel).where(
|
||||||
|
BackupMetadataModel.id == backup_id,
|
||||||
|
BackupMetadataModel.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
target = result.scalar_one_or_none()
|
||||||
|
if target is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
||||||
|
|
||||||
|
await _blob_store.delete(current_user.id, target.s3_key)
|
||||||
|
await db.delete(target)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
|
"""Chat routes: POST /chat (REST fallback).
|
||||||
|
|
||||||
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
"""
|
"""
|
||||||
@@ -7,53 +7,36 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.deep_agent import run_home
|
from app.core.deep_agent import run_home
|
||||||
from app.core.llm import embed
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.db import async_session
|
||||||
|
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
# ── Embed helpers ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _EmbedRequest(BaseModel):
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class _EmbedResponse(BaseModel):
|
|
||||||
vector: list[float]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""REST fallback for home chat when websocket streaming is unavailable."""
|
"""Route a chat message through the Home deep agent (non-streaming)."""
|
||||||
response = await run_home(
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(current_user.id, body.message)
|
||||||
|
|
||||||
|
context = {
|
||||||
|
**body.context.model_dump(),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
response_text = await run_home(
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
message=body.message,
|
message=body.message,
|
||||||
context=body.context.model_dump(),
|
context=context,
|
||||||
|
db_session_factory=async_session,
|
||||||
)
|
)
|
||||||
return JSONResponse(content={"response": response})
|
result = ChatResponse(response=response_text)
|
||||||
|
return JSONResponse(content=result.model_dump())
|
||||||
|
|
||||||
@router.post("/embed", response_model=_EmbedResponse)
|
|
||||||
async def embed_text(
|
|
||||||
body: _EmbedRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> _EmbedResponse:
|
|
||||||
"""Generate a 1536-dim embedding vector for the given text.
|
|
||||||
|
|
||||||
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
|
||||||
Used by Electron (vectordb.ts) for local note search.
|
|
||||||
"""
|
|
||||||
vector = await embed(body.text)
|
|
||||||
return _EmbedResponse(vector=vector)
|
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ Protocol:
|
|||||||
|
|
||||||
Incoming frame dispatch:
|
Incoming frame dispatch:
|
||||||
- ``tool_result`` → resolves a pending tool-call Future.
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
- ``journey_start`` → starts a guided setup journey session.
|
- ``agent_data`` → enqueued in the per-run agent data queue.
|
||||||
- ``journey_message`` → continues a journey conversation.
|
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
||||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
- unknown types → logged, ignored.
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
@@ -39,13 +39,12 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
from app.core.deep_agent import run_floating_stream, run_home_stream
|
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.output_formatter import StreamFormatter
|
from app.core.deep_agent import run_home_stream, run_floating_stream
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
@@ -148,6 +147,37 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: tool_result missing id from user=%s", user_id
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_data:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
await queue.put(frame)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data for unknown run user=%s run=%s",
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_complete:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
# Sentinel: signals the agent data stream is finished.
|
||||||
|
await queue.put(None)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.home_request:
|
elif frame_type == WsFrameType.home_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
@@ -158,16 +188,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.journey_start:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_journey_start(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.journey_message:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_journey_message(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -180,13 +200,35 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
|
|
||||||
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
|
||||||
|
|
||||||
|
|
||||||
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
async def _executor(payload: dict) -> dict:
|
async def _executor(payload: dict) -> dict:
|
||||||
payload["type"] = WsFrameType.tool_call
|
payload["type"] = WsFrameType.tool_call
|
||||||
|
call_id = payload["id"]
|
||||||
|
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
||||||
await websocket.send_text(json.dumps(payload))
|
await websocket.send_text(json.dumps(payload))
|
||||||
future = device_manager.create_pending_call(user_id, payload["id"])
|
future = device_manager.create_pending_call(user_id, call_id)
|
||||||
return await future
|
try:
|
||||||
|
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
|
||||||
|
call_id, payload.get("action"), user_id,
|
||||||
|
)
|
||||||
|
# Clean up the pending future so it doesn't leak
|
||||||
|
conn = device_manager._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.pending_calls.pop(call_id, None)
|
||||||
|
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
|
||||||
|
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
||||||
|
call_id, type(result).__name__,
|
||||||
|
list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
if result is None:
|
||||||
|
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
||||||
|
return result
|
||||||
return _executor
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
@@ -199,27 +241,14 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
logger.info(
|
|
||||||
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
message[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
user_id,
|
|
||||||
message,
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -227,11 +256,12 @@ async def _handle_home_request(
|
|||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_home_stream(user_id, message, context)
|
event_stream = run_home_stream(
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
user_id, message, context, db_session_factory=async_session
|
||||||
|
)
|
||||||
|
formatter = HomeFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
# Collect text chunks to build the full response for episode storage
|
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -246,14 +276,7 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
len("".join(response_chunks)),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -267,37 +290,23 @@ async def _handle_floating_request(
|
|||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope: dict = frame.get("scope", {})
|
||||||
logger.info(
|
|
||||||
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
json.dumps(scope, ensure_ascii=True)[:200],
|
|
||||||
message[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
user_id,
|
|
||||||
message,
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {"scope": scope, **memory_context}
|
||||||
"scope": scope,
|
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_floating_stream(user_id, message, context)
|
event_stream = run_floating_stream(
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
user_id, message, context, scope=scope,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
formatter = FloatingFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
@@ -314,72 +323,8 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
len("".join(response_chunks)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_journey_start(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a journey_start frame — explores directory and sends first question."""
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
|
||||||
set_client_executor(executor)
|
|
||||||
try:
|
|
||||||
reply = await handle_journey_start(user_id, frame)
|
|
||||||
await websocket.send_text(json.dumps(reply))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(
|
|
||||||
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
|
||||||
)
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": frame.get("session_id", ""),
|
|
||||||
"message": f"Failed to start journey: {exc}",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
}))
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_journey_message(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a journey_message frame — continues the journey conversation."""
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
|
||||||
set_client_executor(executor)
|
|
||||||
try:
|
|
||||||
reply = await handle_journey_message(user_id, frame)
|
|
||||||
await websocket.send_text(json.dumps(reply))
|
|
||||||
except Exception as exc:
|
|
||||||
session_id = frame.get("session_id", "")
|
|
||||||
logger.error(
|
|
||||||
"device_ws: journey_message failed user=%s session=%s: %s",
|
|
||||||
user_id, session_id, exc,
|
|
||||||
)
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": f"Journey error: {exc}",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
}))
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
@@ -415,3 +360,6 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
user_id,
|
user_id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
148
app/api/routes/plugins.py
Normal file
148
app/api/routes/plugins.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""Plugins routes: browse and install plugins from the marketplace.
|
||||||
|
|
||||||
|
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
||||||
|
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.db import get_session
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
||||||
|
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _require_plugin_tier(user: UserProfile) -> None:
|
||||||
|
"""Raise HTTP 403 for users below Power tier."""
|
||||||
|
if user.tier not in ("power", "team"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Plugin marketplace requires Power tier or above",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local detail schema ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _PluginDetail(BaseModel):
|
||||||
|
plugin: PluginManifest
|
||||||
|
install_count: int
|
||||||
|
ratings: list[Any]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("", response_model=PluginListResponse)
|
||||||
|
async def list_plugins(
|
||||||
|
category: str | None = Query(default=None),
|
||||||
|
q: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||||
|
async def get_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _PluginDetail:
|
||||||
|
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Fetch review ratings for this plugin
|
||||||
|
review_result = await db.execute(
|
||||||
|
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
||||||
|
)
|
||||||
|
reviews = review_result.scalars().all()
|
||||||
|
ratings = [
|
||||||
|
{
|
||||||
|
"reviewer_id": r.reviewer_id,
|
||||||
|
"decision": r.decision,
|
||||||
|
"notes": r.notes,
|
||||||
|
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
||||||
|
}
|
||||||
|
for r in reviews
|
||||||
|
]
|
||||||
|
|
||||||
|
return _PluginDetail(
|
||||||
|
plugin=entry["manifest"],
|
||||||
|
install_count=entry["install_count"],
|
||||||
|
ratings=ratings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{plugin_id}/install", response_model=dict)
|
||||||
|
async def install_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||||
|
|
||||||
|
Requires Power tier or above.
|
||||||
|
"""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Record the installation in plugin_installations
|
||||||
|
installation = PluginInstallation(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
)
|
||||||
|
db.add(installation)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
await revenue_share.record_install(
|
||||||
|
db,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
amount_cents=entry["manifest"].price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
||||||
|
return {"ok": True, "download_url": download_url}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{plugin_id}/install", response_model=dict)
|
||||||
|
async def uninstall_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Unregister a plugin installation."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(PluginInstallation).where(
|
||||||
|
PluginInstallation.plugin_id == plugin_id,
|
||||||
|
PluginInstallation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
installation = result.scalar_one_or_none()
|
||||||
|
if installation is not None:
|
||||||
|
await db.delete(installation)
|
||||||
|
await db.commit()
|
||||||
|
await registry.record_uninstall(db, plugin_id)
|
||||||
|
return {"ok": True}
|
||||||
195
app/api/routes/storage.py
Normal file
195
app/api/routes/storage.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
||||||
|
PostgreSQL ``storage_records`` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import StorageRecord
|
||||||
|
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["storage"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local response schemas ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CreateResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordMeta(BaseModel):
|
||||||
|
id: str
|
||||||
|
table: str
|
||||||
|
checksum: str
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return total bytes stored by *user_id*."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
||||||
|
StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
||||||
|
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
||||||
|
current = await _current_usage_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_record_for_user(
|
||||||
|
record_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> StorageRecord:
|
||||||
|
"""Look up a record and verify ownership. Returns 404 on mismatch
|
||||||
|
to prevent user enumeration attacks."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(StorageRecord).where(
|
||||||
|
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_record(
|
||||||
|
body: StorageRecordCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _CreateResponse:
|
||||||
|
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||||
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
await _check_quota(current_user, len(body.blob), db)
|
||||||
|
|
||||||
|
record_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, body.table, record_id, body.blob, body.checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
record = StorageRecord(
|
||||||
|
id=record_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
table_name=body.table,
|
||||||
|
s3_key=s3_key,
|
||||||
|
checksum=body.checksum,
|
||||||
|
size_bytes=len(body.blob),
|
||||||
|
)
|
||||||
|
db.add(record)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(record)
|
||||||
|
|
||||||
|
created_at_ms = int(record.created_at.timestamp() * 1000)
|
||||||
|
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records", response_model=list[_RecordMeta])
|
||||||
|
async def list_records(
|
||||||
|
table: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[_RecordMeta]:
|
||||||
|
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
||||||
|
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
||||||
|
if table is not None:
|
||||||
|
query = query.where(StorageRecord.table_name == table)
|
||||||
|
query = query.offset((page - 1) * limit).limit(limit)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
_RecordMeta(
|
||||||
|
id=r.id,
|
||||||
|
table=r.table_name,
|
||||||
|
checksum=r.checksum,
|
||||||
|
created_at=int(r.created_at.timestamp() * 1000),
|
||||||
|
updated_at=int(r.updated_at.timestamp() * 1000),
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records/{record_id}")
|
||||||
|
async def download_record(
|
||||||
|
record_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> Response:
|
||||||
|
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
blob = await _blob_store.download(current_user.id, record.s3_key)
|
||||||
|
return Response(
|
||||||
|
content=blob,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={"X-Checksum": record.checksum},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/records/{record_id}", response_model=dict)
|
||||||
|
async def update_record(
|
||||||
|
record_id: str,
|
||||||
|
body: StorageRecordUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
|
||||||
|
delta = len(body.blob) - record.size_bytes
|
||||||
|
if delta > 0:
|
||||||
|
await _check_quota(current_user, delta, db)
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
record.s3_key = s3_key
|
||||||
|
record.checksum = body.checksum
|
||||||
|
record.size_bytes = len(body.blob)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/records/{record_id}", response_model=dict)
|
||||||
|
async def delete_record(
|
||||||
|
record_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a record and its S3 blob."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
await _blob_store.delete(current_user.id, record.s3_key)
|
||||||
|
await db.delete(record)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
79
app/api/routes/vectors.py
Normal file
79
app/api/routes/vectors.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.llm import embed
|
||||||
|
from app.schemas import (
|
||||||
|
UserProfile,
|
||||||
|
VectorSearchRequest,
|
||||||
|
VectorSearchResponse,
|
||||||
|
VectorUpsertRequest,
|
||||||
|
)
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
from app.storage.vector_store import VectorStore
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["vectors"])
|
||||||
|
|
||||||
|
_vector_store = VectorStore()
|
||||||
|
|
||||||
|
|
||||||
|
class _VectorDeleteRequest(BaseModel):
|
||||||
|
ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedResponse(BaseModel):
|
||||||
|
vector: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/upsert", response_model=dict)
|
||||||
|
async def upsert_vectors(
|
||||||
|
body: VectorUpsertRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
||||||
|
for item in body.vectors:
|
||||||
|
reject_if_tampered(item.blob, item.checksum)
|
||||||
|
await _vector_store.upsert(current_user.id, body.vectors)
|
||||||
|
return {"upserted": len(body.vectors)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
||||||
|
async def search_vectors(
|
||||||
|
body: VectorSearchRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> VectorSearchResponse:
|
||||||
|
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
||||||
|
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
||||||
|
return VectorSearchResponse(results=results)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/vectors", response_model=dict)
|
||||||
|
async def delete_vectors(
|
||||||
|
body: _VectorDeleteRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete vectors by ID, scoped to the authenticated user."""
|
||||||
|
await _vector_store.delete(current_user.id, body.ids)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/embed", response_model=_EmbedResponse)
|
||||||
|
async def embed_text(
|
||||||
|
body: _EmbedRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _EmbedResponse:
|
||||||
|
"""Generate a 1536-dim embedding vector for the given text.
|
||||||
|
|
||||||
|
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||||
|
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
|
||||||
|
"""
|
||||||
|
vector = await embed(body.text)
|
||||||
|
return _EmbedResponse(vector=vector)
|
||||||
@@ -21,33 +21,41 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
"batch_runs_per_day": 5,
|
"cloud_storage_gb": 0,
|
||||||
|
"backup_gb": 0,
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
"batch_runs_per_day": 50,
|
"cloud_storage_gb": 5,
|
||||||
|
"backup_gb": 5,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
"batch_runs_per_day": -1, # unlimited
|
"cloud_storage_gb": 25,
|
||||||
|
"backup_gb": 25,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
"batch_runs_per_day": -1, # unlimited
|
"cloud_storage_gb": -1, # unlimited
|
||||||
|
"backup_gb": -1, # unlimited
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
"sso": True,
|
"sso": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -69,18 +77,16 @@ class TierManager:
|
|||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id`` from the DB.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
Falls back to ``'free'`` when no subscription row exists.
|
||||||
when no subscription row exists.
|
|
||||||
"""
|
"""
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription # noqa: PLC0415
|
||||||
from app.config.settings import settings # noqa: PLC0415
|
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "power" if settings.ENV == "dev" else "free"
|
return "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
@@ -113,6 +119,71 @@ class TierManager:
|
|||||||
"""Return the requests-per-minute limit for ``tier``."""
|
"""Return the requests-per-minute limit for ``tier``."""
|
||||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
|
# ── Storage quota ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def enforce_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||||
|
|
||||||
|
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||||
|
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||||
|
"""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def enforce_backup_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
return False
|
||||||
|
if limit_gb == -1:
|
||||||
|
return True
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton shared across the app.
|
# Module-level singleton shared across the app.
|
||||||
tier_manager = TierManager()
|
tier_manager = TierManager()
|
||||||
|
|||||||
@@ -12,6 +12,17 @@ class Settings(BaseSettings):
|
|||||||
STRIPE_SECRET_KEY: str = ""
|
STRIPE_SECRET_KEY: str = ""
|
||||||
STRIPE_WEBHOOK_SECRET: str = ""
|
STRIPE_WEBHOOK_SECRET: str = ""
|
||||||
|
|
||||||
|
S3_BUCKET: str = ""
|
||||||
|
S3_REGION: str = "us-east-1"
|
||||||
|
S3_ENDPOINT_URL: str = ""
|
||||||
|
AWS_ACCESS_KEY_ID: str = ""
|
||||||
|
AWS_SECRET_ACCESS_KEY: str = ""
|
||||||
|
|
||||||
|
PINECONE_API_KEY: str = ""
|
||||||
|
PINECONE_INDEX: str = "adiuva"
|
||||||
|
QDRANT_URL: str = ""
|
||||||
|
QDRANT_API_KEY: str = ""
|
||||||
|
|
||||||
OPENAI_API_KEY: str = ""
|
OPENAI_API_KEY: str = ""
|
||||||
ANTHROPIC_API_KEY: str = ""
|
ANTHROPIC_API_KEY: str = ""
|
||||||
GOOGLE_API_KEY: str = ""
|
GOOGLE_API_KEY: str = ""
|
||||||
@@ -41,10 +52,6 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|
||||||
LANGFUSE_SECRET_KEY: str = ""
|
|
||||||
LANGFUSE_PUBLIC_KEY: str = ""
|
|
||||||
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
|
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
"""Minimal agent base types retained for compatibility with batch runners."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
|
||||||
"""Common base for non-chat agents still using the old base contract."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
user_id: str = "",
|
|
||||||
shared_memory: dict[str, Any] | None = None,
|
|
||||||
vector_store_context: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.user_id = user_id
|
|
||||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
|
||||||
self.vector_store_context: list[str] = vector_store_context or []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_name(self) -> str: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_description(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def skills(self) -> list[str]:
|
|
||||||
return []
|
|
||||||
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
@@ -3,15 +3,20 @@
|
|||||||
Maintains in-memory state for all active Electron → backend WebSocket
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
connections. One connection per user (latest replaces previous).
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
The manager handles the **tool-call round-trip** pattern:
|
The manager participates in two interaction patterns:
|
||||||
- Backend sends ``tool_call`` frame → Electron executes the action →
|
|
||||||
returns ``tool_result`` frame.
|
1. **Tool-call round-trip** (bidirectional CRUD):
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
||||||
|
``tool_result`` frame.
|
||||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
receive the result dict from Electron.
|
receive the result dict from Electron.
|
||||||
|
|
||||||
This pattern is used by all tools (CRUD, file-system, etc.) via
|
2. **Agent-data streaming** (local directory agent runs):
|
||||||
``execute_on_client()`` in ``ws_context.py``.
|
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
||||||
|
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
||||||
|
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
||||||
|
a specific ``run_id`` so the agent runner can iterate frames.
|
||||||
|
|
||||||
The ``device_manager`` module-level singleton is imported by both the
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
device WS route and the agent runner.
|
device WS route and the agent runner.
|
||||||
@@ -37,6 +42,8 @@ class DeviceConnection:
|
|||||||
device_id: str
|
device_id: str
|
||||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
# Per-run queues for agent_data / agent_complete frames.
|
||||||
|
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class DeviceConnectionManager:
|
class DeviceConnectionManager:
|
||||||
@@ -146,6 +153,31 @@ class DeviceConnectionManager:
|
|||||||
if fut is not None and not fut.done():
|
if fut is not None and not fut.done():
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
|
# ── Agent-data queue ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_agent_data_queue(
|
||||||
|
self, user_id: str, run_id: str
|
||||||
|
) -> asyncio.Queue[dict | None]:
|
||||||
|
"""Return (creating if absent) the queue for *run_id* agent frames.
|
||||||
|
|
||||||
|
The agent runner reads from this queue. The device WS handler writes
|
||||||
|
to it. ``None`` is the sentinel that signals the stream is finished.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"get_agent_data_queue: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
if run_id not in conn.agent_data_queues:
|
||||||
|
conn.agent_data_queues[run_id] = asyncio.Queue()
|
||||||
|
return conn.agent_data_queues[run_id]
|
||||||
|
|
||||||
|
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
||||||
|
"""Remove the queue for *run_id* once a run has completed."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.agent_data_queues.pop(run_id, None)
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — import this everywhere.
|
# Module-level singleton — import this everywhere.
|
||||||
device_manager = DeviceConnectionManager()
|
device_manager = DeviceConnectionManager()
|
||||||
|
|||||||
@@ -1,147 +0,0 @@
|
|||||||
"""Langfuse observability — singleton client and prompt helpers.
|
|
||||||
|
|
||||||
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
|
|
||||||
all helpers are no-ops so the app works without Langfuse configured.
|
|
||||||
|
|
||||||
Usage
|
|
||||||
-----
|
|
||||||
Tracing::
|
|
||||||
|
|
||||||
from app.core.langfuse_client import get_langfuse
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
|
|
||||||
span.update(input=user_message)
|
|
||||||
# ... do work ...
|
|
||||||
span.update(output=result)
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
Prompt management::
|
|
||||||
|
|
||||||
from app.core.langfuse_client import get_prompt_or_fallback
|
|
||||||
|
|
||||||
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
|
|
||||||
# Use text as the system prompt; pass prompt_obj to generations for linking.
|
|
||||||
|
|
||||||
Linking a prompt to a generation::
|
|
||||||
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="llm-call",
|
|
||||||
model="gpt-4o",
|
|
||||||
prompt=prompt_obj, # links generation → prompt version in the UI
|
|
||||||
input=messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
gen.update(output=response.content, usage=_usage(response))
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_client: Any = None
|
|
||||||
_initialized: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_langfuse() -> Any | None:
|
|
||||||
"""Return the Langfuse singleton, or ``None`` when not configured."""
|
|
||||||
global _client, _initialized
|
|
||||||
if _initialized:
|
|
||||||
return _client
|
|
||||||
_initialized = True
|
|
||||||
|
|
||||||
from app.config.settings import settings # local import to avoid circular deps
|
|
||||||
|
|
||||||
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
|
||||||
logger.debug("langfuse: not configured — observability disabled")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import Langfuse
|
|
||||||
|
|
||||||
_client = Langfuse(
|
|
||||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
|
||||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
|
||||||
host=settings.LANGFUSE_HOST,
|
|
||||||
)
|
|
||||||
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_HOST)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse: failed to initialize: %s", exc)
|
|
||||||
_client = None
|
|
||||||
|
|
||||||
return _client
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
|
||||||
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
|
||||||
|
|
||||||
Returns ``(raw_template, prompt_obj_or_None)``.
|
|
||||||
|
|
||||||
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
|
|
||||||
on it directly; use :func:`compile_prompt` instead so the correct variable
|
|
||||||
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
|
|
||||||
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
|
||||||
unavailable / the fetch failed. Pass this to generation observations so
|
|
||||||
Langfuse links the generation to the exact prompt version in the UI.
|
|
||||||
"""
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf is None:
|
|
||||||
return fallback, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
prompt = lf.get_prompt(name, label="production", fallback=fallback)
|
|
||||||
# For text-type prompts .prompt holds the raw template string.
|
|
||||||
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
|
|
||||||
return raw, prompt
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
|
|
||||||
return fallback, None
|
|
||||||
|
|
||||||
|
|
||||||
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
|
|
||||||
"""Compile *template* with *variables*, choosing the right syntax.
|
|
||||||
|
|
||||||
* When *prompt_obj* is a real Langfuse prompt object, calls
|
|
||||||
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
|
|
||||||
substitution as defined in the Langfuse UI.
|
|
||||||
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
|
|
||||||
falls back to ``template.format(**variables)`` which handles the
|
|
||||||
``{variable}`` syntax used in the hardcoded fallback strings.
|
|
||||||
|
|
||||||
This keeps callers oblivious to which syntax is in use.
|
|
||||||
"""
|
|
||||||
if prompt_obj is not None:
|
|
||||||
try:
|
|
||||||
compiled = prompt_obj.compile(**variables)
|
|
||||||
# compile() returns a string for text prompts.
|
|
||||||
if isinstance(compiled, str):
|
|
||||||
return compiled
|
|
||||||
# Chat prompts return a list of dicts — join text parts.
|
|
||||||
if isinstance(compiled, list):
|
|
||||||
return "\n".join(
|
|
||||||
m.get("content", "") for m in compiled if isinstance(m, dict)
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
|
|
||||||
getattr(prompt_obj, "name", "?"),
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
return template.format(**variables)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_usage(response: Any) -> dict[str, int]:
|
|
||||||
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
|
||||||
meta = getattr(response, "usage_metadata", None)
|
|
||||||
if not meta:
|
|
||||||
return {}
|
|
||||||
return {
|
|
||||||
"input": int(meta.get("input_tokens", 0)),
|
|
||||||
"output": int(meta.get("output_tokens", 0)),
|
|
||||||
"total": int(meta.get("total_tokens", 0)),
|
|
||||||
}
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -18,7 +18,6 @@ Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -33,14 +32,6 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
# Some provider responses include a plain dict in the `usage` field where a
|
|
||||||
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
|
||||||
category=UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
|||||||
@@ -43,21 +43,15 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
|||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware:
|
class MemoryMiddleware:
|
||||||
"""Enrich orchestrator context with memory and persist interactions after."""
|
"""Enrich agent context with memory and persist interactions after."""
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(
|
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
||||||
self,
|
"""Build memory context dict to inject into the agent before LLM call.
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
core_memory — {key: plaintext_value, ...}
|
core_memory — {key: plaintext_value, ...}
|
||||||
@@ -71,21 +65,9 @@ class MemoryMiddleware:
|
|||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
episodic = await self._load_episodic(user_id, fernet)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
|
||||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
user_dbg.get("tier") or "-",
|
|
||||||
len(core),
|
|
||||||
len(associative),
|
|
||||||
len(episodic),
|
|
||||||
len(proactive),
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
@@ -99,7 +81,6 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
trace_id: str | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
@@ -122,19 +103,11 @@ class MemoryMiddleware:
|
|||||||
self._db.add(row)
|
self._db.add(row)
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
|
||||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
user_dbg.get("tier") or "-",
|
|
||||||
session_id,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -160,176 +133,10 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
|
||||||
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
user_dbg.get("tier") or "-",
|
|
||||||
key,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
|
||||||
"""Return core memory as editable blocks (label/value)."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore)
|
|
||||||
.where(MemoryCore.user_id == user_id)
|
|
||||||
.order_by(MemoryCore.key.asc())
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
out: list[dict[str, str]] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out.append({"label": row.key, "value": plaintext})
|
|
||||||
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
|
||||||
"""Return a single core memory block value by label."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(
|
|
||||||
MemoryCore.user_id == user_id,
|
|
||||||
MemoryCore.key == label,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
|
||||||
return None
|
|
||||||
value = _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
|
||||||
return value
|
|
||||||
|
|
||||||
async def delete_core(self, user_id: str, label: str) -> bool:
|
|
||||||
"""Delete a core memory block by label. Returns True if deleted."""
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(
|
|
||||||
MemoryCore.user_id == user_id,
|
|
||||||
MemoryCore.key == label,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
|
||||||
return False
|
|
||||||
|
|
||||||
await self._db.delete(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
|
||||||
return True
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
|
||||||
"""Append content to a core block, creating it if missing."""
|
|
||||||
current = await self.get_core_block(user_id, label)
|
|
||||||
if current is None:
|
|
||||||
await self.update_core(user_id, label, content)
|
|
||||||
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
|
||||||
return
|
|
||||||
await self.update_core(user_id, label, f"{current}\n{content}")
|
|
||||||
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
|
||||||
|
|
||||||
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
|
||||||
"""Replace one exact string inside a core block. Returns False if not found."""
|
|
||||||
current = await self.get_core_block(user_id, label)
|
|
||||||
if current is None or old not in current:
|
|
||||||
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
|
||||||
return False
|
|
||||||
await self.update_core(user_id, label, current.replace(old, new, 1))
|
|
||||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
|
||||||
"""Insert a long-term archival memory entry."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
encrypted = _encrypt(fernet, content)
|
|
||||||
row = MemoryAssociative(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
content_encrypted=encrypted,
|
|
||||||
embedding=None,
|
|
||||||
entity_type=source,
|
|
||||||
entity_id=None,
|
|
||||||
)
|
|
||||||
self._db.add(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
|
||||||
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryAssociative)
|
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
|
||||||
.order_by(MemoryAssociative.updated_at.desc())
|
|
||||||
.limit(100)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
needle = query.strip().lower()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
|
||||||
if plaintext is None:
|
|
||||||
continue
|
|
||||||
if not needle or needle in plaintext.lower():
|
|
||||||
out.append(plaintext)
|
|
||||||
if len(out) >= max(top_k, 1):
|
|
||||||
break
|
|
||||||
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
|
||||||
"""Search recall memory (episodic summaries) by keyword."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryEpisodic)
|
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
|
||||||
.limit(100)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
needle = query.strip().lower()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
|
||||||
if plaintext is None:
|
|
||||||
continue
|
|
||||||
if not needle or needle in plaintext.lower():
|
|
||||||
out.append(plaintext)
|
|
||||||
if len(out) >= max(top_k, 1):
|
|
||||||
break
|
|
||||||
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -341,16 +148,6 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
|
||||||
"""Load lightweight user debug fields for trace logs."""
|
|
||||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if user is None:
|
|
||||||
return {"tier": None}
|
|
||||||
return {
|
|
||||||
"tier": user.tier,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
@@ -386,17 +183,10 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(
|
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
fernet: Fernet,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
|
||||||
if session_id:
|
|
||||||
query = query.where(MemoryEpisodic.session_id == session_id)
|
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
query
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -1,47 +1,141 @@
|
|||||||
"""Output formatter for deep-agent stream events."""
|
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
||||||
|
|
||||||
|
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
||||||
|
* ``("token", str)`` — supervisor text token
|
||||||
|
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
||||||
|
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||||
|
|
||||||
|
HomeFormatter:
|
||||||
|
* Streams text tokens as-is → emits ``WsStreamText``
|
||||||
|
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
||||||
|
for the frontend to parse and render as interactive components)
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
|
||||||
|
FloatingFormatter:
|
||||||
|
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
||||||
|
* Streams text tokens → emits ``WsStreamText``
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Map sub-agent tool name → floating domain / entity type
|
||||||
|
_AGENT_DOMAIN: dict[str, str] = {
|
||||||
|
"task_agent": "tasks",
|
||||||
|
"timeline_agent": "timelines",
|
||||||
|
"note_agent": "notes",
|
||||||
|
"project_agent": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
class StreamFormatter:
|
class HomeFormatter:
|
||||||
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||||
|
|
||||||
|
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
||||||
|
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
||||||
|
is responsible for parsing those and rendering interactive components.
|
||||||
|
Mutations are attached to ``WsStreamEnd``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
started = False
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
async for event_type, data in event_stream:
|
||||||
if event_type == "floating_domain":
|
if event_type == "token":
|
||||||
if isinstance(data, dict):
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FloatingFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
||||||
|
|
||||||
|
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
||||||
|
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
||||||
|
``WsStreamText``. No block parsing for floating context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
domain_sent = False
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "tool_end" and not domain_sent:
|
||||||
|
# Sniff domain from the first sub-agent that completes
|
||||||
|
name = data.get("name", "")
|
||||||
|
domain = _AGENT_DOMAIN.get(name, "tasks")
|
||||||
yield WsFloatingDomain(
|
yield WsFloatingDomain(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
domain=data,
|
domain=domain, # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
continue
|
|
||||||
|
|
||||||
if event_type != "token":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not started:
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
started = True
|
domain_sent = True
|
||||||
|
|
||||||
text = str(data or "")
|
elif event_type == "token":
|
||||||
if text:
|
if not domain_sent:
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=text)
|
# First token arrived before any tool_end — default domain
|
||||||
|
yield WsFloatingDomain(
|
||||||
if not started:
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
domain_sent = True
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
# If no events triggered domain_sent (edge case), still emit structure
|
||||||
|
if not domain_sent:
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,104 +0,0 @@
|
|||||||
"""Preprocessor registry: detect content type and dispatch to handlers.
|
|
||||||
|
|
||||||
Public API
|
|
||||||
----------
|
|
||||||
detect_content_type(filename, raw_content) -> str
|
|
||||||
Heuristic detection based on file extension and content patterns.
|
|
||||||
|
|
||||||
preprocess(content_type, raw_content) -> PreprocessResult
|
|
||||||
Dispatch to the appropriate handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from app.core.preprocessors.base import PreprocessResult
|
|
||||||
|
|
||||||
# ── Heuristics ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
# Patterns that strongly suggest an email HTML file
|
|
||||||
_EMAIL_SIGNALS = re.compile(
|
|
||||||
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Patterns that suggest a generic HTML page (not an email)
|
|
||||||
_GENERIC_HTML_SIGNALS = re.compile(
|
|
||||||
r"<(nav|main|header|footer|article|section)\b",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_content_type(filename: str, raw_content: str) -> str:
|
|
||||||
"""Return a content-type string for the given file.
|
|
||||||
|
|
||||||
Supported types: ``"email_html"``, ``"generic_html"``,
|
|
||||||
``"plain_text"``, ``"unknown"``.
|
|
||||||
"""
|
|
||||||
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
|
||||||
|
|
||||||
if ext == "txt":
|
|
||||||
return "plain_text"
|
|
||||||
|
|
||||||
if ext in ("html", "htm", "eml", "mhtml", "mht"):
|
|
||||||
# Prefer email detection over generic HTML
|
|
||||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
|
||||||
return "email_html"
|
|
||||||
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
|
|
||||||
return "generic_html"
|
|
||||||
# .html without clear signals — check for any email header
|
|
||||||
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
|
|
||||||
return "email_html"
|
|
||||||
return "generic_html"
|
|
||||||
|
|
||||||
# Plain text files with email headers
|
|
||||||
if ext in ("", "txt") or not ext:
|
|
||||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
|
||||||
return "email_html"
|
|
||||||
|
|
||||||
# Detect binary content
|
|
||||||
try:
|
|
||||||
raw_content.encode("utf-8")
|
|
||||||
except (UnicodeEncodeError, AttributeError):
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
# Non-text bytes heuristic: high ratio of non-printable chars
|
|
||||||
sample = raw_content[:512]
|
|
||||||
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
|
|
||||||
if len(sample) > 0 and non_printable / len(sample) > 0.1:
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Generic fallback handler ──────────────────────────────────────────
|
|
||||||
|
|
||||||
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
|
|
||||||
"""Strip HTML tags if present, return text as-is."""
|
|
||||||
try:
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
|
|
||||||
except ImportError:
|
|
||||||
# No BeautifulSoup — strip tags with a simple regex
|
|
||||||
text = re.sub(r"<[^>]+>", "", raw_content)
|
|
||||||
|
|
||||||
text = re.sub(r"\n{3,}", "\n\n", text).strip()
|
|
||||||
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
|
|
||||||
|
|
||||||
|
|
||||||
# ── Dispatch ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
|
|
||||||
"""Dispatch *raw_content* to the handler registered for *content_type*.
|
|
||||||
|
|
||||||
Falls back to the generic handler for unknown types.
|
|
||||||
"""
|
|
||||||
if content_type == "email_html":
|
|
||||||
from app.core.preprocessors.email_html import preprocess_email_html
|
|
||||||
return preprocess_email_html(raw_content)
|
|
||||||
|
|
||||||
return _preprocess_generic(raw_content, content_type)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
"""Base types for the preprocessor system."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PreprocessResult:
|
|
||||||
"""Output of a preprocessor handler.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
content_type:
|
|
||||||
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
|
|
||||||
clean_text:
|
|
||||||
Human-readable text stripped of markup/binary noise.
|
|
||||||
metadata:
|
|
||||||
Dict of extracted metadata (keys vary by handler).
|
|
||||||
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
content_type: str
|
|
||||||
clean_text: str
|
|
||||||
metadata: dict = field(default_factory=dict)
|
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
"""Preprocessor for email HTML files.
|
|
||||||
|
|
||||||
Handles:
|
|
||||||
- HTML stripping via BeautifulSoup
|
|
||||||
- Metadata extraction (Subject, From, To, Date)
|
|
||||||
- Thread splitting — isolates the latest reply
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from app.core.preprocessors.base import PreprocessResult
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# ── Thread split markers ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
# Matches patterns like:
|
|
||||||
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
|
|
||||||
# "-----Original Message-----"
|
|
||||||
# "> " (plain-text quote prefix)
|
|
||||||
_THREAD_PATTERNS = [
|
|
||||||
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
|
|
||||||
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
|
|
||||||
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
|
|
||||||
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
|
|
||||||
]
|
|
||||||
|
|
||||||
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
|
|
||||||
|
|
||||||
_META_PATTERNS: dict[str, list[re.Pattern]] = {
|
|
||||||
"subject": [
|
|
||||||
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
|
|
||||||
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"from": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"From:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"to": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"To:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"date": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
|
|
||||||
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_metadata(raw_html: str, text: str) -> dict:
|
|
||||||
"""Extract Subject/From/To/Date from raw HTML or plain text."""
|
|
||||||
metadata: dict[str, str] = {}
|
|
||||||
for field, patterns in _META_PATTERNS.items():
|
|
||||||
for pat in patterns:
|
|
||||||
m = pat.search(raw_html) or pat.search(text)
|
|
||||||
if m:
|
|
||||||
metadata[field] = m.group(1).strip()
|
|
||||||
break
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
|
|
||||||
def _split_thread(text: str) -> str:
|
|
||||||
"""Return only the latest message in a threaded email."""
|
|
||||||
earliest_pos: int | None = None
|
|
||||||
for pat in _THREAD_PATTERNS:
|
|
||||||
m = pat.search(text)
|
|
||||||
if m and (earliest_pos is None or m.start() < earliest_pos):
|
|
||||||
earliest_pos = m.start()
|
|
||||||
|
|
||||||
if earliest_pos is not None and earliest_pos > 0:
|
|
||||||
return text[:earliest_pos].strip()
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_email_html(raw_content: str) -> PreprocessResult:
|
|
||||||
"""Strip HTML, extract metadata, split thread from an email HTML file."""
|
|
||||||
try:
|
|
||||||
from bs4 import BeautifulSoup # lazy import — optional dep
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"beautifulsoup4 is required for email_html preprocessing. "
|
|
||||||
"Install it with: pip install beautifulsoup4"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# Parse with lxml if available, fall back to html.parser
|
|
||||||
try:
|
|
||||||
soup = BeautifulSoup(raw_content, "lxml")
|
|
||||||
except Exception:
|
|
||||||
soup = BeautifulSoup(raw_content, "html.parser")
|
|
||||||
|
|
||||||
# Remove noise tags
|
|
||||||
for tag in soup(["style", "script", "head", "noscript"]):
|
|
||||||
tag.decompose()
|
|
||||||
|
|
||||||
clean_text = soup.get_text(separator="\n")
|
|
||||||
# Collapse excessive blank lines
|
|
||||||
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
|
|
||||||
|
|
||||||
metadata = _extract_metadata(raw_content, clean_text)
|
|
||||||
latest_message = _split_thread(clean_text)
|
|
||||||
|
|
||||||
return PreprocessResult(
|
|
||||||
content_type="email_html",
|
|
||||||
clean_text=latest_message,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
@@ -7,18 +7,21 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
# Holds the execute callback for the current WS session.
|
# Holds the execute callback for the current WS session.
|
||||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
# Set by the chat WS handler before the deep agent runs; cleared after.
|
||||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
"_client_executor"
|
"_client_executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional collector that captures raw execute_on_client results.
|
# Optional collector that captures raw execute_on_client results.
|
||||||
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
# Set by the deep agent tool loop to capture CRUD mutations.
|
||||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
"_tool_result_collector", default=None
|
"_tool_result_collector", default=None
|
||||||
)
|
)
|
||||||
@@ -81,12 +84,17 @@ async def execute_on_client(
|
|||||||
if limit is not None:
|
if limit is not None:
|
||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
|
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
result = await callback(payload)
|
result = await callback(payload)
|
||||||
|
if result is None:
|
||||||
|
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
else:
|
||||||
|
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
collector = _tool_result_collector.get(None)
|
collector = _tool_result_collector.get(None)
|
||||||
if collector is not None:
|
if collector is not None and action in ("insert", "update", "delete"):
|
||||||
collector.append({
|
collector.append({
|
||||||
"action": action,
|
"action": action,
|
||||||
"table": table,
|
"table": table,
|
||||||
"data": result,
|
"data": data or {},
|
||||||
})
|
})
|
||||||
return result
|
return result
|
||||||
|
|||||||
11
app/main.py
11
app/main.py
@@ -18,9 +18,7 @@ from app.config.settings import settings
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: ensure agent tool modules are loaded.
|
# Startup: initialise DB connection pool
|
||||||
import app.agents # noqa: F401
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: dispose SQLAlchemy connection pool
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
@@ -50,12 +48,17 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agents, auth, billing, chat, device_ws
|
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
|
app.include_router(agent_setup.router, prefix="/api/v1")
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
|||||||
7
app/marketplace/__init__.py
Normal file
7
app/marketplace/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Plugin marketplace package.
|
||||||
|
|
||||||
|
Three service classes introduced in Step 10:
|
||||||
|
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
||||||
|
- ``ReviewQueue`` — approval workflow + security checklist
|
||||||
|
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
||||||
|
"""
|
||||||
212
app/marketplace/plugin_registry.py
Normal file
212
app/marketplace/plugin_registry.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""Plugin catalog registry backed by PostgreSQL.
|
||||||
|
|
||||||
|
Maintains the authoritative list of plugins, their review status, and
|
||||||
|
aggregate install counts. All data is persisted in the ``plugins`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import Plugin
|
||||||
|
from app.schemas import PluginListResponse, PluginManifest
|
||||||
|
|
||||||
|
_PAGE_SIZE = 20
|
||||||
|
|
||||||
|
|
||||||
|
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
||||||
|
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
||||||
|
try:
|
||||||
|
permissions = json.loads(p.permissions) if p.permissions else []
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
permissions = []
|
||||||
|
return PluginManifest(
|
||||||
|
id=p.id,
|
||||||
|
name=p.name,
|
||||||
|
description=p.description,
|
||||||
|
version=p.version,
|
||||||
|
author=p.author_name,
|
||||||
|
permissions=permissions,
|
||||||
|
category=p.category,
|
||||||
|
price_cents=p.price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginRegistry:
|
||||||
|
"""PostgreSQL-backed plugin catalog.
|
||||||
|
|
||||||
|
All methods accept an ``AsyncSession`` parameter so the calling route
|
||||||
|
controls the session lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Queries ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def list_plugins(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
category: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||||
|
base = select(Plugin).where(Plugin.status == "approved")
|
||||||
|
|
||||||
|
if category:
|
||||||
|
base = base.where(Plugin.category == category)
|
||||||
|
if query:
|
||||||
|
pattern = f"%{query}%"
|
||||||
|
base = base.where(
|
||||||
|
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count
|
||||||
|
count_q = select(func.count()).select_from(base.subquery())
|
||||||
|
total = (await db.execute(count_q)).scalar_one()
|
||||||
|
|
||||||
|
# Sort
|
||||||
|
if sort == "installs":
|
||||||
|
base = base.order_by(Plugin.install_count.desc())
|
||||||
|
elif sort == "rating":
|
||||||
|
base = base.order_by(Plugin.avg_rating.desc())
|
||||||
|
else: # newest
|
||||||
|
base = base.order_by(Plugin.created_at.desc())
|
||||||
|
|
||||||
|
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
||||||
|
rows = (await db.execute(base)).scalars().all()
|
||||||
|
|
||||||
|
return PluginListResponse(
|
||||||
|
plugins=[_plugin_to_manifest(r) for r in rows],
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
p = result.scalar_one_or_none()
|
||||||
|
if p is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"manifest": _plugin_to_manifest(p),
|
||||||
|
"status": p.status,
|
||||||
|
"install_count": p.install_count,
|
||||||
|
"avg_rating": p.avg_rating,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Mutations ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def submit_plugin(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
manifest: PluginManifest,
|
||||||
|
package_s3_key: str,
|
||||||
|
) -> str:
|
||||||
|
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
||||||
|
|
||||||
|
Returns the plugin_id. If a plugin with the same id already exists
|
||||||
|
it is overwritten (re-submission after rejection).
|
||||||
|
"""
|
||||||
|
plugin_id = manifest.id
|
||||||
|
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = existing.scalar_one_or_none()
|
||||||
|
|
||||||
|
if row is not None:
|
||||||
|
row.name = manifest.name
|
||||||
|
row.description = manifest.description
|
||||||
|
row.version = manifest.version
|
||||||
|
row.author_name = manifest.author
|
||||||
|
row.category = manifest.category
|
||||||
|
row.price_cents = manifest.price_cents
|
||||||
|
row.permissions = json.dumps(manifest.permissions)
|
||||||
|
row.status = "pending_review"
|
||||||
|
row.s3_package_key = package_s3_key
|
||||||
|
row.rejection_reason = None
|
||||||
|
else:
|
||||||
|
row = Plugin(
|
||||||
|
id=plugin_id,
|
||||||
|
name=manifest.name,
|
||||||
|
description=manifest.description,
|
||||||
|
version=manifest.version,
|
||||||
|
author_name=manifest.author,
|
||||||
|
category=manifest.category,
|
||||||
|
price_cents=manifest.price_cents,
|
||||||
|
permissions=json.dumps(manifest.permissions),
|
||||||
|
status="pending_review",
|
||||||
|
s3_package_key=package_s3_key,
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
|
return plugin_id
|
||||||
|
|
||||||
|
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'approved'``.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
row.status = "approved"
|
||||||
|
row.rejection_reason = None
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
row.status = "rejected"
|
||||||
|
row.rejection_reason = reason
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = row.install_count + 1
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = max(0, row.install_count - 1)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||||
|
|
||||||
|
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
|
"""Return all entries with status='pending_review'."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Plugin).where(Plugin.status == "pending_review")
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"manifest": _plugin_to_manifest(r),
|
||||||
|
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
registry = PluginRegistry()
|
||||||
125
app/marketplace/plugin_review.py
Normal file
125
app/marketplace/plugin_review.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Plugin review workflow backed by PostgreSQL.
|
||||||
|
|
||||||
|
Manages the approval queue for newly submitted plugins and enforces a
|
||||||
|
security checklist before any plugin is made visible in the marketplace.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_review import review_queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import PluginReview as PluginReviewModel
|
||||||
|
from app.schemas import PluginManifest
|
||||||
|
|
||||||
|
# ── Security policy ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"read:tasks",
|
||||||
|
"write:tasks",
|
||||||
|
"read:projects",
|
||||||
|
"write:projects",
|
||||||
|
"read:notes",
|
||||||
|
"write:notes",
|
||||||
|
"read:timelines",
|
||||||
|
"write:timelines",
|
||||||
|
"read:calendar",
|
||||||
|
"write:calendar",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_manifest(manifest: PluginManifest) -> None:
|
||||||
|
"""Enforce the plugin security checklist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``ValueError`` on the first violation found. Callers should catch
|
||||||
|
this and return HTTP 422 / reject the submission.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
1. Plugin id matches ``^[a-z0-9-]+$``
|
||||||
|
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
||||||
|
3. No manifest field contains raw binary data
|
||||||
|
"""
|
||||||
|
if not _PLUGIN_ID_RE.match(manifest.id):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid plugin id format: '{manifest.id}'. "
|
||||||
|
"Only lowercase letters, digits, and hyphens are allowed."
|
||||||
|
)
|
||||||
|
|
||||||
|
for perm in manifest.permissions:
|
||||||
|
if perm not in ALLOWED_PERMISSIONS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown permission: '{perm}'. "
|
||||||
|
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field_name, value in manifest.model_dump().items():
|
||||||
|
if isinstance(value, (bytes, bytearray)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Binary content is not allowed in manifest field '{field_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewQueue:
|
||||||
|
"""Approval queue for pending plugin submissions.
|
||||||
|
|
||||||
|
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
||||||
|
Review records are persisted in the ``plugin_reviews`` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
|
"""Return all plugins currently awaiting review.
|
||||||
|
|
||||||
|
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||||
|
"""
|
||||||
|
entries = await registry.get_pending_entries(db)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"plugin_id": e["manifest"].id,
|
||||||
|
"manifest": e["manifest"],
|
||||||
|
"submitted_at": e["submitted_at"],
|
||||||
|
}
|
||||||
|
for e in entries
|
||||||
|
]
|
||||||
|
|
||||||
|
async def submit_review(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
plugin_id: str,
|
||||||
|
reviewer_id: str,
|
||||||
|
decision: Literal["approved", "rejected"],
|
||||||
|
notes: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Record a review decision and update the plugin's status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``KeyError`` if *plugin_id* is not found in the registry.
|
||||||
|
"""
|
||||||
|
if decision == "approved":
|
||||||
|
await registry.approve_plugin(db, plugin_id)
|
||||||
|
else:
|
||||||
|
await registry.reject_plugin(db, plugin_id, reason=notes)
|
||||||
|
|
||||||
|
review = PluginReviewModel(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
reviewer_id=reviewer_id,
|
||||||
|
decision=decision,
|
||||||
|
notes=notes,
|
||||||
|
)
|
||||||
|
db.add(review)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
review_queue = ReviewQueue()
|
||||||
233
app/marketplace/revenue_share.py
Normal file
233
app/marketplace/revenue_share.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
||||||
|
|
||||||
|
Records every plugin installation as a revenue event and facilitates
|
||||||
|
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
||||||
|
in the ``revenue_events`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from sqlalchemy import extract, func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import Plugin, RevenueEvent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Revenue split constants ───────────────────────────────────────────
|
||||||
|
|
||||||
|
DEVELOPER_SHARE: float = 0.70
|
||||||
|
PLATFORM_SHARE: float = 0.30
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueShare:
|
||||||
|
"""Records installation revenue events and coordinates developer payouts.
|
||||||
|
|
||||||
|
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
||||||
|
is not configured, consistent with the rest of the billing layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe_configured() -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe() -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Core operations ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def record_install(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
plugin_id: str,
|
||||||
|
user_id: str,
|
||||||
|
amount_cents: int,
|
||||||
|
) -> None:
|
||||||
|
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
||||||
|
|
||||||
|
For free plugins (``amount_cents == 0``) no payment is initiated but
|
||||||
|
the event is still recorded for analytics.
|
||||||
|
|
||||||
|
For paid plugins the developer receives 70 % via a Stripe Connect
|
||||||
|
destination charge. If Stripe is not configured or the charge fails
|
||||||
|
the installation still succeeds (the event is recorded and the install
|
||||||
|
count is incremented) — a warning is logged for monitoring.
|
||||||
|
"""
|
||||||
|
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
||||||
|
stripe_transfer_id: str | None = None
|
||||||
|
|
||||||
|
if amount_cents > 0 and self._stripe_configured():
|
||||||
|
# Look up the plugin's author Stripe account from the DB
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = result.scalar_one_or_none()
|
||||||
|
developer_stripe_account: str | None = None
|
||||||
|
if plugin_row and plugin_row.author_id:
|
||||||
|
# Future: look up user.stripe_connect_account_id
|
||||||
|
developer_stripe_account = None # no real account yet
|
||||||
|
|
||||||
|
if developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
transfer = s.Transfer.create(
|
||||||
|
amount=developer_share_cents,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Revenue share for plugin {plugin_id}",
|
||||||
|
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
stripe_transfer_id = transfer["id"]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Stripe Connect transfer failed for plugin %s: %s",
|
||||||
|
plugin_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"No Stripe account on file for plugin %s developer; "
|
||||||
|
"skipping transfer.",
|
||||||
|
plugin_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
event = RevenueEvent(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=user_id,
|
||||||
|
amount_cents=amount_cents,
|
||||||
|
developer_share_cents=developer_share_cents,
|
||||||
|
stripe_transfer_id=stripe_transfer_id,
|
||||||
|
)
|
||||||
|
db.add(event)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await registry.record_install(db, plugin_id)
|
||||||
|
|
||||||
|
async def get_earnings(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
developer_id: str,
|
||||||
|
period: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return aggregated earnings for *developer_id*.
|
||||||
|
|
||||||
|
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
||||||
|
|
||||||
|
Returns::
|
||||||
|
|
||||||
|
{
|
||||||
|
"developer_id": str,
|
||||||
|
"period": str | None,
|
||||||
|
"total_installs": int,
|
||||||
|
"total_revenue_cents": int,
|
||||||
|
"developer_share_cents": int,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Find plugin ids belonging to this developer (by author_name match)
|
||||||
|
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
||||||
|
plugin_result = await db.execute(plugin_q)
|
||||||
|
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
||||||
|
|
||||||
|
if not developer_plugin_ids:
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": 0,
|
||||||
|
"total_revenue_cents": 0,
|
||||||
|
"developer_share_cents": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
query = select(
|
||||||
|
func.count().label("total_installs"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
||||||
|
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
||||||
|
|
||||||
|
if period:
|
||||||
|
# Filter by YYYY-MM: extract year and month from created_at
|
||||||
|
try:
|
||||||
|
year, month = period.split("-")
|
||||||
|
query = query.where(
|
||||||
|
extract("year", RevenueEvent.created_at) == int(year),
|
||||||
|
extract("month", RevenueEvent.created_at) == int(month),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass # invalid period format — return all
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
row = result.one()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": row.total_installs,
|
||||||
|
"total_revenue_cents": row.total_revenue,
|
||||||
|
"developer_share_cents": row.dev_share,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
||||||
|
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||||
|
|
||||||
|
Marks processed events with ``paid_at`` timestamp.
|
||||||
|
Stubs gracefully when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
year, month = period.split("-")
|
||||||
|
year_int, month_int = int(year), int(month)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid period format: %s", period)
|
||||||
|
return
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(RevenueEvent).where(
|
||||||
|
RevenueEvent.plugin_id == plugin_id,
|
||||||
|
RevenueEvent.paid_at.is_(None),
|
||||||
|
extract("year", RevenueEvent.created_at) == year_int,
|
||||||
|
extract("month", RevenueEvent.created_at) == month_int,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
unpaid = list(result.scalars().all())
|
||||||
|
|
||||||
|
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
||||||
|
if total_dev_share <= 0 or not unpaid:
|
||||||
|
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._stripe_configured():
|
||||||
|
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = plugin_result.scalar_one_or_none()
|
||||||
|
developer_stripe_account: str | None = None # Future: fetch from DB
|
||||||
|
if plugin_row and developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
s.Transfer.create(
|
||||||
|
amount=total_dev_share,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Payout for plugin {plugin_id} period {period}",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
paid_ts = datetime.now(timezone.utc)
|
||||||
|
for event in unpaid:
|
||||||
|
event.paid_at = paid_ts
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
revenue_share = RevenueShare()
|
||||||
164
app/models.py
164
app/models.py
@@ -1,15 +1,19 @@
|
|||||||
"""SQLAlchemy ORM models for all persistent tables.
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
Only auth, billing, agent config, and memory data live here.
|
Only auth, billing, storage metadata, and marketplace data live here.
|
||||||
User content (notes, tasks, etc.) lives exclusively on the client.
|
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||||
|
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||||
|
|
||||||
Table inventory:
|
Table inventory:
|
||||||
users — account credentials + tier
|
users — account credentials + tier
|
||||||
refresh_tokens — hashed refresh token store
|
refresh_tokens — hashed refresh token store
|
||||||
subscriptions — Stripe subscription records
|
subscriptions — Stripe subscription records
|
||||||
local_agent_configs — per-device batch agent configs
|
storage_records — S3 blob metadata (no plaintext)
|
||||||
cloud_agent_configs — OAuth-backed cloud agent configs
|
backup_metadata — encrypted backup manifests
|
||||||
agent_run_logs — execution history for all agents
|
plugins — marketplace plugin catalog
|
||||||
|
plugin_installations — per-user install records
|
||||||
|
plugin_reviews — admin review decisions
|
||||||
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
memory_core — per-user persistent key/value preferences (encrypted)
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
memory_episodic — per-user session summaries (encrypted)
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
@@ -22,6 +26,7 @@ import uuid
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
Boolean,
|
Boolean,
|
||||||
DateTime,
|
DateTime,
|
||||||
Enum,
|
Enum,
|
||||||
@@ -31,6 +36,7 @@ from sqlalchemy import (
|
|||||||
JSON,
|
JSON,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
Uuid,
|
Uuid,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
@@ -52,6 +58,8 @@ def _now() -> datetime:
|
|||||||
# ── Enum types ────────────────────────────────────────────────────────────
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
|
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||||
|
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||||
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
||||||
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
||||||
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||||
@@ -129,6 +137,151 @@ class Subscription(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="subscription")
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecord(Base):
|
||||||
|
__tablename__ = "storage_records"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupMetadata(Base):
|
||||||
|
__tablename__ = "backup_metadata"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||||
|
# nullable until developer account system is built
|
||||||
|
author_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||||
|
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
submitted_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
reviews: Mapped[list[PluginReview]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallation(Base):
|
||||||
|
__tablename__ = "plugin_installations"
|
||||||
|
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
installed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginReview(Base):
|
||||||
|
__tablename__ = "plugin_reviews"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
reviewer_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||||
|
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
reviewed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueEvent(Base):
|
||||||
|
__tablename__ = "revenue_events"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfig(Base):
|
class LocalAgentConfig(Base):
|
||||||
__tablename__ = "local_agent_configs"
|
__tablename__ = "local_agent_configs"
|
||||||
|
|
||||||
@@ -143,7 +296,6 @@ class LocalAgentConfig(Base):
|
|||||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
|
||||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
|
|||||||
249
app/schemas.py
249
app/schemas.py
@@ -50,6 +50,88 @@ class ChatResponse(BaseModel):
|
|||||||
response: str
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Backup ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class BackupMetadata(BaseModel):
|
||||||
|
version: int
|
||||||
|
timestamp: int
|
||||||
|
checksum: str
|
||||||
|
chunk_count: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
||||||
|
|
||||||
|
class StorageRecord(BaseModel):
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
table: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordCreate(BaseModel):
|
||||||
|
table: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordUpdate(BaseModel):
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
||||||
|
|
||||||
|
class VectorItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
blob: bytes # encrypted vector + metadata — backend never decrypts
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
class VectorUpsertRequest(BaseModel):
|
||||||
|
vectors: list[VectorItem]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchRequest(BaseModel):
|
||||||
|
query_blob: bytes # encrypted query — backend never decrypts
|
||||||
|
top_k: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResult(BaseModel):
|
||||||
|
id: str
|
||||||
|
score: float
|
||||||
|
blob: bytes
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResponse(BaseModel):
|
||||||
|
results: list[VectorSearchResult]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin Marketplace ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PluginManifest(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
version: str
|
||||||
|
author: str
|
||||||
|
permissions: list[str]
|
||||||
|
category: str
|
||||||
|
price_cents: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class PluginListResponse(BaseModel):
|
||||||
|
plugins: list[PluginManifest]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallRequest(BaseModel):
|
||||||
|
plugin_id: str
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||||
|
|
||||||
class WsFrameType(str, Enum):
|
class WsFrameType(str, Enum):
|
||||||
@@ -60,6 +142,9 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
|
agent_run = "agent_run"
|
||||||
|
agent_data = "agent_data"
|
||||||
|
agent_complete = "agent_complete"
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
@@ -71,10 +156,6 @@ class WsFrameType(str, Enum):
|
|||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
# ── v4 journey frame types ────────────────────────────────────────
|
|
||||||
journey_start = "journey_start"
|
|
||||||
journey_message = "journey_message"
|
|
||||||
journey_reply = "journey_reply"
|
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -127,6 +208,31 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentRun(BaseModel):
|
||||||
|
"""Server → Client: trigger an agent run on the connected device."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
||||||
|
run_id: str
|
||||||
|
agent_id: str
|
||||||
|
config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentData(BaseModel):
|
||||||
|
"""Client → Server: files read by the local agent."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
||||||
|
run_id: str
|
||||||
|
files: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentComplete(BaseModel):
|
||||||
|
"""Client → Server: Electron signals it has finished reading files."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
||||||
|
run_id: str
|
||||||
|
files_read: int
|
||||||
|
errors: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
@@ -173,14 +279,7 @@ class WsStreamEnd(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
|
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
class WsDomain(BaseModel):
|
|
||||||
"""Structured floating domain payload for UI routing decisions."""
|
|
||||||
|
|
||||||
type: Literal["task", "timeline", "project", "node"]
|
|
||||||
id: str | None = None
|
|
||||||
section: Literal["task", "timeline", "note"] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -188,28 +287,7 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: WsDomain
|
domain: Literal["tasks", "timelines", "notes", "projects"]
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Config V2 ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ContentTypeConfig(BaseModel):
|
|
||||||
"""Per-type extraction config produced by the journey chatbot."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
label: str = ""
|
|
||||||
detection_hint: str = ""
|
|
||||||
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
|
|
||||||
extraction_prompt: str
|
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(BaseModel):
|
|
||||||
"""Structured agent configuration (replaces freeform prompt_template)."""
|
|
||||||
|
|
||||||
content_types: list[ContentTypeConfig] = []
|
|
||||||
global_rules: list[str] = []
|
|
||||||
data_types: list[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -218,28 +296,84 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
config_schema: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckRequest(BaseModel):
|
# ── Local Agent Config ────────────────────────────────────────────────
|
||||||
active_agents: int = Field(ge=0, default=0)
|
|
||||||
|
class LocalAgentConfigCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
device_id: str
|
||||||
|
directory_paths: list[str]
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
file_extensions: list[str]
|
||||||
|
schedule_cron: str
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckResponse(BaseModel):
|
class LocalAgentConfigUpdate(BaseModel):
|
||||||
allowed: bool
|
name: str | None = None
|
||||||
tier: BillingTier
|
device_id: str | None = None
|
||||||
active_agents: int
|
directory_paths: list[str] | None = None
|
||||||
limit: int
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
file_extensions: list[str] | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentTriggerRequest(BaseModel):
|
class LocalAgentConfigResponse(BaseModel):
|
||||||
directory: str = Field(min_length=1)
|
id: str
|
||||||
device_id: str = Field(default="")
|
name: str
|
||||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
device_id: str
|
||||||
what_to_extract: list[str] = Field(min_length=1)
|
directory_paths: list[str]
|
||||||
actions_by_type: dict[str, list[str]] | None = None
|
data_types: list[str]
|
||||||
batch_interval: str = Field(min_length=1)
|
prompt_template: str
|
||||||
custom_agent_prompt: str = Field(min_length=1)
|
file_extensions: list[str]
|
||||||
active_agents: int = Field(ge=0, default=0)
|
schedule_cron: str
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Agent Config ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CloudAgentConfigCreate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
oauth_token_encrypted: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigUpdate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"] | None = None
|
||||||
|
name: str | None = None
|
||||||
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
oauth_token_encrypted: str | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigResponse(BaseModel):
|
||||||
|
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -258,3 +392,18 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class JourneyStartRequest(BaseModel):
|
||||||
|
agent_type: Literal["local", "cloud"]
|
||||||
|
agent_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyMessageRequest(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyResponse(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
done: bool
|
||||||
|
prompt_template: str | None = None
|
||||||
|
|||||||
1
app/storage/__init__.py
Normal file
1
app/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
||||||
106
app/storage/blob_store.py
Normal file
106
app/storage/blob_store.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""S3-backed store for E2E-encrypted blobs.
|
||||||
|
|
||||||
|
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
||||||
|
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class BlobStore:
|
||||||
|
"""Thin wrapper around boto3 S3.
|
||||||
|
|
||||||
|
All blobs must be E2E encrypted by the client before upload.
|
||||||
|
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
||||||
|
but cannot decrypt the inner client-side payload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"region_name": settings.S3_REGION,
|
||||||
|
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
|
||||||
|
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
|
||||||
|
}
|
||||||
|
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
|
||||||
|
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
|
||||||
|
return boto3.client("s3", **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _key(user_id: str, table: str, record_id: str) -> str:
|
||||||
|
return f"{user_id}/{table}/{record_id}"
|
||||||
|
|
||||||
|
async def upload(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
table: str,
|
||||||
|
record_id: str,
|
||||||
|
blob: bytes,
|
||||||
|
checksum: str,
|
||||||
|
) -> str:
|
||||||
|
"""Store *blob* in S3 and return the S3 key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Owner of the blob (used as key prefix).
|
||||||
|
table: Logical table name (e.g. ``"tasks"``).
|
||||||
|
record_id: Record UUID.
|
||||||
|
blob: Raw bytes (pre-encrypted by client).
|
||||||
|
checksum: SHA-256 hex digest supplied by the client; stored as
|
||||||
|
object metadata for download-time verification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The S3 key under which the blob was stored.
|
||||||
|
"""
|
||||||
|
key = self._key(user_id, table, record_id)
|
||||||
|
self._client().put_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=key,
|
||||||
|
Body=blob,
|
||||||
|
ServerSideEncryption="AES256", # SSE-S3 at rest
|
||||||
|
Metadata={"checksum": checksum},
|
||||||
|
)
|
||||||
|
return key
|
||||||
|
|
||||||
|
async def download(self, user_id: str, s3_key: str) -> bytes:
|
||||||
|
"""Retrieve the blob stored at *s3_key*.
|
||||||
|
|
||||||
|
*user_id* is retained in the signature so higher-level code can
|
||||||
|
enforce ownership without re-parsing the key.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
||||||
|
object does not exist.
|
||||||
|
"""
|
||||||
|
response = self._client().get_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
)
|
||||||
|
return response["Body"].read()
|
||||||
|
|
||||||
|
async def delete(self, user_id: str, s3_key: str) -> None:
|
||||||
|
"""Delete the object at *s3_key*.
|
||||||
|
|
||||||
|
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
||||||
|
not exist.
|
||||||
|
"""
|
||||||
|
self._client().delete_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
||||||
|
"""Return all S3 keys for a given user + table combination.
|
||||||
|
|
||||||
|
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
||||||
|
"""
|
||||||
|
prefix = f"{user_id}/{table}/"
|
||||||
|
response = self._client().list_objects_v2(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Prefix=prefix,
|
||||||
|
)
|
||||||
|
return [obj["Key"] for obj in response.get("Contents", [])]
|
||||||
32
app/storage/encryption.py
Normal file
32
app/storage/encryption.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Integrity verification only — the backend NEVER decrypts user data."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
||||||
|
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
||||||
|
|
||||||
|
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
||||||
|
timing-based side-channel attacks.
|
||||||
|
"""
|
||||||
|
computed = hashlib.sha256(blob).hexdigest()
|
||||||
|
return hmac.compare_digest(computed, checksum)
|
||||||
|
|
||||||
|
|
||||||
|
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
||||||
|
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
||||||
|
|
||||||
|
Call this before storing or forwarding any client-provided blob.
|
||||||
|
The backend never holds decryption keys — this check only verifies
|
||||||
|
that the opaque bytes arrived intact.
|
||||||
|
"""
|
||||||
|
if not verify_checksum(blob, checksum):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Checksum mismatch: blob integrity check failed",
|
||||||
|
)
|
||||||
205
app/storage/vector_store.py
Normal file
205
app/storage/vector_store.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
||||||
|
|
||||||
|
Vectors are pre-encrypted blobs from the client. The backend stores them
|
||||||
|
alongside a deterministic 32-dim float representation derived from the blob's
|
||||||
|
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
||||||
|
is a known trade-off documented in the backend plan.
|
||||||
|
|
||||||
|
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
||||||
|
``user_id`` payload field on a shared collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pinecone import Pinecone
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
|
||||||
|
_QDRANT_COLLECTION = "adiuva_vectors"
|
||||||
|
|
||||||
|
|
||||||
|
def _blob_to_vector(blob: bytes) -> list[float]:
|
||||||
|
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
||||||
|
|
||||||
|
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
||||||
|
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
||||||
|
semantic meaning on encrypted data.
|
||||||
|
"""
|
||||||
|
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStore:
|
||||||
|
"""Thin wrapper around Pinecone or Qdrant.
|
||||||
|
|
||||||
|
The backend to use is selected at runtime:
|
||||||
|
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
||||||
|
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _use_pinecone(self) -> bool:
|
||||||
|
return bool(settings.PINECONE_API_KEY)
|
||||||
|
|
||||||
|
# ── Pinecone helpers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _pinecone_index(self) -> Any:
|
||||||
|
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
||||||
|
return pc.Index(settings.PINECONE_INDEX)
|
||||||
|
|
||||||
|
# ── Qdrant helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _qdrant_client(self) -> Any:
|
||||||
|
return QdrantClient(
|
||||||
|
url=settings.QDRANT_URL,
|
||||||
|
api_key=settings.QDRANT_API_KEY or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
"""Store encrypted vectors in the backend.
|
||||||
|
|
||||||
|
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
||||||
|
so it can be returned verbatim during search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Used as Pinecone namespace or Qdrant payload field.
|
||||||
|
vectors: List of encrypted vector items from the client.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
await self._pinecone_upsert(user_id, vectors)
|
||||||
|
else:
|
||||||
|
await self._qdrant_upsert(user_id, vectors)
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
query_blob: bytes,
|
||||||
|
top_k: int,
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
"""Query the vector store and return encrypted result blobs.
|
||||||
|
|
||||||
|
The query vector is derived from *query_blob* using the same
|
||||||
|
deterministic mapping as upsert.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Scopes the search to this user's namespace.
|
||||||
|
query_blob: Encrypted query from the client.
|
||||||
|
top_k: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
return await self._pinecone_search(user_id, query_blob, top_k)
|
||||||
|
return await self._qdrant_search(user_id, query_blob, top_k)
|
||||||
|
|
||||||
|
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
"""Remove vectors by ID, scoped to *user_id*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Namespace / payload filter to prevent cross-user deletion.
|
||||||
|
vector_ids: List of vector IDs to remove.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
await self._pinecone_delete(user_id, vector_ids)
|
||||||
|
else:
|
||||||
|
await self._qdrant_delete(user_id, vector_ids)
|
||||||
|
|
||||||
|
# ── Pinecone implementation ───────────────────────────────────────
|
||||||
|
|
||||||
|
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
records = [
|
||||||
|
{
|
||||||
|
"id": v.id,
|
||||||
|
"values": _blob_to_vector(v.blob),
|
||||||
|
"metadata": {
|
||||||
|
"blob": base64.b64encode(v.blob).decode(),
|
||||||
|
"checksum": v.checksum,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for v in vectors
|
||||||
|
]
|
||||||
|
index.upsert(vectors=records, namespace=user_id)
|
||||||
|
|
||||||
|
async def _pinecone_search(
|
||||||
|
self, user_id: str, query_blob: bytes, top_k: int
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
query_vector = _blob_to_vector(query_blob)
|
||||||
|
response = index.query(
|
||||||
|
vector=query_vector,
|
||||||
|
top_k=top_k,
|
||||||
|
namespace=user_id,
|
||||||
|
include_metadata=True,
|
||||||
|
)
|
||||||
|
results: list[VectorSearchResult] = []
|
||||||
|
for match in response.get("matches", []):
|
||||||
|
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
||||||
|
results.append(
|
||||||
|
VectorSearchResult(
|
||||||
|
id=match["id"],
|
||||||
|
score=match["score"],
|
||||||
|
blob=blob_bytes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
index.delete(ids=vector_ids, namespace=user_id)
|
||||||
|
|
||||||
|
# ── Qdrant implementation ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
points = [
|
||||||
|
PointStruct(
|
||||||
|
id=v.id,
|
||||||
|
vector=_blob_to_vector(v.blob),
|
||||||
|
payload={
|
||||||
|
"blob": base64.b64encode(v.blob).decode(),
|
||||||
|
"checksum": v.checksum,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for v in vectors
|
||||||
|
]
|
||||||
|
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
||||||
|
|
||||||
|
async def _qdrant_search(
|
||||||
|
self, user_id: str, query_blob: bytes, top_k: int
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
query_vector = _blob_to_vector(query_blob)
|
||||||
|
hits = client.search(
|
||||||
|
collection_name=_QDRANT_COLLECTION,
|
||||||
|
query_vector=query_vector,
|
||||||
|
query_filter=Filter(
|
||||||
|
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
||||||
|
),
|
||||||
|
limit=top_k,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
VectorSearchResult(
|
||||||
|
id=str(hit.id),
|
||||||
|
score=hit.score,
|
||||||
|
blob=base64.b64decode(hit.payload["blob"]),
|
||||||
|
)
|
||||||
|
for hit in hits
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
client.delete(
|
||||||
|
collection_name=_QDRANT_COLLECTION,
|
||||||
|
points_selector=PointIdsList(points=vector_ids),
|
||||||
|
)
|
||||||
@@ -36,6 +36,37 @@ services:
|
|||||||
# image: redis:7-alpine
|
# image: redis:7-alpine
|
||||||
# restart: unless-stopped
|
# restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local S3-compatible storage (MinIO) ──
|
||||||
|
minio:
|
||||||
|
image: minio/minio:latest
|
||||||
|
command: server /data --console-address ":9001"
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
- "9001:9001"
|
||||||
|
environment:
|
||||||
|
MINIO_ROOT_USER: minioadmin
|
||||||
|
MINIO_ROOT_PASSWORD: minioadmin
|
||||||
|
volumes:
|
||||||
|
- minio_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "mc", "ready", "local"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local vector store (Qdrant) ──
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
ports:
|
||||||
|
- "6333:6333"
|
||||||
|
- "6334:6334"
|
||||||
|
volumes:
|
||||||
|
- qdrant_data:/qdrant/storage
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
minio_data:
|
||||||
|
qdrant_data:
|
||||||
copilot_tokens:
|
copilot_tokens:
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ gunicorn>=22.0.0
|
|||||||
langchain>=0.3.0
|
langchain>=0.3.0
|
||||||
langchain-openai>=0.3.0
|
langchain-openai>=0.3.0
|
||||||
langchain-litellm>=0.1.0
|
langchain-litellm>=0.1.0
|
||||||
|
langgraph>=0.3.0
|
||||||
|
deepagents>=0.4.10
|
||||||
litellm>=1.50.0
|
litellm>=1.50.0
|
||||||
pydantic>=2.10.0
|
pydantic>=2.10.0
|
||||||
pydantic-settings>=2.7.0
|
pydantic-settings>=2.7.0
|
||||||
@@ -32,8 +34,4 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
langfuse>=2.0.0
|
|
||||||
beautifulsoup4>=4.12.0
|
|
||||||
lxml>=5.0.0
|
|
||||||
PyYAML>=6.0.0
|
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
@@ -6,21 +6,26 @@ a per-test session, and a FastAPI ``TestClient`` wired to use it.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import boto3
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
|
from moto import mock_aws
|
||||||
from sqlalchemy import StaticPool, event
|
from sqlalchemy import StaticPool, event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.db import Base, get_session
|
from app.db import Base, get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.models import Subscription, User
|
from app.models import Plugin, Subscription, User
|
||||||
|
|
||||||
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
||||||
|
|
||||||
@@ -104,6 +109,79 @@ def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # n
|
|||||||
app.dependency_overrides.pop(get_session, None)
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Seed data helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
Plugin(
|
||||||
|
id="plugin-github-sync",
|
||||||
|
name="GitHub Sync",
|
||||||
|
description="Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
version="1.0.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=0,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-slack-notify",
|
||||||
|
name="Slack Notifier",
|
||||||
|
description="Post task and timeline updates to Slack channels.",
|
||||||
|
version="1.2.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="communication",
|
||||||
|
price_cents=499,
|
||||||
|
permissions=json.dumps(["read:tasks", "read:timelines"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-time-tracker",
|
||||||
|
name="Time Tracker",
|
||||||
|
description="Track time spent on tasks with automatic reporting.",
|
||||||
|
version="0.9.1",
|
||||||
|
author_name="Third Party",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=999,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
|
||||||
|
"""Insert the 3 default approved plugins and return them."""
|
||||||
|
plugins = []
|
||||||
|
for template in _SEED_PLUGINS:
|
||||||
|
p = Plugin(
|
||||||
|
id=template.id,
|
||||||
|
name=template.name,
|
||||||
|
description=template.description,
|
||||||
|
version=template.version,
|
||||||
|
author_name=template.author_name,
|
||||||
|
category=template.category,
|
||||||
|
price_cents=template.price_cents,
|
||||||
|
permissions=template.permissions,
|
||||||
|
status=template.status,
|
||||||
|
s3_package_key=template.s3_package_key,
|
||||||
|
install_count=template.install_count,
|
||||||
|
avg_rating=template.avg_rating,
|
||||||
|
)
|
||||||
|
db_session.add(p)
|
||||||
|
plugins.append(p)
|
||||||
|
await db_session.commit()
|
||||||
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
# ── JWT helpers ──────────────────────────────────────────────────────
|
# ── JWT helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -134,21 +212,24 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
|||||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
# ── CLI options ───────────────────────────────────────────────────────
|
# ── S3 mock fixture ──────────────────────────────────────────────────
|
||||||
|
|
||||||
def pytest_addoption(parser):
|
S3_TEST_BUCKET = "test-bucket"
|
||||||
parser.addoption(
|
S3_TEST_REGION = "us-east-1"
|
||||||
"--preprocess-dir",
|
|
||||||
default=None,
|
|
||||||
help="Override fixture folder for preprocessor tests (must contain cases.yaml + data/)",
|
@pytest.fixture
|
||||||
)
|
def s3_bucket():
|
||||||
parser.addoption(
|
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
|
||||||
"--runner-dir",
|
with mock_aws():
|
||||||
default=None,
|
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
||||||
help="Override fixture folder for agent_runner_v2 eval tests (must contain cases.yaml + data/)",
|
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
||||||
)
|
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
|
||||||
parser.addoption(
|
client = boto3.client("s3", region_name=S3_TEST_REGION)
|
||||||
"--journey-dir",
|
client.create_bucket(Bucket=S3_TEST_BUCKET)
|
||||||
default=None,
|
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||||
help="Override fixture folder for journey_v2 eval tests (must contain cases.yaml + data/)",
|
mock_settings.S3_BUCKET = S3_TEST_BUCKET
|
||||||
)
|
mock_settings.S3_REGION = S3_TEST_REGION
|
||||||
|
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||||
|
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||||
|
yield S3_TEST_BUCKET
|
||||||
|
|||||||
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
@@ -1,86 +0,0 @@
|
|||||||
# Agent Runner V2 — eval test cases (Step 2, requires real LLM)
|
|
||||||
#
|
|
||||||
# Each case drives one parametrized `test_eval_runner` invocation.
|
|
||||||
#
|
|
||||||
# Keys
|
|
||||||
# ----
|
|
||||||
# id: str unique identifier shown in pytest output
|
|
||||||
# description: str human-readable label
|
|
||||||
# file: str filename inside data/
|
|
||||||
# file_path: str path reported to the executor (affects project-matching via filename)
|
|
||||||
# projects: [alpha|beta] symbolic project names resolved by the test helper
|
|
||||||
#
|
|
||||||
# Optional pre-existing records (dedup tests)
|
|
||||||
# existing_tasks: list of {id, title, status, priority}
|
|
||||||
# existing_notes: list of {id, title, content}
|
|
||||||
# existing_timelines: list of {id, title, date}
|
|
||||||
#
|
|
||||||
# Assertions (one or more)
|
|
||||||
# expect_insert: <table> at least 1 insert row in this table (tasks|notes|timelines)
|
|
||||||
# expect_no_insert: true zero inserts in any table
|
|
||||||
# expect_project_id: <id> any insert must carry this projectId
|
|
||||||
# expect_dedup: true task inserts == 0 OR task updates >= 1 (dedup check)
|
|
||||||
#
|
|
||||||
# Langfuse
|
|
||||||
# score_name: str observation score name
|
|
||||||
|
|
||||||
- id: "2.1"
|
|
||||||
description: "Action email → create_task"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_action.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: tasks
|
|
||||||
score_name: runner.email_to_task
|
|
||||||
|
|
||||||
- id: "2.2"
|
|
||||||
description: "Informational email → create_note"
|
|
||||||
file: email_info.html
|
|
||||||
file_path: /emails/ProjectAlpha_info.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: notes
|
|
||||||
score_name: runner.email_to_note
|
|
||||||
|
|
||||||
- id: "2.3"
|
|
||||||
description: "Email with meeting date → create_timeline"
|
|
||||||
file: email_date.html
|
|
||||||
file_path: /emails/ProjectAlpha_kickoff.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: timelines
|
|
||||||
score_name: runner.email_to_timeline
|
|
||||||
|
|
||||||
- id: "2.4"
|
|
||||||
description: "Filename contains project name → correct project assigned"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_report.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_project_id: proj-alpha
|
|
||||||
score_name: runner.project_filename
|
|
||||||
|
|
||||||
- id: "2.5"
|
|
||||||
description: "Email body mentions project → correct project assigned"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/email_001.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_project_id: proj-alpha
|
|
||||||
score_name: runner.project_content
|
|
||||||
|
|
||||||
- id: "2.6"
|
|
||||||
description: "Newsletter + global rule no-project → no creates"
|
|
||||||
file: email_no_project.html
|
|
||||||
file_path: /emails/newsletter.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_no_insert: true
|
|
||||||
score_name: runner.no_project
|
|
||||||
|
|
||||||
- id: "2.7"
|
|
||||||
description: "Existing task with same title → dedup (update not create)"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_followup.html
|
|
||||||
projects: [alpha]
|
|
||||||
existing_tasks:
|
|
||||||
- id: task-existing
|
|
||||||
title: Fix the login bug
|
|
||||||
status: todo
|
|
||||||
priority: medium
|
|
||||||
expect_dedup: true
|
|
||||||
score_name: runner.dedup
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> boss@company.com</p>
|
|
||||||
<p><b>To:</b> dev@company.com</p>
|
|
||||||
<p><b>Subject:</b> Fix the login bug</p>
|
|
||||||
<p><b>Date:</b> 2026-04-07</p>
|
|
||||||
<p>Hi,<br>Please fix the login bug in Project Alpha by Friday. High priority!</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> pm@company.com</p>
|
|
||||||
<p><b>Subject:</b> Project Alpha kick-off meeting</p>
|
|
||||||
<p>The kick-off meeting for Project Alpha is scheduled for 2026-04-15 at 10:00.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> pm@company.com</p>
|
|
||||||
<p><b>To:</b> team@company.com</p>
|
|
||||||
<p><b>Subject:</b> FYI: New policy for Project Alpha</p>
|
|
||||||
<p>Just a heads-up that starting next week all code reviews must be done
|
|
||||||
within 24 hours for Project Alpha. No action needed from you now.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> newsletter@ads.com</p>
|
|
||||||
<p><b>Subject:</b> Weekly newsletter</p>
|
|
||||||
<p>Check out our latest deals on electronics!</p>
|
|
||||||
</body></html>
|
|
||||||
19
tests/fixtures/journey_v2/cases.yaml
vendored
19
tests/fixtures/journey_v2/cases.yaml
vendored
@@ -1,19 +0,0 @@
|
|||||||
# Journey V2 eval test cases — Step 4
|
|
||||||
#
|
|
||||||
# Only case 4.1 is kept as an automated eval. Cases 4.2–4.5 (multi-turn
|
|
||||||
# conversations that expect the LLM to produce a complete AgentConfig)
|
|
||||||
# are non-deterministic and tested manually — results tracked in Langfuse.
|
|
||||||
#
|
|
||||||
# Assertion keys:
|
|
||||||
# expect_question: true → first reply must contain "?"
|
|
||||||
|
|
||||||
- id: "4.1"
|
|
||||||
description: "Journey start explores directory, first reply contains a question"
|
|
||||||
directory: "/test/emails"
|
|
||||||
data_types: ["tasks", "notes", "timelines"]
|
|
||||||
directory_files:
|
|
||||||
- path: "/test/emails/outlook_export_2024.html"
|
|
||||||
content_file: "email_action.html"
|
|
||||||
user_messages: []
|
|
||||||
score_name: "journey.start"
|
|
||||||
expect_question: true
|
|
||||||
23
tests/fixtures/journey_v2/data/email_action.html
vendored
23
tests/fixtures/journey_v2/data/email_action.html
vendored
@@ -1,23 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>Email: Fix the login bug</title>
|
|
||||||
<style>body { font-family: Arial; } .header { color: #666; }</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> boss@company.com</p>
|
|
||||||
<p><strong>To:</strong> dev@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> Fix the login bug</p>
|
|
||||||
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:15:00 +0000</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi,</p>
|
|
||||||
<p>Please fix the login bug in Project Alpha as soon as possible.
|
|
||||||
Users are reporting that they can't log in with their Google accounts.
|
|
||||||
This is blocking the whole team. Please resolve it by Friday.</p>
|
|
||||||
<p>Thanks,<br>Boss</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
23
tests/fixtures/journey_v2/data/email_info.html
vendored
23
tests/fixtures/journey_v2/data/email_info.html
vendored
@@ -1,23 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>Email: New policy update</title>
|
|
||||||
<style>body { font-family: Arial; }</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> hr@company.com</p>
|
|
||||||
<p><strong>To:</strong> all@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> FYI: New remote work policy effective May 1</p>
|
|
||||||
<p><strong>Date:</strong> Tue, 8 Apr 2026 10:00:00 +0000</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi everyone,</p>
|
|
||||||
<p>Just a heads-up that starting May 1, 2026 the company will be moving to
|
|
||||||
a hybrid work model. You will be expected to come into the office at least
|
|
||||||
two days per week. More details will follow in the employee handbook.</p>
|
|
||||||
<p>Best,<br>HR Team</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
68
tests/fixtures/preprocessors/cases.yaml
vendored
68
tests/fixtures/preprocessors/cases.yaml
vendored
@@ -1,68 +0,0 @@
|
|||||||
# Preprocessor test cases
|
|
||||||
#
|
|
||||||
# detect: <expected_type> → chiama detect_content_type(filename, content)
|
|
||||||
# process: <content_type> → chiama preprocess(content_type, content)
|
|
||||||
#
|
|
||||||
# Sorgente: file: <nome in data/> oppure generate: binary_noise
|
|
||||||
#
|
|
||||||
# Assertions piatte (solo per process):
|
|
||||||
# no_html: true clean_text senza tag HTML
|
|
||||||
# min_chars: N len(clean_text) >= N
|
|
||||||
# ratio_lt: F len(clean) / len(raw) < F
|
|
||||||
# has_meta: [k, ...] chiavi presenti in metadata
|
|
||||||
# contains: str | [str] substring(s) presenti in clean_text
|
|
||||||
# excludes: str | [str] substring(s) assenti da clean_text
|
|
||||||
# content_type: str result.content_type == questo valore
|
|
||||||
|
|
||||||
- id: "1.1"
|
|
||||||
file: email_action.html
|
|
||||||
detect: email_html
|
|
||||||
|
|
||||||
- id: "1.2"
|
|
||||||
file: generic_page.html
|
|
||||||
detect: generic_html
|
|
||||||
|
|
||||||
- id: "1.3"
|
|
||||||
file: notes.txt
|
|
||||||
detect: plain_text
|
|
||||||
|
|
||||||
- id: "1.4"
|
|
||||||
file: archive.xyz
|
|
||||||
generate: binary_noise
|
|
||||||
detect: unknown
|
|
||||||
|
|
||||||
- id: "1.5"
|
|
||||||
file: email_action.html
|
|
||||||
process: email_html
|
|
||||||
no_html: true
|
|
||||||
min_chars: 50
|
|
||||||
ratio_lt: 0.8
|
|
||||||
|
|
||||||
- id: "1.6"
|
|
||||||
file: email_action.html
|
|
||||||
process: email_html
|
|
||||||
has_meta: [subject, from]
|
|
||||||
|
|
||||||
- id: "1.7"
|
|
||||||
file: email_thread.html
|
|
||||||
process: email_html
|
|
||||||
contains: "Sure, I'll handle the deploy"
|
|
||||||
excludes: "Let's plan the deploy"
|
|
||||||
|
|
||||||
- id: "1.8"
|
|
||||||
file: email_single.html
|
|
||||||
process: email_html
|
|
||||||
contains: "deploy is done"
|
|
||||||
|
|
||||||
- id: "1.9"
|
|
||||||
file: email_heavy.html
|
|
||||||
process: email_html
|
|
||||||
no_html: true
|
|
||||||
min_chars: 30
|
|
||||||
excludes: [border-collapse, font-size]
|
|
||||||
|
|
||||||
- id: "1.10"
|
|
||||||
file: fallback.txt
|
|
||||||
process: unknown
|
|
||||||
min_chars: 1
|
|
||||||
content_type: unknown
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<title>Fix the login bug</title>
|
|
||||||
<style>
|
|
||||||
body { font-family: Arial, sans-serif; color: #333; margin: 0; padding: 20px; }
|
|
||||||
.header { background: #f5f5f5; padding: 10px; border-bottom: 1px solid #ddd; }
|
|
||||||
.body { padding: 20px; }
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> boss@company.com</p>
|
|
||||||
<p><strong>To:</strong> dev@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> Fix the login bug</p>
|
|
||||||
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:00:00 +0200</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi,</p>
|
|
||||||
<p>Please fix the login bug by Friday. It is blocking the release.</p>
|
|
||||||
<p>Priority: high. Let me know if you need anything.</p>
|
|
||||||
<p>Thanks,<br>Boss</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<style>
|
|
||||||
table { border-collapse: collapse; width: 100%; max-width: 600px; margin: 0 auto; }
|
|
||||||
td { padding: 8px 12px; border: 1px solid #dddddd; font-size: 12px; color: #444444; }
|
|
||||||
.header-row { background-color: #003366; color: #ffffff; font-weight: bold; }
|
|
||||||
.label-col { background-color: #f0f0f0; width: 80px; font-weight: bold; }
|
|
||||||
.footer-row { font-size: 10px; color: #999999; text-align: center; }
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body bgcolor="#eeeeee">
|
|
||||||
<center>
|
|
||||||
<table cellpadding="0" cellspacing="0">
|
|
||||||
<tr class="header-row">
|
|
||||||
<td colspan="2">Company Internal Update</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">From:</td>
|
|
||||||
<td>newsletter@corp.com</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">Subject:</td>
|
|
||||||
<td>Q1 Results Update</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">Date:</td>
|
|
||||||
<td>Apr 7, 2026</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td colspan="2">
|
|
||||||
<table width="100%" cellpadding="10">
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<p style="font-size:14px; font-weight:bold;">Dear Team,</p>
|
|
||||||
<p>Q1 results are in. Revenue up 15% year-over-year.</p>
|
|
||||||
<p>Please review the attached report and share any feedback by EOW.</p>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr class="footer-row">
|
|
||||||
<td colspan="2">Confidential — do not forward outside the company.</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
</center>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html><body>
|
|
||||||
<p><strong>From:</strong> alice@co.com</p>
|
|
||||||
<p><strong>To:</strong> team@co.com</p>
|
|
||||||
<p><strong>Subject:</strong> Quick update</p>
|
|
||||||
<p><strong>Date:</strong> Tue, 7 Apr 2026 10:30:00 +0200</p>
|
|
||||||
<p>The deploy is done. Everything looks good. No issues so far.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html><body>
|
|
||||||
<div class="message-latest">
|
|
||||||
<p><strong>From:</strong> alice@co.com</p>
|
|
||||||
<p><strong>Subject:</strong> Re: Re: Deploy plan</p>
|
|
||||||
<p>Sure, I'll handle the deploy.</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<p>On Mon, Apr 6, 2026 at 3:00 PM, Bob <bob@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: bob@co.com</p>
|
|
||||||
<p>Can you handle the deploy?</p>
|
|
||||||
<p>On Sun, Apr 5, 2026 at 1:00 PM, Alice <alice@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: alice@co.com</p>
|
|
||||||
<p>Let's plan the deploy for Monday.</p>
|
|
||||||
<p>On Sat, Apr 4, 2026 at 11:00 AM, Charlie <charlie@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: charlie@co.com</p>
|
|
||||||
<p>We need to schedule the deploy. What day works?</p>
|
|
||||||
</blockquote>
|
|
||||||
</blockquote>
|
|
||||||
</blockquote>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
random text content without any structure
|
|
||||||
line two with some words
|
|
||||||
line three and more content here
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>My Web App</title>
|
|
||||||
<link rel="stylesheet" href="styles.css">
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<nav>
|
|
||||||
<a href="/">Home</a>
|
|
||||||
<a href="/about">About</a>
|
|
||||||
<a href="/contact">Contact</a>
|
|
||||||
</nav>
|
|
||||||
<main>
|
|
||||||
<header>
|
|
||||||
<h1>Welcome to My App</h1>
|
|
||||||
</header>
|
|
||||||
<article>
|
|
||||||
<p>This is a generic web page with no email headers.</p>
|
|
||||||
<p>It has navigation, main content, and a footer.</p>
|
|
||||||
</article>
|
|
||||||
<section>
|
|
||||||
<h2>Features</h2>
|
|
||||||
<ul>
|
|
||||||
<li>Fast</li>
|
|
||||||
<li>Reliable</li>
|
|
||||||
<li>Secure</li>
|
|
||||||
</ul>
|
|
||||||
</section>
|
|
||||||
</main>
|
|
||||||
<footer>
|
|
||||||
<p>© 2026 My App</p>
|
|
||||||
</footer>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
15
tests/fixtures/preprocessors/data/notes.txt
vendored
15
tests/fixtures/preprocessors/data/notes.txt
vendored
@@ -1,15 +0,0 @@
|
|||||||
Meeting notes - April 7, 2026
|
|
||||||
|
|
||||||
Attendees: Alice, Bob, Charlie
|
|
||||||
|
|
||||||
Discussion points:
|
|
||||||
- Deploy scheduled for Friday
|
|
||||||
- Bug fix for login must be completed by Thursday
|
|
||||||
- Review Q1 numbers before EOW
|
|
||||||
|
|
||||||
Action items:
|
|
||||||
- Alice: fix login bug
|
|
||||||
- Bob: prepare deploy checklist
|
|
||||||
- Charlie: send Q1 report
|
|
||||||
|
|
||||||
Next meeting: April 14, 2026
|
|
||||||
@@ -10,13 +10,13 @@ Coverage:
|
|||||||
- run_local_agent — file-read timeout path
|
- run_local_agent — file-read timeout path
|
||||||
- run_local_agent — LLM extraction error path
|
- run_local_agent — LLM extraction error path
|
||||||
- run_cloud_agent — stub returns error immediately
|
- run_cloud_agent — stub returns error immediately
|
||||||
- trigger_pending_runs — skipped when config is client-owned
|
- trigger_pending_runs — overdue local + cloud dispatched
|
||||||
- trigger_pending_runs — non-overdue skipped
|
- trigger_pending_runs — non-overdue skipped
|
||||||
- trigger_pending_runs — device_id filter for local agents
|
- trigger_pending_runs — device_id filter for local agents
|
||||||
|
|
||||||
Integration:
|
Integration:
|
||||||
- POST /agents/can-create — billing eligibility check
|
- POST /agents/{id}/run — 404 on unknown agent
|
||||||
- POST /agents/trigger — creates run log + dispatches background task
|
- POST /agents/{id}/run — creates run log + dispatches background task
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
|
|||||||
assert kwargs["items_processed"] == 1
|
assert kwargs["items_processed"] == 1
|
||||||
assert kwargs["items_created"] == 1
|
assert kwargs["items_created"] == 1
|
||||||
assert kwargs["errors"] == []
|
assert kwargs["errors"] == []
|
||||||
assert kwargs["update_config_last_run"] is False
|
assert kwargs["update_config_last_run"] is True
|
||||||
|
|
||||||
# Verify agent_run frame was sent.
|
# Verify agent_run frame was sent.
|
||||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||||
@@ -690,11 +690,31 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_no_overdue():
|
async def test_trigger_pending_runs_no_overdue():
|
||||||
"""Pending-run scan is skipped because agent config is client-owned."""
|
"""If no agents are overdue trigger_pending_runs does nothing."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
config = _make_local_config()
|
||||||
|
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
||||||
|
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -702,11 +722,31 @@ async def test_trigger_pending_runs_no_overdue():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_device_id_filter():
|
async def test_trigger_pending_runs_device_id_filter():
|
||||||
"""Device filtering is no longer backend-managed in pending runs."""
|
"""Local agents are only triggered for the matching device_id."""
|
||||||
|
# The DB query already filters by device_id, so we verify the SELECT
|
||||||
|
# includes the device_id filter by checking that a config bound to a
|
||||||
|
# different device is never dispatched.
|
||||||
|
#
|
||||||
|
# Since trigger_pending_runs queries with device_id == "dev-001",
|
||||||
|
# simulate the DB returning an empty list (as it would for a mismatch).
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
mgr = _make_manager(device_id="dev-001")
|
mgr = _make_manager(device_id="dev-001")
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
mock_run.assert_not_called()
|
||||||
@@ -714,18 +754,56 @@ async def test_trigger_pending_runs_device_id_filter():
|
|||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_pending_runs_dispatches_overdue():
|
async def test_trigger_pending_runs_dispatches_overdue():
|
||||||
"""No pending runs are dispatched by backend after config deprecation."""
|
"""Overdue local agent triggers run_local_agent sequentially."""
|
||||||
|
config = _make_local_config() # last_run_at=None → always overdue
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
mgr = _make_manager()
|
mgr = _make_manager()
|
||||||
|
|
||||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
call_order: list[str] = []
|
||||||
|
|
||||||
|
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
||||||
|
call_order.append("run_local")
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
||||||
|
# First call: query configs. Subsequent calls: create run_log.
|
||||||
|
mock_query_ctx = AsyncMock()
|
||||||
|
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
||||||
|
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_query_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
|
||||||
|
run_log_obj = AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=_FREE_UID,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
mock_insert_ctx = AsyncMock()
|
||||||
|
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
||||||
|
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_insert_ctx.add = MagicMock()
|
||||||
|
mock_insert_ctx.commit = AsyncMock()
|
||||||
|
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
||||||
|
|
||||||
|
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
||||||
|
|
||||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
mock_run.assert_not_called()
|
assert call_order == ["run_local"]
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Integration: POST /agents/can-create and /agents/trigger
|
# Integration: POST /agents/{id}/run
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -742,67 +820,50 @@ def _override_db(db_session):
|
|||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_can_create_agent_allows_when_under_limit(client):
|
async def test_trigger_run_unknown_agent(client):
|
||||||
"""POST /agents/can-create returns allowed=True when under tier limit."""
|
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/v1/agents/can-create",
|
f"/api/v1/agents/{uuid.uuid4()}/run",
|
||||||
json={"active_agents": 0},
|
headers=auth_header("power"),
|
||||||
headers=auth_header("free"),
|
|
||||||
)
|
)
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 404
|
||||||
body = resp.json()
|
|
||||||
assert body["allowed"] is True
|
|
||||||
assert body["tier"] == "free"
|
|
||||||
assert body["active_agents"] == 0
|
|
||||||
assert body["limit"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_can_create_agent_denies_when_at_limit(client):
|
|
||||||
"""POST /agents/can-create returns allowed=False at free-tier limit."""
|
|
||||||
resp = client.post(
|
|
||||||
"/api/v1/agents/can-create",
|
|
||||||
json={"active_agents": 2},
|
|
||||||
headers=auth_header("free"),
|
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
|
||||||
body = resp.json()
|
|
||||||
assert body["allowed"] is False
|
|
||||||
assert body["limit"] == 2
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||||
"""POST /agents/trigger creates a local run log and dispatches background task."""
|
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
||||||
dispatched: list[tuple[str, str]] = []
|
# Create the local agent config in the DB.
|
||||||
|
config = LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=TEST_USER_IDS["power"],
|
||||||
|
device_id="dev-001",
|
||||||
|
name="My Agent",
|
||||||
|
directory_paths=["/home/user/docs"],
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks.",
|
||||||
|
file_extensions=[".txt"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
db_session.add(config)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
dispatched: list = []
|
||||||
|
|
||||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||||
dispatched.append((user_id, cfg.id))
|
dispatched.append((user_id, cfg.id))
|
||||||
|
|
||||||
def _fake_create_task(coro):
|
|
||||||
coro.close()
|
|
||||||
return MagicMock()
|
|
||||||
|
|
||||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||||
|
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
||||||
patch("asyncio.create_task") as mock_create_task:
|
patch("asyncio.create_task") as mock_create_task:
|
||||||
mock_create_task.side_effect = _fake_create_task
|
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
"/api/v1/agents/trigger",
|
f"/api/v1/agents/{config.id}/run",
|
||||||
json={
|
|
||||||
"directory": "/home/user/docs",
|
|
||||||
"what_to_extract": ["task", "note"],
|
|
||||||
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
|
|
||||||
"batch_interval": "0 */6 * * *",
|
|
||||||
"custom_agent_prompt": "Extract tasks and notes.",
|
|
||||||
"active_agents": 0,
|
|
||||||
},
|
|
||||||
headers=auth_header("power"),
|
headers=auth_header("power"),
|
||||||
)
|
)
|
||||||
|
|
||||||
assert resp.status_code == 202
|
assert resp.status_code == 202
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert isinstance(data["agent_id"], str)
|
assert data["agent_id"] == config.id
|
||||||
assert data["agent_id"]
|
|
||||||
assert data["status"] == "running"
|
assert data["status"] == "running"
|
||||||
assert data["agent_type"] == "local"
|
assert data["agent_type"] == "local"
|
||||||
|
|
||||||
|
|||||||
@@ -1,432 +0,0 @@
|
|||||||
"""Tests for Local Agent V2 runner (Step 2).
|
|
||||||
|
|
||||||
Covers the unified per-file flow:
|
|
||||||
Phase A — detect + preprocess (Python, zero LLM)
|
|
||||||
Phase B — single LLM call with tools (classify + extract + create)
|
|
||||||
|
|
||||||
Fixture-based eval tests (2.1–2.7)
|
|
||||||
-----------------------------------
|
|
||||||
Cases are defined in tests/fixtures/agent_runner_v2/cases.yaml.
|
|
||||||
Email HTML files live in tests/fixtures/agent_runner_v2/data/.
|
|
||||||
Use --runner-dir to point at a custom folder (same structure required).
|
|
||||||
|
|
||||||
Unit tests (no LLM)
|
|
||||||
--------------------
|
|
||||||
2.8 items_created count → items_created == N create_* calls
|
|
||||||
2.9 Device offline → status=error
|
|
||||||
2.10 Empty file → items_processed=0, status=success
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_agent_runner_v2.py -v
|
|
||||||
pytest tests/test_agent_runner_v2.py -v -k "2_9 or 2_10 or 2_8" # unit only
|
|
||||||
pytest tests/test_agent_runner_v2.py -v -k "eval" # LLM evals only
|
|
||||||
pytest tests/test_agent_runner_v2.py -v --runner-dir /path/to/dir # custom fixtures
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from app.core.agent_runner import (
|
|
||||||
_format_metadata,
|
|
||||||
_format_projects,
|
|
||||||
_get_extraction_rules,
|
|
||||||
_get_no_match_behavior,
|
|
||||||
_is_overdue,
|
|
||||||
run_local_agent,
|
|
||||||
)
|
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
|
||||||
from app.core.langfuse_client import get_langfuse
|
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS
|
|
||||||
|
|
||||||
# ── Constants ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_USER_ID = TEST_USER_IDS["power"]
|
|
||||||
|
|
||||||
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "agent_runner_v2"
|
|
||||||
|
|
||||||
_AGENT_CONFIG = {
|
|
||||||
"content_types": [
|
|
||||||
{
|
|
||||||
"id": "email_html",
|
|
||||||
"label": "Email HTML",
|
|
||||||
"detection_hint": "HTML file with From/To/Subject headers",
|
|
||||||
"preprocessing": "email_html",
|
|
||||||
"extraction_prompt": (
|
|
||||||
"If the email contains a direct action request or task assignment → create a task. "
|
|
||||||
"If the email contains informational content, updates, or FYI → create a note. "
|
|
||||||
"If the email mentions a specific date for a meeting or deadline → create a timeline entry."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"global_rules": [
|
|
||||||
"Se il file non è riconducibile a nessun progetto, non creare alcuna entità."
|
|
||||||
],
|
|
||||||
"data_types": ["tasks", "notes", "timelines"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Canonical project definitions, referenced symbolically in cases.yaml.
|
|
||||||
_PROJECTS: dict[str, dict] = {
|
|
||||||
"alpha": {"id": "proj-alpha", "name": "Project Alpha", "status": "active"},
|
|
||||||
"beta": {"id": "proj-beta", "name": "Project Beta", "status": "active"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixture loading ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _fixtures_dir(config) -> Path:
|
|
||||||
override = config.getoption("--runner-dir")
|
|
||||||
return Path(override) if override else _DEFAULT_FIXTURE_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cases(config) -> list[dict]:
|
|
||||||
return yaml.safe_load(
|
|
||||||
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_case_file(case: dict, data_dir: Path) -> str:
|
|
||||||
return (data_dir / case["file"]).read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_projects(entries: list[str | dict]) -> list[dict]:
|
|
||||||
"""Resolve project list from YAML: symbolic names and/or inline dicts."""
|
|
||||||
result = []
|
|
||||||
for entry in entries:
|
|
||||||
if isinstance(entry, str):
|
|
||||||
if entry in _PROJECTS:
|
|
||||||
result.append(_PROJECTS[entry])
|
|
||||||
elif isinstance(entry, dict):
|
|
||||||
result.append(entry)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ── pytest_generate_tests — parametrize eval tests from YAML ─────────────
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
|
||||||
if "runner_case" not in metafunc.fixturenames:
|
|
||||||
return
|
|
||||||
cases = _load_cases(metafunc.config)
|
|
||||||
metafunc.parametrize("runner_case", cases, ids=[c["id"] for c in cases])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Test helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _make_config(
|
|
||||||
agent_config: dict | None = None,
|
|
||||||
directory: str = "/emails",
|
|
||||||
device_id: str = "dev-001",
|
|
||||||
) -> LocalAgentConfig:
|
|
||||||
return LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=_USER_ID,
|
|
||||||
device_id=device_id,
|
|
||||||
name="Test V2 Agent",
|
|
||||||
directory_paths=[directory],
|
|
||||||
data_types=["tasks", "notes", "timelines"],
|
|
||||||
prompt_template="",
|
|
||||||
agent_config=agent_config or _AGENT_CONFIG,
|
|
||||||
file_extensions=[".html", ".eml"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_run_log(agent_id: str) -> AgentRunLog:
|
|
||||||
return AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=_USER_ID,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_manager(online: bool = True) -> DeviceConnectionManager:
|
|
||||||
mgr = DeviceConnectionManager()
|
|
||||||
if online:
|
|
||||||
ws = MagicMock()
|
|
||||||
ws.send_text = AsyncMock()
|
|
||||||
mgr.register(_USER_ID, "dev-001", ws)
|
|
||||||
return mgr
|
|
||||||
|
|
||||||
|
|
||||||
def _make_executor(
|
|
||||||
file_path: str,
|
|
||||||
file_content: str,
|
|
||||||
projects: list[dict] | None = None,
|
|
||||||
existing_tasks: list[dict] | None = None,
|
|
||||||
existing_notes: list[dict] | None = None,
|
|
||||||
existing_timelines: list[dict] | None = None,
|
|
||||||
) -> tuple[Any, list[dict]]:
|
|
||||||
"""Return (async_executor, captured_calls).
|
|
||||||
|
|
||||||
The executor handles all ``execute_on_client`` payloads:
|
|
||||||
directory listing, file reading, project/entity fetching, and CRUD.
|
|
||||||
"""
|
|
||||||
calls: list[dict] = []
|
|
||||||
_projects = projects if projects is not None else list(_PROJECTS.values())
|
|
||||||
|
|
||||||
async def _executor(payload: dict) -> dict:
|
|
||||||
action = payload.get("action", "")
|
|
||||||
table = payload.get("table", "")
|
|
||||||
data = payload.get("data") or {}
|
|
||||||
calls.append({"action": action, "table": table, "data": data})
|
|
||||||
|
|
||||||
if action == "list_directory":
|
|
||||||
return {"entries": [{"type": "file", "path": file_path}]}
|
|
||||||
|
|
||||||
if action == "get_file_metadata":
|
|
||||||
return {"modifiedAt": None}
|
|
||||||
|
|
||||||
if action == "read_file_content":
|
|
||||||
return {"content": file_content}
|
|
||||||
|
|
||||||
if action == "select":
|
|
||||||
if table == "projects":
|
|
||||||
return {"rows": _projects}
|
|
||||||
if table == "tasks":
|
|
||||||
return {"rows": existing_tasks or []}
|
|
||||||
if table == "notes":
|
|
||||||
return {"rows": existing_notes or []}
|
|
||||||
if table == "timelines":
|
|
||||||
return {"rows": existing_timelines or []}
|
|
||||||
return {"rows": []}
|
|
||||||
|
|
||||||
if action == "insert":
|
|
||||||
return {"row": {"id": str(uuid.uuid4()), **data}}
|
|
||||||
|
|
||||||
if action == "update":
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
return _executor, calls
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: helper functions ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_projects_empty():
|
|
||||||
assert "(no projects" in _format_projects([])
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_projects_with_data():
|
|
||||||
result = _format_projects([_PROJECTS["alpha"]])
|
|
||||||
assert "proj-alpha" in result
|
|
||||||
assert "Project Alpha" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_metadata_empty():
|
|
||||||
assert _format_metadata({}) == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_metadata_email():
|
|
||||||
meta = {"subject": "Fix bug", "from": "boss@co.com", "date": "2026-04-07"}
|
|
||||||
result = _format_metadata(meta)
|
|
||||||
assert "Fix bug" in result
|
|
||||||
assert "boss@co.com" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_extraction_rules_match():
|
|
||||||
rules = _get_extraction_rules(_AGENT_CONFIG, "email_html")
|
|
||||||
assert "task" in rules.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_extraction_rules_fallback():
|
|
||||||
rules = _get_extraction_rules(_AGENT_CONFIG, "plain_text")
|
|
||||||
assert "extract" in rules.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_no_match_behavior_from_global_rules():
|
|
||||||
behavior = _get_no_match_behavior(_AGENT_CONFIG)
|
|
||||||
assert behavior # non-empty
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_no_match_behavior_default():
|
|
||||||
behavior = _get_no_match_behavior({})
|
|
||||||
assert "project" in behavior.lower()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.9 — device offline ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_9_device_offline():
|
|
||||||
"""2.9 No device online → status=error, no executor created."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager(online=False)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("not connected" in e for e in kwargs.get("errors", []))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.10 — empty file ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_10_empty_file():
|
|
||||||
"""2.10 File with empty content → skipped, items_processed=0, success."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, calls = _make_executor(
|
|
||||||
file_path="/emails/empty.html",
|
|
||||||
file_content="",
|
|
||||||
projects=[_PROJECTS["alpha"]],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
assert kwargs["items_processed"] == 0
|
|
||||||
assert kwargs["status"] == "success"
|
|
||||||
assert kwargs["items_created"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.8 — items_created count ─────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_8_items_created_count():
|
|
||||||
"""2.8 items_created == number of create_* tool calls per run."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, _calls = _make_executor(
|
|
||||||
file_path="/emails/action.html",
|
|
||||||
file_content="<html><body><p>Fix the login bug in Project Alpha.</p></body></html>",
|
|
||||||
projects=[_PROJECTS["alpha"]],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_run_agent(*, _tool_calls_out=None, **kw) -> str:
|
|
||||||
if _tool_calls_out is not None:
|
|
||||||
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
|
|
||||||
return "Done."
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
# Only create_task + create_note count (not update_task).
|
|
||||||
assert kwargs["items_created"] == 2
|
|
||||||
assert kwargs["items_processed"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
# ── Eval: 2.1–2.7 — fixture-driven, real LLM + Langfuse scoring ──────────
|
|
||||||
#
|
|
||||||
# Cases loaded from tests/fixtures/agent_runner_v2/cases.yaml.
|
|
||||||
# Supported assertions (from YAML):
|
|
||||||
# expect_insert: <table> → at least 1 insert in that table
|
|
||||||
# expect_no_insert: true → zero inserts in any table
|
|
||||||
# expect_project_id: <id> → any insert carries this projectId
|
|
||||||
# expect_dedup: true → task inserts == 0 OR task updates >= 1
|
|
||||||
# ─────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.eval
|
|
||||||
async def test_eval_runner(runner_case, pytestconfig):
|
|
||||||
"""Parametrized eval test — one invocation per YAML case."""
|
|
||||||
case: dict = runner_case
|
|
||||||
data_dir = _fixtures_dir(pytestconfig) / "data"
|
|
||||||
file_content = _read_case_file(case, data_dir)
|
|
||||||
projects = _resolve_projects(case.get("projects", []))
|
|
||||||
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, calls = _make_executor(
|
|
||||||
file_path=case["file_path"],
|
|
||||||
file_content=file_content,
|
|
||||||
projects=projects,
|
|
||||||
existing_tasks=case.get("existing_tasks"),
|
|
||||||
existing_notes=case.get("existing_notes"),
|
|
||||||
existing_timelines=case.get("existing_timelines"),
|
|
||||||
)
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
obs_ctx = lf.start_as_current_observation(
|
|
||||||
name=f"eval-runner-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
|
|
||||||
metadata={"step": "2", "case_id": case["id"]},
|
|
||||||
) if lf else nullcontext()
|
|
||||||
|
|
||||||
with obs_ctx as obs:
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
inserts = [c for c in calls if c["action"] == "insert"]
|
|
||||||
score, comment = _evaluate_case(case, calls, kwargs)
|
|
||||||
|
|
||||||
if obs is not None:
|
|
||||||
obs.score(
|
|
||||||
name=case.get("score_name", f"runner.case_{case['id']}"),
|
|
||||||
value=score,
|
|
||||||
comment=comment,
|
|
||||||
)
|
|
||||||
|
|
||||||
if lf:
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
assert score == 1.0, f"[{case['id']}] {case.get('description', '')} — {comment}"
|
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_case(case: dict, calls: list[dict], finalize_kwargs: dict) -> tuple[float, str]:
|
|
||||||
"""Return (score, comment) for a YAML case given the captured executor calls."""
|
|
||||||
inserts = [c for c in calls if c["action"] == "insert"]
|
|
||||||
|
|
||||||
if case.get("expect_no_insert"):
|
|
||||||
score = 1.0 if len(inserts) == 0 else 0.0
|
|
||||||
return score, f"inserts={len(inserts)} (expected 0)"
|
|
||||||
|
|
||||||
if "expect_insert" in case:
|
|
||||||
tables = case["expect_insert"]
|
|
||||||
if isinstance(tables, str):
|
|
||||||
tables = [tables]
|
|
||||||
missing = [t for t in tables if not any(c["table"] == t for c in inserts)]
|
|
||||||
score = 1.0 if not missing else 0.0
|
|
||||||
counts = {t: sum(1 for c in inserts if c["table"] == t) for t in tables}
|
|
||||||
return score, f"inserts={counts}" + (f" missing={missing}" if missing else "")
|
|
||||||
|
|
||||||
if "expect_project_id" in case:
|
|
||||||
expected_pid = case["expect_project_id"]
|
|
||||||
correct = any(c.get("data", {}).get("projectId") == expected_pid for c in inserts)
|
|
||||||
score = 1.0 if correct else 0.0
|
|
||||||
all_pids = [c.get("data", {}).get("projectId") for c in inserts]
|
|
||||||
return score, f"projectIds={all_pids} (expected {expected_pid!r})"
|
|
||||||
|
|
||||||
if case.get("expect_dedup"):
|
|
||||||
task_creates = [c for c in inserts if c["table"] == "tasks"]
|
|
||||||
task_updates = [c for c in calls if c["action"] == "update" and c["table"] == "tasks"]
|
|
||||||
score = 1.0 if len(task_creates) == 0 or len(task_updates) >= 1 else 0.0
|
|
||||||
return score, f"task_creates={len(task_creates)} task_updates={len(task_updates)}"
|
|
||||||
|
|
||||||
return 0.0, "no assertion defined in case"
|
|
||||||
243
tests/test_backup.py
Normal file
243
tests/test_backup.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""Tests for backup routes: upload, download, history, delete.
|
||||||
|
|
||||||
|
Exercises the backup lifecycle through the FastAPI TestClient against the
|
||||||
|
in-memory SQLite test database and moto-mocked S3 bucket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
from tests.conftest import auth_header, TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_BLOB = b"encrypted-backup-blob-opaque-bytes"
|
||||||
|
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||||
|
_VERSION = 1
|
||||||
|
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
|
||||||
|
|
||||||
|
|
||||||
|
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
|
||||||
|
"""Return auth + backup metadata headers."""
|
||||||
|
headers = auth_header(tier)
|
||||||
|
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
|
||||||
|
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
|
||||||
|
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
|
||||||
|
headers["Content-Type"] = "application/octet-stream"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
|
||||||
|
"""Upload a backup blob and return the response."""
|
||||||
|
return client.put(
|
||||||
|
"/api/v1/backup",
|
||||||
|
content=overrides.pop("blob", _BLOB),
|
||||||
|
headers=_backup_headers(tier, **overrides),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestUploadBackup ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadBackup:
|
||||||
|
"""PUT /api/v1/backup"""
|
||||||
|
|
||||||
|
def test_upload_success(self, client, s3_bucket) -> None:
|
||||||
|
resp = _upload(client, tier="power")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert len(history) == 1
|
||||||
|
assert history[0]["version"] == _VERSION
|
||||||
|
assert history[0]["timestamp"] == _TIMESTAMP
|
||||||
|
assert history[0]["checksum"] == _CHECKSUM
|
||||||
|
|
||||||
|
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
resp = _upload(client, tier="power", checksum="0" * 64)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
|
||||||
|
"""Free tier has backup_gb=0 → should return 402."""
|
||||||
|
resp = _upload(client, tier="free")
|
||||||
|
assert resp.status_code == 402
|
||||||
|
|
||||||
|
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
|
||||||
|
"""Pro tier has backup_gb=5 → small blob succeeds."""
|
||||||
|
resp = _upload(client, tier="pro")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestDownloadBackup ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDownloadBackup:
|
||||||
|
"""GET /api/v1/backup"""
|
||||||
|
|
||||||
|
def test_download_latest(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == _BLOB
|
||||||
|
assert resp.headers["X-Checksum"] == _CHECKSUM
|
||||||
|
assert resp.headers["X-Backup-Version"] == str(_VERSION)
|
||||||
|
|
||||||
|
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
|
||||||
|
"""When If-Modified-Since is after the backup timestamp → 304."""
|
||||||
|
_upload(client, tier="power", timestamp=1700000000000)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/backup",
|
||||||
|
headers={
|
||||||
|
**auth_header("power"),
|
||||||
|
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 304
|
||||||
|
|
||||||
|
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
|
||||||
|
"""When If-Modified-Since is before the backup timestamp → serve blob."""
|
||||||
|
_upload(client, tier="power", timestamp=1700000000000)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/backup",
|
||||||
|
headers={
|
||||||
|
**auth_header("power"),
|
||||||
|
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == _BLOB
|
||||||
|
|
||||||
|
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
|
||||||
|
"""When multiple backups exist, GET returns the one with the highest timestamp."""
|
||||||
|
_upload(client, tier="power", timestamp=1000)
|
||||||
|
blob2 = b"second-encrypted-backup"
|
||||||
|
checksum2 = hashlib.sha256(blob2).hexdigest()
|
||||||
|
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == blob2
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBackupHistory ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackupHistory:
|
||||||
|
"""GET /api/v1/backup/history"""
|
||||||
|
|
||||||
|
def test_history_empty(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
def test_history_returns_entries(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power", timestamp=1000)
|
||||||
|
_upload(client, tier="power", timestamp=2000)
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert len(history) == 2
|
||||||
|
# Ordered by timestamp descending
|
||||||
|
assert history[0]["timestamp"] == 2000
|
||||||
|
assert history[1]["timestamp"] == 1000
|
||||||
|
|
||||||
|
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
|
||||||
|
"""One user's backups should not appear in another user's history."""
|
||||||
|
_upload(client, tier="power")
|
||||||
|
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestDeleteBackup ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteBackup:
|
||||||
|
"""DELETE /api/v1/backup/{backup_id}"""
|
||||||
|
|
||||||
|
def _get_backup_id(self, client, tier="power") -> str:
|
||||||
|
"""Upload a backup and return its DB id from history."""
|
||||||
|
_upload(client, tier=tier)
|
||||||
|
client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header(tier)
|
||||||
|
).json()
|
||||||
|
# History returns BackupMetadata schema which doesn't have `id`.
|
||||||
|
# We need to look it up via a different means.
|
||||||
|
# Since there's only 1 backup, find via history length.
|
||||||
|
# Actually the schema doesn't return id — let's verify via re-download.
|
||||||
|
# We'll use a workaround: upload, then list history to confirm it exists,
|
||||||
|
# then try to delete — but we need the id...
|
||||||
|
# Let's check if history includes an id field.
|
||||||
|
# The schema is: version, timestamp, checksum, chunk_count — no id.
|
||||||
|
# We'll need to query the DB directly or use a known ID.
|
||||||
|
# For testing, we'll search history then use the DB.
|
||||||
|
return None # pragma: no cover — overridden below
|
||||||
|
|
||||||
|
def test_delete_success(self, client, s3_bucket, db_session) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
|
||||||
|
# Discover the backup_id via direct DB query
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models import BackupMetadata
|
||||||
|
|
||||||
|
async def _get_id():
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(BackupMetadata.id).where(
|
||||||
|
BackupMetadata.user_id == TEST_USER_IDS["power"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
|
||||||
|
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# History should now be empty
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert history == []
|
||||||
|
|
||||||
|
def test_delete_nonexistent(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/backup/no-such-id", headers=auth_header("power")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
|
||||||
|
"""Cannot delete another user's backup (ownership check returns 404)."""
|
||||||
|
_upload(client, tier="power")
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models import BackupMetadata
|
||||||
|
|
||||||
|
async def _get_id():
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(BackupMetadata.id).where(
|
||||||
|
BackupMetadata.user_id == TEST_USER_IDS["power"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
|
||||||
|
|
||||||
|
# team user tries to delete power user's backup → 404
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
"""Unit tests for Step 1 file classification (_classify_file).
|
|
||||||
|
|
||||||
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
|
|
||||||
Run with: pytest tests/test_classify_file.py -v
|
|
||||||
|
|
||||||
To run a quick manual check against a real file without the full UI:
|
|
||||||
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import sys
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_runner import _classify_file
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
PROJECTS_SAMPLE = [
|
|
||||||
{
|
|
||||||
"id": "aaaa-0001-0000-0000-000000000001",
|
|
||||||
"name": "ARPA Sicilia POC",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "bbbb-0002-0000-0000-000000000002",
|
|
||||||
"name": "SNAM AI Meeting Prep",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "cccc-0003-0000-0000-000000000003",
|
|
||||||
"name": "SFERA+ Wave 2",
|
|
||||||
"status": "active",
|
|
||||||
"aiSummary": "Second wave of the SFERA+ whitelist project.",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
ARPA_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
|
|
||||||
isImportance: normal
|
|
||||||
hasAttachment: True
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Buongiorno,
|
|
||||||
|
|
||||||
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
|
|
||||||
dei deliverable concordati:
|
|
||||||
- Preparare demo entro il 30 marzo
|
|
||||||
- Condividere documentazione tecnica con il team ARPA
|
|
||||||
- Fissare call di follow-up la prossima settimana
|
|
||||||
|
|
||||||
Cordiali saluti
|
|
||||||
Roberto Marchetti
|
|
||||||
"""
|
|
||||||
|
|
||||||
SNAM_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com
|
|
||||||
isImportance: high
|
|
||||||
hasAttachment: False
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Ciao,
|
|
||||||
ti invio l'agenda per la riunione SNAM di domani.
|
|
||||||
Per favore conferma la tua presenza.
|
|
||||||
"""
|
|
||||||
|
|
||||||
UNRELATED_EMAIL = """\
|
|
||||||
to: roberto.musso@hpe.com
|
|
||||||
isImportance: normal
|
|
||||||
---
|
|
||||||
## Body
|
|
||||||
Benvenuto nel programma HPE Employee Learning Series.
|
|
||||||
Completa la formazione richiesta entro la fine del trimestre.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tests ─────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_arpa_matches_existing():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="arpa_email.txt",
|
|
||||||
file_content=ARPA_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes", "timelines"],
|
|
||||||
)
|
|
||||||
assert project_id == "aaaa-0001-0000-0000-000000000001", (
|
|
||||||
f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
|
|
||||||
)
|
|
||||||
assert new_name is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_snam_matches_existing():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="snam_email.txt",
|
|
||||||
file_content=SNAM_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "bbbb-0002-0000-0000-000000000002", (
|
|
||||||
f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_unrelated_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="learning_email.txt",
|
|
||||||
file_content=UNRELATED_EMAIL,
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
assert new_name is not None # LLM should suggest a name
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_empty_file_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="empty.txt",
|
|
||||||
file_content=" ",
|
|
||||||
projects=PROJECTS_SAMPLE,
|
|
||||||
config_data_types=["tasks"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_classify_no_projects_returns_new():
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path="arpa_email.txt",
|
|
||||||
file_content=ARPA_EMAIL,
|
|
||||||
projects=[],
|
|
||||||
config_data_types=["tasks", "notes"],
|
|
||||||
)
|
|
||||||
assert project_id == "new"
|
|
||||||
assert new_name is not None
|
|
||||||
|
|
||||||
|
|
||||||
# ── CLI quick-test runner ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _cli_test(file_path: str, project_names: list[str]) -> None:
|
|
||||||
"""Run Step 1 classification against a real file from the CLI."""
|
|
||||||
import json
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
content = Path(file_path).read_text(encoding="utf-8", errors="replace")
|
|
||||||
projects = [
|
|
||||||
{"id": f"test-id-{i:04d}", "name": name, "status": "active", "aiSummary": ""}
|
|
||||||
for i, name in enumerate(project_names)
|
|
||||||
]
|
|
||||||
|
|
||||||
print(f"\nClassifying: {file_path}")
|
|
||||||
print(f"Projects in context: {[p['name'] for p in projects]}\n")
|
|
||||||
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path=file_path,
|
|
||||||
file_content=content,
|
|
||||||
projects=projects,
|
|
||||||
config_data_types=["tasks", "notes", "timelines"],
|
|
||||||
)
|
|
||||||
|
|
||||||
result = {
|
|
||||||
"project_id": project_id,
|
|
||||||
"matched_name": next((p["name"] for p in projects if p["id"] == project_id), None),
|
|
||||||
"new_project_name": new_name,
|
|
||||||
"domains": domains,
|
|
||||||
}
|
|
||||||
print(json.dumps(result, indent=2, ensure_ascii=False))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
if len(sys.argv) < 2:
|
|
||||||
print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
|
|
||||||
sys.exit(1)
|
|
||||||
asyncio.run(_cli_test(sys.argv[1], sys.argv[2:]))
|
|
||||||
@@ -1,288 +0,0 @@
|
|||||||
"""Unit tests for single-agent deep_agent flows with mocked tool results."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import date, timedelta
|
|
||||||
from types import SimpleNamespace
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.core.deep_agent import (
|
|
||||||
_infer_floating_domain,
|
|
||||||
_normalize_tagged_list_lines,
|
|
||||||
run_floating,
|
|
||||||
run_floating_stream,
|
|
||||||
run_home,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeTool:
|
|
||||||
name = "list_tasks"
|
|
||||||
|
|
||||||
async def ainvoke(self, args):
|
|
||||||
return {"rows": [{"id": "task-1", "title": "Mock Task"}], "echo": args}
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeLLM:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.agent_calls = 0
|
|
||||||
|
|
||||||
def bind_tools(self, _tools):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def ainvoke(self, messages):
|
|
||||||
system_prompt = str(getattr(messages[0], "content", "")) if messages else ""
|
|
||||||
if "strict domain classifier" in system_prompt:
|
|
||||||
return AIMessage(content='{"type":"timeline","id":"tl-1","section":null}')
|
|
||||||
|
|
||||||
self.agent_calls += 1
|
|
||||||
if self.agent_calls == 1:
|
|
||||||
return AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "call-1",
|
|
||||||
"name": "list_tasks",
|
|
||||||
"args": {"project_id": "proj-1"},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_messages = [m for m in messages if isinstance(m, ToolMessage)]
|
|
||||||
assert tool_messages, "Expected at least one tool message"
|
|
||||||
return AIMessage(content=f"Final answer from mocked tool: {tool_messages[-1].content}")
|
|
||||||
|
|
||||||
async def astream(self, _messages):
|
|
||||||
yield SimpleNamespace(content="stream-")
|
|
||||||
yield SimpleNamespace(content="ok")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_home_uses_mocked_tool_result():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
|
||||||
):
|
|
||||||
out = await run_home("user-1", "list my tasks", {})
|
|
||||||
|
|
||||||
assert "Final answer from mocked tool" in out
|
|
||||||
assert "Mock Task" in out
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"show me timeline updates",
|
|
||||||
{"scope": {"type": "timeline", "id": "tl-1"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert events[0] == (
|
|
||||||
"floating_domain",
|
|
||||||
{"type": "timeline", "id": "tl-1", "section": None},
|
|
||||||
)
|
|
||||||
assert ("token", "stream-") in events
|
|
||||||
assert ("token", "ok") in events
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
|
||||||
class _ClassifierOnlyLLM:
|
|
||||||
async def ainvoke(self, _messages):
|
|
||||||
return AIMessage(
|
|
||||||
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
|
||||||
domain = await _infer_floating_domain(
|
|
||||||
"Quali sono i miei task per il progetto X",
|
|
||||||
{
|
|
||||||
"scope": {"type": "timeline"},
|
|
||||||
"resolved_project_id": "213213-312321-312312-421321",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert domain == {
|
|
||||||
"type": "project",
|
|
||||||
"id": "213213-312321-312312-421321",
|
|
||||||
"section": "task",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
|
|
||||||
raw = (
|
|
||||||
"Certo!\n\n"
|
|
||||||
"1. **Task A** — priorita high <task>[task-1]</task>\n"
|
|
||||||
"2. **Task B** — priorita medium <task>[task-2]</task>\n"
|
|
||||||
)
|
|
||||||
|
|
||||||
out = _normalize_tagged_list_lines(raw, "quali sono le prossime attivita?")
|
|
||||||
|
|
||||||
assert "<task>[task-1]</task>" in out
|
|
||||||
assert "<task>[task-2]</task>" in out
|
|
||||||
assert "Task A" not in out
|
|
||||||
assert "Task B" not in out
|
|
||||||
|
|
||||||
|
|
||||||
def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_month_future_only():
|
|
||||||
today = date.today()
|
|
||||||
tomorrow = today + timedelta(days=1)
|
|
||||||
yesterday = today - timedelta(days=1)
|
|
||||||
next_month = (today.replace(day=28) + timedelta(days=5)).replace(day=1)
|
|
||||||
|
|
||||||
raw = "\n".join(
|
|
||||||
[
|
|
||||||
f"- Milestone old — {yesterday.strftime('%d/%m/%Y')} <timeline>[tl-old]</timeline>",
|
|
||||||
f"- Milestone next — {tomorrow.strftime('%d/%m/%Y')} <timeline>[tl-next]</timeline>",
|
|
||||||
f"- Milestone future — {next_month.strftime('%d/%m/%Y')} <timeline>[tl-future]</timeline>",
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
out = _normalize_tagged_list_lines(raw, "invece i miei eventi prossimi?")
|
|
||||||
|
|
||||||
assert "<timeline>[tl-next]</timeline>" in out
|
|
||||||
assert "<timeline>[tl-old]</timeline>" not in out
|
|
||||||
assert "<timeline>[tl-future]</timeline>" not in out
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_strips_xml_like_tags_from_final_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_run_single_agent(**_kwargs):
|
|
||||||
return (
|
|
||||||
"Hai 1 task:\\n"
|
|
||||||
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
|
||||||
):
|
|
||||||
text, _domain = await run_floating(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert "<task>" not in text
|
|
||||||
assert "</task>" not in text
|
|
||||||
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_stream(**_kwargs):
|
|
||||||
yield "token", "Hai 1 task:\\n"
|
|
||||||
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
token_events = [str(data) for event_type, data in events if event_type == "token"]
|
|
||||||
combined = "".join(token_events)
|
|
||||||
assert "<task>" not in combined
|
|
||||||
assert "</task>" not in combined
|
|
||||||
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
|
|
||||||
class _NoChunkLLM:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self.calls = 0
|
|
||||||
|
|
||||||
def bind_tools(self, _tools):
|
|
||||||
return self
|
|
||||||
|
|
||||||
async def ainvoke(self, _messages):
|
|
||||||
self.calls += 1
|
|
||||||
if self.calls == 1:
|
|
||||||
return AIMessage(
|
|
||||||
content="",
|
|
||||||
tool_calls=[
|
|
||||||
{
|
|
||||||
"id": "call-1",
|
|
||||||
"name": "list_tasks",
|
|
||||||
"args": {},
|
|
||||||
}
|
|
||||||
],
|
|
||||||
)
|
|
||||||
return AIMessage(content="No notes found.")
|
|
||||||
|
|
||||||
async def astream(self, _messages):
|
|
||||||
if False:
|
|
||||||
yield None
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
|
|
||||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali sono le note?",
|
|
||||||
{"scope": {"type": "note"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert events[0][0] == "floating_domain"
|
|
||||||
assert ("token", "No notes found.") in events
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_run_single_agent(**_kwargs):
|
|
||||||
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
|
||||||
):
|
|
||||||
text, _domain = await run_floating(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
)
|
|
||||||
|
|
||||||
assert text == "No results found."
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
|
|
||||||
fake_llm = _FakeLLM()
|
|
||||||
|
|
||||||
async def _fake_stream(**_kwargs):
|
|
||||||
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
|
||||||
|
|
||||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
|
||||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
|
||||||
):
|
|
||||||
events = []
|
|
||||||
async for event in run_floating_stream(
|
|
||||||
"user-1",
|
|
||||||
"quali task ho?",
|
|
||||||
{"scope": {"type": "task"}},
|
|
||||||
):
|
|
||||||
events.append(event)
|
|
||||||
|
|
||||||
assert ("token", "No results found.") in events
|
|
||||||
@@ -1,299 +0,0 @@
|
|||||||
"""Tests for Local Agent V2 journey setup (Step 4).
|
|
||||||
|
|
||||||
Covers the chatbot journey that produces a structured AgentConfig JSON
|
|
||||||
instead of a freeform prompt_template string.
|
|
||||||
|
|
||||||
Unit tests (no LLM)
|
|
||||||
--------------------
|
|
||||||
4.6a _extract_agent_config: valid JSON → returns serialised config
|
|
||||||
4.6b _extract_agent_config: invalid JSON → returns None
|
|
||||||
4.6c _extract_agent_config: markers absent → returns None
|
|
||||||
4.6d _extract_agent_config: only START marker → returns None
|
|
||||||
4.6e Session not found → done=True, agent_config=None
|
|
||||||
4.6f Nudge uses AGENT_CONFIG_START/END markers (not old PROMPT_TEMPLATE)
|
|
||||||
|
|
||||||
Eval test (real LLM + Langfuse scoring)
|
|
||||||
----------------------------------------
|
|
||||||
4.1 Journey start explores directory → first reply contains a question
|
|
||||||
|
|
||||||
Cases 4.2–4.5 (multi-turn conversations producing a full AgentConfig) are
|
|
||||||
non-deterministic and tested manually — results tracked in Langfuse.
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_journey_v2.py -v
|
|
||||||
pytest tests/test_journey_v2.py -v -k "4_6" # unit only
|
|
||||||
pytest tests/test_journey_v2.py -v -k "eval" # single LLM eval
|
|
||||||
pytest tests/test_journey_v2.py -v --journey-dir /p # custom fixtures
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from app.api.routes.agent_setup import (
|
|
||||||
_CONFIG_END,
|
|
||||||
_CONFIG_START,
|
|
||||||
_MAX_TURNS,
|
|
||||||
_extract_agent_config,
|
|
||||||
_sessions,
|
|
||||||
handle_journey_message,
|
|
||||||
handle_journey_start,
|
|
||||||
)
|
|
||||||
from app.core.langfuse_client import get_langfuse
|
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
|
||||||
from app.schemas import AgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS
|
|
||||||
|
|
||||||
# ── Constants ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_USER_ID = TEST_USER_IDS["power"]
|
|
||||||
|
|
||||||
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "journey_v2"
|
|
||||||
|
|
||||||
# ── Fixture loading ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _fixtures_dir(config) -> Path:
|
|
||||||
override = config.getoption("--journey-dir")
|
|
||||||
return Path(override) if override else _DEFAULT_FIXTURE_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cases(config) -> list[dict]:
|
|
||||||
return yaml.safe_load(
|
|
||||||
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_data_file(filename: str, fixtures_dir: Path) -> str:
|
|
||||||
return (fixtures_dir / "data" / filename).read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
# ── pytest_generate_tests ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
|
||||||
if "journey_case" not in metafunc.fixturenames:
|
|
||||||
return
|
|
||||||
cases = _load_cases(metafunc.config)
|
|
||||||
metafunc.parametrize("journey_case", cases, ids=[c["id"] for c in cases])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Executor builder ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _make_fs_executor(directory_files: list[dict], fixtures_dir: Path):
|
|
||||||
"""Return an async callback that simulates filesystem tool responses.
|
|
||||||
|
|
||||||
Matches the signature expected by ``set_client_executor`` / ``execute_on_client``:
|
|
||||||
receives the full ``payload`` dict and returns a result dict.
|
|
||||||
|
|
||||||
``directory_files`` is a list of ``{path, content_file}`` dicts;
|
|
||||||
``content_file`` is relative to ``fixtures_dir/data/``.
|
|
||||||
"""
|
|
||||||
file_map: dict[str, str] = {
|
|
||||||
entry["path"]: _read_data_file(entry["content_file"], fixtures_dir)
|
|
||||||
for entry in directory_files
|
|
||||||
}
|
|
||||||
|
|
||||||
async def _executor(payload: dict) -> dict:
|
|
||||||
action = payload.get("action", "")
|
|
||||||
data = payload.get("data") or {}
|
|
||||||
|
|
||||||
if action == "list_directory":
|
|
||||||
return {"entries": [
|
|
||||||
{"type": "file", "name": p.split("/")[-1], "path": p}
|
|
||||||
for p in file_map
|
|
||||||
]}
|
|
||||||
|
|
||||||
if action == "read_file_content":
|
|
||||||
path = data.get("path", "")
|
|
||||||
return {"content": file_map.get(path, "")}
|
|
||||||
|
|
||||||
if action == "get_file_metadata":
|
|
||||||
path = data.get("path", "")
|
|
||||||
name = path.split("/")[-1]
|
|
||||||
ext = "." + name.rsplit(".", 1)[-1] if "." in name else ""
|
|
||||||
return {"name": name, "extension": ext, "size": 1024,
|
|
||||||
"createdAt": None, "modifiedAt": None}
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
return _executor
|
|
||||||
|
|
||||||
|
|
||||||
# ── Journey runner helper ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_journey(user_id: str, case: dict, executor) -> dict[str, Any]:
|
|
||||||
"""Drive start + all user_messages for a case. Returns the final reply dict.
|
|
||||||
|
|
||||||
Mirrors ``device_ws._handle_journey_start/message``: sets the client
|
|
||||||
executor (so filesystem tools work) before each handler call.
|
|
||||||
"""
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
try:
|
|
||||||
set_client_executor(executor)
|
|
||||||
reply = await handle_journey_start(user_id, {
|
|
||||||
"agent_type": "local",
|
|
||||||
"directory": case["directory"],
|
|
||||||
"data_types": case["data_types"],
|
|
||||||
"session_id": session_id,
|
|
||||||
})
|
|
||||||
|
|
||||||
for msg in case.get("user_messages", []):
|
|
||||||
if reply.get("done"):
|
|
||||||
break
|
|
||||||
set_client_executor(executor)
|
|
||||||
reply = await handle_journey_message(user_id, {
|
|
||||||
"session_id": reply["session_id"],
|
|
||||||
"message": msg,
|
|
||||||
})
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
|
|
||||||
return reply
|
|
||||||
|
|
||||||
|
|
||||||
# ── Assertion helper ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_case(case: dict, reply: dict) -> tuple[float, str]:
|
|
||||||
"""Return (score, comment) for a journey case given the final reply dict."""
|
|
||||||
if case.get("expect_question"):
|
|
||||||
has_q = "?" in reply.get("message", "")
|
|
||||||
return (1.0 if has_q else 0.0), f"first_reply_has_question={has_q}"
|
|
||||||
|
|
||||||
return 1.0, "no specific assertion"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit tests ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_6a_extract_valid_json():
|
|
||||||
"""_extract_agent_config: valid JSON between markers → returns serialised config."""
|
|
||||||
config = AgentConfig(
|
|
||||||
content_types=[],
|
|
||||||
global_rules=["No project = no entity"],
|
|
||||||
data_types=["tasks"],
|
|
||||||
)
|
|
||||||
text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing"
|
|
||||||
result = _extract_agent_config(text)
|
|
||||||
assert result is not None
|
|
||||||
parsed = AgentConfig.model_validate_json(result)
|
|
||||||
assert parsed.global_rules == ["No project = no entity"]
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_6b_extract_invalid_json():
|
|
||||||
"""_extract_agent_config: malformed JSON between markers → returns None."""
|
|
||||||
text = f"{_CONFIG_START}\n{{not: valid json\n{_CONFIG_END}"
|
|
||||||
assert _extract_agent_config(text) is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_6c_extract_markers_absent():
|
|
||||||
"""_extract_agent_config: no markers at all → returns None."""
|
|
||||||
assert _extract_agent_config("No markers here at all") is None
|
|
||||||
|
|
||||||
|
|
||||||
def test_4_6d_extract_only_start_marker():
|
|
||||||
"""_extract_agent_config: START without END → returns None."""
|
|
||||||
assert _extract_agent_config(f"text {_CONFIG_START} no end marker") is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_4_6e_session_not_found():
|
|
||||||
"""4.6e Session not found → done=True, agent_config=None, informative message."""
|
|
||||||
reply = await handle_journey_message(_USER_ID, {
|
|
||||||
"session_id": "nonexistent-session-id",
|
|
||||||
"message": "Hello",
|
|
||||||
})
|
|
||||||
assert reply["done"] is True
|
|
||||||
assert reply["agent_config"] is None
|
|
||||||
assert "not found" in reply["message"].lower() or "expired" in reply["message"].lower()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_4_6f_nudge_uses_new_markers():
|
|
||||||
"""4.6f Nudge injected after max turns uses AGENT_CONFIG markers, not PROMPT_TEMPLATE."""
|
|
||||||
session_id = str(uuid.uuid4())
|
|
||||||
captured_histories: list[list[dict]] = []
|
|
||||||
|
|
||||||
async def _mock_llm(system_prompt, history, tools, **kwargs) -> str:
|
|
||||||
captured_histories.append(list(history))
|
|
||||||
# Return plain text — no markers — to trigger the nudge path.
|
|
||||||
return "I still need more information from you."
|
|
||||||
|
|
||||||
from app.api.routes.agent_setup import JourneySession
|
|
||||||
|
|
||||||
fake_session = JourneySession(
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=_USER_ID,
|
|
||||||
agent_type="local",
|
|
||||||
directory="/test",
|
|
||||||
data_types=["tasks"],
|
|
||||||
system_prompt="system",
|
|
||||||
langfuse_prompt=None,
|
|
||||||
)
|
|
||||||
# Fill history to the turn limit so the next message triggers the nudge.
|
|
||||||
for i in range(_MAX_TURNS):
|
|
||||||
fake_session.history.append({"role": "user", "content": f"msg {i}"})
|
|
||||||
fake_session.history.append({"role": "assistant", "content": "ok"})
|
|
||||||
_sessions[session_id] = fake_session
|
|
||||||
|
|
||||||
try:
|
|
||||||
with patch("app.api.routes.agent_setup._call_llm_with_tools", side_effect=_mock_llm):
|
|
||||||
await handle_journey_message(_USER_ID, {
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": "one more message to trigger nudge",
|
|
||||||
})
|
|
||||||
finally:
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
|
|
||||||
# Second LLM call receives the nudge appended to history.
|
|
||||||
assert len(captured_histories) >= 2, "Expected ≥ 2 LLM calls (main reply + nudge)"
|
|
||||||
nudge_history = captured_histories[1]
|
|
||||||
user_msgs = " ".join(t["content"] for t in nudge_history if t["role"] == "user")
|
|
||||||
assert _CONFIG_START in user_msgs, f"Nudge must reference {_CONFIG_START}"
|
|
||||||
assert _CONFIG_END in user_msgs, f"Nudge must reference {_CONFIG_END}"
|
|
||||||
assert "PROMPT_TEMPLATE" not in user_msgs, "Old PROMPT_TEMPLATE markers must not appear in nudge"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Eval tests (real LLM + Langfuse) ─────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.eval
|
|
||||||
async def test_eval_journey(journey_case, pytestconfig):
|
|
||||||
"""Parametrized eval test — one invocation per YAML case."""
|
|
||||||
case: dict = journey_case
|
|
||||||
fixtures_dir = _fixtures_dir(pytestconfig)
|
|
||||||
executor = _make_fs_executor(case.get("directory_files", []), fixtures_dir)
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
obs_ctx = lf.start_as_current_observation(
|
|
||||||
name=f"eval-journey-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
|
|
||||||
metadata={"step": "4", "case_id": case["id"]},
|
|
||||||
) if lf else nullcontext()
|
|
||||||
|
|
||||||
with obs_ctx as obs:
|
|
||||||
reply = await _run_journey(_USER_ID, case, executor)
|
|
||||||
score, comment = _evaluate_case(case, reply)
|
|
||||||
|
|
||||||
if obs is not None:
|
|
||||||
obs.score(
|
|
||||||
name=case.get("score_name", f"journey.case_{case['id']}"),
|
|
||||||
value=score,
|
|
||||||
comment=comment,
|
|
||||||
)
|
|
||||||
|
|
||||||
if lf:
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
assert score == 1.0, f"[{case['id']}] {case.get('description', '')} — {comment}"
|
|
||||||
@@ -110,32 +110,6 @@ async def test_enrich_context_returns_episodic_memory(db_session, user_with_key)
|
|||||||
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_enrich_context_filters_episodic_by_session_id(db_session, user_with_key):
|
|
||||||
target_session = str(uuid.uuid4())
|
|
||||||
other_session = str(uuid.uuid4())
|
|
||||||
db_session.add(MemoryEpisodic(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=USER_ID,
|
|
||||||
summary_encrypted=_enc("Target session memory"),
|
|
||||||
session_id=target_session,
|
|
||||||
))
|
|
||||||
db_session.add(MemoryEpisodic(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=USER_ID,
|
|
||||||
summary_encrypted=_enc("Other session memory"),
|
|
||||||
session_id=other_session,
|
|
||||||
))
|
|
||||||
await db_session.commit()
|
|
||||||
|
|
||||||
middleware = MemoryMiddleware(db_session)
|
|
||||||
ctx = await middleware.enrich_context(USER_ID, "any message", session_id=target_session)
|
|
||||||
|
|
||||||
episodic = ctx.get("episodic_memory", [])
|
|
||||||
assert any("Target session" in s for s in episodic)
|
|
||||||
assert not any("Other session" in s for s in episodic)
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
# Add one pattern above threshold and one below
|
# Add one pattern above threshold and one below
|
||||||
@@ -255,40 +229,6 @@ async def test_update_core_upsert(db_session, user_with_key):
|
|||||||
assert _dec(rows[0].value_encrypted) == "fr"
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_core_block_edit_ops(db_session, user_with_key):
|
|
||||||
middleware = MemoryMiddleware(db_session)
|
|
||||||
|
|
||||||
await middleware.update_core(USER_ID, "human", "Name: Roberto")
|
|
||||||
await middleware.append_core(USER_ID, "human", "Timezone: Europe/Rome")
|
|
||||||
replaced = await middleware.replace_core(USER_ID, "human", "Roberto", "Robert")
|
|
||||||
|
|
||||||
blocks = await middleware.list_core_blocks(USER_ID)
|
|
||||||
human = next(b for b in blocks if b["label"] == "human")
|
|
||||||
|
|
||||||
assert replaced is True
|
|
||||||
assert "Name: Robert" in human["value"]
|
|
||||||
assert "Timezone: Europe/Rome" in human["value"]
|
|
||||||
|
|
||||||
deleted = await middleware.delete_core(USER_ID, "human")
|
|
||||||
assert deleted is True
|
|
||||||
assert await middleware.get_core_block(USER_ID, "human") is None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_archival_and_recall_search_helpers(db_session, user_with_key):
|
|
||||||
middleware = MemoryMiddleware(db_session)
|
|
||||||
|
|
||||||
await middleware.insert_archival(USER_ID, "Project whitelist has release risk", source="assistant")
|
|
||||||
await middleware.store_episode(USER_ID, str(uuid.uuid4()), "How is whitelist?", "Whitelist is delayed")
|
|
||||||
|
|
||||||
arch = await middleware.search_archival(USER_ID, "whitelist", top_k=3)
|
|
||||||
rec = await middleware.search_recall(USER_ID, "delayed", top_k=3)
|
|
||||||
|
|
||||||
assert any("whitelist" in item.lower() for item in arch)
|
|
||||||
assert any("delayed" in item.lower() for item in rec)
|
|
||||||
|
|
||||||
|
|
||||||
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
def test_home_request_calls_memory_middleware(client):
|
def test_home_request_calls_memory_middleware(client):
|
||||||
@@ -300,20 +240,21 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
def __init__(self, db):
|
def __init__(self, db):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def enrich_context(self, user_id, message, **kwargs):
|
async def enrich_context(self, user_id, message):
|
||||||
enrich_calls.append((user_id, message))
|
enrich_calls.append((user_id, message))
|
||||||
return {"core_memory": {"tz": "UTC"}}
|
return {"core_memory": {"tz": "UTC"}}
|
||||||
|
|
||||||
async def store_episode(self, user_id, session_id, message, response, **kwargs):
|
async def store_episode(self, user_id, session_id, message, response):
|
||||||
store_calls.append((user_id, session_id, message, response))
|
store_calls.append((user_id, session_id, message, response))
|
||||||
|
|
||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
session_id = str(uuid.uuid4())
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
async def _mock_stream(user_id, message, context):
|
async def _mock_stream(user_id, message, context, db_session_factory=None):
|
||||||
# Verify memory context was injected
|
# Verify memory context was injected
|
||||||
assert context.get("core_memory") == {"tz": "UTC"}
|
assert context.get("core_memory") == {"tz": "UTC"}
|
||||||
yield "token", "Done"
|
yield ("token", "Done")
|
||||||
|
yield ("mutations", [])
|
||||||
|
|
||||||
with (
|
with (
|
||||||
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
||||||
|
|||||||
@@ -1,82 +1,214 @@
|
|||||||
"""Tests for app.core.output_formatter.StreamFormatter."""
|
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from app.core.output_formatter import StreamFormatter
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _stream(*events: tuple[str, object]):
|
async def _stream(*events: tuple[str, object]):
|
||||||
|
"""Async generator that yields (event_type, data) tuples."""
|
||||||
for event in events:
|
for event in events:
|
||||||
yield event
|
yield event
|
||||||
|
|
||||||
|
|
||||||
async def _collect(formatter: StreamFormatter, event_stream):
|
async def collect(formatter, event_stream):
|
||||||
frames = []
|
frames = []
|
||||||
async for frame in formatter.format(event_stream):
|
async for frame in formatter.format(event_stream):
|
||||||
frames.append(frame)
|
frames.append(frame)
|
||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_text_stream() -> None:
|
async def test_home_formatter_plain_text():
|
||||||
formatter = StreamFormatter(request_id="req-1")
|
req_id = "req-1"
|
||||||
frames = await _collect(
|
events = [
|
||||||
formatter,
|
("token", "Hello world"),
|
||||||
_stream(("token", "Hello"), ("token", " world")),
|
("mutations", []),
|
||||||
)
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
assert isinstance(frames[1], WsStreamText)
|
assert frames[0].request_id == req_id
|
||||||
assert frames[1].chunk == "Hello"
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
assert isinstance(frames[2], WsStreamText)
|
assert any("Hello world" in f.chunk for f in text_frames)
|
||||||
assert frames[2].chunk == " world"
|
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_floating_domain_first() -> None:
|
async def test_home_formatter_entity_tags_passed_through():
|
||||||
formatter = StreamFormatter(request_id="req-2")
|
"""Entity tags are streamed as-is — the frontend parses them."""
|
||||||
frames = await _collect(
|
req_id = "req-2"
|
||||||
formatter,
|
events = [
|
||||||
_stream(
|
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
|
||||||
(
|
("mutations", []),
|
||||||
"floating_domain",
|
]
|
||||||
{"type": "node", "id": "n-1", "section": None},
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
),
|
frames = await collect(formatter, _stream(*events))
|
||||||
("token", "Summary"),
|
|
||||||
),
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
)
|
assert "<project>[abc-123]</project>" in text
|
||||||
|
assert "Here is your project:" in text
|
||||||
|
assert "All good." in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_multiple_tags_passed_through():
|
||||||
|
req_id = "req-3"
|
||||||
|
events = [
|
||||||
|
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
|
assert "<project>[p1]</project>" in text
|
||||||
|
assert "<task>[t1,t2]</task>" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_tool_end_ignored():
|
||||||
|
"""tool_end events are silently ignored by HomeFormatter."""
|
||||||
|
req_id = "req-4"
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
||||||
|
("token", "No tags here."),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
|
assert text == "No tags here."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_mutations_in_stream_end():
|
||||||
|
req_id = "req-5"
|
||||||
|
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
||||||
|
events = [
|
||||||
|
("token", "Done"),
|
||||||
|
("mutations", muts),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
end_frame = frames[-1]
|
||||||
|
assert isinstance(end_frame, WsStreamEnd)
|
||||||
|
assert len(end_frame.mutations) == 1
|
||||||
|
assert end_frame.mutations[0]["action"] == "insert"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_frame_order():
|
||||||
|
"""stream_start is first, stream_end is last."""
|
||||||
|
req_id = "req-6"
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
||||||
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_domain_from_tool_end():
|
||||||
|
req_id = "pop-1"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "task_agent", "result": "ok"}),
|
||||||
|
("token", "Hello"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
assert isinstance(frames[0], WsFloatingDomain)
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
assert frames[0].domain.type == "node"
|
assert frames[0].domain == "tasks"
|
||||||
assert frames[0].domain.id == "n-1"
|
assert frames[0].request_id == req_id
|
||||||
assert isinstance(frames[1], WsStreamStart)
|
|
||||||
assert isinstance(frames[2], WsStreamText)
|
|
||||||
assert frames[2].chunk == "Summary"
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_text_only():
|
||||||
|
req_id = "pop-2"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "timeline_agent", "result": "done"}),
|
||||||
|
("token", "Summary"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
|
assert frames[0].domain == "timelines"
|
||||||
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
|
assert len(text_frames) == 1
|
||||||
|
assert text_frames[0].chunk == "Summary"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_no_entity_tags():
|
||||||
|
"""FloatingFormatter never emits entity tag blocks."""
|
||||||
|
req_id = "pop-3"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "note_agent", "result": "data"}),
|
||||||
|
("token", "some text"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
# Only expected frame types
|
||||||
|
for f in frames:
|
||||||
|
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_end_frame():
|
||||||
|
req_id = "pop-4"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "project_agent", "result": "ok"}),
|
||||||
|
("token", "Done"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
assert isinstance(frames[-1], WsStreamEnd)
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_ignores_unknown_events() -> None:
|
async def test_floating_formatter_default_domain_on_early_token():
|
||||||
formatter = StreamFormatter(request_id="req-3")
|
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
|
||||||
frames = await _collect(
|
req_id = "pop-5"
|
||||||
formatter,
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
_stream(("tool_end", {"name": "x"}), ("token", "ok")),
|
events = [("token", "hi"), ("mutations", [])]
|
||||||
)
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
assert frames[0].domain == "tasks"
|
||||||
assert len(text_frames) == 1
|
|
||||||
assert text_frames[0].chunk == "ok"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_stream_formatter_empty_stream_still_brackets() -> None:
|
async def test_floating_formatter_mutations_in_stream_end():
|
||||||
formatter = StreamFormatter(request_id="req-4")
|
req_id = "pop-6"
|
||||||
frames = await _collect(formatter, _stream())
|
muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
|
||||||
|
events = [
|
||||||
|
("token", "Updated"),
|
||||||
|
("mutations", muts),
|
||||||
|
]
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
assert len(frames) == 2
|
end_frame = frames[-1]
|
||||||
assert isinstance(frames[0], WsStreamStart)
|
assert isinstance(end_frame, WsStreamEnd)
|
||||||
assert isinstance(frames[1], WsStreamEnd)
|
assert len(end_frame.mutations) == 1
|
||||||
|
|||||||
400
tests/test_plugins.py
Normal file
400
tests/test_plugins.py
Normal file
@@ -0,0 +1,400 @@
|
|||||||
|
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
|
||||||
|
- ReviewQueue: pending queue, review decisions, manifest security checklist
|
||||||
|
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
|
||||||
|
- Route integration: tier gate, list/get/install/uninstall via TestClient
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import PluginRegistry
|
||||||
|
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
|
||||||
|
from app.marketplace.revenue_share import RevenueShare
|
||||||
|
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
|
||||||
|
from app.schemas import PluginManifest
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _fresh_manifest(
|
||||||
|
plugin_id: str | None = None,
|
||||||
|
category: str = "productivity",
|
||||||
|
price_cents: int = 0,
|
||||||
|
permissions: list[str] | None = None,
|
||||||
|
) -> PluginManifest:
|
||||||
|
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
|
||||||
|
return PluginManifest(
|
||||||
|
id=pid,
|
||||||
|
name=f"Plugin {pid}",
|
||||||
|
description=f"Description for {pid}",
|
||||||
|
version="1.0.0",
|
||||||
|
author="test-author",
|
||||||
|
permissions=permissions or ["read:tasks"],
|
||||||
|
category=category,
|
||||||
|
price_cents=price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PluginRegistry (DB-backed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginRegistry:
|
||||||
|
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def reg(self) -> PluginRegistry:
|
||||||
|
return PluginRegistry()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_seed_plugins_are_listed(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session)
|
||||||
|
assert result.total == 3
|
||||||
|
assert all(p.id.startswith("plugin-") for p in result.plugins)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_approved_only(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
|
||||||
|
result = await reg.list_plugins(db_session)
|
||||||
|
ids = [p.id for p in result.plugins]
|
||||||
|
assert manifest.id not in ids # still pending
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_filter_by_category(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session, category="communication")
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.plugins[0].id == "plugin-slack-notify"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_filter_by_query(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session, query="time tracker")
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.plugins[0].id == "plugin-time-tracker"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sort_by_installs(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await reg.record_install(db_session, "plugin-slack-notify")
|
||||||
|
await reg.record_install(db_session, "plugin-slack-notify")
|
||||||
|
result = await reg.list_plugins(db_session, sort="installs")
|
||||||
|
assert result.plugins[0].id == "plugin-slack-notify"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_plugin_found(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["manifest"].id == "plugin-github-sync"
|
||||||
|
assert "install_count" in entry
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_plugin_not_found(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
entry = await reg.get_plugin(db_session, "no-such-plugin")
|
||||||
|
assert entry is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_sets_pending(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
assert plugin_id == manifest.id
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "pending_review"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_approve_makes_visible(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
await reg.approve_plugin(db_session, manifest.id)
|
||||||
|
result = await reg.list_plugins(db_session)
|
||||||
|
assert manifest.id in [p.id for p in result.plugins]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_stores_reason(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "rejected"
|
||||||
|
assert row.rejection_reason == "Unsafe permissions"
|
||||||
|
listed = await reg.list_plugins(db_session)
|
||||||
|
assert manifest.id not in [p.id for p in listed.plugins]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_approve_unknown_raises_key_error(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await reg.approve_plugin(db_session, "ghost-plugin")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_increments_count(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_uninstall_decrements_count(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
|
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_uninstall_floors_at_zero(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ReviewQueue (DB-backed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewQueue:
|
||||||
|
@pytest.fixture
|
||||||
|
def reg(self) -> PluginRegistry:
|
||||||
|
return PluginRegistry()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def queue(self) -> ReviewQueue:
|
||||||
|
return ReviewQueue()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_returns_submitted_plugins(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
pending = await queue.get_pending(db_session)
|
||||||
|
assert any(p["plugin_id"] == manifest.id for p in pending)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_review_approved(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "approved"
|
||||||
|
# Check review row was persisted
|
||||||
|
review_result = await db_session.execute(
|
||||||
|
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
|
||||||
|
)
|
||||||
|
review = review_result.scalar_one()
|
||||||
|
assert review.decision == "approved"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_review_rejected(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
await queue.submit_review(
|
||||||
|
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
|
||||||
|
)
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "rejected"
|
||||||
|
|
||||||
|
def test_validate_manifest_ok(self) -> None:
|
||||||
|
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
|
||||||
|
validate_manifest(manifest) # should not raise
|
||||||
|
|
||||||
|
def test_validate_manifest_unknown_permission(self) -> None:
|
||||||
|
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
|
||||||
|
with pytest.raises(ValueError, match="Unknown permission"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
def test_validate_manifest_invalid_id_format(self) -> None:
|
||||||
|
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
|
||||||
|
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
def test_validate_manifest_id_with_uppercase(self) -> None:
|
||||||
|
manifest = _fresh_manifest(plugin_id="UpperCase")
|
||||||
|
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RevenueShare (DB-backed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevenueShare:
|
||||||
|
@pytest.fixture
|
||||||
|
def rs(self) -> RevenueShare:
|
||||||
|
return RevenueShare()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_free_plugin(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
|
||||||
|
)
|
||||||
|
event = result.scalar_one()
|
||||||
|
assert event.developer_share_cents == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_paid_plugin_no_stripe(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await rs.record_install(
|
||||||
|
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
|
||||||
|
)
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
|
||||||
|
)
|
||||||
|
event = result.scalar_one()
|
||||||
|
assert event.amount_cents == 499
|
||||||
|
assert event.developer_share_cents == int(499 * 0.70)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_increments_registry_count(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
reg = PluginRegistry()
|
||||||
|
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_earnings_empty(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
result = await rs.get_earnings(db_session, "unknown-dev")
|
||||||
|
assert result["total_installs"] == 0
|
||||||
|
assert result["total_revenue_cents"] == 0
|
||||||
|
assert result["developer_share_cents"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_earnings_aggregates(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
|
||||||
|
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
|
||||||
|
result = await rs.get_earnings(db_session, "Adiuva")
|
||||||
|
assert result["total_installs"] == 2
|
||||||
|
assert result["total_revenue_cents"] == 998
|
||||||
|
assert result["developer_share_cents"] == int(499 * 0.70) * 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Route integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginRoutes:
|
||||||
|
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "plugins" in data
|
||||||
|
assert data["total"] == 3
|
||||||
|
|
||||||
|
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_get_plugin_found(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["plugin"]["id"] == "plugin-github-sync"
|
||||||
|
assert "install_count" in data
|
||||||
|
|
||||||
|
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_install_plugin_free(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
json={"plugin_id": "plugin-github-sync"},
|
||||||
|
headers=auth_header(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["ok"] is True
|
||||||
|
assert "download_url" in data
|
||||||
|
|
||||||
|
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/ghost/install",
|
||||||
|
json={"plugin_id": "ghost"},
|
||||||
|
headers=auth_header(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
headers=auth_header(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["ok"] is True
|
||||||
|
|
||||||
|
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
json={"plugin_id": "plugin-github-sync"},
|
||||||
|
headers=auth_header("free"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
"""Tests for the preprocessor system (Step 1 — Local Agent V2).
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_preprocessors.py -v
|
|
||||||
pytest tests/test_preprocessors.py -v --preprocess-dir /path/to/folder
|
|
||||||
|
|
||||||
The folder must contain cases.yaml + data/.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from app.core.preprocessors import detect_content_type, preprocess
|
|
||||||
|
|
||||||
_DEFAULT_DIR = Path(__file__).parent / "fixtures" / "preprocessors"
|
|
||||||
|
|
||||||
_GENERATORS = {
|
|
||||||
"binary_noise": "some\x00\x01\x02\x03\x04\x05content" * 20,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _fixtures_dir(config) -> Path:
|
|
||||||
override = config.getoption("--preprocess-dir")
|
|
||||||
return Path(override) if override else _DEFAULT_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cases(config) -> list[dict]:
|
|
||||||
return yaml.safe_load((_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
|
|
||||||
def _content(case: dict, data_dir: Path) -> str:
|
|
||||||
if "generate" in case:
|
|
||||||
return _GENERATORS[case["generate"]]
|
|
||||||
return (data_dir / case["file"]).read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
# ── parametrize at collection time via pytest hook ────────────────────
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
|
||||||
if "preprocess_case" not in metafunc.fixturenames:
|
|
||||||
return
|
|
||||||
cases = _load_cases(metafunc.config)
|
|
||||||
test_name = metafunc.function.__name__
|
|
||||||
if test_name == "test_detect":
|
|
||||||
subset = [c for c in cases if "detect" in c]
|
|
||||||
else:
|
|
||||||
subset = [c for c in cases if "process" in c]
|
|
||||||
metafunc.parametrize("preprocess_case", subset, ids=[c["id"] for c in subset])
|
|
||||||
|
|
||||||
|
|
||||||
# ── detect ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_detect(preprocess_case, pytestconfig) -> None:
|
|
||||||
case = preprocess_case
|
|
||||||
data_dir = _fixtures_dir(pytestconfig) / "data"
|
|
||||||
raw = _content(case, data_dir)
|
|
||||||
filename = case.get("file", "")
|
|
||||||
ct = detect_content_type(filename, raw)
|
|
||||||
expected = case["detect"]
|
|
||||||
assert ct == expected, f"[{case['id']}] expected {expected!r}, got {ct!r}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── preprocess ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def test_preprocess(preprocess_case, pytestconfig) -> None:
|
|
||||||
case = preprocess_case
|
|
||||||
data_dir = _fixtures_dir(pytestconfig) / "data"
|
|
||||||
raw = _content(case, data_dir)
|
|
||||||
result = preprocess(case["process"], raw)
|
|
||||||
|
|
||||||
if case.get("no_html"):
|
|
||||||
assert not re.search(r"<[^>]+>", result.clean_text), "clean_text contains HTML tags"
|
|
||||||
|
|
||||||
if "min_chars" in case:
|
|
||||||
assert len(result.clean_text) >= case["min_chars"], \
|
|
||||||
f"clean_text too short: {len(result.clean_text)} < {case['min_chars']}"
|
|
||||||
|
|
||||||
if "ratio_lt" in case:
|
|
||||||
ratio = len(result.clean_text) / len(raw)
|
|
||||||
assert ratio < case["ratio_lt"], f"compression ratio {ratio:.2f} >= {case['ratio_lt']}"
|
|
||||||
|
|
||||||
for key in case.get("has_meta", []):
|
|
||||||
assert result.metadata.get(key), f"metadata missing {key!r} (got {result.metadata})"
|
|
||||||
|
|
||||||
for item in ([case["contains"]] if isinstance(case.get("contains"), str) else case.get("contains", [])):
|
|
||||||
assert item in result.clean_text, f"clean_text missing {item!r}"
|
|
||||||
|
|
||||||
for item in ([case["excludes"]] if isinstance(case.get("excludes"), str) else case.get("excludes", [])):
|
|
||||||
assert item not in result.clean_text, f"clean_text contains forbidden {item!r}"
|
|
||||||
|
|
||||||
if "content_type" in case:
|
|
||||||
assert result.content_type == case["content_type"], \
|
|
||||||
f"expected content_type {case['content_type']!r}, got {result.content_type!r}"
|
|
||||||
@@ -4,7 +4,6 @@ import pytest
|
|||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
|
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
WsDomain,
|
|
||||||
WsFrameType,
|
WsFrameType,
|
||||||
WsHomeRequest,
|
WsHomeRequest,
|
||||||
WsFloatingDomain,
|
WsFloatingDomain,
|
||||||
@@ -179,15 +178,23 @@ def test_stream_text_deserializes():
|
|||||||
def test_stream_end_defaults():
|
def test_stream_end_defaults():
|
||||||
frame = WsStreamEnd(request_id="r1")
|
frame = WsStreamEnd(request_id="r1")
|
||||||
assert frame.type == WsFrameType.stream_end
|
assert frame.type == WsFrameType.stream_end
|
||||||
|
assert frame.mutations == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_end_with_mutations():
|
||||||
|
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
|
||||||
|
frame = WsStreamEnd(request_id="r1", mutations=mutations)
|
||||||
|
assert len(frame.mutations) == 1
|
||||||
|
assert frame.mutations[0]["action"] == "create"
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_serializes():
|
def test_stream_end_serializes():
|
||||||
data = WsStreamEnd(request_id="r2").model_dump()
|
data = WsStreamEnd(request_id="r2").model_dump()
|
||||||
assert data == {"type": "stream_end", "request_id": "r2"}
|
assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
|
||||||
|
|
||||||
|
|
||||||
def test_stream_end_deserializes():
|
def test_stream_end_deserializes():
|
||||||
raw = {"type": "stream_end", "request_id": "r3"}
|
raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
|
||||||
frame = WsStreamEnd.model_validate(raw)
|
frame = WsStreamEnd.model_validate(raw)
|
||||||
assert frame.request_id == "r3"
|
assert frame.request_id == "r3"
|
||||||
|
|
||||||
@@ -196,47 +203,28 @@ def test_stream_end_deserializes():
|
|||||||
|
|
||||||
|
|
||||||
def test_floating_domain_tasks():
|
def test_floating_domain_tasks():
|
||||||
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
|
frame = WsFloatingDomain(request_id="r1", domain="tasks")
|
||||||
assert frame.type == WsFrameType.floating_domain
|
assert frame.type == WsFrameType.floating_domain
|
||||||
assert frame.domain.type == "task"
|
assert frame.domain == "tasks"
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_valid_domains():
|
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
|
||||||
frame = WsFloatingDomain(
|
def test_floating_domain_valid_domains(domain: str):
|
||||||
request_id="r1",
|
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
|
||||||
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
|
assert frame.domain == domain
|
||||||
)
|
|
||||||
assert frame.domain.type == "project"
|
|
||||||
assert frame.domain.id == "213213-312321-312312-421321"
|
|
||||||
assert frame.domain.section == "task"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_object_valid():
|
def test_floating_domain_invalid():
|
||||||
frame = WsFloatingDomain(
|
with pytest.raises(ValidationError):
|
||||||
request_id="r1",
|
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
|
||||||
domain=WsDomain(type="project", id="p1", section="task"),
|
|
||||||
)
|
|
||||||
assert frame.domain.type == "project"
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_serializes():
|
def test_floating_domain_serializes():
|
||||||
d = WsFloatingDomain(
|
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
|
||||||
request_id="r1",
|
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
|
||||||
domain=WsDomain(type="timeline"),
|
|
||||||
).model_dump()
|
|
||||||
assert d == {
|
|
||||||
"type": "floating_domain",
|
|
||||||
"request_id": "r1",
|
|
||||||
"domain": {"type": "timeline", "id": None, "section": None},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def test_floating_domain_deserializes():
|
def test_floating_domain_deserializes():
|
||||||
raw = {
|
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
|
||||||
"type": "floating_domain",
|
|
||||||
"request_id": "r1",
|
|
||||||
"domain": {"type": "node", "id": "n-1", "section": None},
|
|
||||||
}
|
|
||||||
frame = WsFloatingDomain.model_validate(raw)
|
frame = WsFloatingDomain.model_validate(raw)
|
||||||
assert frame.domain.type == "node"
|
assert frame.domain == "projects"
|
||||||
assert frame.domain.id == "n-1"
|
|
||||||
|
|||||||
562
tests/test_storage.py
Normal file
562
tests/test_storage.py
Normal file
@@ -0,0 +1,562 @@
|
|||||||
|
"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import pytest
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
from app.storage.encryption import reject_if_tampered, verify_checksum
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.vector_store import VectorStore, _blob_to_vector
|
||||||
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
from tests.conftest import auth_header, S3_TEST_BUCKET
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_BLOB = b"encrypted-payload-opaque-to-server"
|
||||||
|
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||||
|
_BUCKET = S3_TEST_BUCKET
|
||||||
|
_REGION = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
def _pinecone_mock():
|
||||||
|
"""Return a mock Pinecone index with realistic return shapes."""
|
||||||
|
mock_index = MagicMock()
|
||||||
|
mock_index.query.return_value = {
|
||||||
|
"matches": [
|
||||||
|
{
|
||||||
|
"id": "v1",
|
||||||
|
"score": 0.95,
|
||||||
|
"metadata": {
|
||||||
|
"blob": base64.b64encode(b"result-blob").decode(),
|
||||||
|
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
|
||||||
|
"user_id": "u1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_pc = MagicMock()
|
||||||
|
mock_pc.return_value.Index.return_value = mock_index
|
||||||
|
return mock_pc, mock_index
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestEncryption ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncryption:
|
||||||
|
def test_verify_checksum_correct(self) -> None:
|
||||||
|
assert verify_checksum(_BLOB, _CHECKSUM) is True
|
||||||
|
|
||||||
|
def test_verify_checksum_wrong(self) -> None:
|
||||||
|
assert verify_checksum(_BLOB, "0" * 64) is False
|
||||||
|
|
||||||
|
def test_verify_checksum_empty_checksum(self) -> None:
|
||||||
|
assert verify_checksum(_BLOB, "") is False
|
||||||
|
|
||||||
|
def test_verify_checksum_empty_blob(self) -> None:
|
||||||
|
expected = hashlib.sha256(b"").hexdigest()
|
||||||
|
assert verify_checksum(b"", expected) is True
|
||||||
|
|
||||||
|
def test_verify_checksum_tampered_blob(self) -> None:
|
||||||
|
tampered = _BLOB + b"\x00"
|
||||||
|
assert verify_checksum(tampered, _CHECKSUM) is False
|
||||||
|
|
||||||
|
def test_reject_if_tampered_passes_when_valid(self) -> None:
|
||||||
|
# Should not raise
|
||||||
|
reject_if_tampered(_BLOB, _CHECKSUM)
|
||||||
|
|
||||||
|
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
reject_if_tampered(_BLOB, "bad" * 20)
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
|
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
reject_if_tampered(_BLOB, "bad" * 20)
|
||||||
|
assert "checksum" in exc_info.value.detail.lower()
|
||||||
|
|
||||||
|
def test_checksum_is_sha256_hex(self) -> None:
|
||||||
|
cs = hashlib.sha256(_BLOB).hexdigest()
|
||||||
|
assert len(cs) == 64
|
||||||
|
assert all(c in "0123456789abcdef" for c in cs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBlobStore ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlobStore:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
assert key == "u1/tasks/r1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
# Verify by downloading — no exception means object exists
|
||||||
|
retrieved = await store.download("u1", "u1/tasks/r1")
|
||||||
|
assert retrieved == _BLOB
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
|
||||||
|
result = await store.download("u1", "u1/notes/n1")
|
||||||
|
assert result == b"note-data"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_removes_object(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
await store.delete("u1", "u1/tasks/r1")
|
||||||
|
with pytest.raises(ClientError) as exc_info:
|
||||||
|
await store.download("u1", "u1/tasks/r1")
|
||||||
|
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
# Delete a key that never existed — should not raise
|
||||||
|
await store.delete("u1", "u1/tasks/nonexistent")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
|
||||||
|
keys = await store.list_keys("u1", "tasks")
|
||||||
|
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
|
||||||
|
keys = await store.list_keys("u1", "tasks")
|
||||||
|
assert "u1/notes/n1" not in keys
|
||||||
|
assert "u1/tasks/r1" in keys
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
keys_u1 = await store.list_keys("u1", "tasks")
|
||||||
|
assert "u2/tasks/r1" not in keys_u1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
keys = await store.list_keys("u1", "tasks")
|
||||||
|
assert keys == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
# Verify S3 metadata was set — check via head_object
|
||||||
|
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||||
|
mock_settings.S3_BUCKET = _BUCKET
|
||||||
|
mock_settings.S3_REGION = _REGION
|
||||||
|
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||||
|
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||||
|
client = boto3.client("s3", region_name=_REGION)
|
||||||
|
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||||
|
assert response.get("ServerSideEncryption") == "AES256"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
client = boto3.client("s3", region_name=_REGION)
|
||||||
|
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||||
|
assert response["Metadata"]["checksum"] == _CHECKSUM
|
||||||
|
|
||||||
|
|
||||||
|
# ── _blob_to_vector helper ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlobToVector:
|
||||||
|
def test_returns_32_floats(self) -> None:
|
||||||
|
v = _blob_to_vector(b"test")
|
||||||
|
assert len(v) == 32
|
||||||
|
|
||||||
|
def test_all_values_in_range(self) -> None:
|
||||||
|
v = _blob_to_vector(b"test")
|
||||||
|
assert all(-1.0 <= x <= 1.0 for x in v)
|
||||||
|
|
||||||
|
def test_deterministic(self) -> None:
|
||||||
|
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
|
||||||
|
|
||||||
|
def test_different_blobs_different_vectors(self) -> None:
|
||||||
|
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestVectorStorePinecone ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestVectorStorePinecone:
|
||||||
|
def _store(self) -> VectorStore:
|
||||||
|
store = VectorStore()
|
||||||
|
store._use_pinecone = lambda: True # type: ignore[method-assign]
|
||||||
|
return store
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_calls_index_upsert(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
|
||||||
|
await store.upsert("u1", items)
|
||||||
|
mock_index.upsert.assert_called_once()
|
||||||
|
call_kwargs = mock_index.upsert.call_args[1]
|
||||||
|
assert call_kwargs.get("namespace") == "u1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
|
||||||
|
await store.upsert("u1", items)
|
||||||
|
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
|
||||||
|
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_calls_index_query(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
await store.search("u1", b"query-blob", top_k=5)
|
||||||
|
mock_index.query.assert_called_once()
|
||||||
|
query_kwargs = mock_index.query.call_args[1]
|
||||||
|
assert query_kwargs.get("namespace") == "u1"
|
||||||
|
assert query_kwargs.get("top_k") == 5
|
||||||
|
assert query_kwargs.get("include_metadata") is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_returns_vector_search_results(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
results = await store.search("u1", b"query", top_k=10)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert isinstance(results[0], VectorSearchResult)
|
||||||
|
assert results[0].id == "v1"
|
||||||
|
assert results[0].score == 0.95
|
||||||
|
assert results[0].blob == b"result-blob"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_uses_derived_query_vector(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
await store.search("u1", b"query-blob", top_k=3)
|
||||||
|
expected_vector = _blob_to_vector(b"query-blob")
|
||||||
|
actual_vector = mock_index.query.call_args[1].get("vector")
|
||||||
|
assert actual_vector == expected_vector
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_calls_index_delete(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
await store.delete("u1", ["v1", "v2"])
|
||||||
|
mock_index.delete.assert_called_once()
|
||||||
|
delete_kwargs = mock_index.delete.call_args[1]
|
||||||
|
assert delete_kwargs.get("namespace") == "u1"
|
||||||
|
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestVectorStoreQdrant:
|
||||||
|
def _store(self) -> VectorStore:
|
||||||
|
store = VectorStore()
|
||||||
|
store._use_pinecone = lambda: False # type: ignore[method-assign]
|
||||||
|
return store
|
||||||
|
|
||||||
|
def _qdrant_mock(self) -> MagicMock:
|
||||||
|
mock_hit = MagicMock()
|
||||||
|
mock_hit.id = "v1"
|
||||||
|
mock_hit.score = 0.88
|
||||||
|
mock_hit.payload = {
|
||||||
|
"blob": base64.b64encode(b"qdrant-result").decode(),
|
||||||
|
"user_id": "u1",
|
||||||
|
}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.search.return_value = [mock_hit]
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_calls_client_upsert(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||||
|
await store.upsert("u1", items)
|
||||||
|
mock_client.upsert.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_uses_correct_collection(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||||
|
await store.upsert("u1", items)
|
||||||
|
call_kwargs = mock_client.upsert.call_args[1]
|
||||||
|
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_calls_client_search(self) -> None:
|
||||||
|
mock_client = self._qdrant_mock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
await store.search("u1", b"query", top_k=5)
|
||||||
|
mock_client.search.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_passes_limit(self) -> None:
|
||||||
|
mock_client = self._qdrant_mock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
await store.search("u1", b"query", top_k=7)
|
||||||
|
call_kwargs = mock_client.search.call_args[1]
|
||||||
|
assert call_kwargs.get("limit") == 7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_returns_vector_search_results(self) -> None:
|
||||||
|
mock_client = self._qdrant_mock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
results = await store.search("u1", b"query", top_k=5)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert isinstance(results[0], VectorSearchResult)
|
||||||
|
assert results[0].id == "v1"
|
||||||
|
assert results[0].score == 0.88
|
||||||
|
assert results[0].blob == b"qdrant-result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_calls_client_delete(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
await store.delete("u1", ["v1", "v2"])
|
||||||
|
mock_client.delete.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_uses_correct_collection(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
await store.delete("u1", ["v1"])
|
||||||
|
call_kwargs = mock_client.delete.call_args[1]
|
||||||
|
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestStorageRoutes (integration) ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageRoutes:
|
||||||
|
"""Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records.
|
||||||
|
|
||||||
|
Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``.
|
||||||
|
So "hello" in JSON becomes ``b"hello"`` on the server. We use plain
|
||||||
|
ASCII strings as blob values and compute checksums accordingly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_BLOB_STR = "encrypted-payload-opaque-to-server"
|
||||||
|
_BLOB_BYTES = _BLOB_STR.encode()
|
||||||
|
_BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create_payload(cls, blob_str: str | None = None) -> dict:
|
||||||
|
blob_str = blob_str or cls._BLOB_STR
|
||||||
|
checksum = hashlib.sha256(blob_str.encode()).hexdigest()
|
||||||
|
return {
|
||||||
|
"table": "tasks",
|
||||||
|
"blob": blob_str,
|
||||||
|
"checksum": checksum,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_record(self, client, tier="power", blob_str=None):
|
||||||
|
payload = self._create_payload(blob_str)
|
||||||
|
return client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header(tier),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Create ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_create_record(self, client, s3_bucket) -> None:
|
||||||
|
resp = self._create_record(client)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert "created_at" in data
|
||||||
|
|
||||||
|
def test_create_record_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
payload = {
|
||||||
|
"table": "tasks",
|
||||||
|
"blob": self._BLOB_STR,
|
||||||
|
"checksum": "0" * 64,
|
||||||
|
}
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None:
|
||||||
|
"""Free tier has cloud_storage_gb=0 → 402."""
|
||||||
|
resp = self._create_record(client, tier="free")
|
||||||
|
assert resp.status_code == 402
|
||||||
|
|
||||||
|
def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None:
|
||||||
|
"""Pro tier has cloud_storage_gb=5 → succeeds for small blob."""
|
||||||
|
resp = self._create_record(client, tier="pro")
|
||||||
|
assert resp.status_code == 201
|
||||||
|
|
||||||
|
# ── List ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_list_records(self, client, s3_bucket) -> None:
|
||||||
|
self._create_record(client)
|
||||||
|
self._create_record(client, blob_str="second-blob")
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 2
|
||||||
|
# Each entry has metadata, no blob bytes
|
||||||
|
for item in data:
|
||||||
|
assert "id" in item
|
||||||
|
assert "table" in item
|
||||||
|
assert "checksum" in item
|
||||||
|
assert "blob" not in item
|
||||||
|
|
||||||
|
def test_list_records_filter_by_table(self, client, s3_bucket) -> None:
|
||||||
|
self._create_record(client)
|
||||||
|
# Create in a different table
|
||||||
|
note_blob = "note-blob"
|
||||||
|
payload = {
|
||||||
|
"table": "notes",
|
||||||
|
"blob": note_blob,
|
||||||
|
"checksum": hashlib.sha256(note_blob.encode()).hexdigest(),
|
||||||
|
}
|
||||||
|
client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records?table=notes",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["table"] == "notes"
|
||||||
|
|
||||||
|
def test_list_records_isolated_per_user(self, client, s3_bucket) -> None:
|
||||||
|
"""One user's records should not appear in another user's list."""
|
||||||
|
self._create_record(client, tier="power")
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
headers=auth_header("team"),
|
||||||
|
)
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
# ── Download ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_download_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == self._BLOB_BYTES
|
||||||
|
assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM
|
||||||
|
|
||||||
|
def test_download_record_not_found(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records/nonexistent-id",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
# ── Update ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_update_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
new_blob_str = "updated-encrypted-payload"
|
||||||
|
new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest()
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
json={"blob": new_blob_str, "checksum": new_checksum},
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# Verify download returns the updated blob
|
||||||
|
dl = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert dl.content == new_blob_str.encode()
|
||||||
|
|
||||||
|
def test_update_record_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
json={"blob": "some-data", "checksum": "0" * 64},
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
# ── Delete ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_delete_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# Subsequent GET should return 404
|
||||||
|
dl = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert dl.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_record_not_found(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/storage/records/nonexistent",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
@@ -45,13 +45,15 @@ def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
|||||||
return frames
|
return frames
|
||||||
|
|
||||||
|
|
||||||
async def _mock_home_stream(user_id, message, context):
|
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
|
||||||
yield "token", "Hello"
|
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
|
|
||||||
async def _mock_floating_stream(user_id, message, context):
|
async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
|
||||||
yield "floating_domain", {"type": "task", "id": None, "section": None}
|
yield "tool_end", {"name": "task_agent", "result": "ok"}
|
||||||
yield "token", "Here is a summary"
|
yield "token", "Here is a summary"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
|
|
||||||
# ── tests ─────────────────────────────────────────────────────────────────────
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||||
@@ -102,7 +104,7 @@ def test_floating_request_produces_domain_frame(client):
|
|||||||
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
||||||
assert domain_frame["domain"]["type"] == "task"
|
assert domain_frame["domain"] == "tasks"
|
||||||
assert domain_frame["request_id"] == "p1"
|
assert domain_frame["request_id"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
@@ -111,8 +113,9 @@ def test_home_request_request_id_propagated(client):
|
|||||||
token = make_jwt("power", user_id=USER_ID)
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
req_id = "my-unique-req-id"
|
req_id = "my-unique-req-id"
|
||||||
|
|
||||||
async def _stream(user_id, message, context):
|
async def _stream(user_id, message, context, db_session_factory=None):
|
||||||
yield "token", "ok"
|
yield "token", "ok"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
||||||
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
|||||||
Reference in New Issue
Block a user