Compare commits
23 Commits
71fd1a0a7c
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 06de7c7ab0 | |||
| e3c7547c75 | |||
| 314780d59a | |||
| 091787a6da | |||
| 7f278c6f63 | |||
| 8bfce9da00 | |||
| 480e7ac5bd | |||
| d0b303e745 | |||
| 5d485b3665 | |||
| 9787befd4a | |||
| 8f7bc25611 | |||
| 3e07fff958 | |||
| 9119474e71 | |||
| 4c4df7335a | |||
| c8ef7b119b | |||
| 35dd9ac86f | |||
| e72d72f4f6 | |||
| 14d1a7351d | |||
| 68955d2fc2 | |||
| 864dfdc4e6 | |||
| 0d16729036 | |||
| 82669d3704 | |||
| 4d0917f5df |
44
.env.example
Normal file
44
.env.example
Normal file
@@ -0,0 +1,44 @@
|
|||||||
|
# ── Application ──────────────────────────────────────────────────────────────
|
||||||
|
ENV=dev
|
||||||
|
|
||||||
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
||||||
|
|
||||||
|
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||||
|
JWT_SECRET=replace-with-a-long-random-secret
|
||||||
|
JWT_ALGORITHM=HS256
|
||||||
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
|
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||||
|
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||||
|
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
GOOGLE_API_KEY=
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
|
STRIPE_SECRET_KEY=
|
||||||
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
|
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
||||||
|
S3_BUCKET=adiuva
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_ENDPOINT_URL=
|
||||||
|
AWS_ACCESS_KEY_ID=
|
||||||
|
AWS_SECRET_ACCESS_KEY=
|
||||||
|
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
|
||||||
|
# ── Vector Store ──────────────────────────────────────────────────────────────
|
||||||
|
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
|
||||||
|
PINECONE_API_KEY=
|
||||||
|
PINECONE_INDEX=adiuva
|
||||||
|
QDRANT_URL=
|
||||||
|
QDRANT_API_KEY=
|
||||||
|
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
||||||
|
|
||||||
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
|
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||||
93
.gitea/workflows/deploy.yaml
Normal file
93
.gitea/workflows/deploy.yaml
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
name: Test & Deploy API
|
||||||
|
run-name: ${{ gitea.ref_name }} → Docker LXC
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
tags:
|
||||||
|
- 'v*'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
# ── 1. Run tests in an isolated Python container ──────────────────
|
||||||
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: python:3.12-slim
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Install git
|
||||||
|
run: apt-get update && apt-get install -y --no-install-recommends git
|
||||||
|
|
||||||
|
- name: Checkout Code
|
||||||
|
run: |
|
||||||
|
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||||
|
"http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . || \
|
||||||
|
git clone --depth 1 "http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . && \
|
||||||
|
git checkout "${GITHUB_SHA}"
|
||||||
|
|
||||||
|
- name: Install Dependencies
|
||||||
|
run: pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
- name: Run Linter
|
||||||
|
run: ruff check app/ tests/
|
||||||
|
|
||||||
|
- name: Run Tests
|
||||||
|
run: pytest tests/ -v --tb=short
|
||||||
|
|
||||||
|
# ── 2. Deploy to Docker LXC via SSH ─────────────────────────────────
|
||||||
|
deploy:
|
||||||
|
needs: test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: gitea.event_name == 'push'
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Deploy via SSH
|
||||||
|
uses: appleboy/ssh-action@v1.0.0
|
||||||
|
with:
|
||||||
|
host: ${{ secrets.SSH_HOST }}
|
||||||
|
username: ${{ secrets.SSH_USER }}
|
||||||
|
key: ${{ secrets.SSH_KEY }}
|
||||||
|
script: |
|
||||||
|
set -e
|
||||||
|
DEPLOY_DIR="/opt/adiuva-api"
|
||||||
|
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||||
|
TAG="${{ gitea.ref_name }}"
|
||||||
|
|
||||||
|
# ── Pull latest code ──
|
||||||
|
cd /tmp && rm -rf adiuva-api-deploy
|
||||||
|
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
|
||||||
|
|
||||||
|
# ── Sync source (preserve .env) ──
|
||||||
|
cp -rf /tmp/adiuva-api-deploy/app/ \
|
||||||
|
/tmp/adiuva-api-deploy/alembic/ \
|
||||||
|
/tmp/adiuva-api-deploy/alembic.ini \
|
||||||
|
/tmp/adiuva-api-deploy/Dockerfile \
|
||||||
|
/tmp/adiuva-api-deploy/docker-compose.yml \
|
||||||
|
/tmp/adiuva-api-deploy/requirements.txt \
|
||||||
|
"$DEPLOY_DIR/"
|
||||||
|
rm -rf /tmp/adiuva-api-deploy
|
||||||
|
|
||||||
|
# ── Verify .env ──
|
||||||
|
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||||
|
echo "❌ $DEPLOY_DIR/.env not found. Create it before deploying."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Build & restart ──
|
||||||
|
cd "$DEPLOY_DIR"
|
||||||
|
docker compose down --remove-orphans || true
|
||||||
|
docker compose up -d --build
|
||||||
|
|
||||||
|
# ── Migrations ──
|
||||||
|
docker compose exec -T app alembic upgrade head
|
||||||
|
|
||||||
|
# ── Health check ──
|
||||||
|
echo "Waiting for app..."
|
||||||
|
sleep 5
|
||||||
|
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/api/v1/health)
|
||||||
|
if [ "$HTTP_CODE" -eq 200 ]; then
|
||||||
|
echo "✅ API is healthy (HTTP ${HTTP_CODE})"
|
||||||
|
else
|
||||||
|
echo "❌ Health check failed (HTTP ${HTTP_CODE})"
|
||||||
|
docker compose logs app --tail=50
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
64
.github/workflows/ci.yml
vendored
Normal file
64
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install ruff
|
||||||
|
run: pip install ruff>=0.8.0
|
||||||
|
|
||||||
|
- name: Ruff check
|
||||||
|
run: ruff check .
|
||||||
|
|
||||||
|
- name: Ruff format check
|
||||||
|
run: ruff format --check .
|
||||||
|
|
||||||
|
test:
|
||||||
|
name: Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: lint
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Cache pip
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip
|
||||||
|
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
|
||||||
|
restore-keys: ${{ runner.os }}-pip-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -r requirements.txt
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: pytest -v --tb=short
|
||||||
|
|
||||||
|
docker:
|
||||||
|
name: Docker Build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: test
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Build image
|
||||||
|
run: docker build -t adiuva-api:ci .
|
||||||
|
|
||||||
|
- name: Verify gunicorn installed
|
||||||
|
run: docker run --rm adiuva-api:ci gunicorn --version
|
||||||
33
.gitignore
vendored
Normal file
33
.gitignore
vendored
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
# Python
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*.egg-info/
|
||||||
|
dist/
|
||||||
|
build/
|
||||||
|
|
||||||
|
# Virtual environment
|
||||||
|
.venv/
|
||||||
|
venv/
|
||||||
|
env/
|
||||||
|
|
||||||
|
# Environment variables
|
||||||
|
.env
|
||||||
|
|
||||||
|
# IDE
|
||||||
|
.vscode/
|
||||||
|
.idea/
|
||||||
|
|
||||||
|
# Testing / coverage
|
||||||
|
.pytest_cache/
|
||||||
|
htmlcov/
|
||||||
|
.coverage
|
||||||
|
|
||||||
|
# Docker
|
||||||
|
*.log
|
||||||
|
|
||||||
|
# OS
|
||||||
|
.DS_Store
|
||||||
|
Thumbs.db
|
||||||
|
|
||||||
|
# Claude Code
|
||||||
|
.claude/
|
||||||
365
BACKEND_PLAN.md
365
BACKEND_PLAN.md
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
||||||
>
|
>
|
||||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, and backup blob storage.
|
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
||||||
> The backend NEVER persists user data. It receives context in requests, uses it for orchestration, and discards it.
|
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ adiuva-api/
|
|||||||
│ │ ├── orchestrator.py # LLM-based intent router
|
│ │ ├── orchestrator.py # LLM-based intent router
|
||||||
│ │ ├── execution_plan.py # Plan builder + cache
|
│ │ ├── execution_plan.py # Plan builder + cache
|
||||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
│ │ └── plugin_loader.py # Dynamic agent loading
|
||||||
│ ├── agents/
|
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
||||||
│ │ ├── __init__.py # Auto-registers all agents
|
│ │ ├── __init__.py # Auto-registers all agents
|
||||||
│ │ ├── task_agent.py
|
│ │ ├── task_agent.py
|
||||||
│ │ ├── calendar_agent.py
|
│ │ ├── calendar_agent.py
|
||||||
@@ -32,7 +32,10 @@ adiuva-api/
|
|||||||
│ │ │ ├── __init__.py
|
│ │ │ ├── __init__.py
|
||||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
||||||
│ │ │ ├── plans.py # GET /plans/playbook
|
│ │ │ ├── plans.py # GET /plans/playbook
|
||||||
|
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
||||||
|
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
||||||
│ │ │ ├── backup.py # PUT/GET /backup
|
│ │ │ ├── backup.py # PUT/GET /backup
|
||||||
|
│ │ │ ├── plugins.py # Plugin marketplace
|
||||||
│ │ │ ├── auth.py # Register/login/refresh
|
│ │ │ ├── auth.py # Register/login/refresh
|
||||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
│ │ │ └── billing.py # Checkout/webhook/subscription
|
||||||
│ │ └── middleware/
|
│ │ └── middleware/
|
||||||
@@ -40,6 +43,16 @@ adiuva-api/
|
|||||||
│ │ ├── auth.py # JWT validation
|
│ │ ├── auth.py # JWT validation
|
||||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
||||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
||||||
|
│ ├── storage/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
||||||
|
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
||||||
|
│ │ └── encryption.py # Integrity verification only — NO decryption
|
||||||
|
│ ├── marketplace/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
||||||
|
│ │ ├── plugin_review.py # Review queue + approval workflow
|
||||||
|
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
||||||
│ ├── billing/
|
│ ├── billing/
|
||||||
│ │ ├── __init__.py
|
│ │ ├── __init__.py
|
||||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
||||||
@@ -53,8 +66,10 @@ adiuva-api/
|
|||||||
│ ├── test_orchestrator.py
|
│ ├── test_orchestrator.py
|
||||||
│ ├── test_agents.py
|
│ ├── test_agents.py
|
||||||
│ ├── test_auth.py
|
│ ├── test_auth.py
|
||||||
│ └── test_backup.py
|
│ ├── test_backup.py
|
||||||
├── alembic/ # DB migrations (auth/billing tables only)
|
│ ├── test_storage.py
|
||||||
|
│ └── test_plugins.py
|
||||||
|
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
|
||||||
│ ├── alembic.ini
|
│ ├── alembic.ini
|
||||||
│ └── versions/
|
│ └── versions/
|
||||||
├── requirements.txt
|
├── requirements.txt
|
||||||
@@ -68,9 +83,9 @@ adiuva-api/
|
|||||||
|
|
||||||
## Step-by-Step Implementation
|
## Step-by-Step Implementation
|
||||||
|
|
||||||
### Step 1 — Project scaffolding
|
### Step 1 — Project scaffolding ✅
|
||||||
- [ ] Initialize repo with the directory structure above
|
- [x] Initialize repo with the directory structure above
|
||||||
- [ ] Write `requirements.txt`:
|
- [x] Write `requirements.txt`:
|
||||||
```
|
```
|
||||||
fastapi>=0.115.0
|
fastapi>=0.115.0
|
||||||
uvicorn[standard]>=0.34.0
|
uvicorn[standard]>=0.34.0
|
||||||
@@ -91,29 +106,40 @@ adiuva-api/
|
|||||||
pytest>=8.0.0
|
pytest>=8.0.0
|
||||||
pytest-asyncio>=0.24.0
|
pytest-asyncio>=0.24.0
|
||||||
```
|
```
|
||||||
- [ ] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
||||||
- [ ] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod)
|
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
||||||
- [ ] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
||||||
- [ ] Write `docker-compose.yml`: app, postgres:16, optional redis
|
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
||||||
- [ ] Write `.env.example`
|
- [x] Write `.env.example`
|
||||||
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
||||||
|
|
||||||
### Step 2 — Pydantic schemas (API contracts)
|
### Step 2 — Pydantic schemas (API contracts) ✅
|
||||||
- [ ] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
||||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
||||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
||||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
||||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification']`, `table: str | None`, `data: dict | None`
|
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
||||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
||||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
||||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
||||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
||||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
||||||
|
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
||||||
|
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
||||||
|
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
||||||
|
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
||||||
|
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
||||||
|
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
||||||
|
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
||||||
|
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
||||||
|
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
||||||
|
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
||||||
|
- `PluginInstallRequest`: `plugin_id: str`
|
||||||
- **Outcome:** All request/response models defined and validated.
|
- **Outcome:** All request/response models defined and validated.
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes
|
### Step 3 — Agent Registry + base classes ✅
|
||||||
- [ ] `app/core/agent_registry.py`:
|
- [x] `app/core/agent_registry.py`:
|
||||||
- `BaseAgent(ABC)`:
|
- `BaseAgent(ABC)`:
|
||||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
||||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
- Abstract `get_name() -> str`, `get_description() -> str`
|
||||||
@@ -127,11 +153,11 @@ adiuva-api/
|
|||||||
- `get(name) -> ChatAgent`
|
- `get(name) -> ChatAgent`
|
||||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
||||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
||||||
- [ ] Unit tests: register, get, list, call_agent with mock
|
- [x] Unit tests: register, get, list, call_agent with mock
|
||||||
- **Outcome:** Pluggable agent framework.
|
- **Outcome:** Pluggable agent framework.
|
||||||
|
|
||||||
### Step 4 — Orchestrator
|
### Step 4 — Orchestrator ✅
|
||||||
- [ ] `app/core/orchestrator.py`:
|
- [x] `app/core/orchestrator.py`:
|
||||||
- `async classify_intent(message, context, registry) -> str`:
|
- `async classify_intent(message, context, registry) -> str`:
|
||||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
||||||
- Uses gpt-4o-mini via LangChain for low latency
|
- Uses gpt-4o-mini via LangChain for low latency
|
||||||
@@ -146,16 +172,17 @@ adiuva-api/
|
|||||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
||||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
||||||
- Main entry point
|
- Main entry point
|
||||||
|
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
|
||||||
- Classifies intent
|
- Classifies intent
|
||||||
- If `execution_mode == 'direct'`: route + return response
|
- If `execution_mode == 'direct'`: route + return response
|
||||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
||||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
||||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
- Same as orchestrate but yields tokens for WebSocket streaming
|
||||||
- [ ] Integration tests with mocked LLM and mocked agents
|
- [x] Integration tests with mocked LLM and mocked agents
|
||||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
||||||
|
|
||||||
### Step 5 — Execution Plan generator
|
### Step 5 — Execution Plan generator ✅
|
||||||
- [ ] `app/core/execution_plan.py`:
|
- [x] `app/core/execution_plan.py`:
|
||||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
||||||
- `ExecutionPlanBuilder`:
|
- `ExecutionPlanBuilder`:
|
||||||
- `add_step(action, params) -> self`
|
- `add_step(action, params) -> self`
|
||||||
@@ -168,32 +195,52 @@ adiuva-api/
|
|||||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
||||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
||||||
|
|
||||||
### Step 6 — Chat Agents
|
### Step 6 — Chat Agents ✅
|
||||||
- [ ] `app/agents/task_agent.py` — `@registry.register`:
|
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
||||||
- Description: "Manages tasks: create, update, list, suggest"
|
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
||||||
- Tools: `create_task(title, description, priority, due_date)`, `update_task(id, updates)`, `list_tasks(filters)`, `suggest_tasks(notes_context)`
|
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
||||||
- System prompt: PM-oriented, validates task structure, infers priority from context
|
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
||||||
- `handle()`: LLM + tool loop via `_tool_loop()`, returns response text + list of actions performed
|
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
||||||
- [ ] `app/agents/calendar_agent.py` — `@registry.register`:
|
- [x] `app/agents/checkpoint_agent.py` — `@registry.register`:
|
||||||
- Description: "Calendar management: events, conflicts, scheduling"
|
- Description: "Manages project checkpoints (milestones): list, create, update, delete"
|
||||||
- Tools: `list_events(date_range)`, `detect_conflicts(events)`, `suggest_reschedule(conflict)`
|
- Tools (4): `list_checkpoints(project_id)`, `create_checkpoint(project_id, title, date, is_ai_suggested, is_approved)`, `update_checkpoint(checkpoint_id, ...)`, `delete_checkpoint(checkpoint_id)`
|
||||||
- Works with event metadata passed in context (never raw calendar data stored)
|
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
||||||
- [ ] `app/agents/email_agent.py` — `@registry.register`:
|
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
||||||
- Description: "Email analysis: classify, extract actions, draft responses"
|
- Description: "Manages projects: list, get, create, update, archive, delete"
|
||||||
- Tools: `classify_email(metadata)`, `extract_action_items(metadata)`, `draft_response(thread_context)`
|
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
||||||
- Only processes metadata sent by client — never raw email bodies
|
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
||||||
- [ ] `app/agents/analytics_agent.py` — `@registry.register`:
|
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
||||||
- Description: "Workspace analytics: metrics, reports, trends"
|
- Description: "Manages notes: list, get, create, update, delete"
|
||||||
- Tools: `calculate_metrics(task_data)`, `generate_report(period, data)`, `trend_analysis(data_points)`
|
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
||||||
- Crunches numbers from context, returns structured insights
|
- content is Markdown; `get_note` should be called before update to preserve existing content
|
||||||
- [ ] `app/agents/__init__.py`: imports all agent modules to trigger `@registry.register` decorators
|
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
||||||
- [ ] Unit tests per agent with mocked LLM
|
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
||||||
- **Outcome:** Four specialized agents, all registered and tested.
|
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Checkpoints, Projects, Notes), all registered and tested.
|
||||||
|
|
||||||
### Step 7 — API Routes
|
### Step 7 — Storage Layer ✅
|
||||||
|
- [x] `app/storage/blob_store.py`:
|
||||||
|
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
||||||
|
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
||||||
|
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
||||||
|
- [x] `app/storage/vector_store.py`:
|
||||||
|
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
||||||
|
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
||||||
|
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
|
||||||
|
- ANN on encrypted data: known accuracy trade-off, documented
|
||||||
|
- [x] `app/storage/encryption.py`:
|
||||||
|
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
||||||
|
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
||||||
|
- Backend NEVER holds decryption keys
|
||||||
|
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
||||||
|
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
||||||
|
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
||||||
|
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
||||||
|
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
|
||||||
|
|
||||||
#### 7a — Chat endpoint
|
### Step 8 — API Routes ✅
|
||||||
- [ ] `app/api/routes/chat.py`:
|
|
||||||
|
#### 8a — Chat endpoint
|
||||||
|
- [x] `app/api/routes/chat.py`:
|
||||||
- `POST /api/v1/chat`:
|
- `POST /api/v1/chat`:
|
||||||
- Request: `ChatRequest`
|
- Request: `ChatRequest`
|
||||||
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
||||||
@@ -204,49 +251,94 @@ adiuva-api/
|
|||||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
||||||
- Heartbeat ping every 30s to keep connection alive
|
- Heartbeat ping every 30s to keep connection alive
|
||||||
|
|
||||||
#### 7b — Plans endpoint
|
#### 8b — Plans endpoint
|
||||||
- [ ] `app/api/routes/plans.py`:
|
- [x] `app/api/routes/plans.py`:
|
||||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
||||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
||||||
|
|
||||||
#### 7c — Backup endpoint
|
#### 8c — Storage endpoint (cloud records)
|
||||||
- [ ] `app/api/routes/backup.py`:
|
- [x] `app/api/routes/storage.py`:
|
||||||
|
- `POST /api/v1/storage/records`: Create encrypted record
|
||||||
|
- Request: `StorageRecordCreate`
|
||||||
|
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
||||||
|
- Response: `{id: str, created_at: int}`
|
||||||
|
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
||||||
|
- Query params: `table: str`, `page: int`, `limit: int`
|
||||||
|
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
||||||
|
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
||||||
|
- Response: blob bytes + `X-Checksum` header
|
||||||
|
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
||||||
|
- Request: `StorageRecordUpdate`
|
||||||
|
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
||||||
|
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
|
||||||
|
|
||||||
|
#### 8d — Vectors endpoint (cloud vector store)
|
||||||
|
- [x] `app/api/routes/vectors.py`:
|
||||||
|
- `POST /api/v1/storage/vectors/upsert`:
|
||||||
|
- Request: `VectorUpsertRequest`
|
||||||
|
- Verifies checksums, delegates to `VectorStore.upsert()`
|
||||||
|
- Response: `{upserted: int}`
|
||||||
|
- `POST /api/v1/storage/vectors/search`:
|
||||||
|
- Request: `VectorSearchRequest`
|
||||||
|
- Delegates to `VectorStore.search()`
|
||||||
|
- Response: `VectorSearchResponse`
|
||||||
|
- `DELETE /api/v1/storage/vectors`:
|
||||||
|
- Request: `{ids: list[str]}`
|
||||||
|
|
||||||
|
#### 8e — Backup endpoint
|
||||||
|
- [x] `app/api/routes/backup.py`:
|
||||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
||||||
- Free: 0 (no backup)
|
- Free: 0 (no backup)
|
||||||
- Pro: 5 GB
|
- Pro: 5 GB
|
||||||
- Power: 50 GB
|
- Power: 25 GB
|
||||||
- Team: unlimited
|
- Team: unlimited
|
||||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
||||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
||||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
||||||
|
|
||||||
#### 7d — Auth endpoint
|
#### 8f — Plugins endpoint
|
||||||
- [ ] `app/api/routes/auth.py`:
|
- [x] `app/api/routes/plugins.py`:
|
||||||
|
- `GET /api/v1/plugins`:
|
||||||
|
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
||||||
|
- Response: `PluginListResponse`
|
||||||
|
- Available from Power tier and above
|
||||||
|
- `GET /api/v1/plugins/{id}`:
|
||||||
|
- Response: `PluginManifest` + ratings + install count
|
||||||
|
- `POST /api/v1/plugins/{id}/install`:
|
||||||
|
- Request: `PluginInstallRequest`
|
||||||
|
- Records installation for the user (billing tracking, analytics)
|
||||||
|
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
||||||
|
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
||||||
|
- `DELETE /api/v1/plugins/{id}/install`:
|
||||||
|
- Unregisters installation
|
||||||
|
|
||||||
|
#### 8g — Auth endpoint
|
||||||
|
- [x] `app/api/routes/auth.py`:
|
||||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
||||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
||||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
||||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
||||||
|
|
||||||
#### 7e — Billing endpoint
|
#### 8h — Billing endpoint
|
||||||
- [ ] `app/api/routes/billing.py`:
|
- [x] `app/api/routes/billing.py`:
|
||||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
||||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
||||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
||||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
||||||
|
|
||||||
- **Outcome:** Complete REST + WebSocket API.
|
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
|
||||||
|
|
||||||
### Step 8 — Middleware
|
### Step 9 — Middleware
|
||||||
|
|
||||||
#### 8a — Auth middleware
|
#### 9a — Auth middleware
|
||||||
- [ ] `app/api/middleware/auth.py`:
|
- [x] `app/api/middleware/auth.py`:
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
||||||
- Raises `401` on invalid/expired token
|
- Raises `401` on invalid/expired token
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||||
|
|
||||||
#### 8b — Rate limiter
|
#### 9b — Rate limiter
|
||||||
- [ ] `app/api/middleware/rate_limit.py`:
|
- [x] `app/api/middleware/rate_limit.py`:
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
||||||
- Tier-based limits:
|
- Tier-based limits:
|
||||||
- Free: 20 req/min
|
- Free: 20 req/min
|
||||||
@@ -255,8 +347,8 @@ adiuva-api/
|
|||||||
- Team: 200 req/seat/min
|
- Team: 200 req/seat/min
|
||||||
- Custom 429 response with `Retry-After` header
|
- Custom 429 response with `Retry-After` header
|
||||||
|
|
||||||
#### 8c — Sanitizer
|
#### 9c — Sanitizer
|
||||||
- [ ] `app/api/middleware/sanitizer.py`:
|
- [x] `app/api/middleware/sanitizer.py`:
|
||||||
- Response middleware that scans response bodies
|
- Response middleware that scans response bodies
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
||||||
- Pattern-based detection + exact match against known prompt fingerprints
|
- Pattern-based detection + exact match against known prompt fingerprints
|
||||||
@@ -264,46 +356,113 @@ adiuva-api/
|
|||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
||||||
|
|
||||||
### Step 9 — Billing & Tier management
|
### Step 10 — Plugin Marketplace ✅
|
||||||
- [ ] `app/billing/stripe_service.py`:
|
- [x] `app/marketplace/plugin_registry.py`:
|
||||||
|
- `PluginRegistry`:
|
||||||
|
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
||||||
|
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
||||||
|
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
||||||
|
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
||||||
|
- `async reject_plugin(plugin_id, reason: str) -> None`
|
||||||
|
- [x] `app/marketplace/plugin_review.py`:
|
||||||
|
- `ReviewQueue`:
|
||||||
|
- `async get_pending() -> list[dict]`
|
||||||
|
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
||||||
|
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
||||||
|
- [x] `app/marketplace/revenue_share.py`:
|
||||||
|
- `RevenueShare`:
|
||||||
|
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
||||||
|
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
||||||
|
- `async get_earnings(developer_id, period) -> dict`
|
||||||
|
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
||||||
|
|
||||||
|
### Step 11 — Billing & Tier management ✅
|
||||||
|
- [x] `app/billing/stripe_service.py`:
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
- `create_checkout_session(user_id, tier) -> str`
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
||||||
- `get_subscription(user_id) -> dict | None`
|
- `get_subscription(user_id) -> dict | None`
|
||||||
- `cancel_subscription(user_id) -> None`
|
- `cancel_subscription(user_id) -> None`
|
||||||
- [ ] `app/billing/tier_manager.py`:
|
- [x] `app/billing/tier_manager.py`:
|
||||||
- `TierManager`:
|
- `TierManager`:
|
||||||
- Feature matrix:
|
- Feature matrix:
|
||||||
```python
|
```python
|
||||||
FEATURES = {
|
FEATURES = {
|
||||||
'free': {'agents': 3, 'batch': False, 'providers': 1, 'backup_gb': 0},
|
'free': {
|
||||||
'pro': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': 5},
|
'agents': 3,
|
||||||
'power': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': 50, 'byok': True},
|
'batch_active': 2,
|
||||||
'team': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': -1, 'sso': True},
|
'cloud_storage_gb': 0,
|
||||||
|
'backup_gb': 0,
|
||||||
|
'providers': 1,
|
||||||
|
'batch_builder': False,
|
||||||
|
'plugin_marketplace': False,
|
||||||
|
'sso': False,
|
||||||
|
},
|
||||||
|
'pro': {
|
||||||
|
'agents': -1, # unlimited
|
||||||
|
'batch_active': 10,
|
||||||
|
'cloud_storage_gb': 5,
|
||||||
|
'backup_gb': 5,
|
||||||
|
'providers': -1,
|
||||||
|
'batch_builder': False,
|
||||||
|
'plugin_marketplace': False,
|
||||||
|
'sso': False,
|
||||||
|
},
|
||||||
|
'power': {
|
||||||
|
'agents': -1,
|
||||||
|
'batch_active': -1, # unlimited
|
||||||
|
'cloud_storage_gb': 25,
|
||||||
|
'backup_gb': 25,
|
||||||
|
'providers': -1,
|
||||||
|
'batch_builder': True,
|
||||||
|
'plugin_marketplace': True,
|
||||||
|
'sso': False,
|
||||||
|
},
|
||||||
|
'team': {
|
||||||
|
'agents': -1,
|
||||||
|
'batch_active': -1,
|
||||||
|
'cloud_storage_gb': -1,
|
||||||
|
'backup_gb': -1,
|
||||||
|
'providers': -1,
|
||||||
|
'batch_builder': True,
|
||||||
|
'plugin_marketplace': True,
|
||||||
|
'sso': True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
- `get_tier(user_id) -> BillingTier`
|
- `get_tier(user_id) -> BillingTier`
|
||||||
- `check_feature(user_id, feature) -> bool`
|
- `check_feature(user_id, feature) -> bool`
|
||||||
- `get_rate_limit(tier) -> int`
|
- `get_rate_limit(tier) -> int`
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating.
|
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
||||||
|
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
|
||||||
|
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
|
||||||
|
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
|
||||||
|
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
||||||
|
|
||||||
### Step 10 — Database (auth/billing only)
|
### Step 12 — Database (auth/billing/marketplace only)
|
||||||
- [ ] PostgreSQL schema via Alembic:
|
- [x] PostgreSQL schema via Alembic:
|
||||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
||||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
||||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
||||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
||||||
- [ ] Initial Alembic migration
|
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
||||||
- [ ] SQLAlchemy models in `app/models.py`
|
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
||||||
- **Outcome:** Auth and billing persistence. Zero user data stored.
|
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
||||||
|
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
||||||
|
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
||||||
|
- [x] Initial Alembic migration
|
||||||
|
- [x] SQLAlchemy models in `app/models.py`
|
||||||
|
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
||||||
|
|
||||||
### Step 11 — Testing & deployment
|
### Step 13 — Testing & deployment ✅
|
||||||
- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed)
|
- [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
||||||
- [ ] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
- [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
||||||
- [ ] `tests/test_agents.py`: each agent with mocked tools
|
- [x] `tests/test_agents.py`: each agent with mocked tools
|
||||||
- [ ] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
- [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
||||||
- [ ] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
- [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
||||||
- [ ] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
- [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
||||||
- [ ] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
- [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
||||||
|
- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
||||||
|
- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
||||||
- **Outcome:** Fully tested, deployable backend.
|
- **Outcome:** Fully tested, deployable backend.
|
||||||
|
|
||||||
---
|
---
|
||||||
@@ -320,10 +479,22 @@ adiuva-api/
|
|||||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
||||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
||||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
||||||
|
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
||||||
|
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
||||||
|
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
||||||
|
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
||||||
|
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
||||||
|
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
||||||
|
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
||||||
|
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
||||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
||||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
||||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
||||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
||||||
|
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
||||||
|
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
||||||
|
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
||||||
|
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
||||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
||||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
||||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
||||||
@@ -339,20 +510,24 @@ adiuva-api/
|
|||||||
| Framework | FastAPI + Uvicorn |
|
| Framework | FastAPI + Uvicorn |
|
||||||
| LLM | LangChain + langchain-openai |
|
| LLM | LangChain + langchain-openai |
|
||||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
| Auth | PyJWT + bcrypt + OAuth2 |
|
||||||
| Billing | stripe-python |
|
| Billing | stripe-python + Stripe Connect |
|
||||||
| Storage | boto3 (S3) |
|
| Blob storage | boto3 (S3) |
|
||||||
|
| Vector store | Pinecone or Qdrant (configurable) |
|
||||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
||||||
| Rate limiting | slowapi |
|
| Rate limiting | slowapi |
|
||||||
| Testing | pytest + pytest-asyncio + httpx |
|
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
||||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Development Rules
|
## Development Rules
|
||||||
|
|
||||||
1. **NEVER persist user data.** The DB stores only auth, billing, and backup metadata. User context arrives in requests and is discarded after processing.
|
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
||||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending.
|
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
||||||
3. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
||||||
4. **Type hints everywhere.** All functions have full type annotations.
|
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
||||||
5. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
5. **Type hints everywhere.** All functions have full type annotations.
|
||||||
6. **Structured logging.** JSON logs with request ID correlation.
|
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
||||||
|
7. **Structured logging.** JSON logs with request ID correlation.
|
||||||
|
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
||||||
|
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
||||||
|
|||||||
39
Dockerfile
Normal file
39
Dockerfile
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY requirements.txt .
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
# Non-root user
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy installed packages from builder
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Copy application source
|
||||||
|
COPY app/ app/
|
||||||
|
|
||||||
|
# Copy Alembic migration files
|
||||||
|
COPY alembic/ alembic/
|
||||||
|
COPY alembic.ini .
|
||||||
|
|
||||||
|
# Ensure appuser owns the working directory
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "4", \
|
||||||
|
"--timeout", "120"]
|
||||||
793
README.md
Normal file
793
README.md
Normal file
@@ -0,0 +1,793 @@
|
|||||||
|
# Adiuva Cloud API
|
||||||
|
|
||||||
|
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
|
||||||
|
|
||||||
|
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Overview](#overview)
|
||||||
|
- [Architecture](#architecture)
|
||||||
|
- [Key Features](#key-features)
|
||||||
|
- [Tech Stack](#tech-stack)
|
||||||
|
- [Getting Started](#getting-started)
|
||||||
|
- [Docker Deployment](#docker-deployment)
|
||||||
|
- [Environment Variables](#environment-variables)
|
||||||
|
- [API Reference](#api-reference)
|
||||||
|
- [Data Model](#data-model)
|
||||||
|
- [AI Agent System](#ai-agent-system)
|
||||||
|
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
||||||
|
- [Middleware](#middleware)
|
||||||
|
- [Storage Layer](#storage-layer)
|
||||||
|
- [Billing & Tiers](#billing--tiers)
|
||||||
|
- [Plugin Marketplace](#plugin-marketplace)
|
||||||
|
- [Testing](#testing)
|
||||||
|
- [Project Structure](#project-structure)
|
||||||
|
- [License](#license)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
|
||||||
|
|
||||||
|
### Design Principles
|
||||||
|
|
||||||
|
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
|
||||||
|
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
||||||
|
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
|
||||||
|
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
||||||
|
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
|
||||||
|
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
|
||||||
|
│ Desktop App │────▶│ │
|
||||||
|
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
|
||||||
|
└──────────────┘ │ │
|
||||||
|
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
||||||
|
│ │ Auth Routes │ │ Chat Routes │ │
|
||||||
|
│ │ Billing Routes │ │ ↓ │ │
|
||||||
|
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
||||||
|
│ │ Backup Routes │ │ ↓ classify intent │ │
|
||||||
|
│ │ Plugin Routes │ │ Agent Registry │ │
|
||||||
|
│ │ Vector Routes │ │ ↓ │ │
|
||||||
|
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
|
||||||
|
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
|
||||||
|
│ │ (GPT-4o + LangChain) │ │
|
||||||
|
│ └────────────────────────────┘ │
|
||||||
|
└────────────────────────────────────────────────────────┘
|
||||||
|
│ │ │
|
||||||
|
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
|
||||||
|
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
|
||||||
|
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
|
||||||
|
│ Billing, │ │ backups) │ │ (Vectors) │
|
||||||
|
│ Metadata) │ └───────────────┘ └────────────────┘
|
||||||
|
└────────────┘
|
||||||
|
│
|
||||||
|
┌────────▼───┐
|
||||||
|
│ Stripe │
|
||||||
|
│ (Billing, │
|
||||||
|
│ Connect) │
|
||||||
|
└────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
||||||
|
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Checkpoints (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
||||||
|
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
||||||
|
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
||||||
|
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
||||||
|
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
|
||||||
|
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
|
||||||
|
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
||||||
|
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
||||||
|
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
||||||
|
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
||||||
|
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
|
||||||
|
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
||||||
|
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
|
||||||
|
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
|
||||||
|
| Package | Version | Purpose |
|
||||||
|
|---|---|---|
|
||||||
|
| `fastapi` | ≥ 0.115.0 | Web framework |
|
||||||
|
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
|
||||||
|
| `gunicorn` | ≥ 22.0.0 | Production process manager |
|
||||||
|
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
|
||||||
|
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
|
||||||
|
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
|
||||||
|
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
|
||||||
|
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
||||||
|
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
||||||
|
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
||||||
|
| `boto3` | ≥ 1.35.0 | AWS S3 client |
|
||||||
|
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
||||||
|
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
||||||
|
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
||||||
|
| `alembic` | ≥ 1.14.0 | Database migration management |
|
||||||
|
| `bcrypt` | ≥ 4.2.0 | Password hashing |
|
||||||
|
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
|
||||||
|
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
||||||
|
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
||||||
|
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
||||||
|
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
|
||||||
|
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
|
||||||
|
| `pytest` | ≥ 8.0.0 | Test framework |
|
||||||
|
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
||||||
|
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
||||||
|
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
|
||||||
|
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.12+
|
||||||
|
- PostgreSQL 16+
|
||||||
|
- An OpenAI API key (for LLM features)
|
||||||
|
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
||||||
|
- AWS credentials (optional — needed for S3 storage in production)
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone the repository
|
||||||
|
git clone <repo-url> && cd adiuva-api
|
||||||
|
|
||||||
|
# Create a virtual environment
|
||||||
|
python -m venv .venv && source .venv/bin/activate
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Configure environment
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start PostgreSQL (or use the Docker Compose database)
|
||||||
|
docker compose up db -d
|
||||||
|
|
||||||
|
# Run migrations
|
||||||
|
alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run the Development Server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Docker Deployment
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
This starts two services:
|
||||||
|
|
||||||
|
- **app** — FastAPI server on port `8000`
|
||||||
|
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
||||||
|
|
||||||
|
The compose file also includes optional services for fully local deployments:
|
||||||
|
|
||||||
|
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
|
||||||
|
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
|
||||||
|
|
||||||
|
### Dockerfile Details
|
||||||
|
|
||||||
|
The Dockerfile uses a multi-stage build:
|
||||||
|
|
||||||
|
1. **Builder stage** — Installs Python dependencies into a virtual environment.
|
||||||
|
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
|
||||||
|
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Production command (run by the container)
|
||||||
|
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Homelab / Self-Hosted Deployment
|
||||||
|
|
||||||
|
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
|
||||||
|
|
||||||
|
### 1. Start all services
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
|
||||||
|
|
||||||
|
### 2. Create the MinIO bucket
|
||||||
|
|
||||||
|
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
|
||||||
|
docker compose exec minio mc mb local/adiuva
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Configure your `.env`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Database (uses the compose PostgreSQL)
|
||||||
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
|
||||||
|
# S3 → MinIO
|
||||||
|
S3_BUCKET=adiuva
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
AWS_ACCESS_KEY_ID=minioadmin
|
||||||
|
AWS_SECRET_ACCESS_KEY=minioadmin
|
||||||
|
|
||||||
|
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
|
||||||
|
QDRANT_URL=http://qdrant:6333
|
||||||
|
QDRANT_API_KEY=
|
||||||
|
PINECONE_API_KEY=
|
||||||
|
|
||||||
|
# Billing — leave empty to stub (no Stripe needed)
|
||||||
|
STRIPE_SECRET_KEY=
|
||||||
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
|
# LLM — the only external service
|
||||||
|
OPENAI_API_KEY=sk-...
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Auth
|
||||||
|
JWT_SECRET=your-secret-here
|
||||||
|
ENV=dev
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Run migrations
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec app alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### What runs where
|
||||||
|
|
||||||
|
| Service | Runs on | Port | Notes |
|
||||||
|
|---|---|---|---|
|
||||||
|
| FastAPI app | Docker | 8000 | API server |
|
||||||
|
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
|
||||||
|
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
|
||||||
|
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
|
||||||
|
| Stripe | — | — | Stubbed when keys are empty |
|
||||||
|
| OpenAI / LLM | Cloud | — | Only external dependency |
|
||||||
|
|
||||||
|
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
|
||||||
|
|
||||||
|
| Variable | Type | Default | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
|
||||||
|
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
|
||||||
|
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
|
||||||
|
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
||||||
|
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
||||||
|
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
||||||
|
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
|
||||||
|
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
|
||||||
|
| `S3_REGION` | `str` | `us-east-1` | AWS region |
|
||||||
|
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
|
||||||
|
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
|
||||||
|
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
|
||||||
|
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
|
||||||
|
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
|
||||||
|
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
|
||||||
|
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
|
||||||
|
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
|
||||||
|
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
||||||
|
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
||||||
|
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
||||||
|
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
|
||||||
|
|
||||||
|
### Health
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
|
||||||
|
|
||||||
|
### Auth
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
|
||||||
|
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
|
||||||
|
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
|
||||||
|
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
|
||||||
|
|
||||||
|
### Chat
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
||||||
|
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
||||||
|
|
||||||
|
### Plans
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
||||||
|
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
||||||
|
|
||||||
|
### Storage (Cloud Records)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
|
||||||
|
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
|
||||||
|
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
|
||||||
|
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
|
||||||
|
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
|
||||||
|
|
||||||
|
### Vectors (Cloud Vector Store)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
|
||||||
|
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
|
||||||
|
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
|
||||||
|
|
||||||
|
### Backup
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
|
||||||
|
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
|
||||||
|
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
|
||||||
|
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
|
||||||
|
|
||||||
|
### Plugins (Marketplace)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
|
||||||
|
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
|
||||||
|
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
|
||||||
|
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
|
||||||
|
|
||||||
|
### Billing
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
|
||||||
|
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
|
||||||
|
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
|
||||||
|
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Model
|
||||||
|
|
||||||
|
9 tables managed by Alembic migrations. Source: `app/models.py`
|
||||||
|
|
||||||
|
### Tables
|
||||||
|
|
||||||
|
| Table | Primary Key | Key Columns | Purpose |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
||||||
|
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
||||||
|
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
||||||
|
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
|
||||||
|
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
|
||||||
|
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
|
||||||
|
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
|
||||||
|
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
|
||||||
|
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
|
||||||
|
|
||||||
|
### Enum Types
|
||||||
|
|
||||||
|
| Enum | Values |
|
||||||
|
|---|---|
|
||||||
|
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
||||||
|
| `plugin_status` | `pending_review`, `approved`, `rejected` |
|
||||||
|
| `review_decision` | `approved`, `rejected` |
|
||||||
|
|
||||||
|
### Migrations
|
||||||
|
|
||||||
|
| Version | Description |
|
||||||
|
|---|---|
|
||||||
|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
|
||||||
|
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## AI Agent System
|
||||||
|
|
||||||
|
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
|
||||||
|
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
||||||
|
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
||||||
|
|
||||||
|
### Registered Agents
|
||||||
|
|
||||||
|
| Agent | Registry Name | Tools | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
||||||
|
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
||||||
|
| **CheckpointAgent** | `checkpoint_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_checkpoints`, `create_checkpoint`, `update_checkpoint`, `delete_checkpoint` |
|
||||||
|
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
||||||
|
|
||||||
|
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
||||||
|
|
||||||
|
### Switching LLM Providers
|
||||||
|
|
||||||
|
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# OpenAI (default)
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Anthropic
|
||||||
|
LLM_MODEL=anthropic/claude-3.5-sonnet
|
||||||
|
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
|
||||||
|
|
||||||
|
# Google Gemini
|
||||||
|
LLM_MODEL=gemini/gemini-pro
|
||||||
|
LLM_ROUTER_MODEL=gemini/gemini-flash
|
||||||
|
|
||||||
|
# Local Ollama
|
||||||
|
LLM_MODEL=ollama/llama3
|
||||||
|
LLM_ROUTER_MODEL=ollama/llama3
|
||||||
|
|
||||||
|
# AWS Bedrock
|
||||||
|
LLM_MODEL=bedrock/anthropic.claude-v2
|
||||||
|
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Orchestration & Execution Plans
|
||||||
|
|
||||||
|
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
||||||
|
|
||||||
|
### Orchestrator
|
||||||
|
|
||||||
|
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
|
||||||
|
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
|
||||||
|
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
|
||||||
|
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
|
||||||
|
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
|
||||||
|
|
||||||
|
### Execution Plans
|
||||||
|
|
||||||
|
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
|
||||||
|
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
|
||||||
|
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
|
||||||
|
|
||||||
|
### Built-in Templates (6)
|
||||||
|
|
||||||
|
`tpl_task_agent_default`, `tpl_checkpoint_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
||||||
|
|
||||||
|
### Built-in Playbooks (2)
|
||||||
|
|
||||||
|
| Playbook | Description |
|
||||||
|
|---|---|
|
||||||
|
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
|
||||||
|
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Middleware
|
||||||
|
|
||||||
|
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
|
||||||
|
|
||||||
|
### JWT Authentication
|
||||||
|
|
||||||
|
Source: `app/api/middleware/auth.py`
|
||||||
|
|
||||||
|
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
|
||||||
|
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
|
||||||
|
- Falls back to `free` when no subscription row exists.
|
||||||
|
- Raises `401 Unauthorized` on invalid or expired tokens.
|
||||||
|
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||||
|
|
||||||
|
### Tier-Based Rate Limiter
|
||||||
|
|
||||||
|
Source: `app/api/middleware/rate_limit.py`
|
||||||
|
|
||||||
|
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
|
||||||
|
- Per-user 60-second window sized by subscription tier:
|
||||||
|
|
||||||
|
| Tier | Requests / Minute |
|
||||||
|
|---|---|
|
||||||
|
| Free | 20 |
|
||||||
|
| Pro | 60 |
|
||||||
|
| Power | 120 |
|
||||||
|
| Team | 200 |
|
||||||
|
|
||||||
|
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
|
||||||
|
- **Exempt paths:** register, login, webhook, health
|
||||||
|
|
||||||
|
### Response Sanitizer
|
||||||
|
|
||||||
|
Source: `app/api/middleware/sanitizer.py`
|
||||||
|
|
||||||
|
- Runs only on `/api/v1/chat` endpoints.
|
||||||
|
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
||||||
|
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
||||||
|
- Logs sanitization events as `WARNING`.
|
||||||
|
- Binary responses (storage, backup) are never touched.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Storage Layer
|
||||||
|
|
||||||
|
### Blob Store
|
||||||
|
|
||||||
|
Source: `app/storage/blob_store.py`
|
||||||
|
|
||||||
|
- S3-backed storage for E2E encrypted blobs.
|
||||||
|
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
|
||||||
|
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
|
||||||
|
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
|
||||||
|
- The backend **never inspects or decrypts blob content**.
|
||||||
|
|
||||||
|
### Vector Store
|
||||||
|
|
||||||
|
Source: `app/storage/vector_store.py`
|
||||||
|
|
||||||
|
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
|
||||||
|
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
|
||||||
|
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
|
||||||
|
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
|
||||||
|
- Methods: `upsert()`, `search()`, `delete()`
|
||||||
|
|
||||||
|
### Encryption Utilities
|
||||||
|
|
||||||
|
Source: `app/storage/encryption.py`
|
||||||
|
|
||||||
|
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
|
||||||
|
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
|
||||||
|
- **No decryption key ever reaches the backend.**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Billing & Tiers
|
||||||
|
|
||||||
|
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
||||||
|
|
||||||
|
### Feature Matrix
|
||||||
|
|
||||||
|
| Feature | Free | Pro | Power | Team |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
||||||
|
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
||||||
|
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
||||||
|
| Batch Builder | — | — | ✓ | ✓ |
|
||||||
|
| Plugin Marketplace | — | — | ✓ | ✓ |
|
||||||
|
| SSO | — | — | — | ✓ |
|
||||||
|
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
||||||
|
|
||||||
|
### Stripe Integration
|
||||||
|
|
||||||
|
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
|
||||||
|
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
|
||||||
|
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
|
||||||
|
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
|
||||||
|
|
||||||
|
### Tier Manager
|
||||||
|
|
||||||
|
- `get_tier(user_id)` — Returns the user's current billing tier.
|
||||||
|
- `check_feature(tier, feature)` — Boolean feature gate check.
|
||||||
|
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
||||||
|
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Plugin Marketplace
|
||||||
|
|
||||||
|
Source: `app/marketplace/`
|
||||||
|
|
||||||
|
### Plugin Registry
|
||||||
|
|
||||||
|
- PostgreSQL-backed catalog of submitted and approved plugins.
|
||||||
|
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
|
||||||
|
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
|
||||||
|
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
|
||||||
|
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
|
||||||
|
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
|
||||||
|
|
||||||
|
### Review Queue
|
||||||
|
|
||||||
|
- Automated security checklist before human review:
|
||||||
|
- Plugin ID must match `^[a-z0-9-]+$`
|
||||||
|
- Permissions must be from the allowed set only
|
||||||
|
- No binary blobs in the manifest
|
||||||
|
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:checkpoints`, `write:checkpoints`, `read:calendar`, `write:calendar`
|
||||||
|
- `get_pending(db)` — Lists plugins awaiting review.
|
||||||
|
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
||||||
|
|
||||||
|
### Revenue Sharing
|
||||||
|
|
||||||
|
- **70% developer / 30% platform** split on all paid plugin sales.
|
||||||
|
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
|
||||||
|
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
|
||||||
|
- Gracefully stubs transfers when Stripe is not configured.
|
||||||
|
|
||||||
|
### Seed Plugins
|
||||||
|
|
||||||
|
| Plugin | Category | Price |
|
||||||
|
|---|---|---|
|
||||||
|
| GitHub Sync | Productivity | Free |
|
||||||
|
| Slack Notifier | Communication | €4.99 |
|
||||||
|
| Time Tracker | Productivity | €9.99 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
pytest
|
||||||
|
|
||||||
|
# Run a specific test file
|
||||||
|
pytest tests/test_auth.py
|
||||||
|
|
||||||
|
# Run with verbose output
|
||||||
|
pytest -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test Infrastructure
|
||||||
|
|
||||||
|
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
||||||
|
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
|
||||||
|
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
||||||
|
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
||||||
|
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
|
||||||
|
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
||||||
|
- **No external dependencies** — all tests run fully offline.
|
||||||
|
|
||||||
|
### Test Coverage
|
||||||
|
|
||||||
|
| File | Coverage |
|
||||||
|
|---|---|
|
||||||
|
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
||||||
|
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
|
||||||
|
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
|
||||||
|
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
|
||||||
|
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
|
||||||
|
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
|
||||||
|
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
|
||||||
|
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
|
||||||
|
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
adiuva-api/
|
||||||
|
├── alembic.ini # Alembic configuration
|
||||||
|
├── BACKEND_PLAN.md # Architecture & design decisions
|
||||||
|
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
||||||
|
├── Dockerfile # Multi-stage production build
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
│
|
||||||
|
├── alembic/ # Database migrations
|
||||||
|
│ ├── env.py # Alembic environment config
|
||||||
|
│ ├── script.py.mako # Migration template
|
||||||
|
│ └── versions/
|
||||||
|
│ ├── 001_initial_schema.py # Tables, indexes, FKs
|
||||||
|
│ └── 002_seed_plugins.py # Seed marketplace plugins
|
||||||
|
│
|
||||||
|
├── app/ # Application source
|
||||||
|
│ ├── main.py # FastAPI app factory, middleware, routes
|
||||||
|
│ ├── db.py # Async SQLAlchemy engine & session
|
||||||
|
│ ├── models.py # SQLAlchemy ORM models (9 tables)
|
||||||
|
│ ├── schemas.py # Pydantic request/response schemas
|
||||||
|
│ │
|
||||||
|
│ ├── config/
|
||||||
|
│ │ └── settings.py # Pydantic Settings (env vars)
|
||||||
|
│ │
|
||||||
|
│ ├── agents/ # LLM-powered domain agents
|
||||||
|
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
||||||
|
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
||||||
|
│ │ ├── checkpoint_agent.py # Milestones (4 tools)
|
||||||
|
│ │ └── note_agent.py # Markdown notes (5 tools)
|
||||||
|
│ │
|
||||||
|
│ ├── core/ # Orchestration engine
|
||||||
|
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||||
|
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
||||||
|
│ │ ├── orchestrator.py # Intent classification & routing
|
||||||
|
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||||
|
│ │
|
||||||
|
│ ├── api/ # HTTP layer
|
||||||
|
│ │ ├── deps.py # Shared FastAPI dependencies
|
||||||
|
│ │ ├── middleware/
|
||||||
|
│ │ │ ├── auth.py # JWT validation, live tier lookup
|
||||||
|
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
||||||
|
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
||||||
|
│ │ └── routes/
|
||||||
|
│ │ ├── auth.py # Register, login, refresh, me
|
||||||
|
│ │ ├── chat.py # Chat + WebSocket streaming
|
||||||
|
│ │ ├── plans.py # Execution plan playbooks
|
||||||
|
│ │ ├── storage.py # E2E encrypted record CRUD
|
||||||
|
│ │ ├── vectors.py # Vector upsert, search, delete
|
||||||
|
│ │ ├── backup.py # Encrypted backup management
|
||||||
|
│ │ ├── plugins.py # Marketplace browse & install
|
||||||
|
│ │ └── billing.py # Stripe checkout & webhooks
|
||||||
|
│ │
|
||||||
|
│ ├── storage/ # Storage backends
|
||||||
|
│ │ ├── blob_store.py # S3 blob storage
|
||||||
|
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
|
||||||
|
│ │ └── encryption.py # Checksum verification utilities
|
||||||
|
│ │
|
||||||
|
│ ├── billing/ # Subscription management
|
||||||
|
│ │ ├── stripe_service.py # Stripe API integration
|
||||||
|
│ │ └── tier_manager.py # Feature matrix & quota enforcement
|
||||||
|
│ │
|
||||||
|
│ └── marketplace/ # Plugin ecosystem
|
||||||
|
│ ├── plugin_registry.py # Catalog CRUD & search
|
||||||
|
│ ├── plugin_review.py # Security checklist & review queue
|
||||||
|
│ └── revenue_share.py # 70/30 split & Stripe Connect
|
||||||
|
│
|
||||||
|
└── tests/ # Test suite
|
||||||
|
├── conftest.py # Fixtures: DB, S3, auth, seeds
|
||||||
|
├── test_auth.py
|
||||||
|
├── test_orchestrator.py
|
||||||
|
├── test_agents.py
|
||||||
|
├── test_storage.py
|
||||||
|
├── test_backup.py
|
||||||
|
├── test_plugins.py
|
||||||
|
├── test_agent_registry.py
|
||||||
|
├── test_execution_plan.py
|
||||||
|
└── test_middleware.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
*To be determined.*
|
||||||
47
alembic.ini
Normal file
47
alembic.ini
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Alembic configuration file.
|
||||||
|
# The async app uses postgresql+asyncpg:// at runtime.
|
||||||
|
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
prepend_sys_path = .
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
93
alembic/env.py
Normal file
93
alembic/env.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Alembic migration environment — async-compatible.
|
||||||
|
|
||||||
|
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
|
||||||
|
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
|
||||||
|
env var by replacing the driver prefix.
|
||||||
|
|
||||||
|
Run migrations with:
|
||||||
|
alembic upgrade head
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Set up Python logging from alembic.ini.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# Import the Base so that Alembic can detect model changes for --autogenerate.
|
||||||
|
from app.models import Base # noqa: E402
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_url(async_url: str) -> str:
|
||||||
|
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
|
||||||
|
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url() -> str:
|
||||||
|
db_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not db_url:
|
||||||
|
# Fall back to settings if env var not set directly.
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
db_url = settings.DATABASE_URL
|
||||||
|
return _sync_url(db_url)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Emit SQL without a live DB connection."""
|
||||||
|
url = _get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection): # type: ignore[no-untyped-def]
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations_online_async() -> None:
|
||||||
|
"""Run migrations against a live DB using the async engine."""
|
||||||
|
async_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not async_url:
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
async_url = settings.DATABASE_URL
|
||||||
|
|
||||||
|
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
asyncio.run(run_migrations_online_async())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
202
alembic/versions/001_initial_schema.py
Normal file
202
alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||||
|
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||||
|
|
||||||
|
Revision ID: 001
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "001"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enum types ────────────────────────────────────────────────────────
|
||||||
|
billing_tier = postgresql.ENUM(
|
||||||
|
"free", "pro", "power", "team", name="billing_tier", create_type=False
|
||||||
|
)
|
||||||
|
plugin_status = postgresql.ENUM(
|
||||||
|
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
|
||||||
|
)
|
||||||
|
review_decision = postgresql.ENUM(
|
||||||
|
"approved", "rejected", name="review_decision", create_type=False
|
||||||
|
)
|
||||||
|
for enum in (billing_tier, plugin_status, review_decision):
|
||||||
|
enum.create(op.get_bind(), checkfirst=True)
|
||||||
|
|
||||||
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("email", sa.String(255), nullable=False),
|
||||||
|
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||||
|
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||||
|
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("email"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_users_email", "users", ["email"])
|
||||||
|
|
||||||
|
# ── refresh_tokens ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"refresh_tokens",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("token_hash"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
|
||||||
|
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
|
||||||
|
|
||||||
|
# ── subscriptions ─────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"subscriptions",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||||
|
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||||
|
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("user_id"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
|
# ── storage_records ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"storage_records",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("table_name", sa.String(100), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||||
|
|
||||||
|
# ── backup_metadata ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"backup_metadata",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer, nullable=False),
|
||||||
|
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugins ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugins",
|
||||||
|
sa.Column("id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||||
|
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||||
|
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||||
|
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
||||||
|
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||||
|
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||||
|
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── plugin_installations ──────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_installations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||||
|
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_reviews",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
||||||
|
sa.Column("notes", sa.Text, nullable=True),
|
||||||
|
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||||
|
|
||||||
|
# ── revenue_events ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"revenue_events",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||||
|
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("revenue_events")
|
||||||
|
op.drop_table("plugin_reviews")
|
||||||
|
op.drop_table("plugin_installations")
|
||||||
|
op.drop_table("plugins")
|
||||||
|
op.drop_table("backup_metadata")
|
||||||
|
op.drop_table("storage_records")
|
||||||
|
op.drop_table("subscriptions")
|
||||||
|
op.drop_table("refresh_tokens")
|
||||||
|
op.drop_table("users")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||||
|
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
92
alembic/versions/002_seed_plugins.py
Normal file
92
alembic/versions/002_seed_plugins.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2026-03-03
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "002"
|
||||||
|
down_revision: Union[str, None] = "001"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
{
|
||||||
|
"id": "plugin-github-sync",
|
||||||
|
"name": "GitHub Sync",
|
||||||
|
"description": "Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 0,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-slack-notify",
|
||||||
|
"name": "Slack Notifier",
|
||||||
|
"description": "Post task and checkpoint updates to Slack channels.",
|
||||||
|
"version": "1.2.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "communication",
|
||||||
|
"price_cents": 499,
|
||||||
|
"permissions": json.dumps(["read:tasks", "read:checkpoints"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-time-tracker",
|
||||||
|
"name": "Time Tracker",
|
||||||
|
"description": "Track time spent on tasks with automatic reporting.",
|
||||||
|
"version": "0.9.1",
|
||||||
|
"author_name": "Third Party",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 999,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
plugins = sa.table(
|
||||||
|
"plugins",
|
||||||
|
sa.column("id", sa.String),
|
||||||
|
sa.column("name", sa.String),
|
||||||
|
sa.column("description", sa.Text),
|
||||||
|
sa.column("version", sa.String),
|
||||||
|
sa.column("author_name", sa.String),
|
||||||
|
sa.column("category", sa.String),
|
||||||
|
sa.column("price_cents", sa.Integer),
|
||||||
|
sa.column("permissions", sa.Text),
|
||||||
|
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
||||||
|
sa.column("s3_package_key", sa.String),
|
||||||
|
sa.column("install_count", sa.Integer),
|
||||||
|
sa.column("avg_rating", sa.Float),
|
||||||
|
)
|
||||||
|
op.bulk_insert(plugins, _SEED_PLUGINS)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"DELETE FROM plugins WHERE id IN ("
|
||||||
|
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
||||||
|
")"
|
||||||
|
)
|
||||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
5
app/agents/__init__.py
Normal file
5
app/agents/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Import all agent modules to trigger @registry.register decorators."""
|
||||||
|
|
||||||
|
from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
|
__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
121
app/agents/checkpoint_agent.py
Normal file
121
app/agents/checkpoint_agent.py
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
|
||||||
|
"track progress on a project — they are not calendar events.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||||
|
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||||
|
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
|
||||||
|
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
||||||
|
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Listing without a project_id returns all checkpoints across projects\n"
|
||||||
|
" - Always echo the title and formatted date in your confirmation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_checkpoints(project_id: str = "") -> str:
|
||||||
|
"""List checkpoints. Provide project_id to scope to a specific project."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "list",
|
||||||
|
"table": "checkpoints",
|
||||||
|
"filters": {"projectId": project_id or None},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_checkpoint(
|
||||||
|
project_id: str,
|
||||||
|
title: str,
|
||||||
|
date: int,
|
||||||
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a project checkpoint (milestone).
|
||||||
|
project_id: REQUIRED UUID of the parent project
|
||||||
|
title: descriptive name for the milestone
|
||||||
|
date: Unix timestamp in milliseconds
|
||||||
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms
|
||||||
|
"""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "create_record",
|
||||||
|
"table": "checkpoints",
|
||||||
|
"data": {
|
||||||
|
"projectId": project_id,
|
||||||
|
"title": title,
|
||||||
|
"date": date,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_checkpoint(
|
||||||
|
checkpoint_id: str,
|
||||||
|
title: str = "",
|
||||||
|
date: int = -1,
|
||||||
|
is_approved: int = -1,
|
||||||
|
) -> str:
|
||||||
|
"""Update a checkpoint. Only pass fields that should change.
|
||||||
|
checkpoint_id: UUID of the checkpoint (required)
|
||||||
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if date != -1:
|
||||||
|
updates["date"] = date
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
|
return json.dumps({
|
||||||
|
"action": "update_record",
|
||||||
|
"table": "checkpoints",
|
||||||
|
"data": {"id": checkpoint_id, "updates": updates},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_checkpoint(checkpoint_id: str) -> str:
|
||||||
|
"""Delete a checkpoint permanently by its UUID."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "delete_record",
|
||||||
|
"table": "checkpoints",
|
||||||
|
"data": {"id": checkpoint_id},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class CheckpointAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "checkpoint_agent"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Manages project checkpoints (milestones): list, create, update, delete"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
llm = get_llm()
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=_SYSTEM_PROMPT),
|
||||||
|
HumanMessage(
|
||||||
|
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return await self._tool_loop(llm, messages, self.get_tools())
|
||||||
122
app/agents/note_agent.py
Normal file
122
app/agents/note_agent.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||||
|
"and delete Markdown notes in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - content is always Markdown; preserve formatting when updating\n"
|
||||||
|
" - project_id is optional; link a note to a project when mentioned\n"
|
||||||
|
" - When updating, call get_note first if you need to read existing content\n"
|
||||||
|
" before appending or replacing sections\n"
|
||||||
|
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||||
|
" when the user is working within a specific project\n"
|
||||||
|
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||||
|
" is already in the note (retrieved via get_note)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_notes(project_id: str = "") -> str:
|
||||||
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "list",
|
||||||
|
"table": "notes",
|
||||||
|
"filters": {"projectId": project_id or None},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_note(note_id: str) -> str:
|
||||||
|
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "get",
|
||||||
|
"table": "notes",
|
||||||
|
"data": {"id": note_id},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_note(
|
||||||
|
title: str,
|
||||||
|
content: str,
|
||||||
|
project_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Create a new note.
|
||||||
|
title: note heading (required)
|
||||||
|
content: Markdown body text (required)
|
||||||
|
project_id: optional UUID linking this note to a project
|
||||||
|
"""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "create_record",
|
||||||
|
"table": "notes",
|
||||||
|
"data": {
|
||||||
|
"title": title,
|
||||||
|
"content": content,
|
||||||
|
"projectId": project_id or None,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_note(
|
||||||
|
note_id: str,
|
||||||
|
title: str = "",
|
||||||
|
content: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update an existing note. Only pass fields that should change.
|
||||||
|
note_id: UUID of the note (required)
|
||||||
|
If you need to preserve existing content, call get_note first.
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if content:
|
||||||
|
updates["content"] = content
|
||||||
|
return json.dumps({
|
||||||
|
"action": "update_record",
|
||||||
|
"table": "notes",
|
||||||
|
"data": {"id": note_id, "updates": updates},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_note(note_id: str) -> str:
|
||||||
|
"""Delete a note permanently by its UUID."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "delete_record",
|
||||||
|
"table": "notes",
|
||||||
|
"data": {"id": note_id},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class NoteAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "note_agent"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Manages notes: list, get, create, update, delete"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return [list_notes, get_note, create_note, update_note, delete_note]
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
llm = get_llm()
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=_SYSTEM_PROMPT),
|
||||||
|
HumanMessage(
|
||||||
|
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return await self._tool_loop(llm, messages, self.get_tools())
|
||||||
157
app/agents/project_agent.py
Normal file
157
app/agents/project_agent.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a project management assistant. You help users create, find,\n"
|
||||||
|
"update, and archive projects in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: active, archived\n"
|
||||||
|
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||||
|
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||||
|
" derive it from context data — do not fabricate content\n"
|
||||||
|
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||||
|
" user wants a complete cross-client view including archived projects\n"
|
||||||
|
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||||
|
" list_projects if you only have a project name\n"
|
||||||
|
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||||
|
" only call delete_project when the user explicitly confirms deletion."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_projects(
|
||||||
|
client_id: str = "",
|
||||||
|
include_archived: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""List projects, optionally filtered by client_id.
|
||||||
|
include_archived: 1 to include archived projects, 0 for active only (default).
|
||||||
|
"""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "list",
|
||||||
|
"table": "projects",
|
||||||
|
"filters": {
|
||||||
|
"clientId": client_id or None,
|
||||||
|
"includeArchived": bool(include_archived),
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_all_projects() -> str:
|
||||||
|
"""List every project regardless of client or status.
|
||||||
|
Use only when the user wants a complete cross-client overview.
|
||||||
|
"""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "list_all",
|
||||||
|
"table": "projects",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_project(project_id: str) -> str:
|
||||||
|
"""Fetch a single project by its UUID."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "get",
|
||||||
|
"table": "projects",
|
||||||
|
"data": {"id": project_id},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_project(
|
||||||
|
name: str,
|
||||||
|
client_id: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Create a new project.
|
||||||
|
name: human-readable project name (required)
|
||||||
|
client_id: optional UUID of the owning client
|
||||||
|
"""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "create_record",
|
||||||
|
"table": "projects",
|
||||||
|
"data": {
|
||||||
|
"name": name,
|
||||||
|
"clientId": client_id or None,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_project(
|
||||||
|
project_id: str,
|
||||||
|
name: str = "",
|
||||||
|
client_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
ai_summary: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""Update a project. Only pass fields that should change.
|
||||||
|
project_id: UUID of the project (required)
|
||||||
|
status: active | archived
|
||||||
|
ai_summary: AI-generated summary text (populate only when explicitly requested)
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if name:
|
||||||
|
updates["name"] = name
|
||||||
|
if client_id:
|
||||||
|
updates["clientId"] = client_id
|
||||||
|
if status:
|
||||||
|
updates["status"] = status
|
||||||
|
if ai_summary:
|
||||||
|
updates["aiSummary"] = ai_summary
|
||||||
|
return json.dumps({
|
||||||
|
"action": "update_record",
|
||||||
|
"table": "projects",
|
||||||
|
"data": {"id": project_id, "updates": updates},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_project(project_id: str) -> str:
|
||||||
|
"""Permanently delete a project and orphan its tasks.
|
||||||
|
IMPORTANT: prefer update_project(status='archived') unless the user
|
||||||
|
has explicitly confirmed they want permanent deletion.
|
||||||
|
"""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "delete_record",
|
||||||
|
"table": "projects",
|
||||||
|
"data": {"id": project_id},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class ProjectAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "project_agent"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Manages projects: list, get, create, update, archive, delete"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
llm = get_llm()
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=_SYSTEM_PROMPT),
|
||||||
|
HumanMessage(
|
||||||
|
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return await self._tool_loop(llm, messages, self.get_tools())
|
||||||
228
app/agents/task_agent.py
Normal file
228
app/agents/task_agent.py
Normal file
@@ -0,0 +1,228 @@
|
|||||||
|
"""Task agent — full CRUD for tasks and task comments."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a task management assistant for a project workspace.\n"
|
||||||
|
"You create, update, list, and track tasks and their comments.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: todo, in_progress, done\n"
|
||||||
|
" - priority must be one of: high, medium, low\n"
|
||||||
|
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
||||||
|
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
||||||
|
" - project_id is optional; link to a project when the user mentions one\n"
|
||||||
|
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||||
|
" did not explicitly request; 0 otherwise\n"
|
||||||
|
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
||||||
|
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||||
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_tasks(
|
||||||
|
project_id: str = "",
|
||||||
|
status: str = "",
|
||||||
|
search: str = "",
|
||||||
|
order_by: str = "",
|
||||||
|
) -> str:
|
||||||
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "list",
|
||||||
|
"table": "tasks",
|
||||||
|
"filters": {
|
||||||
|
"projectId": project_id or None,
|
||||||
|
"status": status or None,
|
||||||
|
"search": search or None,
|
||||||
|
"orderBy": order_by or None,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_task(
|
||||||
|
title: str,
|
||||||
|
description: str = "",
|
||||||
|
status: str = "todo",
|
||||||
|
priority: str = "medium",
|
||||||
|
assignees: str = "[]",
|
||||||
|
due_date: int = 0,
|
||||||
|
project_id: str = "",
|
||||||
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a new task.
|
||||||
|
title: task title (required)
|
||||||
|
description: optional details
|
||||||
|
status: todo | in_progress | done (default: todo)
|
||||||
|
priority: high | medium | low (default: medium)
|
||||||
|
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
|
||||||
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
|
project_id: optional UUID of the parent project
|
||||||
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms; 1 when confirmed
|
||||||
|
"""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "create_record",
|
||||||
|
"table": "tasks",
|
||||||
|
"data": {
|
||||||
|
"title": title,
|
||||||
|
"description": description or None,
|
||||||
|
"status": status,
|
||||||
|
"priority": priority,
|
||||||
|
"assignee": assignees,
|
||||||
|
"dueDate": due_date or None,
|
||||||
|
"projectId": project_id or None,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_task(
|
||||||
|
task_id: str,
|
||||||
|
title: str = "",
|
||||||
|
description: str = "",
|
||||||
|
status: str = "",
|
||||||
|
priority: str = "",
|
||||||
|
assignees: str = "",
|
||||||
|
due_date: int = -1,
|
||||||
|
project_id: str = "",
|
||||||
|
is_approved: int = -1,
|
||||||
|
) -> str:
|
||||||
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
|
task_id: the task's UUID (required)
|
||||||
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the value
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if description:
|
||||||
|
updates["description"] = description
|
||||||
|
if status:
|
||||||
|
updates["status"] = status
|
||||||
|
if priority:
|
||||||
|
updates["priority"] = priority
|
||||||
|
if assignees:
|
||||||
|
updates["assignee"] = assignees
|
||||||
|
if due_date != -1:
|
||||||
|
updates["dueDate"] = due_date or None
|
||||||
|
if project_id:
|
||||||
|
updates["projectId"] = project_id
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
|
return json.dumps({
|
||||||
|
"action": "update_record",
|
||||||
|
"table": "tasks",
|
||||||
|
"data": {"id": task_id, "updates": updates},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_task(task_id: str) -> str:
|
||||||
|
"""Delete a task permanently by its UUID."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "delete_record",
|
||||||
|
"table": "tasks",
|
||||||
|
"data": {"id": task_id},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_tasks_due_today() -> str:
|
||||||
|
"""List all tasks whose due date falls on today's date."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "list_due_today",
|
||||||
|
"table": "tasks",
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task comment tools ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_task_comments(task_id: str) -> str:
|
||||||
|
"""List all comments on a task by its UUID."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "list",
|
||||||
|
"table": "taskComments",
|
||||||
|
"filters": {"taskId": task_id},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||||
|
"""Add a comment to a task.
|
||||||
|
task_id: UUID of the task to comment on
|
||||||
|
author: name or ID of the comment author
|
||||||
|
content: comment text
|
||||||
|
"""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "create_record",
|
||||||
|
"table": "taskComments",
|
||||||
|
"data": {
|
||||||
|
"taskId": task_id,
|
||||||
|
"author": author,
|
||||||
|
"content": content,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_task_comment(comment_id: str) -> str:
|
||||||
|
"""Delete a task comment by its UUID."""
|
||||||
|
return json.dumps({
|
||||||
|
"action": "delete_record",
|
||||||
|
"table": "taskComments",
|
||||||
|
"data": {"id": comment_id},
|
||||||
|
})
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
|
||||||
|
class TaskAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "task_agent"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
llm = get_llm()
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content=_SYSTEM_PROMPT),
|
||||||
|
HumanMessage(
|
||||||
|
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
||||||
|
),
|
||||||
|
]
|
||||||
|
return await self._tool_loop(llm, messages, self.get_tools())
|
||||||
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
14
app/api/deps.py
Normal file
14
app/api/deps.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""Shared FastAPI dependencies.
|
||||||
|
|
||||||
|
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
||||||
|
(the canonical location per Step 9). This module re-exports them so that all
|
||||||
|
existing route imports (``from app.api.deps import get_current_user``) continue
|
||||||
|
to work without modification.
|
||||||
|
|
||||||
|
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
||||||
|
instead of reading it from the JWT payload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
||||||
|
|
||||||
|
__all__ = ["get_current_user", "oauth2_scheme"]
|
||||||
19
app/api/middleware/__init__.py
Normal file
19
app/api/middleware/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""API middleware package.
|
||||||
|
|
||||||
|
Exports the three middleware components introduced in Step 9:
|
||||||
|
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
||||||
|
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
||||||
|
- Sanitizer: ``SanitizerMiddleware``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
||||||
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_current_user",
|
||||||
|
"oauth2_scheme",
|
||||||
|
"TierRateLimitMiddleware",
|
||||||
|
"limiter",
|
||||||
|
"SanitizerMiddleware",
|
||||||
|
]
|
||||||
65
app/api/middleware/auth.py
Normal file
65
app/api/middleware/auth.py
Normal file
@@ -0,0 +1,65 @@
|
|||||||
|
"""Auth middleware — JWT validation dependency.
|
||||||
|
|
||||||
|
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||||
|
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||||
|
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||||
|
without requiring token re-issue.
|
||||||
|
|
||||||
|
Exempt routes (no JWT required):
|
||||||
|
- POST /api/v1/auth/register
|
||||||
|
- POST /api/v1/auth/login
|
||||||
|
- POST /api/v1/billing/webhook
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
|
from app.schemas import UserProfile
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: str = Depends(oauth2_scheme),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Validate a Bearer JWT and return the authenticated user.
|
||||||
|
|
||||||
|
The JWT is used for identity and expiry only. The tier is fetched live
|
||||||
|
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||||
|
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
|
||||||
|
Raises HTTP 401 on any invalid or expired token.
|
||||||
|
"""
|
||||||
|
credentials_exc = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id or not email:
|
||||||
|
raise credentials_exc
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exc
|
||||||
|
|
||||||
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
|
|
||||||
|
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
||||||
129
app/api/middleware/rate_limit.py
Normal file
129
app/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""Tier-aware rate limiting middleware.
|
||||||
|
|
||||||
|
Uses a per-user sliding-window counter (in-process, no Redis required).
|
||||||
|
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
||||||
|
|
||||||
|
Limits (requests per minute):
|
||||||
|
- free: 20
|
||||||
|
- pro: 60
|
||||||
|
- power: 120
|
||||||
|
- team: 200
|
||||||
|
|
||||||
|
Exempt paths bypass the limiter entirely:
|
||||||
|
- POST /api/v1/auth/register
|
||||||
|
- POST /api/v1/auth/login
|
||||||
|
- POST /api/v1/billing/webhook
|
||||||
|
- GET /api/v1/health
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
_TIER_LIMITS: dict[str, int] = {
|
||||||
|
"free": 20,
|
||||||
|
"pro": 60,
|
||||||
|
"power": 120,
|
||||||
|
"team": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
"/api/v1/health",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_id_from_jwt(request: Request) -> str:
|
||||||
|
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
token = auth.removeprefix("Bearer ").strip()
|
||||||
|
if not token:
|
||||||
|
return get_remote_address(request)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
return payload.get("sub") or get_remote_address(request)
|
||||||
|
except JWTError:
|
||||||
|
return get_remote_address(request)
|
||||||
|
|
||||||
|
|
||||||
|
# Exported Limiter instance — available for optional route-level decoration.
|
||||||
|
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
||||||
|
|
||||||
|
|
||||||
|
class TierRateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Sliding-window rate limiter applied globally across all non-exempt routes.
|
||||||
|
|
||||||
|
Each authenticated user gets their own 60-second window sized by tier.
|
||||||
|
Unauthenticated requests pass through (the auth dependency will reject them
|
||||||
|
with 401 before the route handler runs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
# user_id → list of request timestamps (float, seconds since epoch)
|
||||||
|
self._window: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||||
|
if request.url.path in _EXEMPT_PATHS:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
token = auth.removeprefix("Bearer ").strip()
|
||||||
|
if not token:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
user_id: str = payload.get("sub") or get_remote_address(request)
|
||||||
|
tier: str = payload.get("tier", "free")
|
||||||
|
except JWTError:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
|
||||||
|
now = time.monotonic()
|
||||||
|
window_start = now - 60.0
|
||||||
|
|
||||||
|
# Slide the window: discard timestamps older than 60 seconds.
|
||||||
|
timestamps = [t for t in self._window[user_id] if t > window_start]
|
||||||
|
|
||||||
|
if len(timestamps) >= limit:
|
||||||
|
retry_after = max(1, int(60 - (now - min(timestamps))))
|
||||||
|
return Response(
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"detail": (
|
||||||
|
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
|
||||||
|
f"Retry in {retry_after}s."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status_code=429,
|
||||||
|
headers={
|
||||||
|
"Retry-After": str(retry_after),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
timestamps.append(now)
|
||||||
|
self._window[user_id] = timestamps
|
||||||
|
return await call_next(request)
|
||||||
139
app/api/middleware/sanitizer.py
Normal file
139
app/api/middleware/sanitizer.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Response sanitizer middleware.
|
||||||
|
|
||||||
|
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
||||||
|
that could reveal server-side prompt IP:
|
||||||
|
- System prompt openers ("You are a/an/the …")
|
||||||
|
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
||||||
|
- LangChain tool schema fragments (``"type": "function"``)
|
||||||
|
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||||
|
- Exact-match known prompt fingerprints
|
||||||
|
|
||||||
|
Binary responses (storage blobs, backup data) are never touched — the
|
||||||
|
middleware only activates for paths under /api/v1/chat.
|
||||||
|
|
||||||
|
Any sanitisation event is logged as a WARNING with the request path and the
|
||||||
|
names of the fields that were modified.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Detection patterns — order matters: fingerprints checked first (exact),
|
||||||
|
# then compiled regexes.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FINGERPRINTS: tuple[str, ...] = (
|
||||||
|
"You are an intent classifier",
|
||||||
|
"Respond with just the agent name",
|
||||||
|
"Summarize these agent results",
|
||||||
|
"Available agents:",
|
||||||
|
"route to:",
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
||||||
|
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
||||||
|
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
||||||
|
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
||||||
|
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
||||||
|
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
||||||
|
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
||||||
|
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
||||||
|
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_text(text: str) -> tuple[str, bool]:
|
||||||
|
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
||||||
|
|
||||||
|
Returns ``(cleaned_text, was_changed)``.
|
||||||
|
"""
|
||||||
|
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
||||||
|
for fp in _FINGERPRINTS:
|
||||||
|
if fp in text:
|
||||||
|
return "[REDACTED]", True
|
||||||
|
|
||||||
|
changed = False
|
||||||
|
for pattern in _PATTERNS:
|
||||||
|
new_text, n = pattern.subn("[REDACTED]", text)
|
||||||
|
if n:
|
||||||
|
text = new_text
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
return text, changed
|
||||||
|
|
||||||
|
|
||||||
|
class SanitizerMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Strip prompt IP from /api/v1/chat JSON responses."""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||||
|
response: Response = await call_next(request)
|
||||||
|
|
||||||
|
# Only process chat endpoint responses.
|
||||||
|
if not request.url.path.startswith("/api/v1/chat"):
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Read body — collect streaming chunks.
|
||||||
|
body_bytes = b""
|
||||||
|
async for chunk in response.body_iterator:
|
||||||
|
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||||
|
|
||||||
|
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
|
||||||
|
try:
|
||||||
|
body = json.loads(body_bytes.decode("utf-8"))
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
|
return Response(
|
||||||
|
content=body_bytes,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
media_type=response.media_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(body, dict):
|
||||||
|
return Response(
|
||||||
|
content=body_bytes,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
media_type=response.media_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Walk top-level string fields and sanitise.
|
||||||
|
sanitised_fields: list[str] = []
|
||||||
|
for key, value in body.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
cleaned, changed = _sanitize_text(value)
|
||||||
|
if changed:
|
||||||
|
body[key] = cleaned
|
||||||
|
sanitised_fields.append(key)
|
||||||
|
|
||||||
|
if sanitised_fields:
|
||||||
|
logger.warning(
|
||||||
|
"Sanitizer redacted prompt fragments",
|
||||||
|
extra={
|
||||||
|
"path": request.url.path,
|
||||||
|
"fields": sanitised_fields,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_body = json.dumps(body).encode("utf-8")
|
||||||
|
headers = dict(response.headers)
|
||||||
|
headers["content-length"] = str(len(new_body))
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=new_body,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=headers,
|
||||||
|
media_type="application/json",
|
||||||
|
)
|
||||||
0
app/api/routes/__init__.py
Normal file
0
app/api/routes/__init__.py
Normal file
197
app/api/routes/auth.py
Normal file
197
app/api/routes/auth.py
Normal file
@@ -0,0 +1,197 @@
|
|||||||
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from jose import jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import RefreshToken, User
|
||||||
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_password(password: str) -> str:
|
||||||
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_password(password: str, hashed: str) -> bool:
|
||||||
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_token(plain_token: str) -> str:
|
||||||
|
"""SHA-256 of the plain refresh token string."""
|
||||||
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||||
|
"""Return (signed JWT, expires_at_ms)."""
|
||||||
|
now = int(time.time())
|
||||||
|
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"email": email,
|
||||||
|
"tier": tier,
|
||||||
|
"exp": exp,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
|
return token, exp * 1000 # ms for client
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _RegisterRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class _LoginRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class _RefreshRequest(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(
|
||||||
|
body: _RegisterRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Create a new account and return JWT tokens."""
|
||||||
|
existing = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
email=body.email,
|
||||||
|
password_hash=_hash_password(body.password),
|
||||||
|
tier="free",
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush() # get user.id without committing
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=AuthTokens)
|
||||||
|
async def login(
|
||||||
|
body: _LoginRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Validate credentials and return JWT tokens."""
|
||||||
|
result = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not _verify_password(body.password, user.password_hash):
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=AuthTokens)
|
||||||
|
async def refresh(
|
||||||
|
body: _RefreshRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Rotate a refresh token and return a new token pair."""
|
||||||
|
token_hash = _hash_token(body.refresh_token)
|
||||||
|
result = await db.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
rt = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||||
|
|
||||||
|
# Rotate: delete old token, issue new one.
|
||||||
|
await db.delete(rt)
|
||||||
|
|
||||||
|
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||||
|
user = user_result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
new_rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=new_expires,
|
||||||
|
)
|
||||||
|
db.add(new_rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserProfile)
|
||||||
|
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||||
|
"""Return the profile for the authenticated user."""
|
||||||
|
return current_user
|
||||||
171
app/api/routes/backup.py
Normal file
171
app/api/routes/backup.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
||||||
|
PostgreSQL ``backup_metadata`` table.
|
||||||
|
|
||||||
|
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||||
|
treating "history" as a ``{backup_id}`` path parameter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import BackupMetadata as BackupMetadataModel
|
||||||
|
from app.schemas import BackupMetadata, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/backup", tags=["backup"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
|
||||||
|
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return total backup bytes stored by *user_id*."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
||||||
|
BackupMetadataModel.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_backup_quota(
|
||||||
|
user: UserProfile, size_bytes: int, db: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||||
|
current = await _current_backup_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_backup_quota(
|
||||||
|
user.tier, current_bytes=current, additional_bytes=size_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("")
|
||||||
|
async def upload_backup(
|
||||||
|
request: Request,
|
||||||
|
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
||||||
|
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
||||||
|
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Upload an E2E-encrypted backup blob.
|
||||||
|
|
||||||
|
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
||||||
|
"""
|
||||||
|
blob = await request.body()
|
||||||
|
reject_if_tampered(blob, x_backup_checksum)
|
||||||
|
await _check_backup_quota(current_user, len(blob), db)
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
row = BackupMetadataModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
|
s3_key=s3_key,
|
||||||
|
version=x_backup_version,
|
||||||
|
timestamp=x_backup_timestamp,
|
||||||
|
checksum=x_backup_checksum,
|
||||||
|
size_bytes=len(blob),
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history", response_model=list[BackupMetadata])
|
||||||
|
async def backup_history(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[BackupMetadata]:
|
||||||
|
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
BackupMetadata(
|
||||||
|
version=r.version,
|
||||||
|
timestamp=r.timestamp,
|
||||||
|
checksum=r.checksum,
|
||||||
|
chunk_count=1,
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def download_backup(
|
||||||
|
request: Request,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> Response:
|
||||||
|
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
latest = result.scalar_one_or_none()
|
||||||
|
if latest is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
||||||
|
|
||||||
|
ims_header = request.headers.get("If-Modified-Since")
|
||||||
|
if ims_header:
|
||||||
|
try:
|
||||||
|
ims_dt = parsedate_to_datetime(ims_header)
|
||||||
|
ims_ms = int(ims_dt.timestamp() * 1000)
|
||||||
|
if latest.timestamp <= ims_ms:
|
||||||
|
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
||||||
|
except Exception:
|
||||||
|
pass # malformed header — ignore and serve the blob
|
||||||
|
|
||||||
|
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
||||||
|
return Response(
|
||||||
|
content=blob,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={
|
||||||
|
"X-Backup-Version": str(latest.version),
|
||||||
|
"X-Backup-Timestamp": str(latest.timestamp),
|
||||||
|
"X-Checksum": latest.checksum,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{backup_id}", response_model=dict)
|
||||||
|
async def delete_backup(
|
||||||
|
backup_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a specific backup by ID."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel).where(
|
||||||
|
BackupMetadataModel.id == backup_id,
|
||||||
|
BackupMetadataModel.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
target = result.scalar_one_or_none()
|
||||||
|
if target is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
||||||
|
|
||||||
|
await _blob_store.delete(current_user.id, target.s3_key)
|
||||||
|
await db.delete(target)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
85
app/api/routes/billing.py
Normal file
85
app/api/routes/billing.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||||
|
|
||||||
|
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||||
|
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||||
|
and delegates everything else to the service singleton.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.db import get_session
|
||||||
|
from app.schemas import BillingTier, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CheckoutRequest(BaseModel):
|
||||||
|
tier: BillingTier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/checkout", response_model=dict)
|
||||||
|
async def create_checkout(
|
||||||
|
body: _CheckoutRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Create a Stripe checkout session for a tier upgrade.
|
||||||
|
|
||||||
|
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||||
|
"""
|
||||||
|
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||||
|
return {"checkout_url": url}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/webhook", response_model=dict)
|
||||||
|
async def stripe_webhook(
|
||||||
|
request: Request,
|
||||||
|
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Handle Stripe webhook events.
|
||||||
|
|
||||||
|
No JWT auth — authenticated via Stripe signature verification instead.
|
||||||
|
Returns 200 immediately when Stripe is not configured (local dev).
|
||||||
|
"""
|
||||||
|
payload = await request.body()
|
||||||
|
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/subscription", response_model=dict)
|
||||||
|
async def get_subscription(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return the current subscription info for the authenticated user."""
|
||||||
|
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||||
|
if sub is None:
|
||||||
|
return {
|
||||||
|
"tier": current_user.tier,
|
||||||
|
"status": "free",
|
||||||
|
"stripe_subscription_id": None,
|
||||||
|
"current_period_end": None,
|
||||||
|
}
|
||||||
|
return sub
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||||
|
async def cancel_subscription(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Cancel the active subscription."""
|
||||||
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
|
return {"ok": True}
|
||||||
78
app/api/routes/chat.py
Normal file
78
app/api/routes/chat.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.orchestrator import orchestrate, orchestrate_stream
|
||||||
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
async def chat(
|
||||||
|
body: ChatRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""Route a chat message through the orchestrator.
|
||||||
|
|
||||||
|
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
||||||
|
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
||||||
|
"""
|
||||||
|
result = await orchestrate(body)
|
||||||
|
return JSONResponse(content=result.model_dump())
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/stream")
|
||||||
|
async def chat_stream(websocket: WebSocket) -> None:
|
||||||
|
"""Streaming chat via WebSocket.
|
||||||
|
|
||||||
|
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
|
||||||
|
|
||||||
|
Protocol:
|
||||||
|
1. Client sends ``ChatRequest`` as the first JSON text frame.
|
||||||
|
2. Server streams response text chunks.
|
||||||
|
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
|
||||||
|
4. Server pings every 30 s to keep the connection alive.
|
||||||
|
"""
|
||||||
|
# Authenticate before accepting the connection
|
||||||
|
token = websocket.query_params.get("token", "")
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
if not user_id:
|
||||||
|
raise JWTError("missing sub")
|
||||||
|
except JWTError:
|
||||||
|
await websocket.close(code=1008) # 1008 = Policy Violation
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
try:
|
||||||
|
raw = await websocket.receive_text()
|
||||||
|
body = ChatRequest.model_validate_json(raw)
|
||||||
|
|
||||||
|
async def _heartbeat() -> None:
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||||
|
await websocket.send_text(json.dumps({"ping": True}))
|
||||||
|
|
||||||
|
heartbeat_task = asyncio.create_task(_heartbeat())
|
||||||
|
try:
|
||||||
|
async for chunk in orchestrate_stream(body):
|
||||||
|
await websocket.send_text(chunk)
|
||||||
|
finally:
|
||||||
|
heartbeat_task.cancel()
|
||||||
|
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
37
app/api/routes/plans.py
Normal file
37
app/api/routes/plans.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.execution_plan import plan_cache
|
||||||
|
from app.schemas import ExecutionPlan, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/plans", tags=["plans"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/playbook", response_model=list[ExecutionPlan])
|
||||||
|
async def list_playbooks(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> list[ExecutionPlan]:
|
||||||
|
"""Return all cached execution plan playbooks for the authenticated user.
|
||||||
|
|
||||||
|
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
|
||||||
|
"""
|
||||||
|
return plan_cache.get_all_playbooks()
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
|
||||||
|
async def get_playbook(
|
||||||
|
plan_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> ExecutionPlan:
|
||||||
|
"""Return a specific execution plan playbook by ID."""
|
||||||
|
plan = plan_cache.get_plan(plan_id)
|
||||||
|
if plan is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Plan not found: {plan_id}",
|
||||||
|
)
|
||||||
|
return plan
|
||||||
148
app/api/routes/plugins.py
Normal file
148
app/api/routes/plugins.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""Plugins routes: browse and install plugins from the marketplace.
|
||||||
|
|
||||||
|
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
||||||
|
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.db import get_session
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
||||||
|
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _require_plugin_tier(user: UserProfile) -> None:
|
||||||
|
"""Raise HTTP 403 for users below Power tier."""
|
||||||
|
if user.tier not in ("power", "team"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Plugin marketplace requires Power tier or above",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local detail schema ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _PluginDetail(BaseModel):
|
||||||
|
plugin: PluginManifest
|
||||||
|
install_count: int
|
||||||
|
ratings: list[Any]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("", response_model=PluginListResponse)
|
||||||
|
async def list_plugins(
|
||||||
|
category: str | None = Query(default=None),
|
||||||
|
q: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||||
|
async def get_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _PluginDetail:
|
||||||
|
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Fetch review ratings for this plugin
|
||||||
|
review_result = await db.execute(
|
||||||
|
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
||||||
|
)
|
||||||
|
reviews = review_result.scalars().all()
|
||||||
|
ratings = [
|
||||||
|
{
|
||||||
|
"reviewer_id": r.reviewer_id,
|
||||||
|
"decision": r.decision,
|
||||||
|
"notes": r.notes,
|
||||||
|
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
||||||
|
}
|
||||||
|
for r in reviews
|
||||||
|
]
|
||||||
|
|
||||||
|
return _PluginDetail(
|
||||||
|
plugin=entry["manifest"],
|
||||||
|
install_count=entry["install_count"],
|
||||||
|
ratings=ratings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{plugin_id}/install", response_model=dict)
|
||||||
|
async def install_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||||
|
|
||||||
|
Requires Power tier or above.
|
||||||
|
"""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Record the installation in plugin_installations
|
||||||
|
installation = PluginInstallation(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
)
|
||||||
|
db.add(installation)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
await revenue_share.record_install(
|
||||||
|
db,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
amount_cents=entry["manifest"].price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
||||||
|
return {"ok": True, "download_url": download_url}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{plugin_id}/install", response_model=dict)
|
||||||
|
async def uninstall_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Unregister a plugin installation."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(PluginInstallation).where(
|
||||||
|
PluginInstallation.plugin_id == plugin_id,
|
||||||
|
PluginInstallation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
installation = result.scalar_one_or_none()
|
||||||
|
if installation is not None:
|
||||||
|
await db.delete(installation)
|
||||||
|
await db.commit()
|
||||||
|
await registry.record_uninstall(db, plugin_id)
|
||||||
|
return {"ok": True}
|
||||||
195
app/api/routes/storage.py
Normal file
195
app/api/routes/storage.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
||||||
|
PostgreSQL ``storage_records`` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import StorageRecord
|
||||||
|
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["storage"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local response schemas ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CreateResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordMeta(BaseModel):
|
||||||
|
id: str
|
||||||
|
table: str
|
||||||
|
checksum: str
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return total bytes stored by *user_id*."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
||||||
|
StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
||||||
|
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
||||||
|
current = await _current_usage_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_record_for_user(
|
||||||
|
record_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> StorageRecord:
|
||||||
|
"""Look up a record and verify ownership. Returns 404 on mismatch
|
||||||
|
to prevent user enumeration attacks."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(StorageRecord).where(
|
||||||
|
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_record(
|
||||||
|
body: StorageRecordCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _CreateResponse:
|
||||||
|
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||||
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
await _check_quota(current_user, len(body.blob), db)
|
||||||
|
|
||||||
|
record_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, body.table, record_id, body.blob, body.checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
record = StorageRecord(
|
||||||
|
id=record_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
table_name=body.table,
|
||||||
|
s3_key=s3_key,
|
||||||
|
checksum=body.checksum,
|
||||||
|
size_bytes=len(body.blob),
|
||||||
|
)
|
||||||
|
db.add(record)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(record)
|
||||||
|
|
||||||
|
created_at_ms = int(record.created_at.timestamp() * 1000)
|
||||||
|
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records", response_model=list[_RecordMeta])
|
||||||
|
async def list_records(
|
||||||
|
table: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[_RecordMeta]:
|
||||||
|
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
||||||
|
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
||||||
|
if table is not None:
|
||||||
|
query = query.where(StorageRecord.table_name == table)
|
||||||
|
query = query.offset((page - 1) * limit).limit(limit)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
_RecordMeta(
|
||||||
|
id=r.id,
|
||||||
|
table=r.table_name,
|
||||||
|
checksum=r.checksum,
|
||||||
|
created_at=int(r.created_at.timestamp() * 1000),
|
||||||
|
updated_at=int(r.updated_at.timestamp() * 1000),
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records/{record_id}")
|
||||||
|
async def download_record(
|
||||||
|
record_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> Response:
|
||||||
|
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
blob = await _blob_store.download(current_user.id, record.s3_key)
|
||||||
|
return Response(
|
||||||
|
content=blob,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={"X-Checksum": record.checksum},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/records/{record_id}", response_model=dict)
|
||||||
|
async def update_record(
|
||||||
|
record_id: str,
|
||||||
|
body: StorageRecordUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
|
||||||
|
delta = len(body.blob) - record.size_bytes
|
||||||
|
if delta > 0:
|
||||||
|
await _check_quota(current_user, delta, db)
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
record.s3_key = s3_key
|
||||||
|
record.checksum = body.checksum
|
||||||
|
record.size_bytes = len(body.blob)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/records/{record_id}", response_model=dict)
|
||||||
|
async def delete_record(
|
||||||
|
record_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a record and its S3 blob."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
await _blob_store.delete(current_user.id, record.s3_key)
|
||||||
|
await db.delete(record)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
56
app/api/routes/vectors.py
Normal file
56
app/api/routes/vectors.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Vectors routes: upsert, search, and delete cloud vector store entries."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.schemas import (
|
||||||
|
UserProfile,
|
||||||
|
VectorSearchRequest,
|
||||||
|
VectorSearchResponse,
|
||||||
|
VectorUpsertRequest,
|
||||||
|
)
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
from app.storage.vector_store import VectorStore
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["vectors"])
|
||||||
|
|
||||||
|
_vector_store = VectorStore()
|
||||||
|
|
||||||
|
|
||||||
|
class _VectorDeleteRequest(BaseModel):
|
||||||
|
ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/upsert", response_model=dict)
|
||||||
|
async def upsert_vectors(
|
||||||
|
body: VectorUpsertRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
||||||
|
for item in body.vectors:
|
||||||
|
reject_if_tampered(item.blob, item.checksum)
|
||||||
|
await _vector_store.upsert(current_user.id, body.vectors)
|
||||||
|
return {"upserted": len(body.vectors)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
||||||
|
async def search_vectors(
|
||||||
|
body: VectorSearchRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> VectorSearchResponse:
|
||||||
|
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
||||||
|
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
||||||
|
return VectorSearchResponse(results=results)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/vectors", response_model=dict)
|
||||||
|
async def delete_vectors(
|
||||||
|
body: _VectorDeleteRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete vectors by ID, scoped to the authenticated user."""
|
||||||
|
await _vector_store.delete(current_user.id, body.ids)
|
||||||
|
return {"ok": True}
|
||||||
4
app/billing/__init__.py
Normal file
4
app/billing/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
|
||||||
|
__all__ = ["stripe_service", "tier_manager"]
|
||||||
256
app/billing/stripe_service.py
Normal file
256
app/billing/stripe_service.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||||
|
|
||||||
|
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||||
|
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||||
|
configured, enabling local development without live credentials.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||||
|
TIER_PRICE_IDS: dict[str, str] = {
|
||||||
|
"pro": "price_pro_monthly",
|
||||||
|
"power": "price_power_monthly",
|
||||||
|
"team": "price_team_monthly",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StripeService:
|
||||||
|
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||||
|
|
||||||
|
# ── Internal helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _configured(self) -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Public API ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_checkout_session(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
tier: str,
|
||||||
|
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
|
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||||
|
) -> str:
|
||||||
|
"""Create a Stripe checkout session and return the URL.
|
||||||
|
|
||||||
|
Returns a stub URL when Stripe is not configured.
|
||||||
|
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
||||||
|
"""
|
||||||
|
if tier == "free":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Cannot create a checkout session for the free tier",
|
||||||
|
)
|
||||||
|
|
||||||
|
price_id = TIER_PRICE_IDS.get(tier)
|
||||||
|
if not price_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unknown tier: {tier}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._configured():
|
||||||
|
return "https://stripe.com/stub-checkout"
|
||||||
|
|
||||||
|
s = self._client()
|
||||||
|
session = s.checkout.Session.create(
|
||||||
|
payment_method_types=["card"],
|
||||||
|
mode="subscription",
|
||||||
|
line_items=[{"price": price_id, "quantity": 1}],
|
||||||
|
success_url=success_url,
|
||||||
|
cancel_url=cancel_url,
|
||||||
|
metadata={"user_id": user_id, "tier": tier},
|
||||||
|
)
|
||||||
|
return session.url
|
||||||
|
|
||||||
|
async def handle_webhook(
|
||||||
|
self,
|
||||||
|
payload: bytes,
|
||||||
|
sig_header: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Process a Stripe webhook event.
|
||||||
|
|
||||||
|
Verifies the signature, then dispatches on event type.
|
||||||
|
Raises ``HTTP 400`` on signature mismatch.
|
||||||
|
No-ops when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
event = s.Webhook.construct_event(
|
||||||
|
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||||
|
)
|
||||||
|
except stripe_lib.error.SignatureVerificationError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid Stripe signature",
|
||||||
|
)
|
||||||
|
|
||||||
|
event_type: str = event["type"]
|
||||||
|
data: dict[str, Any] = event["data"]["object"]
|
||||||
|
|
||||||
|
if event_type == "checkout.session.completed":
|
||||||
|
user_id = data.get("metadata", {}).get("user_id")
|
||||||
|
tier = data.get("metadata", {}).get("tier", "free")
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
await self._upsert_subscription(
|
||||||
|
db, user_id, sub_id, tier, "active", period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.updated":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
new_status = data.get("status", "active")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status=new_status, current_period_end=period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.deleted":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, tier="free", status="canceled"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "invoice.payment_failed":
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status="past_due"
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def get_subscription(
|
||||||
|
self, user_id: str, db: AsyncSession
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"tier": sub.tier,
|
||||||
|
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||||
|
"status": sub.status,
|
||||||
|
"current_period_end": (
|
||||||
|
int(sub.current_period_end.timestamp() * 1000)
|
||||||
|
if sub.current_period_end
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||||
|
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||||
|
|
||||||
|
Raises ``HTTP 404`` when no active subscription exists.
|
||||||
|
"""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None or not sub.stripe_subscription_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="No active subscription found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._configured():
|
||||||
|
s = self._client()
|
||||||
|
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||||
|
|
||||||
|
sub.tier = "free"
|
||||||
|
sub.status = "canceled"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _upsert_subscription(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
stripe_subscription_id: str | None,
|
||||||
|
tier: str,
|
||||||
|
sub_status: str,
|
||||||
|
current_period_end: datetime | None,
|
||||||
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
sub = Subscription(user_id=user_id)
|
||||||
|
db.add(sub)
|
||||||
|
sub.stripe_subscription_id = stripe_subscription_id
|
||||||
|
sub.tier = tier
|
||||||
|
sub.status = sub_status
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
async def _update_subscription_by_stripe_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
stripe_subscription_id: str,
|
||||||
|
*,
|
||||||
|
tier: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
current_period_end: datetime | None = None,
|
||||||
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(
|
||||||
|
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
if tier is not None:
|
||||||
|
sub.tier = tier
|
||||||
|
if status is not None:
|
||||||
|
sub.status = status
|
||||||
|
if current_period_end is not None:
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared across the app.
|
||||||
|
stripe_service = StripeService()
|
||||||
189
app/billing/tier_manager.py
Normal file
189
app/billing/tier_manager.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""Tier manager: feature matrix and quota enforcement.
|
||||||
|
|
||||||
|
``TierManager`` is the single source of truth for what each billing tier
|
||||||
|
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||||
|
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||||
|
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.schemas import BillingTier
|
||||||
|
|
||||||
|
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||||
|
FEATURES: dict[str, dict[str, Any]] = {
|
||||||
|
"free": {
|
||||||
|
"agents": 3,
|
||||||
|
"batch_active": 2,
|
||||||
|
"cloud_storage_gb": 0,
|
||||||
|
"backup_gb": 0,
|
||||||
|
"providers": 1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"pro": {
|
||||||
|
"agents": -1, # unlimited
|
||||||
|
"batch_active": 10,
|
||||||
|
"cloud_storage_gb": 5,
|
||||||
|
"backup_gb": 5,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"power": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1, # unlimited
|
||||||
|
"cloud_storage_gb": 25,
|
||||||
|
"backup_gb": 25,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"team": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1,
|
||||||
|
"cloud_storage_gb": -1, # unlimited
|
||||||
|
"backup_gb": -1, # unlimited
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Requests-per-minute limit per tier.
|
||||||
|
RATE_LIMITS: dict[str, int] = {
|
||||||
|
"free": 20,
|
||||||
|
"pro": 60,
|
||||||
|
"power": 120,
|
||||||
|
"team": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TierManager:
|
||||||
|
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||||
|
|
||||||
|
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
|
Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
"""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str | None = result.scalar_one_or_none()
|
||||||
|
if tier is None or tier not in FEATURES:
|
||||||
|
return "free"
|
||||||
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||||
|
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||||
|
|
||||||
|
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||||
|
"""
|
||||||
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
return value != 0
|
||||||
|
|
||||||
|
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||||
|
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||||
|
if not self.check_feature(tier, feature):
|
||||||
|
detail = (
|
||||||
|
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||||
|
if tier_name
|
||||||
|
else f"Feature '{feature}' is not available on your current tier."
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
# ── Rate limiting ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
|
"""Return the requests-per-minute limit for ``tier``."""
|
||||||
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
|
# ── Storage quota ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def enforce_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||||
|
|
||||||
|
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||||
|
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||||
|
"""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def enforce_backup_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
return False
|
||||||
|
if limit_gb == -1:
|
||||||
|
return True
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared across the app.
|
||||||
|
tier_manager = TierManager()
|
||||||
0
app/config/__init__.py
Normal file
0
app/config/__init__.py
Normal file
42
app/config/settings.py
Normal file
42
app/config/settings.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
from typing import Literal
|
||||||
|
from pydantic_settings import BaseSettings
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
||||||
|
JWT_SECRET: str = "change-me-in-production"
|
||||||
|
JWT_ALGORITHM: str = "HS256"
|
||||||
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||||
|
|
||||||
|
STRIPE_SECRET_KEY: str = ""
|
||||||
|
STRIPE_WEBHOOK_SECRET: str = ""
|
||||||
|
|
||||||
|
S3_BUCKET: str = ""
|
||||||
|
S3_REGION: str = "us-east-1"
|
||||||
|
S3_ENDPOINT_URL: str = ""
|
||||||
|
AWS_ACCESS_KEY_ID: str = ""
|
||||||
|
AWS_SECRET_ACCESS_KEY: str = ""
|
||||||
|
|
||||||
|
PINECONE_API_KEY: str = ""
|
||||||
|
PINECONE_INDEX: str = "adiuva"
|
||||||
|
QDRANT_URL: str = ""
|
||||||
|
QDRANT_API_KEY: str = ""
|
||||||
|
|
||||||
|
OPENAI_API_KEY: str = ""
|
||||||
|
ANTHROPIC_API_KEY: str = ""
|
||||||
|
GOOGLE_API_KEY: str = ""
|
||||||
|
|
||||||
|
LLM_MODEL: str = "gpt-4o"
|
||||||
|
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
||||||
|
|
||||||
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|
||||||
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
env_file = ".env"
|
||||||
|
env_file_encoding = "utf-8"
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
137
app/core/agent_registry.py
Normal file
137
app/core/agent_registry.py
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
"""Agent Registry — base classes and singleton registry for chat agents."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
|
class BaseAgent(ABC):
|
||||||
|
"""Common base for all agents."""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
user_id: str = "",
|
||||||
|
shared_memory: dict[str, Any] | None = None,
|
||||||
|
vector_store_context: list[str] | None = None,
|
||||||
|
) -> None:
|
||||||
|
self.user_id = user_id
|
||||||
|
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||||
|
self.vector_store_context: list[str] = vector_store_context or []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_name(self) -> str: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_description(self) -> str: ...
|
||||||
|
|
||||||
|
@property
|
||||||
|
def skills(self) -> list[str]:
|
||||||
|
"""Override in subclasses to advertise capabilities."""
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class ChatAgent(BaseAgent):
|
||||||
|
"""Base class for LLM-powered chat agents."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
"""Process a user query and return a text response."""
|
||||||
|
...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
"""Return LangChain tool definitions available to this agent."""
|
||||||
|
...
|
||||||
|
|
||||||
|
async def _tool_loop(
|
||||||
|
self,
|
||||||
|
llm: Any,
|
||||||
|
messages: list[Any],
|
||||||
|
tools: list[Any],
|
||||||
|
max_iter: int = 5,
|
||||||
|
) -> str:
|
||||||
|
"""Shared tool-calling loop.
|
||||||
|
|
||||||
|
Binds *tools* to *llm*, invokes iteratively until the model stops
|
||||||
|
requesting tool calls or *max_iter* is reached, and returns the
|
||||||
|
final text response.
|
||||||
|
"""
|
||||||
|
from langchain_core.messages import AIMessage, ToolMessage
|
||||||
|
|
||||||
|
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||||
|
|
||||||
|
for _ in range(max_iter):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return str(response.content)
|
||||||
|
|
||||||
|
# Execute each requested tool call
|
||||||
|
tool_map = {t.name: t for t in tools}
|
||||||
|
for call in response.tool_calls:
|
||||||
|
tool_fn = tool_map.get(call["name"])
|
||||||
|
if tool_fn is None:
|
||||||
|
result = f"Unknown tool: {call['name']}"
|
||||||
|
else:
|
||||||
|
result = await tool_fn.ainvoke(call["args"])
|
||||||
|
messages.append(
|
||||||
|
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Exhausted iterations — ask model for a final answer without tools
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
return str(response.content)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRegistry:
|
||||||
|
"""Singleton registry for ChatAgent subclasses."""
|
||||||
|
|
||||||
|
_instance: AgentRegistry | None = None
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._agents: dict[str, type[ChatAgent]] = {}
|
||||||
|
|
||||||
|
def __new__(cls) -> AgentRegistry:
|
||||||
|
if cls._instance is None:
|
||||||
|
cls._instance = super().__new__(cls)
|
||||||
|
cls._instance._agents = {}
|
||||||
|
return cls._instance
|
||||||
|
|
||||||
|
# ── public API ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
||||||
|
"""Class decorator — registers an agent by its name."""
|
||||||
|
instance = agent_class()
|
||||||
|
name = instance.get_name()
|
||||||
|
self._agents[name] = agent_class
|
||||||
|
return agent_class
|
||||||
|
|
||||||
|
def get(self, name: str) -> ChatAgent:
|
||||||
|
"""Return a fresh instance of the named agent."""
|
||||||
|
cls = self._agents.get(name)
|
||||||
|
if cls is None:
|
||||||
|
raise KeyError(f"Agent not found: {name}")
|
||||||
|
return cls()
|
||||||
|
|
||||||
|
def list_agents(self) -> list[dict[str, str]]:
|
||||||
|
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
||||||
|
result: list[dict[str, str]] = []
|
||||||
|
for cls in self._agents.values():
|
||||||
|
inst = cls()
|
||||||
|
result.append(
|
||||||
|
{"name": inst.get_name(), "description": inst.get_description()}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
async def call_agent(
|
||||||
|
self, name: str, query: str, context: dict[str, Any]
|
||||||
|
) -> str:
|
||||||
|
"""Instantiate the named agent and call its ``handle`` method."""
|
||||||
|
agent = self.get(name)
|
||||||
|
return await agent.handle(query, context)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
registry = AgentRegistry()
|
||||||
222
app/core/execution_plan.py
Normal file
222
app/core/execution_plan.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.schemas import ExecutionPlan, PlanStep
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt Template Registry ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplateRegistry:
|
||||||
|
"""Server-side store mapping template IDs to prompt text.
|
||||||
|
|
||||||
|
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
|
||||||
|
The actual prompt text is resolved here on the server, keeping prompt IP
|
||||||
|
out of API responses.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._templates: dict[str, str] = {}
|
||||||
|
|
||||||
|
def register(self, template_id: str, prompt_text: str) -> None:
|
||||||
|
self._templates[template_id] = prompt_text
|
||||||
|
|
||||||
|
def get(self, template_id: str) -> str:
|
||||||
|
"""Resolve a template ID to its prompt text.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the template is not registered.
|
||||||
|
"""
|
||||||
|
text = self._templates.get(template_id)
|
||||||
|
if text is None:
|
||||||
|
raise KeyError(f"Template not found: {template_id!r}")
|
||||||
|
return text
|
||||||
|
|
||||||
|
def has(self, template_id: str) -> bool:
|
||||||
|
return template_id in self._templates
|
||||||
|
|
||||||
|
def list_ids(self) -> list[str]:
|
||||||
|
"""Return all registered template IDs (never the text)."""
|
||||||
|
return list(self._templates.keys())
|
||||||
|
|
||||||
|
|
||||||
|
# ── Execution Plan Builder ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionPlanBuilder:
|
||||||
|
"""Fluent builder for ``ExecutionPlan`` objects.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
plan = (
|
||||||
|
ExecutionPlanBuilder("task_agent")
|
||||||
|
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
|
||||||
|
.add_data_step("create_record", data_from_step=0)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, agent: str) -> None:
|
||||||
|
self._agent = agent
|
||||||
|
self._steps: list[PlanStep] = []
|
||||||
|
|
||||||
|
# ── step adders ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def add_step(
|
||||||
|
self, action: str, params: dict[str, Any] | None = None
|
||||||
|
) -> ExecutionPlanBuilder:
|
||||||
|
"""Append a generic action step with optional parameters."""
|
||||||
|
self._steps.append(PlanStep(action=action, variables=params))
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_llm_step(
|
||||||
|
self, template_id: str, variables: dict[str, Any] | None = None
|
||||||
|
) -> ExecutionPlanBuilder:
|
||||||
|
"""Append an LLM step referencing a server-side template by ID."""
|
||||||
|
self._steps.append(
|
||||||
|
PlanStep(action="llm", prompt_template=template_id, variables=variables)
|
||||||
|
)
|
||||||
|
return self
|
||||||
|
|
||||||
|
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
|
||||||
|
"""Append a step whose input comes from the output of an earlier step."""
|
||||||
|
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
|
||||||
|
return self
|
||||||
|
|
||||||
|
# ── build ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def build(self) -> ExecutionPlan:
|
||||||
|
"""Validate step references and return the ``ExecutionPlan``.
|
||||||
|
|
||||||
|
Raises ``ValueError`` if any ``data_from_step`` references a
|
||||||
|
non-existent or future step index.
|
||||||
|
"""
|
||||||
|
for i, step in enumerate(self._steps):
|
||||||
|
if step.data_from_step is not None:
|
||||||
|
if not (0 <= step.data_from_step < i):
|
||||||
|
raise ValueError(
|
||||||
|
f"Step {i}: data_from_step={step.data_from_step} must "
|
||||||
|
f"reference a preceding step index in range 0..{i - 1}"
|
||||||
|
)
|
||||||
|
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PlanCache:
|
||||||
|
"""In-memory LRU cache for ``ExecutionPlan`` objects.
|
||||||
|
|
||||||
|
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
|
||||||
|
The cache also serves as a runtime memoisation layer so that repeated
|
||||||
|
identical intent classifications can skip re-building the plan.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, maxsize: int = 1000) -> None:
|
||||||
|
self._maxsize = maxsize
|
||||||
|
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
|
||||||
|
|
||||||
|
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
|
||||||
|
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
|
||||||
|
if key in self._cache:
|
||||||
|
del self._cache[key] # remove so re-insertion places it at the end
|
||||||
|
elif len(self._cache) >= self._maxsize:
|
||||||
|
self._cache.popitem(last=False) # evict least-recently-used
|
||||||
|
self._cache[key] = plan
|
||||||
|
|
||||||
|
def get_plan(self, key: str) -> ExecutionPlan | None:
|
||||||
|
"""Return the cached plan for *key*, or ``None`` if not present.
|
||||||
|
|
||||||
|
Accessing a plan marks it as most-recently used.
|
||||||
|
"""
|
||||||
|
if key not in self._cache:
|
||||||
|
return None
|
||||||
|
self._cache.move_to_end(key)
|
||||||
|
return self._cache[key]
|
||||||
|
|
||||||
|
def get_all_playbooks(self) -> list[ExecutionPlan]:
|
||||||
|
"""Return all cached plans (most-recently used last)."""
|
||||||
|
return list(self._cache.values())
|
||||||
|
|
||||||
|
|
||||||
|
# ── Module-level singletons ───────────────────────────────────────────
|
||||||
|
|
||||||
|
template_registry = PromptTemplateRegistry()
|
||||||
|
plan_cache = PlanCache()
|
||||||
|
|
||||||
|
|
||||||
|
def _register_builtin_templates() -> None:
|
||||||
|
"""Register the built-in server-side prompt templates.
|
||||||
|
|
||||||
|
These strings never leave the server. Clients only receive the IDs.
|
||||||
|
"""
|
||||||
|
_tpls: dict[str, str] = {
|
||||||
|
"tpl_task_agent_default": (
|
||||||
|
"You are a task management assistant. Help the user create, update, "
|
||||||
|
"list, and track tasks. Use correct status values (todo, in_progress, "
|
||||||
|
"done) and priority values (high, medium, low) from the workspace model."
|
||||||
|
),
|
||||||
|
"tpl_checkpoint_agent_default": (
|
||||||
|
"You are a project checkpoint assistant. Help the user create and manage "
|
||||||
|
"milestone checkpoints on their projects. Every checkpoint requires a "
|
||||||
|
"project_id and a date expressed as a Unix timestamp in milliseconds."
|
||||||
|
),
|
||||||
|
"tpl_project_agent_default": (
|
||||||
|
"You are a project management assistant. Help the user create, find, "
|
||||||
|
"update, and archive projects. Projects have a name, an optional client, "
|
||||||
|
"and a status of either active or archived."
|
||||||
|
),
|
||||||
|
"tpl_note_agent_default": (
|
||||||
|
"You are a note-taking assistant. Help the user create, retrieve, update, "
|
||||||
|
"and delete Markdown notes. Notes can optionally be linked to a project."
|
||||||
|
),
|
||||||
|
"tpl_task_extract_from_project": (
|
||||||
|
"Extract all actionable tasks from the provided project context. "
|
||||||
|
"Return a structured list of tasks, each with a title, inferred priority "
|
||||||
|
"(high, medium, or low), suggested status (todo), and a due_date in "
|
||||||
|
"milliseconds where a deadline can be inferred."
|
||||||
|
),
|
||||||
|
"tpl_note_weekly_summary": (
|
||||||
|
"Generate a weekly project summary note from the provided workspace data. "
|
||||||
|
"Include: tasks completed this week, tasks due soon, active projects, "
|
||||||
|
"and upcoming checkpoints. Format the output as clean Markdown."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
for tid, text in _tpls.items():
|
||||||
|
template_registry.register(tid, text)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_playbooks() -> None:
|
||||||
|
"""Pre-build and cache the built-in playbooks."""
|
||||||
|
playbooks: list[tuple[str, ExecutionPlan]] = [
|
||||||
|
(
|
||||||
|
"create_tasks_from_project",
|
||||||
|
ExecutionPlanBuilder("project_agent")
|
||||||
|
.add_llm_step(
|
||||||
|
"tpl_task_extract_from_project",
|
||||||
|
{"source": "project_context"},
|
||||||
|
)
|
||||||
|
.add_data_step("create_record", data_from_step=0)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
|
(
|
||||||
|
"generate_weekly_note",
|
||||||
|
ExecutionPlanBuilder("note_agent")
|
||||||
|
.add_llm_step(
|
||||||
|
"tpl_note_weekly_summary",
|
||||||
|
{"period": "last_7_days"},
|
||||||
|
)
|
||||||
|
.add_data_step("create_record", data_from_step=0)
|
||||||
|
.build(),
|
||||||
|
),
|
||||||
|
]
|
||||||
|
for key, plan in playbooks:
|
||||||
|
plan_cache.cache_plan(key, plan)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialise on module load
|
||||||
|
_register_builtin_templates()
|
||||||
|
_load_playbooks()
|
||||||
68
app/core/llm.py
Normal file
68
app/core/llm.py
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
||||||
|
instead of directly constructing a provider-specific class. The model string
|
||||||
|
follows the `LiteLLM model naming convention
|
||||||
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
|
|
||||||
|
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
|
||||||
|
* Anthropic: ``anthropic/claude-3.5-sonnet``
|
||||||
|
* Google: ``gemini/gemini-pro``
|
||||||
|
* Ollama: ``ollama/llama3``
|
||||||
|
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||||
|
|
||||||
|
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
||||||
|
— no code changes required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI:
|
||||||
|
"""Return a LangChain chat model backed by LiteLLM.
|
||||||
|
|
||||||
|
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
|
||||||
|
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
|
||||||
|
``openai`` client transparently when the model string contains a provider
|
||||||
|
prefix (``anthropic/…``, ``gemini/…``, etc.).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model:
|
||||||
|
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
|
||||||
|
temperature:
|
||||||
|
Sampling temperature. ``0`` = deterministic.
|
||||||
|
"""
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_router_llm(
|
||||||
|
*,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI:
|
||||||
|
"""Return the lighter model used for intent classification / routing."""
|
||||||
|
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
||||||
168
app/core/orchestrator.py
Normal file
168
app/core/orchestrator.py
Normal file
@@ -0,0 +1,168 @@
|
|||||||
|
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from app.core.agent_registry import AgentRegistry
|
||||||
|
from app.core.llm import get_router_llm
|
||||||
|
from app.core.agent_registry import registry as _default_registry
|
||||||
|
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||||
|
|
||||||
|
_FALLBACK_AGENT = "task_agent"
|
||||||
|
|
||||||
|
_CLASSIFY_SYSTEM = (
|
||||||
|
"You are an intent classifier. Given the user message and context, decide "
|
||||||
|
"which agent to route to.\n"
|
||||||
|
"Available agents: {agents}\n"
|
||||||
|
"Respond with just the agent name, nothing else."
|
||||||
|
)
|
||||||
|
|
||||||
|
_SYNTHESIZE_HUMAN = (
|
||||||
|
"Combine the following agent results into one coherent response.\n\n"
|
||||||
|
"Agent results:\n{results}\n\n"
|
||||||
|
"Original message: {message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_llm():
|
||||||
|
return get_router_llm()
|
||||||
|
|
||||||
|
|
||||||
|
async def classify_intent(
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
reg: AgentRegistry,
|
||||||
|
) -> str:
|
||||||
|
"""Use gpt-4o-mini to classify intent and return the matching agent name.
|
||||||
|
|
||||||
|
Falls back to ``task_agent`` when the registry is empty or the model
|
||||||
|
returns a name that is not registered.
|
||||||
|
"""
|
||||||
|
agents = reg.list_agents()
|
||||||
|
if not agents:
|
||||||
|
return _FALLBACK_AGENT
|
||||||
|
|
||||||
|
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
|
||||||
|
# Truncate context to keep the classification prompt short
|
||||||
|
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
|
||||||
|
|
||||||
|
llm = _make_llm()
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[SystemMessage(content=system), HumanMessage(content=human)]
|
||||||
|
)
|
||||||
|
|
||||||
|
agent_name = str(response.content).strip().lower()
|
||||||
|
known = {a["name"] for a in agents}
|
||||||
|
return agent_name if agent_name in known else _FALLBACK_AGENT
|
||||||
|
|
||||||
|
|
||||||
|
async def route_single(
|
||||||
|
agent_name: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
reg: AgentRegistry,
|
||||||
|
) -> ChatResponse:
|
||||||
|
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
|
||||||
|
response_text = await reg.call_agent(agent_name, message, context)
|
||||||
|
return ChatResponse(response=response_text)
|
||||||
|
|
||||||
|
|
||||||
|
async def route_pipeline(
|
||||||
|
agent_names: list[str],
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
reg: AgentRegistry,
|
||||||
|
) -> ChatResponse:
|
||||||
|
"""Execute agents sequentially; each agent receives previous results in context.
|
||||||
|
|
||||||
|
A final LLM synthesis call merges all results into one coherent response.
|
||||||
|
"""
|
||||||
|
previous_results: list[str] = []
|
||||||
|
|
||||||
|
for agent_name in agent_names:
|
||||||
|
ctx = {**context, "previous_results": list(previous_results)}
|
||||||
|
result = await reg.call_agent(agent_name, message, ctx)
|
||||||
|
previous_results.append(result)
|
||||||
|
|
||||||
|
results_str = "\n\n".join(
|
||||||
|
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
|
||||||
|
)
|
||||||
|
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
|
||||||
|
llm = _make_llm()
|
||||||
|
synthesis = await llm.ainvoke([HumanMessage(content=human)])
|
||||||
|
return ChatResponse(response=str(synthesis.content))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
|
||||||
|
"""Build an ``ExecutionPlan`` for the resolved agent.
|
||||||
|
|
||||||
|
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
|
||||||
|
If a default template exists for the agent, an LLM step is emitted;
|
||||||
|
otherwise a plain ``handle`` action step is used.
|
||||||
|
"""
|
||||||
|
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
|
||||||
|
|
||||||
|
template_id = f"tpl_{agent_name}_default"
|
||||||
|
builder = ExecutionPlanBuilder(agent_name)
|
||||||
|
if template_registry.has(template_id):
|
||||||
|
builder.add_llm_step(template_id, {"message": message})
|
||||||
|
else:
|
||||||
|
builder.add_step("handle", {"message": message})
|
||||||
|
return builder.build()
|
||||||
|
|
||||||
|
|
||||||
|
async def orchestrate(
|
||||||
|
request: ChatRequest,
|
||||||
|
reg: AgentRegistry | None = None,
|
||||||
|
) -> ChatResponse | ExecutionPlan:
|
||||||
|
"""Main orchestration entry point.
|
||||||
|
|
||||||
|
* Classifies the user's intent to select an agent.
|
||||||
|
* ``execution_mode == 'direct'``: routes to the agent and returns a
|
||||||
|
``ChatResponse``.
|
||||||
|
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
|
||||||
|
resolved agent and a template-ID-only step (prompt IP stays server-side).
|
||||||
|
"""
|
||||||
|
if reg is None:
|
||||||
|
reg = _default_registry
|
||||||
|
|
||||||
|
context = request.context.model_dump()
|
||||||
|
agent_name = await classify_intent(request.message, context, reg)
|
||||||
|
|
||||||
|
if request.execution_mode == "direct":
|
||||||
|
return await route_single(agent_name, request.message, context, reg)
|
||||||
|
|
||||||
|
# plan mode — return plan, do not execute
|
||||||
|
return _build_plan(agent_name, request.message)
|
||||||
|
|
||||||
|
|
||||||
|
async def orchestrate_stream(
|
||||||
|
request: ChatRequest,
|
||||||
|
reg: AgentRegistry | None = None,
|
||||||
|
) -> AsyncGenerator[str, None]:
|
||||||
|
"""Streaming orchestration — yields text chunks then a final JSON frame.
|
||||||
|
|
||||||
|
The final frame is a JSON object:
|
||||||
|
``{"done": true, "response": "...", "actions": []}``.
|
||||||
|
|
||||||
|
Agents do not yet support token-level streaming; the full response is
|
||||||
|
fetched first, then emitted in fixed-size chunks. Token-level streaming
|
||||||
|
will be wired in Step 6 when agents expose ``astream()``.
|
||||||
|
"""
|
||||||
|
if reg is None:
|
||||||
|
reg = _default_registry
|
||||||
|
|
||||||
|
context = request.context.model_dump()
|
||||||
|
agent_name = await classify_intent(request.message, context, reg)
|
||||||
|
response_text = await reg.call_agent(agent_name, request.message, context)
|
||||||
|
|
||||||
|
chunk_size = 50
|
||||||
|
for i in range(0, len(response_text), chunk_size):
|
||||||
|
yield response_text[i : i + chunk_size]
|
||||||
|
|
||||||
|
final = ChatResponse(response=response_text)
|
||||||
|
yield json.dumps({"done": True, **final.model_dump()})
|
||||||
40
app/db.py
Normal file
40
app/db.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Database engine, session factory, and base model.
|
||||||
|
|
||||||
|
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||||
|
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||||
|
|
||||||
|
Usage in routes:
|
||||||
|
from app.db import get_session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||||
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.DATABASE_URL,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
echo=settings.ENV == "dev",
|
||||||
|
)
|
||||||
|
|
||||||
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Shared declarative base for all ORM models."""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""FastAPI dependency that yields an async DB session per request."""
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
64
app/main.py
Normal file
64
app/main.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||||
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Startup: initialise DB connection pool and agent registry
|
||||||
|
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
||||||
|
import app.agents # noqa: F401 — triggers @registry.register decorators
|
||||||
|
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
|
from app.db import engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="Adiuva Cloud API",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
||||||
|
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
||||||
|
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
||||||
|
app.add_middleware(SanitizerMiddleware)
|
||||||
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
|
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
||||||
|
|
||||||
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
|
app.include_router(plans.router, prefix="/api/v1")
|
||||||
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok", "version": app.version}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
7
app/marketplace/__init__.py
Normal file
7
app/marketplace/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Plugin marketplace package.
|
||||||
|
|
||||||
|
Three service classes introduced in Step 10:
|
||||||
|
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
||||||
|
- ``ReviewQueue`` — approval workflow + security checklist
|
||||||
|
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
||||||
|
"""
|
||||||
212
app/marketplace/plugin_registry.py
Normal file
212
app/marketplace/plugin_registry.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""Plugin catalog registry backed by PostgreSQL.
|
||||||
|
|
||||||
|
Maintains the authoritative list of plugins, their review status, and
|
||||||
|
aggregate install counts. All data is persisted in the ``plugins`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import Plugin
|
||||||
|
from app.schemas import PluginListResponse, PluginManifest
|
||||||
|
|
||||||
|
_PAGE_SIZE = 20
|
||||||
|
|
||||||
|
|
||||||
|
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
||||||
|
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
||||||
|
try:
|
||||||
|
permissions = json.loads(p.permissions) if p.permissions else []
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
permissions = []
|
||||||
|
return PluginManifest(
|
||||||
|
id=p.id,
|
||||||
|
name=p.name,
|
||||||
|
description=p.description,
|
||||||
|
version=p.version,
|
||||||
|
author=p.author_name,
|
||||||
|
permissions=permissions,
|
||||||
|
category=p.category,
|
||||||
|
price_cents=p.price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginRegistry:
|
||||||
|
"""PostgreSQL-backed plugin catalog.
|
||||||
|
|
||||||
|
All methods accept an ``AsyncSession`` parameter so the calling route
|
||||||
|
controls the session lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Queries ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def list_plugins(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
category: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||||
|
base = select(Plugin).where(Plugin.status == "approved")
|
||||||
|
|
||||||
|
if category:
|
||||||
|
base = base.where(Plugin.category == category)
|
||||||
|
if query:
|
||||||
|
pattern = f"%{query}%"
|
||||||
|
base = base.where(
|
||||||
|
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count
|
||||||
|
count_q = select(func.count()).select_from(base.subquery())
|
||||||
|
total = (await db.execute(count_q)).scalar_one()
|
||||||
|
|
||||||
|
# Sort
|
||||||
|
if sort == "installs":
|
||||||
|
base = base.order_by(Plugin.install_count.desc())
|
||||||
|
elif sort == "rating":
|
||||||
|
base = base.order_by(Plugin.avg_rating.desc())
|
||||||
|
else: # newest
|
||||||
|
base = base.order_by(Plugin.created_at.desc())
|
||||||
|
|
||||||
|
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
||||||
|
rows = (await db.execute(base)).scalars().all()
|
||||||
|
|
||||||
|
return PluginListResponse(
|
||||||
|
plugins=[_plugin_to_manifest(r) for r in rows],
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
p = result.scalar_one_or_none()
|
||||||
|
if p is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"manifest": _plugin_to_manifest(p),
|
||||||
|
"status": p.status,
|
||||||
|
"install_count": p.install_count,
|
||||||
|
"avg_rating": p.avg_rating,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Mutations ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def submit_plugin(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
manifest: PluginManifest,
|
||||||
|
package_s3_key: str,
|
||||||
|
) -> str:
|
||||||
|
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
||||||
|
|
||||||
|
Returns the plugin_id. If a plugin with the same id already exists
|
||||||
|
it is overwritten (re-submission after rejection).
|
||||||
|
"""
|
||||||
|
plugin_id = manifest.id
|
||||||
|
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = existing.scalar_one_or_none()
|
||||||
|
|
||||||
|
if row is not None:
|
||||||
|
row.name = manifest.name
|
||||||
|
row.description = manifest.description
|
||||||
|
row.version = manifest.version
|
||||||
|
row.author_name = manifest.author
|
||||||
|
row.category = manifest.category
|
||||||
|
row.price_cents = manifest.price_cents
|
||||||
|
row.permissions = json.dumps(manifest.permissions)
|
||||||
|
row.status = "pending_review"
|
||||||
|
row.s3_package_key = package_s3_key
|
||||||
|
row.rejection_reason = None
|
||||||
|
else:
|
||||||
|
row = Plugin(
|
||||||
|
id=plugin_id,
|
||||||
|
name=manifest.name,
|
||||||
|
description=manifest.description,
|
||||||
|
version=manifest.version,
|
||||||
|
author_name=manifest.author,
|
||||||
|
category=manifest.category,
|
||||||
|
price_cents=manifest.price_cents,
|
||||||
|
permissions=json.dumps(manifest.permissions),
|
||||||
|
status="pending_review",
|
||||||
|
s3_package_key=package_s3_key,
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
|
return plugin_id
|
||||||
|
|
||||||
|
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'approved'``.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
row.status = "approved"
|
||||||
|
row.rejection_reason = None
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
row.status = "rejected"
|
||||||
|
row.rejection_reason = reason
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = row.install_count + 1
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = max(0, row.install_count - 1)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||||
|
|
||||||
|
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
|
"""Return all entries with status='pending_review'."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Plugin).where(Plugin.status == "pending_review")
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"manifest": _plugin_to_manifest(r),
|
||||||
|
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
registry = PluginRegistry()
|
||||||
125
app/marketplace/plugin_review.py
Normal file
125
app/marketplace/plugin_review.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Plugin review workflow backed by PostgreSQL.
|
||||||
|
|
||||||
|
Manages the approval queue for newly submitted plugins and enforces a
|
||||||
|
security checklist before any plugin is made visible in the marketplace.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_review import review_queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import PluginReview as PluginReviewModel
|
||||||
|
from app.schemas import PluginManifest
|
||||||
|
|
||||||
|
# ── Security policy ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"read:tasks",
|
||||||
|
"write:tasks",
|
||||||
|
"read:projects",
|
||||||
|
"write:projects",
|
||||||
|
"read:notes",
|
||||||
|
"write:notes",
|
||||||
|
"read:checkpoints",
|
||||||
|
"write:checkpoints",
|
||||||
|
"read:calendar",
|
||||||
|
"write:calendar",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_manifest(manifest: PluginManifest) -> None:
|
||||||
|
"""Enforce the plugin security checklist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``ValueError`` on the first violation found. Callers should catch
|
||||||
|
this and return HTTP 422 / reject the submission.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
1. Plugin id matches ``^[a-z0-9-]+$``
|
||||||
|
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
||||||
|
3. No manifest field contains raw binary data
|
||||||
|
"""
|
||||||
|
if not _PLUGIN_ID_RE.match(manifest.id):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid plugin id format: '{manifest.id}'. "
|
||||||
|
"Only lowercase letters, digits, and hyphens are allowed."
|
||||||
|
)
|
||||||
|
|
||||||
|
for perm in manifest.permissions:
|
||||||
|
if perm not in ALLOWED_PERMISSIONS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown permission: '{perm}'. "
|
||||||
|
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field_name, value in manifest.model_dump().items():
|
||||||
|
if isinstance(value, (bytes, bytearray)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Binary content is not allowed in manifest field '{field_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewQueue:
|
||||||
|
"""Approval queue for pending plugin submissions.
|
||||||
|
|
||||||
|
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
||||||
|
Review records are persisted in the ``plugin_reviews`` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
|
"""Return all plugins currently awaiting review.
|
||||||
|
|
||||||
|
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||||
|
"""
|
||||||
|
entries = await registry.get_pending_entries(db)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"plugin_id": e["manifest"].id,
|
||||||
|
"manifest": e["manifest"],
|
||||||
|
"submitted_at": e["submitted_at"],
|
||||||
|
}
|
||||||
|
for e in entries
|
||||||
|
]
|
||||||
|
|
||||||
|
async def submit_review(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
plugin_id: str,
|
||||||
|
reviewer_id: str,
|
||||||
|
decision: Literal["approved", "rejected"],
|
||||||
|
notes: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Record a review decision and update the plugin's status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``KeyError`` if *plugin_id* is not found in the registry.
|
||||||
|
"""
|
||||||
|
if decision == "approved":
|
||||||
|
await registry.approve_plugin(db, plugin_id)
|
||||||
|
else:
|
||||||
|
await registry.reject_plugin(db, plugin_id, reason=notes)
|
||||||
|
|
||||||
|
review = PluginReviewModel(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
reviewer_id=reviewer_id,
|
||||||
|
decision=decision,
|
||||||
|
notes=notes,
|
||||||
|
)
|
||||||
|
db.add(review)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
review_queue = ReviewQueue()
|
||||||
233
app/marketplace/revenue_share.py
Normal file
233
app/marketplace/revenue_share.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
||||||
|
|
||||||
|
Records every plugin installation as a revenue event and facilitates
|
||||||
|
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
||||||
|
in the ``revenue_events`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from sqlalchemy import extract, func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import Plugin, RevenueEvent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Revenue split constants ───────────────────────────────────────────
|
||||||
|
|
||||||
|
DEVELOPER_SHARE: float = 0.70
|
||||||
|
PLATFORM_SHARE: float = 0.30
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueShare:
|
||||||
|
"""Records installation revenue events and coordinates developer payouts.
|
||||||
|
|
||||||
|
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
||||||
|
is not configured, consistent with the rest of the billing layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe_configured() -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe() -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Core operations ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def record_install(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
plugin_id: str,
|
||||||
|
user_id: str,
|
||||||
|
amount_cents: int,
|
||||||
|
) -> None:
|
||||||
|
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
||||||
|
|
||||||
|
For free plugins (``amount_cents == 0``) no payment is initiated but
|
||||||
|
the event is still recorded for analytics.
|
||||||
|
|
||||||
|
For paid plugins the developer receives 70 % via a Stripe Connect
|
||||||
|
destination charge. If Stripe is not configured or the charge fails
|
||||||
|
the installation still succeeds (the event is recorded and the install
|
||||||
|
count is incremented) — a warning is logged for monitoring.
|
||||||
|
"""
|
||||||
|
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
||||||
|
stripe_transfer_id: str | None = None
|
||||||
|
|
||||||
|
if amount_cents > 0 and self._stripe_configured():
|
||||||
|
# Look up the plugin's author Stripe account from the DB
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = result.scalar_one_or_none()
|
||||||
|
developer_stripe_account: str | None = None
|
||||||
|
if plugin_row and plugin_row.author_id:
|
||||||
|
# Future: look up user.stripe_connect_account_id
|
||||||
|
developer_stripe_account = None # no real account yet
|
||||||
|
|
||||||
|
if developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
transfer = s.Transfer.create(
|
||||||
|
amount=developer_share_cents,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Revenue share for plugin {plugin_id}",
|
||||||
|
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
stripe_transfer_id = transfer["id"]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Stripe Connect transfer failed for plugin %s: %s",
|
||||||
|
plugin_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"No Stripe account on file for plugin %s developer; "
|
||||||
|
"skipping transfer.",
|
||||||
|
plugin_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
event = RevenueEvent(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=user_id,
|
||||||
|
amount_cents=amount_cents,
|
||||||
|
developer_share_cents=developer_share_cents,
|
||||||
|
stripe_transfer_id=stripe_transfer_id,
|
||||||
|
)
|
||||||
|
db.add(event)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await registry.record_install(db, plugin_id)
|
||||||
|
|
||||||
|
async def get_earnings(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
developer_id: str,
|
||||||
|
period: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return aggregated earnings for *developer_id*.
|
||||||
|
|
||||||
|
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
||||||
|
|
||||||
|
Returns::
|
||||||
|
|
||||||
|
{
|
||||||
|
"developer_id": str,
|
||||||
|
"period": str | None,
|
||||||
|
"total_installs": int,
|
||||||
|
"total_revenue_cents": int,
|
||||||
|
"developer_share_cents": int,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Find plugin ids belonging to this developer (by author_name match)
|
||||||
|
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
||||||
|
plugin_result = await db.execute(plugin_q)
|
||||||
|
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
||||||
|
|
||||||
|
if not developer_plugin_ids:
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": 0,
|
||||||
|
"total_revenue_cents": 0,
|
||||||
|
"developer_share_cents": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
query = select(
|
||||||
|
func.count().label("total_installs"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
||||||
|
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
||||||
|
|
||||||
|
if period:
|
||||||
|
# Filter by YYYY-MM: extract year and month from created_at
|
||||||
|
try:
|
||||||
|
year, month = period.split("-")
|
||||||
|
query = query.where(
|
||||||
|
extract("year", RevenueEvent.created_at) == int(year),
|
||||||
|
extract("month", RevenueEvent.created_at) == int(month),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass # invalid period format — return all
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
row = result.one()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": row.total_installs,
|
||||||
|
"total_revenue_cents": row.total_revenue,
|
||||||
|
"developer_share_cents": row.dev_share,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
||||||
|
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||||
|
|
||||||
|
Marks processed events with ``paid_at`` timestamp.
|
||||||
|
Stubs gracefully when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
year, month = period.split("-")
|
||||||
|
year_int, month_int = int(year), int(month)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid period format: %s", period)
|
||||||
|
return
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(RevenueEvent).where(
|
||||||
|
RevenueEvent.plugin_id == plugin_id,
|
||||||
|
RevenueEvent.paid_at.is_(None),
|
||||||
|
extract("year", RevenueEvent.created_at) == year_int,
|
||||||
|
extract("month", RevenueEvent.created_at) == month_int,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
unpaid = list(result.scalars().all())
|
||||||
|
|
||||||
|
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
||||||
|
if total_dev_share <= 0 or not unpaid:
|
||||||
|
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._stripe_configured():
|
||||||
|
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = plugin_result.scalar_one_or_none()
|
||||||
|
developer_stripe_account: str | None = None # Future: fetch from DB
|
||||||
|
if plugin_row and developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
s.Transfer.create(
|
||||||
|
amount=total_dev_share,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Payout for plugin {plugin_id} period {period}",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
paid_ts = datetime.now(timezone.utc)
|
||||||
|
for event in unpaid:
|
||||||
|
event.paid_at = paid_ts
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
revenue_share = RevenueShare()
|
||||||
268
app/models.py
Normal file
268
app/models.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
|
Only auth, billing, storage metadata, and marketplace data live here.
|
||||||
|
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||||
|
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||||
|
|
||||||
|
Table inventory:
|
||||||
|
users — account credentials + tier
|
||||||
|
refresh_tokens — hashed refresh token store
|
||||||
|
subscriptions — Stripe subscription records
|
||||||
|
storage_records — S3 blob metadata (no plaintext)
|
||||||
|
backup_metadata — encrypted backup manifests
|
||||||
|
plugins — marketplace plugin catalog
|
||||||
|
plugin_installations — per-user install records
|
||||||
|
plugin_reviews — admin review decisions
|
||||||
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
Uuid,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
|
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||||
|
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Models ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base):
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(Base):
|
||||||
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, unique=True, index=True
|
||||||
|
)
|
||||||
|
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||||
|
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecord(Base):
|
||||||
|
__tablename__ = "storage_records"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupMetadata(Base):
|
||||||
|
__tablename__ = "backup_metadata"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||||
|
# nullable until developer account system is built
|
||||||
|
author_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||||
|
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
submitted_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
reviews: Mapped[list[PluginReview]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallation(Base):
|
||||||
|
__tablename__ = "plugin_installations"
|
||||||
|
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
installed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginReview(Base):
|
||||||
|
__tablename__ = "plugin_reviews"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
reviewer_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||||
|
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
reviewed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueEvent(Base):
|
||||||
|
__tablename__ = "revenue_events"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||||
157
app/schemas.py
Normal file
157
app/schemas.py
Normal file
@@ -0,0 +1,157 @@
|
|||||||
|
"""Pydantic schemas — API request/response contracts.
|
||||||
|
|
||||||
|
Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
# ── Billing ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
BillingTier = Literal["free", "pro", "power", "team"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AuthTokens(BaseModel):
|
||||||
|
access_token: str
|
||||||
|
refresh_token: str
|
||||||
|
expires_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class UserProfile(BaseModel):
|
||||||
|
id: str
|
||||||
|
email: str
|
||||||
|
tier: BillingTier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chat ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class ChatContext(BaseModel):
|
||||||
|
user_profile: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
relevant_documents: list[str] = Field(default_factory=list)
|
||||||
|
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class PlanAction(BaseModel):
|
||||||
|
type: Literal[
|
||||||
|
"create_record",
|
||||||
|
"update_record",
|
||||||
|
"delete_record",
|
||||||
|
"index_document",
|
||||||
|
"send_notification",
|
||||||
|
]
|
||||||
|
table: str | None = None
|
||||||
|
data: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ChatRequest(BaseModel):
|
||||||
|
message: str
|
||||||
|
context: ChatContext = Field(default_factory=ChatContext)
|
||||||
|
execution_mode: Literal["direct", "plan"] = "direct"
|
||||||
|
|
||||||
|
|
||||||
|
class ChatResponse(BaseModel):
|
||||||
|
response: str
|
||||||
|
actions: list[PlanAction] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Execution Plans ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PlanStep(BaseModel):
|
||||||
|
action: str
|
||||||
|
prompt_template: str | None = None
|
||||||
|
variables: dict[str, Any] | None = None
|
||||||
|
data_from_step: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionPlan(BaseModel):
|
||||||
|
agent: str
|
||||||
|
steps: list[PlanStep] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Backup ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class BackupMetadata(BaseModel):
|
||||||
|
version: int
|
||||||
|
timestamp: int
|
||||||
|
checksum: str
|
||||||
|
chunk_count: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
||||||
|
|
||||||
|
class StorageRecord(BaseModel):
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
table: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordCreate(BaseModel):
|
||||||
|
table: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordUpdate(BaseModel):
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
||||||
|
|
||||||
|
class VectorItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
blob: bytes # encrypted vector + metadata — backend never decrypts
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
class VectorUpsertRequest(BaseModel):
|
||||||
|
vectors: list[VectorItem]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchRequest(BaseModel):
|
||||||
|
query_blob: bytes # encrypted query — backend never decrypts
|
||||||
|
top_k: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResult(BaseModel):
|
||||||
|
id: str
|
||||||
|
score: float
|
||||||
|
blob: bytes
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResponse(BaseModel):
|
||||||
|
results: list[VectorSearchResult]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin Marketplace ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PluginManifest(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
version: str
|
||||||
|
author: str
|
||||||
|
permissions: list[str]
|
||||||
|
category: str
|
||||||
|
price_cents: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class PluginListResponse(BaseModel):
|
||||||
|
plugins: list[PluginManifest]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallRequest(BaseModel):
|
||||||
|
plugin_id: str
|
||||||
1
app/storage/__init__.py
Normal file
1
app/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
||||||
106
app/storage/blob_store.py
Normal file
106
app/storage/blob_store.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""S3-backed store for E2E-encrypted blobs.
|
||||||
|
|
||||||
|
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
||||||
|
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class BlobStore:
|
||||||
|
"""Thin wrapper around boto3 S3.
|
||||||
|
|
||||||
|
All blobs must be E2E encrypted by the client before upload.
|
||||||
|
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
||||||
|
but cannot decrypt the inner client-side payload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"region_name": settings.S3_REGION,
|
||||||
|
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
|
||||||
|
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
|
||||||
|
}
|
||||||
|
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
|
||||||
|
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
|
||||||
|
return boto3.client("s3", **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _key(user_id: str, table: str, record_id: str) -> str:
|
||||||
|
return f"{user_id}/{table}/{record_id}"
|
||||||
|
|
||||||
|
async def upload(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
table: str,
|
||||||
|
record_id: str,
|
||||||
|
blob: bytes,
|
||||||
|
checksum: str,
|
||||||
|
) -> str:
|
||||||
|
"""Store *blob* in S3 and return the S3 key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Owner of the blob (used as key prefix).
|
||||||
|
table: Logical table name (e.g. ``"tasks"``).
|
||||||
|
record_id: Record UUID.
|
||||||
|
blob: Raw bytes (pre-encrypted by client).
|
||||||
|
checksum: SHA-256 hex digest supplied by the client; stored as
|
||||||
|
object metadata for download-time verification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The S3 key under which the blob was stored.
|
||||||
|
"""
|
||||||
|
key = self._key(user_id, table, record_id)
|
||||||
|
self._client().put_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=key,
|
||||||
|
Body=blob,
|
||||||
|
ServerSideEncryption="AES256", # SSE-S3 at rest
|
||||||
|
Metadata={"checksum": checksum},
|
||||||
|
)
|
||||||
|
return key
|
||||||
|
|
||||||
|
async def download(self, user_id: str, s3_key: str) -> bytes:
|
||||||
|
"""Retrieve the blob stored at *s3_key*.
|
||||||
|
|
||||||
|
*user_id* is retained in the signature so higher-level code can
|
||||||
|
enforce ownership without re-parsing the key.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
||||||
|
object does not exist.
|
||||||
|
"""
|
||||||
|
response = self._client().get_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
)
|
||||||
|
return response["Body"].read()
|
||||||
|
|
||||||
|
async def delete(self, user_id: str, s3_key: str) -> None:
|
||||||
|
"""Delete the object at *s3_key*.
|
||||||
|
|
||||||
|
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
||||||
|
not exist.
|
||||||
|
"""
|
||||||
|
self._client().delete_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
||||||
|
"""Return all S3 keys for a given user + table combination.
|
||||||
|
|
||||||
|
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
||||||
|
"""
|
||||||
|
prefix = f"{user_id}/{table}/"
|
||||||
|
response = self._client().list_objects_v2(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Prefix=prefix,
|
||||||
|
)
|
||||||
|
return [obj["Key"] for obj in response.get("Contents", [])]
|
||||||
32
app/storage/encryption.py
Normal file
32
app/storage/encryption.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Integrity verification only — the backend NEVER decrypts user data."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
||||||
|
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
||||||
|
|
||||||
|
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
||||||
|
timing-based side-channel attacks.
|
||||||
|
"""
|
||||||
|
computed = hashlib.sha256(blob).hexdigest()
|
||||||
|
return hmac.compare_digest(computed, checksum)
|
||||||
|
|
||||||
|
|
||||||
|
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
||||||
|
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
||||||
|
|
||||||
|
Call this before storing or forwarding any client-provided blob.
|
||||||
|
The backend never holds decryption keys — this check only verifies
|
||||||
|
that the opaque bytes arrived intact.
|
||||||
|
"""
|
||||||
|
if not verify_checksum(blob, checksum):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Checksum mismatch: blob integrity check failed",
|
||||||
|
)
|
||||||
205
app/storage/vector_store.py
Normal file
205
app/storage/vector_store.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
||||||
|
|
||||||
|
Vectors are pre-encrypted blobs from the client. The backend stores them
|
||||||
|
alongside a deterministic 32-dim float representation derived from the blob's
|
||||||
|
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
||||||
|
is a known trade-off documented in the backend plan.
|
||||||
|
|
||||||
|
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
||||||
|
``user_id`` payload field on a shared collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pinecone import Pinecone
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
|
||||||
|
_QDRANT_COLLECTION = "adiuva_vectors"
|
||||||
|
|
||||||
|
|
||||||
|
def _blob_to_vector(blob: bytes) -> list[float]:
|
||||||
|
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
||||||
|
|
||||||
|
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
||||||
|
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
||||||
|
semantic meaning on encrypted data.
|
||||||
|
"""
|
||||||
|
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStore:
|
||||||
|
"""Thin wrapper around Pinecone or Qdrant.
|
||||||
|
|
||||||
|
The backend to use is selected at runtime:
|
||||||
|
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
||||||
|
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _use_pinecone(self) -> bool:
|
||||||
|
return bool(settings.PINECONE_API_KEY)
|
||||||
|
|
||||||
|
# ── Pinecone helpers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _pinecone_index(self) -> Any:
|
||||||
|
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
||||||
|
return pc.Index(settings.PINECONE_INDEX)
|
||||||
|
|
||||||
|
# ── Qdrant helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _qdrant_client(self) -> Any:
|
||||||
|
return QdrantClient(
|
||||||
|
url=settings.QDRANT_URL,
|
||||||
|
api_key=settings.QDRANT_API_KEY or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
"""Store encrypted vectors in the backend.
|
||||||
|
|
||||||
|
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
||||||
|
so it can be returned verbatim during search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Used as Pinecone namespace or Qdrant payload field.
|
||||||
|
vectors: List of encrypted vector items from the client.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
await self._pinecone_upsert(user_id, vectors)
|
||||||
|
else:
|
||||||
|
await self._qdrant_upsert(user_id, vectors)
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
query_blob: bytes,
|
||||||
|
top_k: int,
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
"""Query the vector store and return encrypted result blobs.
|
||||||
|
|
||||||
|
The query vector is derived from *query_blob* using the same
|
||||||
|
deterministic mapping as upsert.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Scopes the search to this user's namespace.
|
||||||
|
query_blob: Encrypted query from the client.
|
||||||
|
top_k: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
return await self._pinecone_search(user_id, query_blob, top_k)
|
||||||
|
return await self._qdrant_search(user_id, query_blob, top_k)
|
||||||
|
|
||||||
|
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
"""Remove vectors by ID, scoped to *user_id*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Namespace / payload filter to prevent cross-user deletion.
|
||||||
|
vector_ids: List of vector IDs to remove.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
await self._pinecone_delete(user_id, vector_ids)
|
||||||
|
else:
|
||||||
|
await self._qdrant_delete(user_id, vector_ids)
|
||||||
|
|
||||||
|
# ── Pinecone implementation ───────────────────────────────────────
|
||||||
|
|
||||||
|
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
records = [
|
||||||
|
{
|
||||||
|
"id": v.id,
|
||||||
|
"values": _blob_to_vector(v.blob),
|
||||||
|
"metadata": {
|
||||||
|
"blob": base64.b64encode(v.blob).decode(),
|
||||||
|
"checksum": v.checksum,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for v in vectors
|
||||||
|
]
|
||||||
|
index.upsert(vectors=records, namespace=user_id)
|
||||||
|
|
||||||
|
async def _pinecone_search(
|
||||||
|
self, user_id: str, query_blob: bytes, top_k: int
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
query_vector = _blob_to_vector(query_blob)
|
||||||
|
response = index.query(
|
||||||
|
vector=query_vector,
|
||||||
|
top_k=top_k,
|
||||||
|
namespace=user_id,
|
||||||
|
include_metadata=True,
|
||||||
|
)
|
||||||
|
results: list[VectorSearchResult] = []
|
||||||
|
for match in response.get("matches", []):
|
||||||
|
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
||||||
|
results.append(
|
||||||
|
VectorSearchResult(
|
||||||
|
id=match["id"],
|
||||||
|
score=match["score"],
|
||||||
|
blob=blob_bytes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
index.delete(ids=vector_ids, namespace=user_id)
|
||||||
|
|
||||||
|
# ── Qdrant implementation ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
points = [
|
||||||
|
PointStruct(
|
||||||
|
id=v.id,
|
||||||
|
vector=_blob_to_vector(v.blob),
|
||||||
|
payload={
|
||||||
|
"blob": base64.b64encode(v.blob).decode(),
|
||||||
|
"checksum": v.checksum,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for v in vectors
|
||||||
|
]
|
||||||
|
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
||||||
|
|
||||||
|
async def _qdrant_search(
|
||||||
|
self, user_id: str, query_blob: bytes, top_k: int
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
query_vector = _blob_to_vector(query_blob)
|
||||||
|
hits = client.search(
|
||||||
|
collection_name=_QDRANT_COLLECTION,
|
||||||
|
query_vector=query_vector,
|
||||||
|
query_filter=Filter(
|
||||||
|
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
||||||
|
),
|
||||||
|
limit=top_k,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
VectorSearchResult(
|
||||||
|
id=str(hit.id),
|
||||||
|
score=hit.score,
|
||||||
|
blob=base64.b64decode(hit.payload["blob"]),
|
||||||
|
)
|
||||||
|
for hit in hits
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
client.delete(
|
||||||
|
collection_name=_QDRANT_COLLECTION,
|
||||||
|
points_selector=PointIdsList(points=vector_ids),
|
||||||
|
)
|
||||||
68
docker-compose.yml
Normal file
68
docker-compose.yml
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
services:
|
||||||
|
app:
|
||||||
|
build: .
|
||||||
|
ports:
|
||||||
|
- "8080:8000"
|
||||||
|
env_file:
|
||||||
|
- path: .env
|
||||||
|
required: false
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
db:
|
||||||
|
image: postgres:16-alpine
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: adiuva
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# Optional Redis for future rate-limit or caching needs
|
||||||
|
# redis:
|
||||||
|
# image: redis:7-alpine
|
||||||
|
# restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local S3-compatible storage (MinIO) ──
|
||||||
|
minio:
|
||||||
|
image: minio/minio:latest
|
||||||
|
command: server /data --console-address ":9001"
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
- "9001:9001"
|
||||||
|
environment:
|
||||||
|
MINIO_ROOT_USER: minioadmin
|
||||||
|
MINIO_ROOT_PASSWORD: minioadmin
|
||||||
|
volumes:
|
||||||
|
- minio_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "mc", "ready", "local"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local vector store (Qdrant) ──
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
ports:
|
||||||
|
- "6333:6333"
|
||||||
|
- "6334:6334"
|
||||||
|
volumes:
|
||||||
|
- qdrant_data:/qdrant/storage
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
minio_data:
|
||||||
|
qdrant_data:
|
||||||
27
requirements.txt
Normal file
27
requirements.txt
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
langchain>=0.3.0
|
||||||
|
langchain-openai>=0.3.0
|
||||||
|
litellm>=1.50.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
stripe>=11.0.0
|
||||||
|
boto3>=1.35.0
|
||||||
|
slowapi>=0.1.9
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
alembic>=1.14.0
|
||||||
|
bcrypt>=4.2.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
httpx>=0.28.0
|
||||||
|
websockets>=14.0
|
||||||
|
psycopg2-binary>=2.9.0
|
||||||
|
pytest>=8.0.0
|
||||||
|
pytest-asyncio>=0.24.0
|
||||||
|
aiosqlite>=0.20.0
|
||||||
|
moto[s3]>=5.0.0
|
||||||
|
pinecone>=5.0.0
|
||||||
|
qdrant-client>=1.7.0
|
||||||
|
ruff>=0.8.0
|
||||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
235
tests/conftest.py
Normal file
235
tests/conftest.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""Shared test fixtures for database-backed tests.
|
||||||
|
|
||||||
|
Provides an async SQLite in-memory engine that auto-creates all tables,
|
||||||
|
a per-test session, and a FastAPI ``TestClient`` wired to use it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from jose import jwt
|
||||||
|
from moto import mock_aws
|
||||||
|
from sqlalchemy import StaticPool, event
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.db import Base, get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import Plugin, Subscription, User
|
||||||
|
|
||||||
|
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
||||||
|
|
||||||
|
TEST_USER_IDS: dict[str, str] = {
|
||||||
|
"free": "00000000-0000-0000-0000-000000000001",
|
||||||
|
"pro": "00000000-0000-0000-0000-000000000002",
|
||||||
|
"power": "00000000-0000-0000-0000-000000000003",
|
||||||
|
"team": "00000000-0000-0000-0000-000000000004",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Async SQLite engine ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
_TEST_ENGINE = create_async_engine(
|
||||||
|
"sqlite+aiosqlite://",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
_TestSessionLocal = async_sessionmaker(
|
||||||
|
_TEST_ENGINE,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Enable foreign key enforcement for SQLite (off by default).
|
||||||
|
@event.listens_for(_TEST_ENGINE.sync_engine, "connect")
|
||||||
|
def _set_sqlite_pragma(dbapi_conn, _connection_record): # noqa: ANN001
|
||||||
|
cursor = dbapi_conn.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def _create_tables():
|
||||||
|
"""Create all tables before each test, seed test users, then drop after."""
|
||||||
|
async with _TEST_ENGINE.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
# Seed one User + Subscription per tier so FK constraints and auth work.
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
for tier, uid in TEST_USER_IDS.items():
|
||||||
|
session.add(User(
|
||||||
|
id=uid,
|
||||||
|
email=f"{tier}@test.com",
|
||||||
|
password_hash="$2b$12$fakehashfortesting000000000000000000000000000",
|
||||||
|
tier=tier,
|
||||||
|
))
|
||||||
|
session.add(Subscription(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=uid,
|
||||||
|
tier=tier,
|
||||||
|
stripe_subscription_id=f"sub_test_{tier}",
|
||||||
|
status="active",
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
yield
|
||||||
|
async with _TEST_ENGINE.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Yield a per-test async DB session."""
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # noqa: ANN001
|
||||||
|
"""FastAPI test client with ``get_session`` overridden to use the test DB."""
|
||||||
|
|
||||||
|
async def _override_get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _override_get_session
|
||||||
|
with TestClient(app) as c:
|
||||||
|
yield c
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Seed data helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
Plugin(
|
||||||
|
id="plugin-github-sync",
|
||||||
|
name="GitHub Sync",
|
||||||
|
description="Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
version="1.0.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=0,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-slack-notify",
|
||||||
|
name="Slack Notifier",
|
||||||
|
description="Post task and checkpoint updates to Slack channels.",
|
||||||
|
version="1.2.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="communication",
|
||||||
|
price_cents=499,
|
||||||
|
permissions=json.dumps(["read:tasks", "read:checkpoints"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-time-tracker",
|
||||||
|
name="Time Tracker",
|
||||||
|
description="Track time spent on tasks with automatic reporting.",
|
||||||
|
version="0.9.1",
|
||||||
|
author_name="Third Party",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=999,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
|
||||||
|
"""Insert the 3 default approved plugins and return them."""
|
||||||
|
plugins = []
|
||||||
|
for template in _SEED_PLUGINS:
|
||||||
|
p = Plugin(
|
||||||
|
id=template.id,
|
||||||
|
name=template.name,
|
||||||
|
description=template.description,
|
||||||
|
version=template.version,
|
||||||
|
author_name=template.author_name,
|
||||||
|
category=template.category,
|
||||||
|
price_cents=template.price_cents,
|
||||||
|
permissions=template.permissions,
|
||||||
|
status=template.status,
|
||||||
|
s3_package_key=template.s3_package_key,
|
||||||
|
install_count=template.install_count,
|
||||||
|
avg_rating=template.avg_rating,
|
||||||
|
)
|
||||||
|
db_session.add(p)
|
||||||
|
plugins.append(p)
|
||||||
|
await db_session.commit()
|
||||||
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
|
# ── JWT helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def make_jwt(
|
||||||
|
tier: str = "power",
|
||||||
|
user_id: str | None = None,
|
||||||
|
email: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a signed test JWT.
|
||||||
|
|
||||||
|
Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can
|
||||||
|
find the corresponding ``Subscription`` row in the test database.
|
||||||
|
"""
|
||||||
|
uid = user_id or TEST_USER_IDS.get(tier, str(uuid.uuid4()))
|
||||||
|
now = int(time.time())
|
||||||
|
payload = {
|
||||||
|
"sub": uid,
|
||||||
|
"email": email or f"{tier}@test.com",
|
||||||
|
"tier": tier,
|
||||||
|
"exp": now + 3600,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, str]:
|
||||||
|
"""Return an Authorization header dict for the given tier."""
|
||||||
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── S3 mock fixture ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
S3_TEST_BUCKET = "test-bucket"
|
||||||
|
S3_TEST_REGION = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def s3_bucket():
|
||||||
|
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
|
||||||
|
with mock_aws():
|
||||||
|
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
||||||
|
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
||||||
|
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
|
||||||
|
client = boto3.client("s3", region_name=S3_TEST_REGION)
|
||||||
|
client.create_bucket(Bucket=S3_TEST_BUCKET)
|
||||||
|
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||||
|
mock_settings.S3_BUCKET = S3_TEST_BUCKET
|
||||||
|
mock_settings.S3_REGION = S3_TEST_REGION
|
||||||
|
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||||
|
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||||
|
yield S3_TEST_BUCKET
|
||||||
214
tests/test_agent_registry.py
Normal file
214
tests/test_agent_registry.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Unit tests for the agent registry, base classes, and tool loop."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _StubAgent(ChatAgent):
|
||||||
|
"""Minimal concrete agent for testing."""
|
||||||
|
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "stub"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "A stub agent for tests"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
return f"echo: {query}"
|
||||||
|
|
||||||
|
|
||||||
|
class _AnotherAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "another"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Another stub"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
return "another"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _fresh_registry():
|
||||||
|
"""Reset the singleton between tests."""
|
||||||
|
AgentRegistry._instance = None
|
||||||
|
yield
|
||||||
|
AgentRegistry._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def reg() -> AgentRegistry:
|
||||||
|
return AgentRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tests ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class TestRegisterAndGet:
|
||||||
|
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
agent = reg.get("stub")
|
||||||
|
assert isinstance(agent, _StubAgent)
|
||||||
|
|
||||||
|
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||||
|
with pytest.raises(KeyError, match="not found"):
|
||||||
|
reg.get("nonexistent")
|
||||||
|
|
||||||
|
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
reg.register(_AnotherAgent)
|
||||||
|
assert reg.get("stub").get_name() == "stub"
|
||||||
|
assert reg.get("another").get_name() == "another"
|
||||||
|
|
||||||
|
|
||||||
|
class TestListAgents:
|
||||||
|
def test_empty(self, reg: AgentRegistry) -> None:
|
||||||
|
assert reg.list_agents() == []
|
||||||
|
|
||||||
|
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
agents = reg.list_agents()
|
||||||
|
assert len(agents) == 1
|
||||||
|
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
||||||
|
|
||||||
|
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
reg.register(_AnotherAgent)
|
||||||
|
names = {a["name"] for a in reg.list_agents()}
|
||||||
|
assert names == {"stub", "another"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestCallAgent:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
||||||
|
reg.register(_StubAgent)
|
||||||
|
result = await reg.call_agent("stub", "hello", {})
|
||||||
|
assert result == "echo: hello"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await reg.call_agent("nope", "hi", {})
|
||||||
|
|
||||||
|
|
||||||
|
class TestSingleton:
|
||||||
|
def test_singleton_identity(self) -> None:
|
||||||
|
a = AgentRegistry()
|
||||||
|
b = AgentRegistry()
|
||||||
|
assert a is b
|
||||||
|
|
||||||
|
|
||||||
|
class TestToolLoop:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_tool_calls(self) -> None:
|
||||||
|
"""When the LLM responds without tool calls, return content directly."""
|
||||||
|
agent = _StubAgent()
|
||||||
|
|
||||||
|
ai_msg = MagicMock()
|
||||||
|
ai_msg.content = "final answer"
|
||||||
|
ai_msg.tool_calls = []
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [], [])
|
||||||
|
assert result == "final answer"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_call_then_answer(self) -> None:
|
||||||
|
"""LLM requests one tool call, gets result, then answers."""
|
||||||
|
agent = _StubAgent()
|
||||||
|
|
||||||
|
# First response: tool call
|
||||||
|
tool_call_msg = MagicMock()
|
||||||
|
tool_call_msg.content = ""
|
||||||
|
tool_call_msg.tool_calls = [
|
||||||
|
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
||||||
|
]
|
||||||
|
|
||||||
|
# Second response: final answer
|
||||||
|
final_msg = MagicMock()
|
||||||
|
final_msg.content = "done"
|
||||||
|
final_msg.tool_calls = []
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm)
|
||||||
|
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||||
|
|
||||||
|
# Mock tool
|
||||||
|
tool = AsyncMock()
|
||||||
|
tool.name = "my_tool"
|
||||||
|
tool.ainvoke = AsyncMock(return_value="tool_result")
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [], [tool])
|
||||||
|
assert result == "done"
|
||||||
|
tool.ainvoke.assert_called_once_with({"x": 1})
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_tool_handled(self) -> None:
|
||||||
|
"""Unknown tool names produce an error message instead of crashing."""
|
||||||
|
agent = _StubAgent()
|
||||||
|
|
||||||
|
tool_call_msg = MagicMock()
|
||||||
|
tool_call_msg.content = ""
|
||||||
|
tool_call_msg.tool_calls = [
|
||||||
|
{"id": "call_1", "name": "missing", "args": {}}
|
||||||
|
]
|
||||||
|
|
||||||
|
final_msg = MagicMock()
|
||||||
|
final_msg.content = "recovered"
|
||||||
|
final_msg.tool_calls = []
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm)
|
||||||
|
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [], [])
|
||||||
|
assert result == "recovered"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_max_iter_reached(self) -> None:
|
||||||
|
"""When max iterations are exhausted, a final no-tools call is made."""
|
||||||
|
agent = _StubAgent()
|
||||||
|
|
||||||
|
# Every response requests a tool call
|
||||||
|
loop_msg = MagicMock()
|
||||||
|
loop_msg.content = ""
|
||||||
|
loop_msg.tool_calls = [
|
||||||
|
{"id": "call_x", "name": "t", "args": {}}
|
||||||
|
]
|
||||||
|
|
||||||
|
final_msg = MagicMock()
|
||||||
|
final_msg.content = "gave up"
|
||||||
|
final_msg.tool_calls = []
|
||||||
|
|
||||||
|
tool = AsyncMock()
|
||||||
|
tool.name = "t"
|
||||||
|
tool.ainvoke = AsyncMock(return_value="ok")
|
||||||
|
|
||||||
|
llm_with_tools = AsyncMock()
|
||||||
|
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
||||||
|
|
||||||
|
llm = AsyncMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||||
|
|
||||||
|
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
||||||
|
assert result == "gave up"
|
||||||
|
assert llm_with_tools.ainvoke.call_count == 2
|
||||||
620
tests/test_agents.py
Normal file
620
tests/test_agents.py
Normal file
@@ -0,0 +1,620 @@
|
|||||||
|
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import app.agents # noqa: F401 — triggers @registry.register decorators
|
||||||
|
from app.agents.checkpoint_agent import CheckpointAgent
|
||||||
|
from app.agents.note_agent import NoteAgent
|
||||||
|
from app.agents.project_agent import ProjectAgent
|
||||||
|
from app.agents.task_agent import TaskAgent
|
||||||
|
from app.core.agent_registry import registry
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_llm(response_text: str) -> MagicMock:
|
||||||
|
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = response_text
|
||||||
|
msg.tool_calls = []
|
||||||
|
llm = MagicMock()
|
||||||
|
bound = MagicMock()
|
||||||
|
bound.ainvoke = AsyncMock(return_value=msg)
|
||||||
|
llm.bind_tools = MagicMock(return_value=bound)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=msg)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_llm_with_tool_call(
|
||||||
|
tool_name: str, tool_args: dict[str, Any], final_text: str
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Mock LLM that fires one tool call then returns *final_text*."""
|
||||||
|
tool_msg = MagicMock()
|
||||||
|
tool_msg.content = ""
|
||||||
|
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
|
||||||
|
|
||||||
|
final_msg = MagicMock()
|
||||||
|
final_msg.content = final_text
|
||||||
|
final_msg.tool_calls = []
|
||||||
|
|
||||||
|
bound = MagicMock()
|
||||||
|
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=bound)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
# ── Registration ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentRegistration:
|
||||||
|
def test_all_agents_registered(self) -> None:
|
||||||
|
names = {a["name"] for a in registry.list_agents()}
|
||||||
|
assert {
|
||||||
|
"task_agent", "checkpoint_agent", "project_agent", "note_agent"
|
||||||
|
}.issubset(names)
|
||||||
|
|
||||||
|
def test_registry_returns_correct_types(self) -> None:
|
||||||
|
assert isinstance(registry.get("task_agent"), TaskAgent)
|
||||||
|
assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent)
|
||||||
|
assert isinstance(registry.get("project_agent"), ProjectAgent)
|
||||||
|
assert isinstance(registry.get("note_agent"), NoteAgent)
|
||||||
|
|
||||||
|
def test_descriptions_present(self) -> None:
|
||||||
|
for agent_info in registry.list_agents():
|
||||||
|
assert agent_info["description"], f"Empty description: {agent_info['name']}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TaskAgent ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskAgent:
|
||||||
|
def test_name(self) -> None:
|
||||||
|
assert TaskAgent().get_name() == "task_agent"
|
||||||
|
|
||||||
|
def test_description(self) -> None:
|
||||||
|
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
||||||
|
|
||||||
|
def test_get_tools_count(self) -> None:
|
||||||
|
assert len(TaskAgent().get_tools()) == 8
|
||||||
|
|
||||||
|
def test_tool_names(self) -> None:
|
||||||
|
names = {t.name for t in TaskAgent().get_tools()}
|
||||||
|
assert names == {
|
||||||
|
"list_tasks",
|
||||||
|
"create_task",
|
||||||
|
"update_task",
|
||||||
|
"delete_task",
|
||||||
|
"list_tasks_due_today",
|
||||||
|
"list_task_comments",
|
||||||
|
"add_task_comment",
|
||||||
|
"delete_task_comment",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_returns_string(self) -> None:
|
||||||
|
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Task created.")
|
||||||
|
result = await TaskAgent().handle("create a task", {})
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_no_tool_calls(self) -> None:
|
||||||
|
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Here are your tasks.")
|
||||||
|
result = await TaskAgent().handle("list my tasks", {})
|
||||||
|
assert result == "Here are your tasks."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_with_create_task_tool_call(self) -> None:
|
||||||
|
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||||
|
"create_task",
|
||||||
|
{"title": "Buy groceries", "priority": "low"},
|
||||||
|
"Task 'Buy groceries' created.",
|
||||||
|
)
|
||||||
|
result = await TaskAgent().handle("add a grocery task", {})
|
||||||
|
assert result == "Task 'Buy groceries' created."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_accepts_empty_context(self) -> None:
|
||||||
|
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Done.")
|
||||||
|
result = await TaskAgent().handle("help", {})
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_accepts_rich_context(self) -> None:
|
||||||
|
context = {
|
||||||
|
"user_profile": {"id": "u1", "tier": "pro"},
|
||||||
|
"recent_tasks": [{"id": "t1", "title": "Old task"}],
|
||||||
|
}
|
||||||
|
with patch("app.agents.task_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Tasks listed.")
|
||||||
|
result = await TaskAgent().handle("show tasks", context)
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskAgentTools:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tasks_defaults(self) -> None:
|
||||||
|
from app.agents.task_agent import list_tasks
|
||||||
|
result = await list_tasks.ainvoke({})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "list"
|
||||||
|
assert data["table"] == "tasks"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tasks_with_status_filter(self) -> None:
|
||||||
|
from app.agents.task_agent import list_tasks
|
||||||
|
result = await list_tasks.ainvoke({"status": "done"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["filters"]["status"] == "done"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_task_defaults(self) -> None:
|
||||||
|
from app.agents.task_agent import create_task
|
||||||
|
result = await create_task.ainvoke({"title": "Test task"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "create_record"
|
||||||
|
assert data["table"] == "tasks"
|
||||||
|
assert data["data"]["title"] == "Test task"
|
||||||
|
assert data["data"]["status"] == "todo"
|
||||||
|
assert data["data"]["priority"] == "medium"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_task_with_all_fields(self) -> None:
|
||||||
|
from app.agents.task_agent import create_task
|
||||||
|
result = await create_task.ainvoke({
|
||||||
|
"title": "Deploy",
|
||||||
|
"priority": "high",
|
||||||
|
"status": "in_progress",
|
||||||
|
"project_id": "p1",
|
||||||
|
"is_ai_suggested": 1,
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["priority"] == "high"
|
||||||
|
assert data["data"]["status"] == "in_progress"
|
||||||
|
assert data["data"]["projectId"] == "p1"
|
||||||
|
assert data["data"]["isAiSuggested"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_task_with_status(self) -> None:
|
||||||
|
from app.agents.task_agent import update_task
|
||||||
|
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "update_record"
|
||||||
|
assert data["data"]["id"] == "t1"
|
||||||
|
assert data["data"]["updates"]["status"] == "done"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_task_empty_updates(self) -> None:
|
||||||
|
from app.agents.task_agent import update_task
|
||||||
|
result = await update_task.ainvoke({"task_id": "t1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["updates"] == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_task(self) -> None:
|
||||||
|
from app.agents.task_agent import delete_task
|
||||||
|
result = await delete_task.ainvoke({"task_id": "t1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "delete_record"
|
||||||
|
assert data["table"] == "tasks"
|
||||||
|
assert data["data"]["id"] == "t1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_tasks_due_today(self) -> None:
|
||||||
|
from app.agents.task_agent import list_tasks_due_today
|
||||||
|
result = await list_tasks_due_today.ainvoke({})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "list_due_today"
|
||||||
|
assert data["table"] == "tasks"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_task_comments(self) -> None:
|
||||||
|
from app.agents.task_agent import list_task_comments
|
||||||
|
result = await list_task_comments.ainvoke({"task_id": "t1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "list"
|
||||||
|
assert data["table"] == "taskComments"
|
||||||
|
assert data["filters"]["taskId"] == "t1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_task_comment(self) -> None:
|
||||||
|
from app.agents.task_agent import add_task_comment
|
||||||
|
result = await add_task_comment.ainvoke({
|
||||||
|
"task_id": "t1",
|
||||||
|
"author": "Alice",
|
||||||
|
"content": "Looks good!",
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "create_record"
|
||||||
|
assert data["table"] == "taskComments"
|
||||||
|
assert data["data"]["taskId"] == "t1"
|
||||||
|
assert data["data"]["author"] == "Alice"
|
||||||
|
assert data["data"]["content"] == "Looks good!"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_task_comment(self) -> None:
|
||||||
|
from app.agents.task_agent import delete_task_comment
|
||||||
|
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "delete_record"
|
||||||
|
assert data["table"] == "taskComments"
|
||||||
|
assert data["data"]["id"] == "c1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── CheckpointAgent ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointAgent:
|
||||||
|
def test_name(self) -> None:
|
||||||
|
assert CheckpointAgent().get_name() == "checkpoint_agent"
|
||||||
|
|
||||||
|
def test_description(self) -> None:
|
||||||
|
assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete"
|
||||||
|
|
||||||
|
def test_get_tools_count(self) -> None:
|
||||||
|
assert len(CheckpointAgent().get_tools()) == 4
|
||||||
|
|
||||||
|
def test_tool_names(self) -> None:
|
||||||
|
names = {t.name for t in CheckpointAgent().get_tools()}
|
||||||
|
assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_no_tool_calls(self) -> None:
|
||||||
|
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("No checkpoints found.")
|
||||||
|
result = await CheckpointAgent().handle("list checkpoints", {})
|
||||||
|
assert result == "No checkpoints found."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_with_create_tool_call(self) -> None:
|
||||||
|
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||||
|
"create_checkpoint",
|
||||||
|
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
|
||||||
|
"Checkpoint 'MVP Launch' created.",
|
||||||
|
)
|
||||||
|
result = await CheckpointAgent().handle("add MVP checkpoint", {})
|
||||||
|
assert result == "Checkpoint 'MVP Launch' created."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_accepts_empty_context(self) -> None:
|
||||||
|
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Done.")
|
||||||
|
result = await CheckpointAgent().handle("show milestones", {})
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointAgentTools:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_no_project(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import list_checkpoints
|
||||||
|
result = await list_checkpoints.ainvoke({})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "list"
|
||||||
|
assert data["table"] == "checkpoints"
|
||||||
|
assert data["filters"]["projectId"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_with_project(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import list_checkpoints
|
||||||
|
result = await list_checkpoints.ainvoke({"project_id": "p1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["filters"]["projectId"] == "p1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_checkpoint(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import create_checkpoint
|
||||||
|
result = await create_checkpoint.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"title": "Beta release",
|
||||||
|
"date": 1700000000000,
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "create_record"
|
||||||
|
assert data["table"] == "checkpoints"
|
||||||
|
assert data["data"]["projectId"] == "p1"
|
||||||
|
assert data["data"]["title"] == "Beta release"
|
||||||
|
assert data["data"]["date"] == 1700000000000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_checkpoint_ai_suggested(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import create_checkpoint
|
||||||
|
result = await create_checkpoint.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"title": "Review",
|
||||||
|
"date": 1700000000000,
|
||||||
|
"is_ai_suggested": 1,
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["isAiSuggested"] == 1
|
||||||
|
assert data["data"]["isApproved"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_checkpoint_approve(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import update_checkpoint
|
||||||
|
result = await update_checkpoint.ainvoke({
|
||||||
|
"checkpoint_id": "c1",
|
||||||
|
"is_approved": 1,
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "update_record"
|
||||||
|
assert data["data"]["id"] == "c1"
|
||||||
|
assert data["data"]["updates"]["isApproved"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_checkpoint_empty_updates(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import update_checkpoint
|
||||||
|
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["updates"] == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_checkpoint(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import delete_checkpoint
|
||||||
|
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "delete_record"
|
||||||
|
assert data["table"] == "checkpoints"
|
||||||
|
assert data["data"]["id"] == "c1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── ProjectAgent ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectAgent:
|
||||||
|
def test_name(self) -> None:
|
||||||
|
assert ProjectAgent().get_name() == "project_agent"
|
||||||
|
|
||||||
|
def test_description(self) -> None:
|
||||||
|
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
|
||||||
|
|
||||||
|
def test_get_tools_count(self) -> None:
|
||||||
|
assert len(ProjectAgent().get_tools()) == 6
|
||||||
|
|
||||||
|
def test_tool_names(self) -> None:
|
||||||
|
names = {t.name for t in ProjectAgent().get_tools()}
|
||||||
|
assert names == {
|
||||||
|
"list_projects",
|
||||||
|
"list_all_projects",
|
||||||
|
"get_project",
|
||||||
|
"create_project",
|
||||||
|
"update_project",
|
||||||
|
"delete_project",
|
||||||
|
}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_no_tool_calls(self) -> None:
|
||||||
|
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Project Alpha is active.")
|
||||||
|
result = await ProjectAgent().handle("show my projects", {})
|
||||||
|
assert result == "Project Alpha is active."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_with_create_project_tool_call(self) -> None:
|
||||||
|
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||||
|
"create_project",
|
||||||
|
{"name": "Pippo"},
|
||||||
|
"Project 'Pippo' created.",
|
||||||
|
)
|
||||||
|
result = await ProjectAgent().handle("create project Pippo", {})
|
||||||
|
assert result == "Project 'Pippo' created."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_accepts_empty_context(self) -> None:
|
||||||
|
with patch("app.agents.project_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Done.")
|
||||||
|
result = await ProjectAgent().handle("archive old project", {})
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectAgentTools:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_projects_defaults(self) -> None:
|
||||||
|
from app.agents.project_agent import list_projects
|
||||||
|
result = await list_projects.ainvoke({})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "list"
|
||||||
|
assert data["table"] == "projects"
|
||||||
|
assert data["filters"]["includeArchived"] is False
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_projects_include_archived(self) -> None:
|
||||||
|
from app.agents.project_agent import list_projects
|
||||||
|
result = await list_projects.ainvoke({"include_archived": 1})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["filters"]["includeArchived"] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_all_projects(self) -> None:
|
||||||
|
from app.agents.project_agent import list_all_projects
|
||||||
|
result = await list_all_projects.ainvoke({})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "list_all"
|
||||||
|
assert data["table"] == "projects"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_project(self) -> None:
|
||||||
|
from app.agents.project_agent import get_project
|
||||||
|
result = await get_project.ainvoke({"project_id": "p1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "get"
|
||||||
|
assert data["table"] == "projects"
|
||||||
|
assert data["data"]["id"] == "p1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_name_only(self) -> None:
|
||||||
|
from app.agents.project_agent import create_project
|
||||||
|
result = await create_project.ainvoke({"name": "Alpha"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "create_record"
|
||||||
|
assert data["data"]["name"] == "Alpha"
|
||||||
|
assert data["data"]["clientId"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_project_with_client(self) -> None:
|
||||||
|
from app.agents.project_agent import create_project
|
||||||
|
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["clientId"] == "cl1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_project_archive(self) -> None:
|
||||||
|
from app.agents.project_agent import update_project
|
||||||
|
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "update_record"
|
||||||
|
assert data["data"]["id"] == "p1"
|
||||||
|
assert data["data"]["updates"]["status"] == "archived"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_project_empty_updates(self) -> None:
|
||||||
|
from app.agents.project_agent import update_project
|
||||||
|
result = await update_project.ainvoke({"project_id": "p1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["updates"] == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_project(self) -> None:
|
||||||
|
from app.agents.project_agent import delete_project
|
||||||
|
result = await delete_project.ainvoke({"project_id": "p1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "delete_record"
|
||||||
|
assert data["data"]["id"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── NoteAgent ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoteAgent:
|
||||||
|
def test_name(self) -> None:
|
||||||
|
assert NoteAgent().get_name() == "note_agent"
|
||||||
|
|
||||||
|
def test_description(self) -> None:
|
||||||
|
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
|
||||||
|
|
||||||
|
def test_get_tools_count(self) -> None:
|
||||||
|
assert len(NoteAgent().get_tools()) == 5
|
||||||
|
|
||||||
|
def test_tool_names(self) -> None:
|
||||||
|
names = {t.name for t in NoteAgent().get_tools()}
|
||||||
|
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_no_tool_calls(self) -> None:
|
||||||
|
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Note created.")
|
||||||
|
result = await NoteAgent().handle("create a note", {})
|
||||||
|
assert result == "Note created."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_with_create_note_tool_call(self) -> None:
|
||||||
|
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||||
|
"create_note",
|
||||||
|
{"title": "Daily log", "content": "# Today\nAll good."},
|
||||||
|
"Note 'Daily log' created.",
|
||||||
|
)
|
||||||
|
result = await NoteAgent().handle("log today's progress", {})
|
||||||
|
assert result == "Note 'Daily log' created."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_handle_accepts_empty_context(self) -> None:
|
||||||
|
with patch("app.agents.note_agent.get_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("Done.")
|
||||||
|
result = await NoteAgent().handle("show notes", {})
|
||||||
|
assert isinstance(result, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoteAgentTools:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_notes_no_project(self) -> None:
|
||||||
|
from app.agents.note_agent import list_notes
|
||||||
|
result = await list_notes.ainvoke({})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "list"
|
||||||
|
assert data["table"] == "notes"
|
||||||
|
assert data["filters"]["projectId"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_notes_with_project(self) -> None:
|
||||||
|
from app.agents.note_agent import list_notes
|
||||||
|
result = await list_notes.ainvoke({"project_id": "p1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["filters"]["projectId"] == "p1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_note(self) -> None:
|
||||||
|
from app.agents.note_agent import get_note
|
||||||
|
result = await get_note.ainvoke({"note_id": "n1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "get"
|
||||||
|
assert data["table"] == "notes"
|
||||||
|
assert data["data"]["id"] == "n1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_note_minimal(self) -> None:
|
||||||
|
from app.agents.note_agent import create_note
|
||||||
|
result = await create_note.ainvoke({
|
||||||
|
"title": "Daily log",
|
||||||
|
"content": "# Today\nAll good.",
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "create_record"
|
||||||
|
assert data["table"] == "notes"
|
||||||
|
assert data["data"]["title"] == "Daily log"
|
||||||
|
assert data["data"]["content"] == "# Today\nAll good."
|
||||||
|
assert data["data"]["projectId"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_note_with_project(self) -> None:
|
||||||
|
from app.agents.note_agent import create_note
|
||||||
|
result = await create_note.ainvoke({
|
||||||
|
"title": "Sprint notes",
|
||||||
|
"content": "## Sprint 1",
|
||||||
|
"project_id": "p1",
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["projectId"] == "p1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_note_content_only(self) -> None:
|
||||||
|
from app.agents.note_agent import update_note
|
||||||
|
result = await update_note.ainvoke({
|
||||||
|
"note_id": "n1",
|
||||||
|
"content": "# Updated content",
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "update_record"
|
||||||
|
assert data["data"]["id"] == "n1"
|
||||||
|
assert data["data"]["updates"]["content"] == "# Updated content"
|
||||||
|
assert "title" not in data["data"]["updates"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_note_empty_updates(self) -> None:
|
||||||
|
from app.agents.note_agent import update_note
|
||||||
|
result = await update_note.ainvoke({"note_id": "n1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["updates"] == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_note(self) -> None:
|
||||||
|
from app.agents.note_agent import delete_note
|
||||||
|
result = await delete_note.ainvoke({"note_id": "n1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "delete_record"
|
||||||
|
assert data["table"] == "notes"
|
||||||
|
assert data["data"]["id"] == "n1"
|
||||||
206
tests/test_auth.py
Normal file
206
tests/test_auth.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""Tests for auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
|
Exercises the full auth lifecycle through the FastAPI TestClient against the
|
||||||
|
in-memory SQLite test database seeded by ``conftest.py``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from tests.conftest import auth_header, TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestRegister ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegister:
|
||||||
|
"""POST /api/v1/auth/register"""
|
||||||
|
|
||||||
|
def test_register_success(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "new@example.com", "password": "Str0ngP@ss!"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert "refresh_token" in data
|
||||||
|
assert "expires_at" in data
|
||||||
|
# expires_at should be a future millisecond timestamp
|
||||||
|
assert data["expires_at"] > int(time.time() * 1000)
|
||||||
|
|
||||||
|
def test_register_returns_valid_jwt(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "jwt-check@example.com", "password": "P@ss1234"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
token = resp.json()["access_token"]
|
||||||
|
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
||||||
|
assert payload["email"] == "jwt-check@example.com"
|
||||||
|
assert payload["tier"] == "free"
|
||||||
|
assert "sub" in payload
|
||||||
|
|
||||||
|
def test_register_duplicate_email(self, client) -> None:
|
||||||
|
client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "dupe@example.com", "password": "Pass1234"},
|
||||||
|
)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "dupe@example.com", "password": "Pass5678"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
def test_register_missing_password(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "no-pass@example.com"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_register_missing_email(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"password": "OnlyPass"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestLogin ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogin:
|
||||||
|
"""POST /api/v1/auth/login"""
|
||||||
|
|
||||||
|
def _register(self, client, email="login@example.com", password="MyP@ss123"):
|
||||||
|
client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": email, "password": password},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_login_success(self, client) -> None:
|
||||||
|
self._register(client)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "login@example.com", "password": "MyP@ss123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert "refresh_token" in data
|
||||||
|
assert "expires_at" in data
|
||||||
|
|
||||||
|
def test_login_wrong_password(self, client) -> None:
|
||||||
|
self._register(client)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "login@example.com", "password": "WrongPass!"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_login_unknown_email(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "ghost@example.com", "password": "Whatever"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestRefresh ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefresh:
|
||||||
|
"""POST /api/v1/auth/refresh"""
|
||||||
|
|
||||||
|
def _register_and_get_tokens(self, client, email="refresh@example.com"):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": email, "password": "RefPass123!"},
|
||||||
|
)
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
def test_refresh_returns_new_tokens(self, client) -> None:
|
||||||
|
tokens = self._register_and_get_tokens(client)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
json={"refresh_token": tokens["refresh_token"]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert "refresh_token" in data
|
||||||
|
# New refresh token should differ from old one (rotation)
|
||||||
|
assert data["refresh_token"] != tokens["refresh_token"]
|
||||||
|
|
||||||
|
def test_refresh_old_token_rejected(self, client) -> None:
|
||||||
|
"""After rotation, the original refresh token must be rejected."""
|
||||||
|
tokens = self._register_and_get_tokens(client, email="rotate@example.com")
|
||||||
|
old_rt = tokens["refresh_token"]
|
||||||
|
|
||||||
|
# First refresh succeeds and rotates the token
|
||||||
|
client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt})
|
||||||
|
|
||||||
|
# Second attempt with the old token must fail
|
||||||
|
resp = client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_refresh_bogus_token(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
json={"refresh_token": "not-a-real-token"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestMe ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMe:
|
||||||
|
"""GET /api/v1/auth/me"""
|
||||||
|
|
||||||
|
def test_me_with_valid_jwt(self, client) -> None:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["id"] == TEST_USER_IDS["power"]
|
||||||
|
assert data["email"] == "power@test.com"
|
||||||
|
assert data["tier"] == "power"
|
||||||
|
|
||||||
|
def test_me_returns_correct_tier(self, client) -> None:
|
||||||
|
"""Tier comes from the live subscription row, not the JWT claim."""
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=auth_header("free"))
|
||||||
|
assert resp.json()["tier"] == "free"
|
||||||
|
|
||||||
|
def test_me_missing_token(self, client) -> None:
|
||||||
|
resp = client.get("/api/v1/auth/me")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_me_expired_token(self, client) -> None:
|
||||||
|
"""A JWT with ``exp`` in the past must be rejected."""
|
||||||
|
payload = {
|
||||||
|
"sub": TEST_USER_IDS["power"],
|
||||||
|
"email": "power@test.com",
|
||||||
|
"tier": "power",
|
||||||
|
"exp": int(time.time()) - 3600, # 1 hour ago
|
||||||
|
"iat": int(time.time()) - 7200,
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
|
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_me_invalid_signature(self, client) -> None:
|
||||||
|
payload = {
|
||||||
|
"sub": TEST_USER_IDS["power"],
|
||||||
|
"email": "power@test.com",
|
||||||
|
"tier": "power",
|
||||||
|
"exp": int(time.time()) + 3600,
|
||||||
|
"iat": int(time.time()),
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||||
|
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
assert resp.status_code == 401
|
||||||
243
tests/test_backup.py
Normal file
243
tests/test_backup.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""Tests for backup routes: upload, download, history, delete.
|
||||||
|
|
||||||
|
Exercises the backup lifecycle through the FastAPI TestClient against the
|
||||||
|
in-memory SQLite test database and moto-mocked S3 bucket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
from tests.conftest import auth_header, TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_BLOB = b"encrypted-backup-blob-opaque-bytes"
|
||||||
|
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||||
|
_VERSION = 1
|
||||||
|
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
|
||||||
|
|
||||||
|
|
||||||
|
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
|
||||||
|
"""Return auth + backup metadata headers."""
|
||||||
|
headers = auth_header(tier)
|
||||||
|
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
|
||||||
|
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
|
||||||
|
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
|
||||||
|
headers["Content-Type"] = "application/octet-stream"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
|
||||||
|
"""Upload a backup blob and return the response."""
|
||||||
|
return client.put(
|
||||||
|
"/api/v1/backup",
|
||||||
|
content=overrides.pop("blob", _BLOB),
|
||||||
|
headers=_backup_headers(tier, **overrides),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestUploadBackup ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadBackup:
|
||||||
|
"""PUT /api/v1/backup"""
|
||||||
|
|
||||||
|
def test_upload_success(self, client, s3_bucket) -> None:
|
||||||
|
resp = _upload(client, tier="power")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert len(history) == 1
|
||||||
|
assert history[0]["version"] == _VERSION
|
||||||
|
assert history[0]["timestamp"] == _TIMESTAMP
|
||||||
|
assert history[0]["checksum"] == _CHECKSUM
|
||||||
|
|
||||||
|
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
resp = _upload(client, tier="power", checksum="0" * 64)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
|
||||||
|
"""Free tier has backup_gb=0 → should return 402."""
|
||||||
|
resp = _upload(client, tier="free")
|
||||||
|
assert resp.status_code == 402
|
||||||
|
|
||||||
|
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
|
||||||
|
"""Pro tier has backup_gb=5 → small blob succeeds."""
|
||||||
|
resp = _upload(client, tier="pro")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestDownloadBackup ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDownloadBackup:
|
||||||
|
"""GET /api/v1/backup"""
|
||||||
|
|
||||||
|
def test_download_latest(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == _BLOB
|
||||||
|
assert resp.headers["X-Checksum"] == _CHECKSUM
|
||||||
|
assert resp.headers["X-Backup-Version"] == str(_VERSION)
|
||||||
|
|
||||||
|
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
|
||||||
|
"""When If-Modified-Since is after the backup timestamp → 304."""
|
||||||
|
_upload(client, tier="power", timestamp=1700000000000)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/backup",
|
||||||
|
headers={
|
||||||
|
**auth_header("power"),
|
||||||
|
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 304
|
||||||
|
|
||||||
|
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
|
||||||
|
"""When If-Modified-Since is before the backup timestamp → serve blob."""
|
||||||
|
_upload(client, tier="power", timestamp=1700000000000)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/backup",
|
||||||
|
headers={
|
||||||
|
**auth_header("power"),
|
||||||
|
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == _BLOB
|
||||||
|
|
||||||
|
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
|
||||||
|
"""When multiple backups exist, GET returns the one with the highest timestamp."""
|
||||||
|
_upload(client, tier="power", timestamp=1000)
|
||||||
|
blob2 = b"second-encrypted-backup"
|
||||||
|
checksum2 = hashlib.sha256(blob2).hexdigest()
|
||||||
|
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == blob2
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBackupHistory ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackupHistory:
|
||||||
|
"""GET /api/v1/backup/history"""
|
||||||
|
|
||||||
|
def test_history_empty(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
def test_history_returns_entries(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power", timestamp=1000)
|
||||||
|
_upload(client, tier="power", timestamp=2000)
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert len(history) == 2
|
||||||
|
# Ordered by timestamp descending
|
||||||
|
assert history[0]["timestamp"] == 2000
|
||||||
|
assert history[1]["timestamp"] == 1000
|
||||||
|
|
||||||
|
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
|
||||||
|
"""One user's backups should not appear in another user's history."""
|
||||||
|
_upload(client, tier="power")
|
||||||
|
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestDeleteBackup ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteBackup:
|
||||||
|
"""DELETE /api/v1/backup/{backup_id}"""
|
||||||
|
|
||||||
|
def _get_backup_id(self, client, tier="power") -> str:
|
||||||
|
"""Upload a backup and return its DB id from history."""
|
||||||
|
_upload(client, tier=tier)
|
||||||
|
client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header(tier)
|
||||||
|
).json()
|
||||||
|
# History returns BackupMetadata schema which doesn't have `id`.
|
||||||
|
# We need to look it up via a different means.
|
||||||
|
# Since there's only 1 backup, find via history length.
|
||||||
|
# Actually the schema doesn't return id — let's verify via re-download.
|
||||||
|
# We'll use a workaround: upload, then list history to confirm it exists,
|
||||||
|
# then try to delete — but we need the id...
|
||||||
|
# Let's check if history includes an id field.
|
||||||
|
# The schema is: version, timestamp, checksum, chunk_count — no id.
|
||||||
|
# We'll need to query the DB directly or use a known ID.
|
||||||
|
# For testing, we'll search history then use the DB.
|
||||||
|
return None # pragma: no cover — overridden below
|
||||||
|
|
||||||
|
def test_delete_success(self, client, s3_bucket, db_session) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
|
||||||
|
# Discover the backup_id via direct DB query
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models import BackupMetadata
|
||||||
|
|
||||||
|
async def _get_id():
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(BackupMetadata.id).where(
|
||||||
|
BackupMetadata.user_id == TEST_USER_IDS["power"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
|
||||||
|
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# History should now be empty
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert history == []
|
||||||
|
|
||||||
|
def test_delete_nonexistent(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/backup/no-such-id", headers=auth_header("power")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
|
||||||
|
"""Cannot delete another user's backup (ownership check returns 404)."""
|
||||||
|
_upload(client, tier="power")
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models import BackupMetadata
|
||||||
|
|
||||||
|
async def _get_id():
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(BackupMetadata.id).where(
|
||||||
|
BackupMetadata.user_id == TEST_USER_IDS["power"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
|
||||||
|
|
||||||
|
# team user tries to delete power user's backup → 404
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
286
tests/test_execution_plan.py
Normal file
286
tests/test_execution_plan.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.execution_plan import (
|
||||||
|
ExecutionPlanBuilder,
|
||||||
|
PlanCache,
|
||||||
|
PromptTemplateRegistry,
|
||||||
|
plan_cache,
|
||||||
|
template_registry,
|
||||||
|
)
|
||||||
|
from app.schemas import ExecutionPlan
|
||||||
|
|
||||||
|
|
||||||
|
# ── PromptTemplateRegistry ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateRegistry:
|
||||||
|
def test_register_and_get(self) -> None:
|
||||||
|
reg = PromptTemplateRegistry()
|
||||||
|
reg.register("tpl_foo", "You are a foo agent.")
|
||||||
|
assert reg.get("tpl_foo") == "You are a foo agent."
|
||||||
|
|
||||||
|
def test_get_unknown_raises_key_error(self) -> None:
|
||||||
|
reg = PromptTemplateRegistry()
|
||||||
|
with pytest.raises(KeyError, match="tpl_missing"):
|
||||||
|
reg.get("tpl_missing")
|
||||||
|
|
||||||
|
def test_has_returns_true_for_registered(self) -> None:
|
||||||
|
reg = PromptTemplateRegistry()
|
||||||
|
reg.register("tpl_x", "prompt text")
|
||||||
|
assert reg.has("tpl_x") is True
|
||||||
|
|
||||||
|
def test_has_returns_false_for_unregistered(self) -> None:
|
||||||
|
reg = PromptTemplateRegistry()
|
||||||
|
assert reg.has("tpl_missing") is False
|
||||||
|
|
||||||
|
def test_list_ids_returns_all_registered_ids(self) -> None:
|
||||||
|
reg = PromptTemplateRegistry()
|
||||||
|
reg.register("tpl_a", "a")
|
||||||
|
reg.register("tpl_b", "b")
|
||||||
|
assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
|
||||||
|
|
||||||
|
def test_list_ids_does_not_return_prompt_text(self) -> None:
|
||||||
|
reg = PromptTemplateRegistry()
|
||||||
|
reg.register("tpl_secret", "top secret prompt")
|
||||||
|
ids = reg.list_ids()
|
||||||
|
assert "top secret prompt" not in ids
|
||||||
|
|
||||||
|
def test_overwrite_existing_template(self) -> None:
|
||||||
|
reg = PromptTemplateRegistry()
|
||||||
|
reg.register("tpl_x", "v1")
|
||||||
|
reg.register("tpl_x", "v2")
|
||||||
|
assert reg.get("tpl_x") == "v2"
|
||||||
|
|
||||||
|
def test_empty_registry_has_no_ids(self) -> None:
|
||||||
|
reg = PromptTemplateRegistry()
|
||||||
|
assert reg.list_ids() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutionPlanBuilder:
|
||||||
|
def test_builds_empty_plan(self) -> None:
|
||||||
|
plan = ExecutionPlanBuilder("task_agent").build()
|
||||||
|
assert plan.agent == "task_agent"
|
||||||
|
assert plan.steps == []
|
||||||
|
|
||||||
|
def test_add_step_basic(self) -> None:
|
||||||
|
plan = (
|
||||||
|
ExecutionPlanBuilder("task_agent")
|
||||||
|
.add_step("create_task", {"priority": "high"})
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
assert len(plan.steps) == 1
|
||||||
|
assert plan.steps[0].action == "create_task"
|
||||||
|
assert plan.steps[0].variables == {"priority": "high"}
|
||||||
|
assert plan.steps[0].prompt_template is None
|
||||||
|
assert plan.steps[0].data_from_step is None
|
||||||
|
|
||||||
|
def test_add_step_no_params(self) -> None:
|
||||||
|
plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
|
||||||
|
assert plan.steps[0].variables is None
|
||||||
|
|
||||||
|
def test_add_llm_step(self) -> None:
|
||||||
|
plan = (
|
||||||
|
ExecutionPlanBuilder("task_agent")
|
||||||
|
.add_llm_step("tpl_task_default", {"message": "hi"})
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
assert plan.steps[0].action == "llm"
|
||||||
|
assert plan.steps[0].prompt_template == "tpl_task_default"
|
||||||
|
assert plan.steps[0].variables == {"message": "hi"}
|
||||||
|
|
||||||
|
def test_add_llm_step_no_variables(self) -> None:
|
||||||
|
plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
|
||||||
|
assert plan.steps[0].variables is None
|
||||||
|
|
||||||
|
def test_add_data_step(self) -> None:
|
||||||
|
plan = (
|
||||||
|
ExecutionPlanBuilder("task_agent")
|
||||||
|
.add_step("fetch_data")
|
||||||
|
.add_data_step("transform", data_from_step=0)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
assert plan.steps[1].action == "transform"
|
||||||
|
assert plan.steps[1].data_from_step == 0
|
||||||
|
|
||||||
|
def test_fluent_chaining_returns_builder(self) -> None:
|
||||||
|
builder = ExecutionPlanBuilder("analytics_agent")
|
||||||
|
result = builder.add_step("a")
|
||||||
|
assert result is builder
|
||||||
|
|
||||||
|
def test_fluent_chain_multiple_steps(self) -> None:
|
||||||
|
plan = (
|
||||||
|
ExecutionPlanBuilder("analytics_agent")
|
||||||
|
.add_llm_step("tpl_analytics_default")
|
||||||
|
.add_step("format_output")
|
||||||
|
.add_data_step("store", data_from_step=0)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
assert len(plan.steps) == 3
|
||||||
|
|
||||||
|
def test_build_validates_data_from_step_out_of_range(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="data_from_step"):
|
||||||
|
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
|
||||||
|
|
||||||
|
def test_build_validates_data_from_step_self_reference(self) -> None:
|
||||||
|
"""data_from_step=0 on the first step (index 0) is invalid."""
|
||||||
|
with pytest.raises(ValueError, match="data_from_step"):
|
||||||
|
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
|
||||||
|
|
||||||
|
def test_build_validates_data_from_step_negative(self) -> None:
|
||||||
|
with pytest.raises(ValueError, match="data_from_step"):
|
||||||
|
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
|
||||||
|
|
||||||
|
def test_valid_data_from_step_at_index_two(self) -> None:
|
||||||
|
plan = (
|
||||||
|
ExecutionPlanBuilder("task_agent")
|
||||||
|
.add_step("step0")
|
||||||
|
.add_step("step1")
|
||||||
|
.add_data_step("step2", data_from_step=1)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
assert plan.steps[2].data_from_step == 1
|
||||||
|
|
||||||
|
def test_data_from_step_zero_valid_at_index_one(self) -> None:
|
||||||
|
plan = (
|
||||||
|
ExecutionPlanBuilder("task_agent")
|
||||||
|
.add_step("step0")
|
||||||
|
.add_data_step("step1", data_from_step=0)
|
||||||
|
.build()
|
||||||
|
)
|
||||||
|
assert plan.steps[1].data_from_step == 0
|
||||||
|
|
||||||
|
def test_build_returns_new_plan_each_call(self) -> None:
|
||||||
|
builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
|
||||||
|
plan1 = builder.build()
|
||||||
|
plan2 = builder.build()
|
||||||
|
assert plan1 is not plan2
|
||||||
|
assert plan1.steps == plan2.steps
|
||||||
|
|
||||||
|
def test_plan_is_execution_plan_instance(self) -> None:
|
||||||
|
plan = ExecutionPlanBuilder("task_agent").build()
|
||||||
|
assert isinstance(plan, ExecutionPlan)
|
||||||
|
|
||||||
|
|
||||||
|
# ── PlanCache ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlanCache:
|
||||||
|
def _plan(self, agent: str = "a") -> ExecutionPlan:
|
||||||
|
return ExecutionPlanBuilder(agent).build()
|
||||||
|
|
||||||
|
def test_cache_and_get(self) -> None:
|
||||||
|
cache = PlanCache()
|
||||||
|
plan = self._plan()
|
||||||
|
cache.cache_plan("key1", plan)
|
||||||
|
assert cache.get_plan("key1") is plan
|
||||||
|
|
||||||
|
def test_get_missing_returns_none(self) -> None:
|
||||||
|
cache = PlanCache()
|
||||||
|
assert cache.get_plan("nonexistent") is None
|
||||||
|
|
||||||
|
def test_get_all_playbooks_empty(self) -> None:
|
||||||
|
cache = PlanCache()
|
||||||
|
assert cache.get_all_playbooks() == []
|
||||||
|
|
||||||
|
def test_get_all_playbooks_returns_all_stored(self) -> None:
|
||||||
|
cache = PlanCache()
|
||||||
|
p1, p2 = self._plan("a"), self._plan("b")
|
||||||
|
cache.cache_plan("k1", p1)
|
||||||
|
cache.cache_plan("k2", p2)
|
||||||
|
playbooks = cache.get_all_playbooks()
|
||||||
|
assert len(playbooks) == 2
|
||||||
|
assert p1 in playbooks
|
||||||
|
assert p2 in playbooks
|
||||||
|
|
||||||
|
def test_lru_evicts_oldest_entry(self) -> None:
|
||||||
|
cache = PlanCache(maxsize=2)
|
||||||
|
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
|
||||||
|
cache.cache_plan("k1", p1)
|
||||||
|
cache.cache_plan("k2", p2)
|
||||||
|
cache.cache_plan("k3", p3) # k1 should be evicted
|
||||||
|
assert cache.get_plan("k1") is None
|
||||||
|
assert cache.get_plan("k2") is p2
|
||||||
|
assert cache.get_plan("k3") is p3
|
||||||
|
|
||||||
|
def test_lru_access_updates_recency(self) -> None:
|
||||||
|
cache = PlanCache(maxsize=2)
|
||||||
|
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
|
||||||
|
cache.cache_plan("k1", p1)
|
||||||
|
cache.cache_plan("k2", p2)
|
||||||
|
cache.get_plan("k1") # k1 is now most-recently used
|
||||||
|
cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
|
||||||
|
assert cache.get_plan("k1") is p1
|
||||||
|
assert cache.get_plan("k2") is None
|
||||||
|
assert cache.get_plan("k3") is p3
|
||||||
|
|
||||||
|
def test_overwrite_existing_key(self) -> None:
|
||||||
|
cache = PlanCache()
|
||||||
|
p1, p2 = self._plan("a"), self._plan("b")
|
||||||
|
cache.cache_plan("same_key", p1)
|
||||||
|
cache.cache_plan("same_key", p2)
|
||||||
|
assert cache.get_plan("same_key") is p2
|
||||||
|
assert len(cache.get_all_playbooks()) == 1
|
||||||
|
|
||||||
|
def test_overwrite_does_not_consume_capacity(self) -> None:
|
||||||
|
cache = PlanCache(maxsize=2)
|
||||||
|
p1, p2 = self._plan("a"), self._plan("b")
|
||||||
|
cache.cache_plan("k1", p1)
|
||||||
|
cache.cache_plan("k1", p2) # overwrite, not a new slot
|
||||||
|
cache.cache_plan("k2", p1) # should fit without eviction
|
||||||
|
assert cache.get_plan("k1") is p2
|
||||||
|
assert cache.get_plan("k2") is p1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Module-level singletons ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestModuleSingletons:
|
||||||
|
def test_template_registry_has_all_agent_defaults(self) -> None:
|
||||||
|
for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"):
|
||||||
|
assert template_registry.has(f"tpl_{agent}_default"), (
|
||||||
|
f"Missing template: tpl_{agent}_default"
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_template_registry_has_operation_templates(self) -> None:
|
||||||
|
assert template_registry.has("tpl_task_extract_from_project")
|
||||||
|
assert template_registry.has("tpl_note_weekly_summary")
|
||||||
|
|
||||||
|
def test_template_registry_get_returns_non_empty_string(self) -> None:
|
||||||
|
text = template_registry.get("tpl_task_agent_default")
|
||||||
|
assert isinstance(text, str)
|
||||||
|
assert len(text) > 0
|
||||||
|
|
||||||
|
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
|
||||||
|
assert len(plan_cache.get_all_playbooks()) >= 2
|
||||||
|
|
||||||
|
def test_playbook_create_tasks_from_project(self) -> None:
|
||||||
|
plan = plan_cache.get_plan("create_tasks_from_project")
|
||||||
|
assert plan is not None
|
||||||
|
assert plan.agent == "project_agent"
|
||||||
|
assert len(plan.steps) == 2
|
||||||
|
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
|
||||||
|
assert plan.steps[1].data_from_step == 0
|
||||||
|
|
||||||
|
def test_playbook_generate_weekly_note(self) -> None:
|
||||||
|
plan = plan_cache.get_plan("generate_weekly_note")
|
||||||
|
assert plan is not None
|
||||||
|
assert plan.agent == "note_agent"
|
||||||
|
assert len(plan.steps) == 2
|
||||||
|
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
|
||||||
|
assert plan.steps[1].data_from_step == 0
|
||||||
|
|
||||||
|
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
|
||||||
|
"""Plans must not embed prompt text — only template IDs."""
|
||||||
|
for plan in plan_cache.get_all_playbooks():
|
||||||
|
for step in plan.steps:
|
||||||
|
if step.prompt_template is not None:
|
||||||
|
assert step.prompt_template.startswith("tpl_"), (
|
||||||
|
f"prompt_template looks like raw text: {step.prompt_template!r}"
|
||||||
|
)
|
||||||
322
tests/test_middleware.py
Normal file
322
tests/test_middleware.py
Normal file
@@ -0,0 +1,322 @@
|
|||||||
|
"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
|
||||||
|
|
||||||
|
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
|
||||||
|
Rate limit: use unique user UUIDs per test so windows are independent;
|
||||||
|
the free-tier threshold (20 req/min) is exercised directly.
|
||||||
|
Sanitizer: the orchestrator is mocked to inject controlled prompt
|
||||||
|
fragments, and the chat endpoint response body is inspected.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.schemas import ChatResponse
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Autouse: redirect all DB access to the in-memory SQLite test engine.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
"""Route all get_session calls to the test SQLite session."""
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
_CHAT_BODY = {
|
||||||
|
"message": "hello",
|
||||||
|
"context": {
|
||||||
|
"user_profile": {},
|
||||||
|
"relevant_documents": [],
|
||||||
|
"recent_tasks": [],
|
||||||
|
"conversation_history": [],
|
||||||
|
},
|
||||||
|
"execution_mode": "direct",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_jwt(
|
||||||
|
*,
|
||||||
|
user_id: str | None = None,
|
||||||
|
email: str = "test@example.com",
|
||||||
|
tier: str = "free",
|
||||||
|
exp_offset: int = 3600,
|
||||||
|
secret: str | None = None,
|
||||||
|
include_sub: bool = True,
|
||||||
|
) -> str:
|
||||||
|
"""Mint a test JWT signed with the configured (or custom) secret."""
|
||||||
|
uid = user_id or str(uuid.uuid4())
|
||||||
|
now = int(time.time())
|
||||||
|
payload: dict = {
|
||||||
|
"email": email,
|
||||||
|
"tier": tier,
|
||||||
|
"exp": now + exp_offset,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
if include_sub:
|
||||||
|
payload["sub"] = uid
|
||||||
|
key = secret or settings.JWT_SECRET
|
||||||
|
return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def _auth_header(token: str) -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {token}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Auth middleware
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestAuthMiddleware:
|
||||||
|
"""Tests exercised via GET /api/v1/auth/me."""
|
||||||
|
|
||||||
|
def test_valid_token_returns_profile(self) -> None:
|
||||||
|
# Use the seeded pro user so the subscription lookup returns 'pro'.
|
||||||
|
uid = TEST_USER_IDS["pro"]
|
||||||
|
token = _make_jwt(user_id=uid, email="pro@test.com", tier="pro")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["id"] == uid
|
||||||
|
assert data["email"] == "pro@test.com"
|
||||||
|
assert data["tier"] == "pro"
|
||||||
|
|
||||||
|
def test_missing_token_returns_401(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_expired_token_returns_401(self) -> None:
|
||||||
|
token = _make_jwt(exp_offset=-1) # already expired
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_wrong_signature_returns_401(self) -> None:
|
||||||
|
token = _make_jwt(secret="totally-wrong-secret")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_missing_sub_claim_returns_401(self) -> None:
|
||||||
|
token = _make_jwt(include_sub=False)
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_malformed_token_returns_401(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/auth/me", headers={"Authorization": "Bearer not.a.jwt"}
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Rate limiter middleware
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRateLimitMiddleware:
|
||||||
|
"""Each test uses a fresh unique user_id so windows never collide."""
|
||||||
|
|
||||||
|
def _unique_token(self, tier: str = "free") -> str:
|
||||||
|
return _make_jwt(user_id=str(uuid.uuid4()), tier=tier)
|
||||||
|
|
||||||
|
def test_free_tier_allows_up_to_20_requests(self) -> None:
|
||||||
|
token = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(20):
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_free_tier_blocks_21st_request(self) -> None:
|
||||||
|
token = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(20):
|
||||||
|
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 429
|
||||||
|
|
||||||
|
def test_429_includes_retry_after_header(self) -> None:
|
||||||
|
token = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(20):
|
||||||
|
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 429
|
||||||
|
assert "retry-after" in resp.headers
|
||||||
|
retry_after = int(resp.headers["retry-after"])
|
||||||
|
assert retry_after >= 1
|
||||||
|
|
||||||
|
def test_429_response_has_detail_field(self) -> None:
|
||||||
|
token = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(20):
|
||||||
|
client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 429
|
||||||
|
assert "detail" in resp.json()
|
||||||
|
|
||||||
|
def test_pro_tier_allows_60_requests(self) -> None:
|
||||||
|
token = self._unique_token("pro")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
# Sample: first 60 succeed, 61st is blocked.
|
||||||
|
for _ in range(60):
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
|
assert resp.status_code == 429
|
||||||
|
|
||||||
|
def test_independent_users_have_separate_windows(self) -> None:
|
||||||
|
token_a = self._unique_token("free")
|
||||||
|
token_b = self._unique_token("free")
|
||||||
|
with TestClient(app) as client:
|
||||||
|
# Exhaust user A's quota.
|
||||||
|
for _ in range(20):
|
||||||
|
client.get("/api/v1/auth/me", headers=_auth_header(token_a))
|
||||||
|
assert (
|
||||||
|
client.get(
|
||||||
|
"/api/v1/auth/me", headers=_auth_header(token_a)
|
||||||
|
).status_code
|
||||||
|
== 429
|
||||||
|
)
|
||||||
|
# User B's quota is untouched.
|
||||||
|
resp_b = client.get("/api/v1/auth/me", headers=_auth_header(token_b))
|
||||||
|
assert resp_b.status_code == 200
|
||||||
|
|
||||||
|
def test_exempt_path_register_never_rate_limited(self) -> None:
|
||||||
|
"""POST /auth/register is exempt — 25 calls should never return 429."""
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for i in range(25):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"},
|
||||||
|
)
|
||||||
|
# 201 on first, 409 on email collision — but never 429.
|
||||||
|
assert resp.status_code != 429
|
||||||
|
|
||||||
|
def test_exempt_path_login_never_rate_limited(self) -> None:
|
||||||
|
"""POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(25):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "nosuchuser@example.com", "password": "wrong"},
|
||||||
|
)
|
||||||
|
assert resp.status_code != 429
|
||||||
|
|
||||||
|
def test_exempt_path_health_never_rate_limited(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
for _ in range(25):
|
||||||
|
resp = client.get("/api/v1/health")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Sanitizer middleware
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestSanitizerMiddleware:
|
||||||
|
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
|
||||||
|
|
||||||
|
_CHAT_PATH = "/api/v1/chat"
|
||||||
|
|
||||||
|
def _token(self) -> str:
|
||||||
|
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
||||||
|
|
||||||
|
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
||||||
|
mock_response = ChatResponse(response=response_text, actions=[])
|
||||||
|
with patch(
|
||||||
|
"app.api.routes.chat.orchestrate",
|
||||||
|
new_callable=AsyncMock,
|
||||||
|
return_value=mock_response,
|
||||||
|
):
|
||||||
|
resp = client.post(
|
||||||
|
self._CHAT_PATH,
|
||||||
|
json=_CHAT_BODY,
|
||||||
|
headers=_auth_header(self._token()),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
def test_clean_response_passes_through_unchanged(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(client, "Sure, I created the task for you.")
|
||||||
|
assert data["response"] == "Sure, I created the task for you."
|
||||||
|
|
||||||
|
def test_strips_system_prompt_opener(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, "You are an intent classifier. Route to task_agent."
|
||||||
|
)
|
||||||
|
assert "You are" not in data["response"]
|
||||||
|
assert "[REDACTED]" in data["response"]
|
||||||
|
|
||||||
|
def test_strips_known_fingerprint(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, "Respond with just the agent name and nothing else."
|
||||||
|
)
|
||||||
|
assert data["response"] == "[REDACTED]"
|
||||||
|
|
||||||
|
def test_strips_tool_schema_fragment(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, 'Here is the schema: {"type": "function", "name": "foo"}'
|
||||||
|
)
|
||||||
|
assert '"type": "function"' not in data["response"]
|
||||||
|
|
||||||
|
def test_strips_reasoning_tag(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, "<thinking>I should route this to calendar_agent</thinking>Done."
|
||||||
|
)
|
||||||
|
assert "<thinking>" not in data["response"]
|
||||||
|
assert "[REDACTED]" in data["response"]
|
||||||
|
|
||||||
|
def test_strips_available_agents_fragment(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(
|
||||||
|
client, "Available agents: task_agent, calendar_agent"
|
||||||
|
)
|
||||||
|
assert "[REDACTED]" in data["response"]
|
||||||
|
|
||||||
|
def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
|
||||||
|
"""GET /api/v1/plans/playbook should pass through the sanitizer untouched."""
|
||||||
|
token = self._token()
|
||||||
|
with TestClient(app) as client:
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/plans/playbook",
|
||||||
|
headers=_auth_header(token),
|
||||||
|
)
|
||||||
|
# The sanitizer should not interfere — just check it returns something
|
||||||
|
# (200 or whatever the route returns; we only care it's not broken).
|
||||||
|
assert resp.status_code in (200, 401, 403, 404)
|
||||||
|
|
||||||
|
def test_sanitizer_preserves_empty_response(self) -> None:
|
||||||
|
with TestClient(app) as client:
|
||||||
|
data = self._post_chat(client, "")
|
||||||
|
assert data["response"] == ""
|
||||||
348
tests/test_orchestrator.py
Normal file
348
tests/test_orchestrator.py
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
"""Integration tests for the orchestrator module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||||
|
from app.core.orchestrator import (
|
||||||
|
classify_intent,
|
||||||
|
orchestrate,
|
||||||
|
orchestrate_stream,
|
||||||
|
route_pipeline,
|
||||||
|
route_single,
|
||||||
|
)
|
||||||
|
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stub agents ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _TaskAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "task_agent"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Manages tasks: create, update, list, suggest"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
return f"task: {query}"
|
||||||
|
|
||||||
|
|
||||||
|
class _CalendarAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "calendar_agent"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "Calendar management: events, conflicts, scheduling"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
return f"calendar: {query}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_llm(response_text: str) -> MagicMock:
|
||||||
|
"""Return a mock LLM that always produces *response_text*."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = response_text
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=msg)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _fresh_registry():
|
||||||
|
"""Reset the AgentRegistry singleton between tests."""
|
||||||
|
AgentRegistry._instance = None
|
||||||
|
yield
|
||||||
|
AgentRegistry._instance = None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def reg() -> AgentRegistry:
|
||||||
|
r = AgentRegistry()
|
||||||
|
r.register(_TaskAgent)
|
||||||
|
r.register(_CalendarAgent)
|
||||||
|
return r
|
||||||
|
|
||||||
|
|
||||||
|
# ── classify_intent ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassifyIntent:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
result = await classify_intent("add a task", {}, reg)
|
||||||
|
assert result == "task_agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("calendar_agent")
|
||||||
|
result = await classify_intent("schedule a meeting", {}, reg)
|
||||||
|
assert result == "calendar_agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("nonexistent_agent")
|
||||||
|
result = await classify_intent("do something", {}, reg)
|
||||||
|
assert result == "task_agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
|
||||||
|
empty_reg = AgentRegistry()
|
||||||
|
# No LLM should be instantiated — early return path
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
result = await classify_intent("anything", {}, empty_reg)
|
||||||
|
mock_cls.assert_not_called()
|
||||||
|
assert result == "task_agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm(" task_agent \n")
|
||||||
|
result = await classify_intent("create task", {}, reg)
|
||||||
|
assert result == "task_agent"
|
||||||
|
|
||||||
|
|
||||||
|
# ── route_single ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouteSingle:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
||||||
|
result = await route_single("task_agent", "create a task", {}, reg)
|
||||||
|
assert isinstance(result, ChatResponse)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
|
||||||
|
result = await route_single("task_agent", "create a task", {}, reg)
|
||||||
|
assert result.response == "task: create a task"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await route_single("nonexistent", "hello", {}, reg)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
|
||||||
|
result = await route_single("task_agent", "hi", {}, reg)
|
||||||
|
assert result.actions == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── route_pipeline ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoutePipeline:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("synthesized result")
|
||||||
|
result = await route_pipeline(
|
||||||
|
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
||||||
|
)
|
||||||
|
assert isinstance(result, ChatResponse)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("synthesized result")
|
||||||
|
result = await route_pipeline(
|
||||||
|
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
||||||
|
)
|
||||||
|
assert result.response == "synthesized result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_passes_previous_results_to_subsequent_agents(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
"""Each agent after the first should receive prior outputs in context."""
|
||||||
|
received_contexts: list[dict[str, Any]] = []
|
||||||
|
|
||||||
|
class _CapturingAgent(ChatAgent):
|
||||||
|
def get_name(self) -> str:
|
||||||
|
return "capture"
|
||||||
|
|
||||||
|
def get_description(self) -> str:
|
||||||
|
return "captures context for testing"
|
||||||
|
|
||||||
|
def get_tools(self) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||||
|
received_contexts.append(dict(context))
|
||||||
|
return "captured"
|
||||||
|
|
||||||
|
reg.register(_CapturingAgent)
|
||||||
|
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("done")
|
||||||
|
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
|
||||||
|
|
||||||
|
# The second agent (capture) must have received previous results
|
||||||
|
assert len(received_contexts) == 1
|
||||||
|
assert "previous_results" in received_contexts[0]
|
||||||
|
assert received_contexts[0]["previous_results"] == ["task: hi"]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("single result")
|
||||||
|
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
|
||||||
|
assert result.response == "single result"
|
||||||
|
|
||||||
|
|
||||||
|
# ── orchestrate ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestrate:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_direct_mode_returns_chat_response(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||||
|
result = await orchestrate(request, reg)
|
||||||
|
assert isinstance(result, ChatResponse)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||||
|
result = await orchestrate(request, reg)
|
||||||
|
assert isinstance(result, ChatResponse)
|
||||||
|
assert result.response == "task: add a task"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_mode_returns_execution_plan(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(message="plan my tasks", execution_mode="plan")
|
||||||
|
result = await orchestrate(request, reg)
|
||||||
|
assert isinstance(result, ExecutionPlan)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_mode_agent_matches_classified(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("calendar_agent")
|
||||||
|
request = ChatRequest(
|
||||||
|
message="schedule something", execution_mode="plan"
|
||||||
|
)
|
||||||
|
result = await orchestrate(request, reg)
|
||||||
|
assert isinstance(result, ExecutionPlan)
|
||||||
|
assert result.agent == "calendar_agent"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
||||||
|
result = await orchestrate(request, reg)
|
||||||
|
assert isinstance(result, ExecutionPlan)
|
||||||
|
assert len(result.steps) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_plan_mode_template_id_contains_agent_name(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
||||||
|
result = await orchestrate(request, reg)
|
||||||
|
assert isinstance(result, ExecutionPlan)
|
||||||
|
assert result.steps[0].prompt_template is not None
|
||||||
|
assert "task_agent" in result.steps[0].prompt_template
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_default_execution_mode_is_direct(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
# execution_mode defaults to "direct"
|
||||||
|
request = ChatRequest(message="help me")
|
||||||
|
result = await orchestrate(request, reg)
|
||||||
|
assert isinstance(result, ChatResponse)
|
||||||
|
|
||||||
|
|
||||||
|
# ── orchestrate_stream ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestrateStream:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||||
|
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||||
|
assert len(chunks) >= 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_last_chunk_is_final_json_frame(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||||
|
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||||
|
|
||||||
|
last = json.loads(chunks[-1])
|
||||||
|
assert last["done"] is True
|
||||||
|
assert "response" in last
|
||||||
|
assert "actions" in last
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_final_frame_response_matches_agent_output(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(message="create a task", execution_mode="direct")
|
||||||
|
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||||
|
|
||||||
|
final = json.loads(chunks[-1])
|
||||||
|
assert final["response"] == "task: create a task"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_text_chunks_before_final_frame(
|
||||||
|
self, reg: AgentRegistry
|
||||||
|
) -> None:
|
||||||
|
with patch("app.core.orchestrator._make_llm") as mock_cls:
|
||||||
|
mock_cls.return_value = _mock_llm("task_agent")
|
||||||
|
request = ChatRequest(
|
||||||
|
message="x" * 200, execution_mode="direct"
|
||||||
|
) # long enough to produce multiple chunks
|
||||||
|
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||||
|
|
||||||
|
# All but the last chunk should be plain text (not valid final JSON)
|
||||||
|
non_final = chunks[:-1]
|
||||||
|
for chunk in non_final:
|
||||||
|
try:
|
||||||
|
parsed = json.loads(chunk)
|
||||||
|
assert parsed.get("done") is not True
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
pass # plain text chunk — expected
|
||||||
400
tests/test_plugins.py
Normal file
400
tests/test_plugins.py
Normal file
@@ -0,0 +1,400 @@
|
|||||||
|
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
|
||||||
|
- ReviewQueue: pending queue, review decisions, manifest security checklist
|
||||||
|
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
|
||||||
|
- Route integration: tier gate, list/get/install/uninstall via TestClient
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import PluginRegistry
|
||||||
|
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
|
||||||
|
from app.marketplace.revenue_share import RevenueShare
|
||||||
|
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
|
||||||
|
from app.schemas import PluginManifest
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
def _fresh_manifest(
|
||||||
|
plugin_id: str | None = None,
|
||||||
|
category: str = "productivity",
|
||||||
|
price_cents: int = 0,
|
||||||
|
permissions: list[str] | None = None,
|
||||||
|
) -> PluginManifest:
|
||||||
|
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
|
||||||
|
return PluginManifest(
|
||||||
|
id=pid,
|
||||||
|
name=f"Plugin {pid}",
|
||||||
|
description=f"Description for {pid}",
|
||||||
|
version="1.0.0",
|
||||||
|
author="test-author",
|
||||||
|
permissions=permissions or ["read:tasks"],
|
||||||
|
category=category,
|
||||||
|
price_cents=price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# PluginRegistry (DB-backed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginRegistry:
|
||||||
|
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def reg(self) -> PluginRegistry:
|
||||||
|
return PluginRegistry()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_seed_plugins_are_listed(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session)
|
||||||
|
assert result.total == 3
|
||||||
|
assert all(p.id.startswith("plugin-") for p in result.plugins)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_approved_only(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
|
||||||
|
result = await reg.list_plugins(db_session)
|
||||||
|
ids = [p.id for p in result.plugins]
|
||||||
|
assert manifest.id not in ids # still pending
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_filter_by_category(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session, category="communication")
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.plugins[0].id == "plugin-slack-notify"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_filter_by_query(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session, query="time")
|
||||||
|
assert result.total == 1
|
||||||
|
assert result.plugins[0].id == "plugin-time-tracker"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_sort_by_installs(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await reg.record_install(db_session, "plugin-slack-notify")
|
||||||
|
await reg.record_install(db_session, "plugin-slack-notify")
|
||||||
|
result = await reg.list_plugins(db_session, sort="installs")
|
||||||
|
assert result.plugins[0].id == "plugin-slack-notify"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_plugin_found(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["manifest"].id == "plugin-github-sync"
|
||||||
|
assert "install_count" in entry
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_plugin_not_found(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
entry = await reg.get_plugin(db_session, "no-such-plugin")
|
||||||
|
assert entry is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_sets_pending(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
assert plugin_id == manifest.id
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "pending_review"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_approve_makes_visible(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
await reg.approve_plugin(db_session, manifest.id)
|
||||||
|
result = await reg.list_plugins(db_session)
|
||||||
|
assert manifest.id in [p.id for p in result.plugins]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_reject_stores_reason(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "rejected"
|
||||||
|
assert row.rejection_reason == "Unsafe permissions"
|
||||||
|
listed = await reg.list_plugins(db_session)
|
||||||
|
assert manifest.id not in [p.id for p in listed.plugins]
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_approve_unknown_raises_key_error(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await reg.approve_plugin(db_session, "ghost-plugin")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_increments_count(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_uninstall_decrements_count(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
|
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_uninstall_floors_at_zero(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# ReviewQueue (DB-backed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestReviewQueue:
|
||||||
|
@pytest.fixture
|
||||||
|
def reg(self) -> PluginRegistry:
|
||||||
|
return PluginRegistry()
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def queue(self) -> ReviewQueue:
|
||||||
|
return ReviewQueue()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_pending_returns_submitted_plugins(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
pending = await queue.get_pending(db_session)
|
||||||
|
assert any(p["plugin_id"] == manifest.id for p in pending)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_review_approved(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "approved"
|
||||||
|
# Check review row was persisted
|
||||||
|
review_result = await db_session.execute(
|
||||||
|
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
|
||||||
|
)
|
||||||
|
review = review_result.scalar_one()
|
||||||
|
assert review.decision == "approved"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_submit_review_rejected(
|
||||||
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
manifest = _fresh_manifest()
|
||||||
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
|
await queue.submit_review(
|
||||||
|
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
|
||||||
|
)
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "rejected"
|
||||||
|
|
||||||
|
def test_validate_manifest_ok(self) -> None:
|
||||||
|
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
|
||||||
|
validate_manifest(manifest) # should not raise
|
||||||
|
|
||||||
|
def test_validate_manifest_unknown_permission(self) -> None:
|
||||||
|
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
|
||||||
|
with pytest.raises(ValueError, match="Unknown permission"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
def test_validate_manifest_invalid_id_format(self) -> None:
|
||||||
|
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
|
||||||
|
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
def test_validate_manifest_id_with_uppercase(self) -> None:
|
||||||
|
manifest = _fresh_manifest(plugin_id="UpperCase")
|
||||||
|
with pytest.raises(ValueError, match="Invalid plugin id format"):
|
||||||
|
validate_manifest(manifest)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# RevenueShare (DB-backed)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestRevenueShare:
|
||||||
|
@pytest.fixture
|
||||||
|
def rs(self) -> RevenueShare:
|
||||||
|
return RevenueShare()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_free_plugin(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
|
||||||
|
)
|
||||||
|
event = result.scalar_one()
|
||||||
|
assert event.developer_share_cents == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_paid_plugin_no_stripe(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await rs.record_install(
|
||||||
|
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
|
||||||
|
)
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
|
||||||
|
)
|
||||||
|
event = result.scalar_one()
|
||||||
|
assert event.amount_cents == 499
|
||||||
|
assert event.developer_share_cents == int(499 * 0.70)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_record_install_increments_registry_count(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
reg = PluginRegistry()
|
||||||
|
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
|
assert entry is not None
|
||||||
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_earnings_empty(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
result = await rs.get_earnings(db_session, "unknown-dev")
|
||||||
|
assert result["total_installs"] == 0
|
||||||
|
assert result["total_revenue_cents"] == 0
|
||||||
|
assert result["developer_share_cents"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_earnings_aggregates(
|
||||||
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
|
||||||
|
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
|
||||||
|
result = await rs.get_earnings(db_session, "Adiuva")
|
||||||
|
assert result["total_installs"] == 2
|
||||||
|
assert result["total_revenue_cents"] == 998
|
||||||
|
assert result["developer_share_cents"] == int(499 * 0.70) * 2
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Route integration tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPluginRoutes:
|
||||||
|
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
|
||||||
|
assert resp.status_code == 403
|
||||||
|
|
||||||
|
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "plugins" in data
|
||||||
|
assert data["total"] == 3
|
||||||
|
|
||||||
|
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
def test_get_plugin_found(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["plugin"]["id"] == "plugin-github-sync"
|
||||||
|
assert "install_count" in data
|
||||||
|
|
||||||
|
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_install_plugin_free(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
json={"plugin_id": "plugin-github-sync"},
|
||||||
|
headers=auth_header(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["ok"] is True
|
||||||
|
assert "download_url" in data
|
||||||
|
|
||||||
|
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/ghost/install",
|
||||||
|
json={"plugin_id": "ghost"},
|
||||||
|
headers=auth_header(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
headers=auth_header(),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json()["ok"] is True
|
||||||
|
|
||||||
|
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
|
json={"plugin_id": "plugin-github-sync"},
|
||||||
|
headers=auth_header("free"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 403
|
||||||
562
tests/test_storage.py
Normal file
562
tests/test_storage.py
Normal file
@@ -0,0 +1,562 @@
|
|||||||
|
"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import pytest
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
from app.storage.encryption import reject_if_tampered, verify_checksum
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.vector_store import VectorStore, _blob_to_vector
|
||||||
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
from tests.conftest import auth_header, S3_TEST_BUCKET
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_BLOB = b"encrypted-payload-opaque-to-server"
|
||||||
|
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||||
|
_BUCKET = S3_TEST_BUCKET
|
||||||
|
_REGION = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
def _pinecone_mock():
|
||||||
|
"""Return a mock Pinecone index with realistic return shapes."""
|
||||||
|
mock_index = MagicMock()
|
||||||
|
mock_index.query.return_value = {
|
||||||
|
"matches": [
|
||||||
|
{
|
||||||
|
"id": "v1",
|
||||||
|
"score": 0.95,
|
||||||
|
"metadata": {
|
||||||
|
"blob": base64.b64encode(b"result-blob").decode(),
|
||||||
|
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
|
||||||
|
"user_id": "u1",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
}
|
||||||
|
mock_pc = MagicMock()
|
||||||
|
mock_pc.return_value.Index.return_value = mock_index
|
||||||
|
return mock_pc, mock_index
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestEncryption ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncryption:
|
||||||
|
def test_verify_checksum_correct(self) -> None:
|
||||||
|
assert verify_checksum(_BLOB, _CHECKSUM) is True
|
||||||
|
|
||||||
|
def test_verify_checksum_wrong(self) -> None:
|
||||||
|
assert verify_checksum(_BLOB, "0" * 64) is False
|
||||||
|
|
||||||
|
def test_verify_checksum_empty_checksum(self) -> None:
|
||||||
|
assert verify_checksum(_BLOB, "") is False
|
||||||
|
|
||||||
|
def test_verify_checksum_empty_blob(self) -> None:
|
||||||
|
expected = hashlib.sha256(b"").hexdigest()
|
||||||
|
assert verify_checksum(b"", expected) is True
|
||||||
|
|
||||||
|
def test_verify_checksum_tampered_blob(self) -> None:
|
||||||
|
tampered = _BLOB + b"\x00"
|
||||||
|
assert verify_checksum(tampered, _CHECKSUM) is False
|
||||||
|
|
||||||
|
def test_reject_if_tampered_passes_when_valid(self) -> None:
|
||||||
|
# Should not raise
|
||||||
|
reject_if_tampered(_BLOB, _CHECKSUM)
|
||||||
|
|
||||||
|
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
reject_if_tampered(_BLOB, "bad" * 20)
|
||||||
|
assert exc_info.value.status_code == 400
|
||||||
|
|
||||||
|
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
with pytest.raises(HTTPException) as exc_info:
|
||||||
|
reject_if_tampered(_BLOB, "bad" * 20)
|
||||||
|
assert "checksum" in exc_info.value.detail.lower()
|
||||||
|
|
||||||
|
def test_checksum_is_sha256_hex(self) -> None:
|
||||||
|
cs = hashlib.sha256(_BLOB).hexdigest()
|
||||||
|
assert len(cs) == 64
|
||||||
|
assert all(c in "0123456789abcdef" for c in cs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBlobStore ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlobStore:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
assert key == "u1/tasks/r1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
# Verify by downloading — no exception means object exists
|
||||||
|
retrieved = await store.download("u1", "u1/tasks/r1")
|
||||||
|
assert retrieved == _BLOB
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
|
||||||
|
result = await store.download("u1", "u1/notes/n1")
|
||||||
|
assert result == b"note-data"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_removes_object(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
await store.delete("u1", "u1/tasks/r1")
|
||||||
|
with pytest.raises(ClientError) as exc_info:
|
||||||
|
await store.download("u1", "u1/tasks/r1")
|
||||||
|
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
# Delete a key that never existed — should not raise
|
||||||
|
await store.delete("u1", "u1/tasks/nonexistent")
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
|
||||||
|
keys = await store.list_keys("u1", "tasks")
|
||||||
|
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
|
||||||
|
keys = await store.list_keys("u1", "tasks")
|
||||||
|
assert "u1/notes/n1" not in keys
|
||||||
|
assert "u1/tasks/r1" in keys
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
keys_u1 = await store.list_keys("u1", "tasks")
|
||||||
|
assert "u2/tasks/r1" not in keys_u1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
keys = await store.list_keys("u1", "tasks")
|
||||||
|
assert keys == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
# Verify S3 metadata was set — check via head_object
|
||||||
|
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||||
|
mock_settings.S3_BUCKET = _BUCKET
|
||||||
|
mock_settings.S3_REGION = _REGION
|
||||||
|
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||||
|
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||||
|
client = boto3.client("s3", region_name=_REGION)
|
||||||
|
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||||
|
assert response.get("ServerSideEncryption") == "AES256"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
|
||||||
|
store = BlobStore()
|
||||||
|
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||||
|
client = boto3.client("s3", region_name=_REGION)
|
||||||
|
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||||
|
assert response["Metadata"]["checksum"] == _CHECKSUM
|
||||||
|
|
||||||
|
|
||||||
|
# ── _blob_to_vector helper ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlobToVector:
|
||||||
|
def test_returns_32_floats(self) -> None:
|
||||||
|
v = _blob_to_vector(b"test")
|
||||||
|
assert len(v) == 32
|
||||||
|
|
||||||
|
def test_all_values_in_range(self) -> None:
|
||||||
|
v = _blob_to_vector(b"test")
|
||||||
|
assert all(-1.0 <= x <= 1.0 for x in v)
|
||||||
|
|
||||||
|
def test_deterministic(self) -> None:
|
||||||
|
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
|
||||||
|
|
||||||
|
def test_different_blobs_different_vectors(self) -> None:
|
||||||
|
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestVectorStorePinecone ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestVectorStorePinecone:
|
||||||
|
def _store(self) -> VectorStore:
|
||||||
|
store = VectorStore()
|
||||||
|
store._use_pinecone = lambda: True # type: ignore[method-assign]
|
||||||
|
return store
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_calls_index_upsert(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
|
||||||
|
await store.upsert("u1", items)
|
||||||
|
mock_index.upsert.assert_called_once()
|
||||||
|
call_kwargs = mock_index.upsert.call_args[1]
|
||||||
|
assert call_kwargs.get("namespace") == "u1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
|
||||||
|
await store.upsert("u1", items)
|
||||||
|
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
|
||||||
|
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_calls_index_query(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
await store.search("u1", b"query-blob", top_k=5)
|
||||||
|
mock_index.query.assert_called_once()
|
||||||
|
query_kwargs = mock_index.query.call_args[1]
|
||||||
|
assert query_kwargs.get("namespace") == "u1"
|
||||||
|
assert query_kwargs.get("top_k") == 5
|
||||||
|
assert query_kwargs.get("include_metadata") is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_returns_vector_search_results(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
results = await store.search("u1", b"query", top_k=10)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert isinstance(results[0], VectorSearchResult)
|
||||||
|
assert results[0].id == "v1"
|
||||||
|
assert results[0].score == 0.95
|
||||||
|
assert results[0].blob == b"result-blob"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_uses_derived_query_vector(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
await store.search("u1", b"query-blob", top_k=3)
|
||||||
|
expected_vector = _blob_to_vector(b"query-blob")
|
||||||
|
actual_vector = mock_index.query.call_args[1].get("vector")
|
||||||
|
assert actual_vector == expected_vector
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_calls_index_delete(self) -> None:
|
||||||
|
mock_pc, mock_index = _pinecone_mock()
|
||||||
|
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||||
|
store = self._store()
|
||||||
|
await store.delete("u1", ["v1", "v2"])
|
||||||
|
mock_index.delete.assert_called_once()
|
||||||
|
delete_kwargs = mock_index.delete.call_args[1]
|
||||||
|
assert delete_kwargs.get("namespace") == "u1"
|
||||||
|
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestVectorStoreQdrant:
|
||||||
|
def _store(self) -> VectorStore:
|
||||||
|
store = VectorStore()
|
||||||
|
store._use_pinecone = lambda: False # type: ignore[method-assign]
|
||||||
|
return store
|
||||||
|
|
||||||
|
def _qdrant_mock(self) -> MagicMock:
|
||||||
|
mock_hit = MagicMock()
|
||||||
|
mock_hit.id = "v1"
|
||||||
|
mock_hit.score = 0.88
|
||||||
|
mock_hit.payload = {
|
||||||
|
"blob": base64.b64encode(b"qdrant-result").decode(),
|
||||||
|
"user_id": "u1",
|
||||||
|
}
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.search.return_value = [mock_hit]
|
||||||
|
return mock_client
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_calls_client_upsert(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||||
|
await store.upsert("u1", items)
|
||||||
|
mock_client.upsert.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_upsert_uses_correct_collection(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||||
|
await store.upsert("u1", items)
|
||||||
|
call_kwargs = mock_client.upsert.call_args[1]
|
||||||
|
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_calls_client_search(self) -> None:
|
||||||
|
mock_client = self._qdrant_mock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
await store.search("u1", b"query", top_k=5)
|
||||||
|
mock_client.search.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_passes_limit(self) -> None:
|
||||||
|
mock_client = self._qdrant_mock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
await store.search("u1", b"query", top_k=7)
|
||||||
|
call_kwargs = mock_client.search.call_args[1]
|
||||||
|
assert call_kwargs.get("limit") == 7
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_search_returns_vector_search_results(self) -> None:
|
||||||
|
mock_client = self._qdrant_mock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
results = await store.search("u1", b"query", top_k=5)
|
||||||
|
assert len(results) == 1
|
||||||
|
assert isinstance(results[0], VectorSearchResult)
|
||||||
|
assert results[0].id == "v1"
|
||||||
|
assert results[0].score == 0.88
|
||||||
|
assert results[0].blob == b"qdrant-result"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_calls_client_delete(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
await store.delete("u1", ["v1", "v2"])
|
||||||
|
mock_client.delete.assert_called_once()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_uses_correct_collection(self) -> None:
|
||||||
|
mock_client = MagicMock()
|
||||||
|
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||||
|
store = self._store()
|
||||||
|
await store.delete("u1", ["v1"])
|
||||||
|
call_kwargs = mock_client.delete.call_args[1]
|
||||||
|
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestStorageRoutes (integration) ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageRoutes:
|
||||||
|
"""Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records.
|
||||||
|
|
||||||
|
Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``.
|
||||||
|
So "hello" in JSON becomes ``b"hello"`` on the server. We use plain
|
||||||
|
ASCII strings as blob values and compute checksums accordingly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_BLOB_STR = "encrypted-payload-opaque-to-server"
|
||||||
|
_BLOB_BYTES = _BLOB_STR.encode()
|
||||||
|
_BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create_payload(cls, blob_str: str | None = None) -> dict:
|
||||||
|
blob_str = blob_str or cls._BLOB_STR
|
||||||
|
checksum = hashlib.sha256(blob_str.encode()).hexdigest()
|
||||||
|
return {
|
||||||
|
"table": "tasks",
|
||||||
|
"blob": blob_str,
|
||||||
|
"checksum": checksum,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_record(self, client, tier="power", blob_str=None):
|
||||||
|
payload = self._create_payload(blob_str)
|
||||||
|
return client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header(tier),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Create ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_create_record(self, client, s3_bucket) -> None:
|
||||||
|
resp = self._create_record(client)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert "created_at" in data
|
||||||
|
|
||||||
|
def test_create_record_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
payload = {
|
||||||
|
"table": "tasks",
|
||||||
|
"blob": self._BLOB_STR,
|
||||||
|
"checksum": "0" * 64,
|
||||||
|
}
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None:
|
||||||
|
"""Free tier has cloud_storage_gb=0 → 402."""
|
||||||
|
resp = self._create_record(client, tier="free")
|
||||||
|
assert resp.status_code == 402
|
||||||
|
|
||||||
|
def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None:
|
||||||
|
"""Pro tier has cloud_storage_gb=5 → succeeds for small blob."""
|
||||||
|
resp = self._create_record(client, tier="pro")
|
||||||
|
assert resp.status_code == 201
|
||||||
|
|
||||||
|
# ── List ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_list_records(self, client, s3_bucket) -> None:
|
||||||
|
self._create_record(client)
|
||||||
|
self._create_record(client, blob_str="second-blob")
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 2
|
||||||
|
# Each entry has metadata, no blob bytes
|
||||||
|
for item in data:
|
||||||
|
assert "id" in item
|
||||||
|
assert "table" in item
|
||||||
|
assert "checksum" in item
|
||||||
|
assert "blob" not in item
|
||||||
|
|
||||||
|
def test_list_records_filter_by_table(self, client, s3_bucket) -> None:
|
||||||
|
self._create_record(client)
|
||||||
|
# Create in a different table
|
||||||
|
note_blob = "note-blob"
|
||||||
|
payload = {
|
||||||
|
"table": "notes",
|
||||||
|
"blob": note_blob,
|
||||||
|
"checksum": hashlib.sha256(note_blob.encode()).hexdigest(),
|
||||||
|
}
|
||||||
|
client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records?table=notes",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["table"] == "notes"
|
||||||
|
|
||||||
|
def test_list_records_isolated_per_user(self, client, s3_bucket) -> None:
|
||||||
|
"""One user's records should not appear in another user's list."""
|
||||||
|
self._create_record(client, tier="power")
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
headers=auth_header("team"),
|
||||||
|
)
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
# ── Download ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_download_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == self._BLOB_BYTES
|
||||||
|
assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM
|
||||||
|
|
||||||
|
def test_download_record_not_found(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records/nonexistent-id",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
# ── Update ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_update_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
new_blob_str = "updated-encrypted-payload"
|
||||||
|
new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest()
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
json={"blob": new_blob_str, "checksum": new_checksum},
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# Verify download returns the updated blob
|
||||||
|
dl = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert dl.content == new_blob_str.encode()
|
||||||
|
|
||||||
|
def test_update_record_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
json={"blob": "some-data", "checksum": "0" * 64},
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
# ── Delete ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_delete_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# Subsequent GET should return 404
|
||||||
|
dl = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert dl.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_record_not_found(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/storage/records/nonexistent",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
Reference in New Issue
Block a user