Compare commits

..

18 Commits

Author SHA1 Message Date
7f278c6f63 complete backend plan 2026-03-03 16:09:13 +01:00
8bfce9da00 Refactor LLM instantiation across agents and orchestrator
- Replaced direct instantiation of ChatOpenAI with a centralized get_llm function in CheckpointAgent, NoteAgent, ProjectAgent, and TaskAgent.
- Introduced a new llm.py module to handle LLM model instantiation and API key management.
- Updated settings.py to include LLM_MODEL and LLM_ROUTER_MODEL configurations.
- Modified orchestrator.py to use get_router_llm for intent classification.
- Updated requirements.txt to include litellm for LLM management.
- Adjusted tests to mock get_llm instead of ChatOpenAI directly.
2026-03-03 15:46:44 +01:00
480e7ac5bd Step 13 - completed 2026-03-03 15:14:04 +01:00
d0b303e745 Step 12 - completed 2026-03-03 14:53:34 +01:00
5d485b3665 step 12 2026-03-03 12:39:32 +01:00
9787befd4a step 11 complete: billing service and tier manager
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 22:41:35 +01:00
8f7bc25611 step 10 complete: plugin marketplace with catalog, review workflow, and revenue split
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 22:32:44 +01:00
3e07fff958 step 9 complete: auth middleware, tier-aware rate limiter, and response sanitizer
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 22:18:17 +01:00
4c4df7335a auto deploy
Some checks failed
Deploy to Proxmox Docker / Deploy (push) Failing after 2m11s
2026-03-02 17:41:23 +01:00
c8ef7b119b Refactor tests for execution plan and add comprehensive storage tests
- Updated `TestModuleSingletons` in `test_execution_plan.py` to reflect new agent templates and playbook names.
- Changed assertions in playbook tests to match updated templates and agents.
- Introduced `test_storage.py` to cover the storage layer, including encryption, BlobStore, and VectorStore functionalities.
- Added tests for S3 interactions, ensuring upload, download, delete, and list operations work as expected.
- Implemented mock tests for Pinecone and Qdrant vector stores to validate upsert, search, and delete operations.
2026-03-02 15:36:09 +01:00
35dd9ac86f step 8 complete: REST + WebSocket API routes for chat, plans, storage, vectors, backup, plugins, billing
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 15:33:57 +01:00
e72d72f4f6 step 6 complete: four specialized agents, all registered and tested
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 13:18:53 +01:00
14d1a7351d step 5 complete: execution plan builder, template registry, and LRU plan cache
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 13:13:02 +01:00
68955d2fc2 step 4 complete: intelligent routing with single-agent and pipeline modes
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-02 13:03:54 +01:00
864dfdc4e6 add .gitignore 2026-03-02 00:06:21 +01:00
0d16729036 step 3 complete: pluggable agent framework
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-02 00:03:42 +01:00
82669d3704 step 2 complete: all request/response models defined and validated
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-03-01 23:56:32 +01:00
4d0917f5df step 1 complete: runnable FastAPI skeleton
- Full directory structure with all __init__.py stubs
- requirements.txt with all pinned dependencies
- app/config/settings.py (BaseSettings, env-based)
- app/main.py (CORS, lifespan, /api/v1/health)
- Dockerfile (multi-stage, Python 3.12-slim, non-root user)
- docker-compose.yml (app + postgres:16 with healthcheck)
- .env.example
- BACKEND_PLAN.md: mark step 1 done, add one-step-at-a-time rule

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-01 23:51:37 +01:00
68 changed files with 9825 additions and 95 deletions

28
.env.example Normal file
View File

@@ -0,0 +1,28 @@
# ── Application ──────────────────────────────────────────────────────────────
ENV=dev
# ── Database ──────────────────────────────────────────────────────────────────
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
# ── Auth ──────────────────────────────────────────────────────────────────────
JWT_SECRET=replace-with-a-long-random-secret
JWT_ALGORITHM=HS256
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
# ── OpenAI ────────────────────────────────────────────────────────────────────
OPENAI_API_KEY=sk-...
# ── Stripe ────────────────────────────────────────────────────────────────────
STRIPE_SECRET_KEY=sk_test_...
STRIPE_WEBHOOK_SECRET=whsec_...
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
S3_BUCKET=adiuva-backups
S3_REGION=us-east-1
AWS_ACCESS_KEY_ID=AKIA...
AWS_SECRET_ACCESS_KEY=...
# ── CORS ──────────────────────────────────────────────────────────────────────
# Comma-separated list parsed by Settings (override default if needed)
# CORS_ORIGINS=["app://.","http://localhost:3000"]

View File

@@ -0,0 +1,96 @@
name: Test & Deploy API
run-name: ${{ gitea.ref_name }} → Docker LXC
on:
push:
branches: [main]
tags: ['v*']
pull_request:
branches: [main]
jobs:
# ── 1. Run tests in an isolated Python container ──────────────────
test:
runs-on: ubuntu-latest
container:
image: python:3.12-slim
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Install Dependencies
run: pip install --no-cache-dir -r requirements.txt
- name: Run Linter
run: ruff check app/ tests/
- name: Run Tests
run: pytest tests/ -v --tb=short
# ── 2. Deploy to Docker LXC (only main branch & tags) ─────────────
deploy:
needs: test
runs-on: ubuntu-latest
if: gitea.event_name == 'push'
steps:
- name: Checkout Code
uses: actions/checkout@v4
- name: Sync to deploy directory
run: |
DEPLOY_DIR="/opt/adiuva-api"
mkdir -p "$DEPLOY_DIR"
# Sync source, preserve .env and volumes
cp -rf app/ alembic/ alembic.ini Dockerfile docker-compose.yml requirements.txt "$DEPLOY_DIR/"
- name: Build & restart services
run: |
cd /opt/adiuva-api
docker compose up -d --build --remove-orphans
- name: Run database migrations
run: |
cd /opt/adiuva-api
docker compose exec -T app alembic upgrade head
- name: Verify deployment
run: |
echo "Waiting for app to be ready..."
sleep 5
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8000/api/v1/health)
if [ "$HTTP_CODE" -eq 200 ]; then
echo "✅ API is healthy (HTTP ${HTTP_CODE})"
else
echo "❌ Health check failed (HTTP ${HTTP_CODE})"
docker compose -f /opt/adiuva-api/docker-compose.yml logs app --tail=50
exit 1
fi
- name: Create Gitea Release (tags only)
if: startsWith(gitea.ref, 'refs/tags/')
run: |
GITEA_URL="http://10.0.0.119:3000"
TAG="${GITHUB_REF_NAME}"
REPO="${GITHUB_REPOSITORY}"
TOKEN="${{ gitea.token }}"
RELEASE_ID=$(curl -sf \
-H "Authorization: token ${TOKEN}" \
"${GITEA_URL}/api/v1/repos/${REPO}/releases/tags/${TAG}" \
| grep -o '"id":[0-9]*' | head -1 | cut -d: -f2)
if [ -z "$RELEASE_ID" ]; then
curl -sf \
-X POST \
-H "Authorization: token ${TOKEN}" \
-H "Content-Type: application/json" \
-d "{\"tag_name\":\"${TAG}\",\"name\":\"Adiuva API ${TAG}\",\"body\":\"Deployed to Docker LXC\"}" \
"${GITEA_URL}/api/v1/repos/${REPO}/releases"
echo "✅ Release ${TAG} created"
else
echo " Release ${TAG} already exists (ID: ${RELEASE_ID})"
fi

64
.github/workflows/ci.yml vendored Normal file
View File

@@ -0,0 +1,64 @@
name: CI
on:
push:
branches: [main]
pull_request:
branches: [main]
jobs:
lint:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Install ruff
run: pip install ruff>=0.8.0
- name: Ruff check
run: ruff check .
- name: Ruff format check
run: ruff format --check .
test:
name: Test
runs-on: ubuntu-latest
needs: lint
steps:
- uses: actions/checkout@v4
- uses: actions/setup-python@v5
with:
python-version: "3.12"
- name: Cache pip
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
restore-keys: ${{ runner.os }}-pip-
- name: Install dependencies
run: pip install -r requirements.txt
- name: Run tests
run: pytest -v --tb=short
docker:
name: Docker Build
runs-on: ubuntu-latest
needs: test
steps:
- uses: actions/checkout@v4
- name: Build image
run: docker build -t adiuva-api:ci .
- name: Verify gunicorn installed
run: docker run --rm adiuva-api:ci gunicorn --version

33
.gitignore vendored Normal file
View File

@@ -0,0 +1,33 @@
# Python
__pycache__/
*.py[cod]
*.egg-info/
dist/
build/
# Virtual environment
.venv/
venv/
env/
# Environment variables
.env
# IDE
.vscode/
.idea/
# Testing / coverage
.pytest_cache/
htmlcov/
.coverage
# Docker
*.log
# OS
.DS_Store
Thumbs.db
# Claude Code
.claude/

View File

@@ -2,8 +2,8 @@
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with. > **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
> >
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, and backup blob storage. > The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
> The backend NEVER persists user data. It receives context in requests, uses it for orchestration, and discards it. > The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
--- ---
@@ -20,7 +20,7 @@ adiuva-api/
│ │ ├── orchestrator.py # LLM-based intent router │ │ ├── orchestrator.py # LLM-based intent router
│ │ ├── execution_plan.py # Plan builder + cache │ │ ├── execution_plan.py # Plan builder + cache
│ │ └── plugin_loader.py # Dynamic agent loading │ │ └── plugin_loader.py # Dynamic agent loading
│ ├── agents/ │ ├── agents/ # Chat agents (proprietary logic + prompts)
│ │ ├── __init__.py # Auto-registers all agents │ │ ├── __init__.py # Auto-registers all agents
│ │ ├── task_agent.py │ │ ├── task_agent.py
│ │ ├── calendar_agent.py │ │ ├── calendar_agent.py
@@ -32,7 +32,10 @@ adiuva-api/
│ │ │ ├── __init__.py │ │ │ ├── __init__.py
│ │ │ ├── chat.py # POST /chat + WS /chat/stream │ │ │ ├── chat.py # POST /chat + WS /chat/stream
│ │ │ ├── plans.py # GET /plans/playbook │ │ │ ├── plans.py # GET /plans/playbook
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
│ │ │ ├── vectors.py # Upsert/search cloud vector store
│ │ │ ├── backup.py # PUT/GET /backup │ │ │ ├── backup.py # PUT/GET /backup
│ │ │ ├── plugins.py # Plugin marketplace
│ │ │ ├── auth.py # Register/login/refresh │ │ │ ├── auth.py # Register/login/refresh
│ │ │ └── billing.py # Checkout/webhook/subscription │ │ │ └── billing.py # Checkout/webhook/subscription
│ │ └── middleware/ │ │ └── middleware/
@@ -40,6 +43,16 @@ adiuva-api/
│ │ ├── auth.py # JWT validation │ │ ├── auth.py # JWT validation
│ │ ├── rate_limit.py # Tier-aware rate limiting │ │ ├── rate_limit.py # Tier-aware rate limiting
│ │ └── sanitizer.py # Strip prompt metadata from responses │ │ └── sanitizer.py # Strip prompt metadata from responses
│ ├── storage/
│ │ ├── __init__.py
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
│ │ └── encryption.py # Integrity verification only — NO decryption
│ ├── marketplace/
│ │ ├── __init__.py
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
│ │ ├── plugin_review.py # Review queue + approval workflow
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
│ ├── billing/ │ ├── billing/
│ │ ├── __init__.py │ │ ├── __init__.py
│ │ ├── stripe_service.py # Stripe checkout + webhooks │ │ ├── stripe_service.py # Stripe checkout + webhooks
@@ -53,8 +66,10 @@ adiuva-api/
│ ├── test_orchestrator.py │ ├── test_orchestrator.py
│ ├── test_agents.py │ ├── test_agents.py
│ ├── test_auth.py │ ├── test_auth.py
── test_backup.py ── test_backup.py
├── alembic/ # DB migrations (auth/billing tables only) │ ├── test_storage.py
│ └── test_plugins.py
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
│ ├── alembic.ini │ ├── alembic.ini
│ └── versions/ │ └── versions/
├── requirements.txt ├── requirements.txt
@@ -68,9 +83,9 @@ adiuva-api/
## Step-by-Step Implementation ## Step-by-Step Implementation
### Step 1 — Project scaffolding ### Step 1 — Project scaffolding
- [ ] Initialize repo with the directory structure above - [x] Initialize repo with the directory structure above
- [ ] Write `requirements.txt`: - [x] Write `requirements.txt`:
``` ```
fastapi>=0.115.0 fastapi>=0.115.0
uvicorn[standard]>=0.34.0 uvicorn[standard]>=0.34.0
@@ -91,29 +106,40 @@ adiuva-api/
pytest>=8.0.0 pytest>=8.0.0
pytest-asyncio>=0.24.0 pytest-asyncio>=0.24.0
``` ```
- [ ] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1` - [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
- [ ] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod) - [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
- [ ] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user - [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
- [ ] Write `docker-compose.yml`: app, postgres:16, optional redis - [x] Write `docker-compose.yml`: app, postgres:16, optional redis
- [ ] Write `.env.example` - [x] Write `.env.example`
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes). - **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
### Step 2 — Pydantic schemas (API contracts) ### Step 2 — Pydantic schemas (API contracts)
- [ ] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo): - [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']` - `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]` - `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
- `ChatResponse`: `response: str`, `actions: list[PlanAction]` - `ChatResponse`: `response: str`, `actions: list[PlanAction]`
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification']`, `table: str | None`, `data: dict | None` - `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]` - `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None` - `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int` - `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']` - `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int` - `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier` - `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
- `PluginInstallRequest`: `plugin_id: str`
- **Outcome:** All request/response models defined and validated. - **Outcome:** All request/response models defined and validated.
### Step 3 — Agent Registry + base classes ### Step 3 — Agent Registry + base classes
- [ ] `app/core/agent_registry.py`: - [x] `app/core/agent_registry.py`:
- `BaseAgent(ABC)`: - `BaseAgent(ABC)`:
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]` - `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
- Abstract `get_name() -> str`, `get_description() -> str` - Abstract `get_name() -> str`, `get_description() -> str`
@@ -127,11 +153,11 @@ adiuva-api/
- `get(name) -> ChatAgent` - `get(name) -> ChatAgent`
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt - `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
- `async call_agent(name, query, context) -> str` — for inter-agent calls - `async call_agent(name, query, context) -> str` — for inter-agent calls
- [ ] Unit tests: register, get, list, call_agent with mock - [x] Unit tests: register, get, list, call_agent with mock
- **Outcome:** Pluggable agent framework. - **Outcome:** Pluggable agent framework.
### Step 4 — Orchestrator ### Step 4 — Orchestrator
- [ ] `app/core/orchestrator.py`: - [x] `app/core/orchestrator.py`:
- `async classify_intent(message, context, registry) -> str`: - `async classify_intent(message, context, registry) -> str`:
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name." - System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
- Uses gpt-4o-mini via LangChain for low latency - Uses gpt-4o-mini via LangChain for low latency
@@ -146,16 +172,17 @@ adiuva-api/
- Final synthesis via LLM: "Summarize these agent results into a coherent response" - Final synthesis via LLM: "Summarize these agent results into a coherent response"
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`: - `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
- Main entry point - Main entry point
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
- Classifies intent - Classifies intent
- If `execution_mode == 'direct'`: route + return response - If `execution_mode == 'direct'`: route + return response
- If `execution_mode == 'plan'`: route + return execution plan with template IDs - If `execution_mode == 'plan'`: route + return execution plan with template IDs
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`: - `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
- Same as orchestrate but yields tokens for WebSocket streaming - Same as orchestrate but yields tokens for WebSocket streaming
- [ ] Integration tests with mocked LLM and mocked agents - [x] Integration tests with mocked LLM and mocked agents
- **Outcome:** Intelligent routing with single-agent and pipeline modes. - **Outcome:** Intelligent routing with single-agent and pipeline modes.
### Step 5 — Execution Plan generator ### Step 5 — Execution Plan generator
- [ ] `app/core/execution_plan.py`: - [x] `app/core/execution_plan.py`:
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs. - `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
- `ExecutionPlanBuilder`: - `ExecutionPlanBuilder`:
- `add_step(action, params) -> self` - `add_step(action, params) -> self`
@@ -168,32 +195,52 @@ adiuva-api/
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report") - Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server. - **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
### Step 6 — Chat Agents ### Step 6 — Chat Agents
- [ ] `app/agents/task_agent.py` — `@registry.register`: - [x] `app/agents/task_agent.py` — `@registry.register`:
- Description: "Manages tasks: create, update, list, suggest" - Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
- Tools: `create_task(title, description, priority, due_date)`, `update_task(id, updates)`, `list_tasks(filters)`, `suggest_tasks(notes_context)` - Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
- System prompt: PM-oriented, validates task structure, infers priority from context - status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
- `handle()`: LLM + tool loop via `_tool_loop()`, returns response text + list of actions performed - Accepts flexible context; sentinel `-1` for optional integer update fields
- [ ] `app/agents/calendar_agent.py` — `@registry.register`: - [x] `app/agents/checkpoint_agent.py` — `@registry.register`:
- Description: "Calendar management: events, conflicts, scheduling" - Description: "Manages project checkpoints (milestones): list, create, update, delete"
- Tools: `list_events(date_range)`, `detect_conflicts(events)`, `suggest_reschedule(conflict)` - Tools (4): `list_checkpoints(project_id)`, `create_checkpoint(project_id, title, date, is_ai_suggested, is_approved)`, `update_checkpoint(checkpoint_id, ...)`, `delete_checkpoint(checkpoint_id)`
- Works with event metadata passed in context (never raw calendar data stored) - `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
- [ ] `app/agents/email_agent.py` — `@registry.register`: - [x] `app/agents/project_agent.py` — `@registry.register`:
- Description: "Email analysis: classify, extract actions, draft responses" - Description: "Manages projects: list, get, create, update, archive, delete"
- Tools: `classify_email(metadata)`, `extract_action_items(metadata)`, `draft_response(thread_context)` - Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
- Only processes metadata sent by client — never raw email bodies - status: `active|archived`; prefers archive over deletion (docstring guard on delete)
- [ ] `app/agents/analytics_agent.py` — `@registry.register`: - [x] `app/agents/note_agent.py` — `@registry.register`:
- Description: "Workspace analytics: metrics, reports, trends" - Description: "Manages notes: list, get, create, update, delete"
- Tools: `calculate_metrics(task_data)`, `generate_report(period, data)`, `trend_analysis(data_points)` - Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
- Crunches numbers from context, returns structured insights - content is Markdown; `get_note` should be called before update to preserve existing content
- [ ] `app/agents/__init__.py`: imports all agent modules to trigger `@registry.register` decorators - [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
- [ ] Unit tests per agent with mocked LLM - [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
- **Outcome:** Four specialized agents, all registered and tested. - **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Checkpoints, Projects, Notes), all registered and tested.
### Step 7 — API Routes ### Step 7 — Storage Layer ✅
- [x] `app/storage/blob_store.py`:
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
- [x] `app/storage/vector_store.py`:
- `VectorStore`: `async upsert`, `async search`, `async delete`
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
- ANN on encrypted data: known accuracy trade-off, documented
- [x] `app/storage/encryption.py`:
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
- Backend NEVER holds decryption keys
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
#### 7a — Chat endpoint ### Step 8 — API Routes ✅
- [ ] `app/api/routes/chat.py`:
#### 8a — Chat endpoint
- [x] `app/api/routes/chat.py`:
- `POST /api/v1/chat`: - `POST /api/v1/chat`:
- Request: `ChatRequest` - Request: `ChatRequest`
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()` - Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
@@ -204,49 +251,94 @@ adiuva-api/
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}` - Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
- Heartbeat ping every 30s to keep connection alive - Heartbeat ping every 30s to keep connection alive
#### 7b — Plans endpoint #### 8b — Plans endpoint
- [ ] `app/api/routes/plans.py`: - [x] `app/api/routes/plans.py`:
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier - `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan - `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
#### 7c — Backup endpoint #### 8c — Storage endpoint (cloud records)
- [ ] `app/api/routes/backup.py`: - [x] `app/api/routes/storage.py`:
- `POST /api/v1/storage/records`: Create encrypted record
- Request: `StorageRecordCreate`
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
- Response: `{id: str, created_at: int}`
- `GET /api/v1/storage/records`: List record metadata (no blobs)
- Query params: `table: str`, `page: int`, `limit: int`
- Response: `list[{id, table, checksum, created_at, updated_at}]`
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
- Response: blob bytes + `X-Checksum` header
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
- Request: `StorageRecordUpdate`
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
#### 8d — Vectors endpoint (cloud vector store)
- [x] `app/api/routes/vectors.py`:
- `POST /api/v1/storage/vectors/upsert`:
- Request: `VectorUpsertRequest`
- Verifies checksums, delegates to `VectorStore.upsert()`
- Response: `{upserted: int}`
- `POST /api/v1/storage/vectors/search`:
- Request: `VectorSearchRequest`
- Delegates to `VectorStore.search()`
- Response: `VectorSearchResponse`
- `DELETE /api/v1/storage/vectors`:
- Request: `{ids: list[str]}`
#### 8e — Backup endpoint
- [x] `app/api/routes/backup.py`:
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits: - `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
- Free: 0 (no backup) - Free: 0 (no backup)
- Pro: 5 GB - Pro: 5 GB
- Power: 50 GB - Power: 25 GB
- Team: unlimited - Team: unlimited
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`. - `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs). - `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup. - `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
#### 7dAuth endpoint #### 8fPlugins endpoint
- [ ] `app/api/routes/auth.py`: - [x] `app/api/routes/plugins.py`:
- `GET /api/v1/plugins`:
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
- Response: `PluginListResponse`
- Available from Power tier and above
- `GET /api/v1/plugins/{id}`:
- Response: `PluginManifest` + ratings + install count
- `POST /api/v1/plugins/{id}/install`:
- Request: `PluginInstallRequest`
- Records installation for the user (billing tracking, analytics)
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
- `DELETE /api/v1/plugins/{id}/install`:
- Unregisters installation
#### 8g — Auth endpoint
- [x] `app/api/routes/auth.py`:
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens` - `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens` - `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens` - `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT - `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
#### 7e — Billing endpoint #### 8h — Billing endpoint
- [ ] `app/api/routes/billing.py`: - [x] `app/api/routes/billing.py`:
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL - `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle) - `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
- `GET /api/v1/billing/subscription`: Returns current subscription info - `GET /api/v1/billing/subscription`: Returns current subscription info
- `DELETE /api/v1/billing/subscription`: Cancels subscription - `DELETE /api/v1/billing/subscription`: Cancels subscription
- **Outcome:** Complete REST + WebSocket API. - **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
### Step 8 — Middleware ### Step 9 — Middleware
#### 8a — Auth middleware #### 9a — Auth middleware
- [ ] `app/api/middleware/auth.py`: - [x] `app/api/middleware/auth.py`:
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile` - FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
- Validates JWT signature, expiry, extracts `user_id` and `tier` - Validates JWT signature, expiry, extracts `user_id` and `tier`
- Raises `401` on invalid/expired token - Raises `401` on invalid/expired token
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook` - Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
#### 8b — Rate limiter #### 9b — Rate limiter
- [ ] `app/api/middleware/rate_limit.py`: - [x] `app/api/middleware/rate_limit.py`:
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)` - Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
- Tier-based limits: - Tier-based limits:
- Free: 20 req/min - Free: 20 req/min
@@ -255,8 +347,8 @@ adiuva-api/
- Team: 200 req/seat/min - Team: 200 req/seat/min
- Custom 429 response with `Retry-After` header - Custom 429 response with `Retry-After` header
#### 8c — Sanitizer #### 9c — Sanitizer
- [ ] `app/api/middleware/sanitizer.py`: - [x] `app/api/middleware/sanitizer.py`:
- Response middleware that scans response bodies - Response middleware that scans response bodies
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata - Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
- Pattern-based detection + exact match against known prompt fingerprints - Pattern-based detection + exact match against known prompt fingerprints
@@ -264,46 +356,113 @@ adiuva-api/
- **Outcome:** Secure, rate-limited API with prompt IP protection. - **Outcome:** Secure, rate-limited API with prompt IP protection.
### Step 9Billing & Tier management ### Step 10Plugin Marketplace ✅
- [ ] `app/billing/stripe_service.py`: - [x] `app/marketplace/plugin_registry.py`:
- `PluginRegistry`:
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
- `async get_plugin(plugin_id) -> PluginManifest | None`
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
- `async reject_plugin(plugin_id, reason: str) -> None`
- [x] `app/marketplace/plugin_review.py`:
- `ReviewQueue`:
- `async get_pending() -> list[dict]`
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
- [x] `app/marketplace/revenue_share.py`:
- `RevenueShare`:
- `async record_install(plugin_id, user_id, amount_cents) -> None`
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
- `async get_earnings(developer_id, period) -> dict`
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
### Step 11 — Billing & Tier management ✅
- [x] `app/billing/stripe_service.py`:
- `create_checkout_session(user_id, tier) -> str` - `create_checkout_session(user_id, tier) -> str`
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` - `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
- `get_subscription(user_id) -> dict | None` - `get_subscription(user_id) -> dict | None`
- `cancel_subscription(user_id) -> None` - `cancel_subscription(user_id) -> None`
- [ ] `app/billing/tier_manager.py`: - [x] `app/billing/tier_manager.py`:
- `TierManager`: - `TierManager`:
- Feature matrix: - Feature matrix:
```python ```python
FEATURES = { FEATURES = {
'free': {'agents': 3, 'batch': False, 'providers': 1, 'backup_gb': 0}, 'free': {
'pro': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': 5}, 'agents': 3,
'power': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': 50, 'byok': True}, 'batch_active': 2,
'team': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': -1, 'sso': True}, 'cloud_storage_gb': 0,
'backup_gb': 0,
'providers': 1,
'batch_builder': False,
'plugin_marketplace': False,
'sso': False,
},
'pro': {
'agents': -1, # unlimited
'batch_active': 10,
'cloud_storage_gb': 5,
'backup_gb': 5,
'providers': -1,
'batch_builder': False,
'plugin_marketplace': False,
'sso': False,
},
'power': {
'agents': -1,
'batch_active': -1, # unlimited
'cloud_storage_gb': 25,
'backup_gb': 25,
'providers': -1,
'batch_builder': True,
'plugin_marketplace': True,
'sso': False,
},
'team': {
'agents': -1,
'batch_active': -1,
'cloud_storage_gb': -1,
'backup_gb': -1,
'providers': -1,
'batch_builder': True,
'plugin_marketplace': True,
'sso': True,
},
} }
``` ```
- `get_tier(user_id) -> BillingTier` - `get_tier(user_id) -> BillingTier`
- `check_feature(user_id, feature) -> bool` - `check_feature(user_id, feature) -> bool`
- `get_rate_limit(tier) -> int` - `get_rate_limit(tier) -> int`
- **Outcome:** Stripe integration with tier-based feature gating. - `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
- [x] `app/billing/__init__.py`: exports `stripe_service` and `tier_manager` singletons
- [x] `app/api/routes/billing.py`: refactored to delegate to `StripeService`
- [x] `app/api/routes/storage.py` and `backup.py`: `_check_quota` now delegates to `tier_manager.enforce_quota` / `enforce_backup_quota`
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
### Step 10 — Database (auth/billing only) ### Step 12 — Database (auth/billing/marketplace only)
- [ ] PostgreSQL schema via Alembic: - [x] PostgreSQL schema via Alembic:
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at` - `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at` - `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at` - `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at` - `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
- [ ] Initial Alembic migration - `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
- [ ] SQLAlchemy models in `app/models.py` - `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
- **Outcome:** Auth and billing persistence. Zero user data stored. - `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
- [x] Initial Alembic migration
- [x] SQLAlchemy models in `app/models.py`
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
### Step 11 — Testing & deployment ### Step 13 — Testing & deployment
- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed) - [x] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
- [ ] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode - [x] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
- [ ] `tests/test_agents.py`: each agent with mocked tools - [x] `tests/test_agents.py`: each agent with mocked tools
- [ ] `tests/test_auth.py`: register → login → access protected → refresh → expired token - [x] `tests/test_auth.py`: register → login → access protected → refresh → expired token
- [ ] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement - [x] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
- [ ] `Dockerfile` optimized for production (gunicorn + uvicorn workers) - [x] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
- [ ] GitHub Actions CI: lint (ruff), test (pytest), build Docker image - [x] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
- [x] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
- [x] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
- **Outcome:** Fully tested, deployable backend. - **Outcome:** Fully tested, deployable backend.
--- ---
@@ -320,10 +479,22 @@ adiuva-api/
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON | | WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` | | GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` | | GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` | | PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
| GET | `/api/v1/backup` | JWT | — | Binary blob | | GET | `/api/v1/backup` | JWT | — | Binary blob |
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` | | GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` | | DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` | | POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` | | POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info | | GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
@@ -339,20 +510,24 @@ adiuva-api/
| Framework | FastAPI + Uvicorn | | Framework | FastAPI + Uvicorn |
| LLM | LangChain + langchain-openai | | LLM | LangChain + langchain-openai |
| Auth | PyJWT + bcrypt + OAuth2 | | Auth | PyJWT + bcrypt + OAuth2 |
| Billing | stripe-python | | Billing | stripe-python + Stripe Connect |
| Storage | boto3 (S3) | | Blob storage | boto3 (S3) |
| Vector store | Pinecone or Qdrant (configurable) |
| Database | PostgreSQL + SQLAlchemy + Alembic | | Database | PostgreSQL + SQLAlchemy + Alembic |
| Rate limiting | slowapi | | Rate limiting | slowapi |
| Testing | pytest + pytest-asyncio + httpx | | Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
| Deployment | Docker → fly.io / Railway / AWS ECS | | Deployment | Docker → fly.io / Railway / AWS ECS |
--- ---
## Development Rules ## Development Rules
1. **NEVER persist user data.** The DB stores only auth, billing, and backup metadata. User context arrives in requests and is discarded after processing. 1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. 2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
3. **Stateless request handling.** No server-side session state. All context comes from the client + JWT. 3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
4. **Type hints everywhere.** All functions have full type annotations. 4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
5. **Test every agent.** Each chat agent has unit tests with mocked LLM responses. 5. **Type hints everywhere.** All functions have full type annotations.
6. **Structured logging.** JSON logs with request ID correlation. 6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
7. **Structured logging.** JSON logs with request ID correlation.
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.

39
Dockerfile Normal file
View File

@@ -0,0 +1,39 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY requirements.txt .
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
# Non-root user
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
# Copy installed packages from builder
COPY --from=builder /install /usr/local
# Copy application source
COPY app/ app/
# Copy Alembic migration files
COPY alembic/ alembic/
COPY alembic.ini .
# Ensure appuser owns the working directory
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "4", \
"--timeout", "120"]

793
README.md Normal file
View File

@@ -0,0 +1,793 @@
# Adiuva Cloud API
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
---
## Table of Contents
- [Overview](#overview)
- [Architecture](#architecture)
- [Key Features](#key-features)
- [Tech Stack](#tech-stack)
- [Getting Started](#getting-started)
- [Docker Deployment](#docker-deployment)
- [Environment Variables](#environment-variables)
- [API Reference](#api-reference)
- [Data Model](#data-model)
- [AI Agent System](#ai-agent-system)
- [Orchestration & Execution Plans](#orchestration--execution-plans)
- [Middleware](#middleware)
- [Storage Layer](#storage-layer)
- [Billing & Tiers](#billing--tiers)
- [Plugin Marketplace](#plugin-marketplace)
- [Testing](#testing)
- [Project Structure](#project-structure)
- [License](#license)
---
## Overview
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
### Design Principles
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
---
## Architecture
```
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
│ Desktop App │────▶│ │
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
└──────────────┘ │ │
│ ┌──────────────────┐ ┌────────────────────────────┐ │
│ │ Auth Routes │ │ Chat Routes │ │
│ │ Billing Routes │ │ ↓ │ │
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
│ │ Backup Routes │ │ ↓ classify intent │ │
│ │ Plugin Routes │ │ Agent Registry │ │
│ │ Vector Routes │ │ ↓ │ │
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
│ │ (GPT-4o + LangChain) │ │
│ └────────────────────────────┘ │
└────────────────────────────────────────────────────────┘
│ │ │
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
│ Billing, │ │ backups) │ │ (Vectors) │
│ Metadata) │ └───────────────┘ └────────────────┘
└────────────┘
┌────────▼───┐
│ Stripe │
│ (Billing, │
│ Connect) │
└────────────┘
```
---
## Key Features
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Checkpoints (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
---
## Tech Stack
| Package | Version | Purpose |
|---|---|---|
| `fastapi` | ≥ 0.115.0 | Web framework |
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
| `gunicorn` | ≥ 22.0.0 | Production process manager |
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
| `boto3` | ≥ 1.35.0 | AWS S3 client |
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
| `alembic` | ≥ 1.14.0 | Database migration management |
| `bcrypt` | ≥ 4.2.0 | Password hashing |
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
| `websockets` | ≥ 14.0 | WebSocket protocol support |
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
| `pytest` | ≥ 8.0.0 | Test framework |
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
| `ruff` | ≥ 0.8.0 | Linter and formatter |
---
## Getting Started
### Prerequisites
- Python 3.12+
- PostgreSQL 16+
- An OpenAI API key (for LLM features)
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
- AWS credentials (optional — needed for S3 storage in production)
### Installation
```bash
# Clone the repository
git clone <repo-url> && cd adiuva-api
# Create a virtual environment
python -m venv .venv && source .venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Configure environment
cp .env.example .env
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
```
### Database Setup
```bash
# Start PostgreSQL (or use the Docker Compose database)
docker compose up db -d
# Run migrations
alembic upgrade head
```
### Run the Development Server
```bash
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
---
## Docker Deployment
### Quick Start
```bash
docker compose up --build
```
This starts two services:
- **app** — FastAPI server on port `8000`
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
The compose file also includes optional services for fully local deployments:
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
### Dockerfile Details
The Dockerfile uses a multi-stage build:
1. **Builder stage** — Installs Python dependencies into a virtual environment.
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
```bash
# Production command (run by the container)
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
```
---
## Homelab / Self-Hosted Deployment
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
### 1. Start all services
```bash
docker compose up -d
```
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
### 2. Create the MinIO bucket
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
```bash
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
docker compose exec minio mc mb local/adiuva
```
### 3. Configure your `.env`
```bash
# Database (uses the compose PostgreSQL)
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
# S3 → MinIO
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=http://minio:9000
AWS_ACCESS_KEY_ID=minioadmin
AWS_SECRET_ACCESS_KEY=minioadmin
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
QDRANT_URL=http://qdrant:6333
QDRANT_API_KEY=
PINECONE_API_KEY=
# Billing — leave empty to stub (no Stripe needed)
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=
# LLM — the only external service
OPENAI_API_KEY=sk-...
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# Auth
JWT_SECRET=your-secret-here
ENV=dev
```
### 4. Run migrations
```bash
docker compose exec app alembic upgrade head
```
### What runs where
| Service | Runs on | Port | Notes |
|---|---|---|---|
| FastAPI app | Docker | 8000 | API server |
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
| Stripe | — | — | Stubbed when keys are empty |
| OpenAI / LLM | Cloud | — | Only external dependency |
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
---
## Environment Variables
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
| Variable | Type | Default | Description |
|---|---|---|---|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
| `S3_REGION` | `str` | `us-east-1` | AWS region |
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
---
## API Reference
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
### Health
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
### Auth
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
### Chat
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
### Plans
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
### Storage (Cloud Records)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
### Vectors (Cloud Vector Store)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
### Backup
| Method | Path | Auth | Description |
|---|---|---|---|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
### Plugins (Marketplace)
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
### Billing
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
---
## Data Model
9 tables managed by Alembic migrations. Source: `app/models.py`
### Tables
| Table | Primary Key | Key Columns | Purpose |
|---|---|---|---|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
### Enum Types
| Enum | Values |
|---|---|
| `billing_tier` | `free`, `pro`, `power`, `team` |
| `plugin_status` | `pending_review`, `approved`, `rejected` |
| `review_decision` | `approved`, `rejected` |
### Migrations
| Version | Description |
|---|---|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
---
## AI Agent System
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
### Architecture
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
### Registered Agents
| Agent | Registry Name | Tools | Description |
|---|---|---|---|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
| **CheckpointAgent** | `checkpoint_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_checkpoints`, `create_checkpoint`, `update_checkpoint`, `delete_checkpoint` |
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
### Switching LLM Providers
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
```bash
# OpenAI (default)
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# Anthropic
LLM_MODEL=anthropic/claude-3.5-sonnet
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
# Google Gemini
LLM_MODEL=gemini/gemini-pro
LLM_ROUTER_MODEL=gemini/gemini-flash
# Local Ollama
LLM_MODEL=ollama/llama3
LLM_ROUTER_MODEL=ollama/llama3
# AWS Bedrock
LLM_MODEL=bedrock/anthropic.claude-v2
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
```
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
---
## Orchestration & Execution Plans
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
### Orchestrator
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
### Execution Plans
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
### Built-in Templates (6)
`tpl_task_agent_default`, `tpl_checkpoint_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
### Built-in Playbooks (2)
| Playbook | Description |
|---|---|
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
---
## Middleware
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
### JWT Authentication
Source: `app/api/middleware/auth.py`
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
- Falls back to `free` when no subscription row exists.
- Raises `401 Unauthorized` on invalid or expired tokens.
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
### Tier-Based Rate Limiter
Source: `app/api/middleware/rate_limit.py`
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
- Per-user 60-second window sized by subscription tier:
| Tier | Requests / Minute |
|---|---|
| Free | 20 |
| Pro | 60 |
| Power | 120 |
| Team | 200 |
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
- **Exempt paths:** register, login, webhook, health
### Response Sanitizer
Source: `app/api/middleware/sanitizer.py`
- Runs only on `/api/v1/chat` endpoints.
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
- Logs sanitization events as `WARNING`.
- Binary responses (storage, backup) are never touched.
---
## Storage Layer
### Blob Store
Source: `app/storage/blob_store.py`
- S3-backed storage for E2E encrypted blobs.
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
- The backend **never inspects or decrypts blob content**.
### Vector Store
Source: `app/storage/vector_store.py`
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
- Methods: `upsert()`, `search()`, `delete()`
### Encryption Utilities
Source: `app/storage/encryption.py`
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
- **No decryption key ever reaches the backend.**
---
## Billing & Tiers
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
### Feature Matrix
| Feature | Free | Pro | Power | Team |
|---|---|---|---|---|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
| Batch Active | 2 | 10 | Unlimited | Unlimited |
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
| Batch Builder | — | — | ✓ | ✓ |
| Plugin Marketplace | — | — | ✓ | ✓ |
| SSO | — | — | — | ✓ |
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
### Stripe Integration
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
### Tier Manager
- `get_tier(user_id)` — Returns the user's current billing tier.
- `check_feature(tier, feature)` — Boolean feature gate check.
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
---
## Plugin Marketplace
Source: `app/marketplace/`
### Plugin Registry
- PostgreSQL-backed catalog of submitted and approved plugins.
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
### Review Queue
- Automated security checklist before human review:
- Plugin ID must match `^[a-z0-9-]+$`
- Permissions must be from the allowed set only
- No binary blobs in the manifest
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:checkpoints`, `write:checkpoints`, `read:calendar`, `write:calendar`
- `get_pending(db)` — Lists plugins awaiting review.
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
### Revenue Sharing
- **70% developer / 30% platform** split on all paid plugin sales.
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
- Gracefully stubs transfers when Stripe is not configured.
### Seed Plugins
| Plugin | Category | Price |
|---|---|---|
| GitHub Sync | Productivity | Free |
| Slack Notifier | Communication | €4.99 |
| Time Tracker | Productivity | €9.99 |
---
## Testing
### Running Tests
```bash
# Run all tests
pytest
# Run a specific test file
pytest tests/test_auth.py
# Run with verbose output
pytest -v
```
### Test Infrastructure
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
- **No external dependencies** — all tests run fully offline.
### Test Coverage
| File | Coverage |
|---|---|
| `test_auth.py` | Register, login, token access, refresh, expiration |
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
---
## Project Structure
```
adiuva-api/
├── alembic.ini # Alembic configuration
├── BACKEND_PLAN.md # Architecture & design decisions
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
├── Dockerfile # Multi-stage production build
├── requirements.txt # Python dependencies
├── alembic/ # Database migrations
│ ├── env.py # Alembic environment config
│ ├── script.py.mako # Migration template
│ └── versions/
│ ├── 001_initial_schema.py # Tables, indexes, FKs
│ └── 002_seed_plugins.py # Seed marketplace plugins
├── app/ # Application source
│ ├── main.py # FastAPI app factory, middleware, routes
│ ├── db.py # Async SQLAlchemy engine & session
│ ├── models.py # SQLAlchemy ORM models (9 tables)
│ ├── schemas.py # Pydantic request/response schemas
│ │
│ ├── config/
│ │ └── settings.py # Pydantic Settings (env vars)
│ │
│ ├── agents/ # LLM-powered domain agents
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
│ │ ├── project_agent.py # Project lifecycle (6 tools)
│ │ ├── checkpoint_agent.py # Milestones (4 tools)
│ │ └── note_agent.py # Markdown notes (5 tools)
│ │
│ ├── core/ # Orchestration engine
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
│ │ ├── orchestrator.py # Intent classification & routing
│ │ └── execution_plan.py # Plan builder, templates, cache
│ │
│ ├── api/ # HTTP layer
│ │ ├── deps.py # Shared FastAPI dependencies
│ │ ├── middleware/
│ │ │ ├── auth.py # JWT validation, live tier lookup
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
│ │ │ └── sanitizer.py # Prompt IP leak protection
│ │ └── routes/
│ │ ├── auth.py # Register, login, refresh, me
│ │ ├── chat.py # Chat + WebSocket streaming
│ │ ├── plans.py # Execution plan playbooks
│ │ ├── storage.py # E2E encrypted record CRUD
│ │ ├── vectors.py # Vector upsert, search, delete
│ │ ├── backup.py # Encrypted backup management
│ │ ├── plugins.py # Marketplace browse & install
│ │ └── billing.py # Stripe checkout & webhooks
│ │
│ ├── storage/ # Storage backends
│ │ ├── blob_store.py # S3 blob storage
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
│ │ └── encryption.py # Checksum verification utilities
│ │
│ ├── billing/ # Subscription management
│ │ ├── stripe_service.py # Stripe API integration
│ │ └── tier_manager.py # Feature matrix & quota enforcement
│ │
│ └── marketplace/ # Plugin ecosystem
│ ├── plugin_registry.py # Catalog CRUD & search
│ ├── plugin_review.py # Security checklist & review queue
│ └── revenue_share.py # 70/30 split & Stripe Connect
└── tests/ # Test suite
├── conftest.py # Fixtures: DB, S3, auth, seeds
├── test_auth.py
├── test_orchestrator.py
├── test_agents.py
├── test_storage.py
├── test_backup.py
├── test_plugins.py
├── test_agent_registry.py
├── test_execution_plan.py
└── test_middleware.py
```
---
## License
*To be determined.*

47
alembic.ini Normal file
View File

@@ -0,0 +1,47 @@
# Alembic configuration file.
# The async app uses postgresql+asyncpg:// at runtime.
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
[alembic]
script_location = alembic
prepend_sys_path = .
version_path_separator = os
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
sqlalchemy.url = driver://user:pass@localhost/dbname
[post_write_hooks]
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

93
alembic/env.py Normal file
View File

@@ -0,0 +1,93 @@
"""Alembic migration environment — async-compatible.
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
env var by replacing the driver prefix.
Run migrations with:
alembic upgrade head
"""
from __future__ import annotations
import asyncio
import os
import re
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy.ext.asyncio import create_async_engine
# Alembic Config object (gives access to alembic.ini values).
config = context.config
# Set up Python logging from alembic.ini.
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Import the Base so that Alembic can detect model changes for --autogenerate.
from app.models import Base # noqa: E402
target_metadata = Base.metadata
def _sync_url(async_url: str) -> str:
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
def _get_url() -> str:
db_url = os.environ.get("DATABASE_URL", "")
if not db_url:
# Fall back to settings if env var not set directly.
from app.config.settings import settings # noqa: PLC0415
db_url = settings.DATABASE_URL
return _sync_url(db_url)
def run_migrations_offline() -> None:
"""Emit SQL without a live DB connection."""
url = _get_url()
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
def do_run_migrations(connection): # type: ignore[no-untyped-def]
context.configure(
connection=connection,
target_metadata=target_metadata,
compare_type=True,
)
with context.begin_transaction():
context.run_migrations()
async def run_migrations_online_async() -> None:
"""Run migrations against a live DB using the async engine."""
async_url = os.environ.get("DATABASE_URL", "")
if not async_url:
from app.config.settings import settings # noqa: PLC0415
async_url = settings.DATABASE_URL
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
async with connectable.connect() as connection:
await connection.run_sync(do_run_migrations)
await connectable.dispose()
def run_migrations_online() -> None:
asyncio.run(run_migrations_online_async())
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

28
alembic/script.py.mako Normal file
View File

@@ -0,0 +1,28 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,202 @@
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
Revision ID: 001
Revises:
Create Date: 2026-03-02
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "001"
down_revision: Union[str, None] = None
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── Enum types ────────────────────────────────────────────────────────
billing_tier = postgresql.ENUM(
"free", "pro", "power", "team", name="billing_tier", create_type=False
)
plugin_status = postgresql.ENUM(
"pending_review", "approved", "rejected", name="plugin_status", create_type=False
)
review_decision = postgresql.ENUM(
"approved", "rejected", name="review_decision", create_type=False
)
for enum in (billing_tier, plugin_status, review_decision):
enum.create(op.get_bind(), checkfirst=True)
# ── users ─────────────────────────────────────────────────────────────
op.create_table(
"users",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("email", sa.String(255), nullable=False),
sa.Column("password_hash", sa.String(255), nullable=False),
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("email"),
)
op.create_index("ix_users_email", "users", ["email"])
# ── refresh_tokens ────────────────────────────────────────────────────
op.create_table(
"refresh_tokens",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("token_hash", sa.String(64), nullable=False),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("token_hash"),
)
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
# ── subscriptions ─────────────────────────────────────────────────────
op.create_table(
"subscriptions",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
sa.Column("tier", sa.Enum("free", "pro", "power", "team", name="billing_tier"), nullable=False, server_default="free"),
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("user_id"),
)
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
# ── storage_records ───────────────────────────────────────────────────
op.create_table(
"storage_records",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("table_name", sa.String(100), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
# ── backup_metadata ───────────────────────────────────────────────────
op.create_table(
"backup_metadata",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column("timestamp", sa.BigInteger, nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
# ── plugins ───────────────────────────────────────────────────────────
op.create_table(
"plugins",
sa.Column("id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=False, server_default=""),
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
sa.Column("category", sa.String(100), nullable=False, server_default=""),
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
sa.Column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status"), nullable=False, server_default="pending_review"),
sa.Column("s3_package_key", sa.String(500), nullable=True),
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
sa.Column("rejection_reason", sa.Text, nullable=True),
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
)
# ── plugin_installations ──────────────────────────────────────────────
op.create_table(
"plugin_installations",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
)
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
# ── plugin_reviews ────────────────────────────────────────────────────
op.create_table(
"plugin_reviews",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("decision", sa.Enum("approved", "rejected", name="review_decision"), nullable=False),
sa.Column("notes", sa.Text, nullable=True),
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
)
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
# ── revenue_events ────────────────────────────────────────────────────
op.create_table(
"revenue_events",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
def downgrade() -> None:
op.drop_table("revenue_events")
op.drop_table("plugin_reviews")
op.drop_table("plugin_installations")
op.drop_table("plugins")
op.drop_table("backup_metadata")
op.drop_table("storage_records")
op.drop_table("subscriptions")
op.drop_table("refresh_tokens")
op.drop_table("users")
op.execute("DROP TYPE IF EXISTS review_decision")
op.execute("DROP TYPE IF EXISTS plugin_status")
op.execute("DROP TYPE IF EXISTS billing_tier")

View File

@@ -0,0 +1,92 @@
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
Revision ID: 002
Revises: 001
Create Date: 2026-03-03
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_SEED_PLUGINS = [
{
"id": "plugin-github-sync",
"name": "GitHub Sync",
"description": "Sync tasks with GitHub Issues and pull requests.",
"version": "1.0.0",
"author_name": "Adiuva",
"category": "productivity",
"price_cents": 0,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-slack-notify",
"name": "Slack Notifier",
"description": "Post task and checkpoint updates to Slack channels.",
"version": "1.2.0",
"author_name": "Adiuva",
"category": "communication",
"price_cents": 499,
"permissions": json.dumps(["read:tasks", "read:checkpoints"]),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-time-tracker",
"name": "Time Tracker",
"description": "Track time spent on tasks with automatic reporting.",
"version": "0.9.1",
"author_name": "Third Party",
"category": "productivity",
"price_cents": 999,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
]
def upgrade() -> None:
plugins = sa.table(
"plugins",
sa.column("id", sa.String),
sa.column("name", sa.String),
sa.column("description", sa.Text),
sa.column("version", sa.String),
sa.column("author_name", sa.String),
sa.column("category", sa.String),
sa.column("price_cents", sa.Integer),
sa.column("permissions", sa.Text),
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
sa.column("s3_package_key", sa.String),
sa.column("install_count", sa.Integer),
sa.column("avg_rating", sa.Float),
)
op.bulk_insert(plugins, _SEED_PLUGINS)
def downgrade() -> None:
op.execute(
"DELETE FROM plugins WHERE id IN ("
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
")"
)

0
app/__init__.py Normal file
View File

5
app/agents/__init__.py Normal file
View File

@@ -0,0 +1,5 @@
"""Import all agent modules to trigger @registry.register decorators."""
from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]

View File

@@ -0,0 +1,121 @@
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
" - is_approved: 0 until the user explicitly confirms; then 1\n"
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all checkpoints across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_checkpoints(project_id: str = "") -> str:
"""List checkpoints. Provide project_id to scope to a specific project."""
return json.dumps({
"action": "list",
"table": "checkpoints",
"filters": {"projectId": project_id or None},
})
@tool
async def create_checkpoint(
project_id: str,
title: str,
date: int,
is_ai_suggested: int = 0,
is_approved: int = 0,
) -> str:
"""Create a project checkpoint (milestone).
project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone
date: Unix timestamp in milliseconds
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
is_approved: 0 until the user confirms
"""
return json.dumps({
"action": "create_record",
"table": "checkpoints",
"data": {
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
"isApproved": is_approved,
},
})
@tool
async def update_checkpoint(
checkpoint_id: str,
title: str = "",
date: int = -1,
is_approved: int = -1,
) -> str:
"""Update a checkpoint. Only pass fields that should change.
checkpoint_id: UUID of the checkpoint (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp)
is_approved: -1 means unchanged; 0 or 1 sets the approval state
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
if is_approved != -1:
updates["isApproved"] = is_approved
return json.dumps({
"action": "update_record",
"table": "checkpoints",
"data": {"id": checkpoint_id, "updates": updates},
})
@tool
async def delete_checkpoint(checkpoint_id: str) -> str:
"""Delete a checkpoint permanently by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "checkpoints",
"data": {"id": checkpoint_id},
})
@registry.register
class CheckpointAgent(ChatAgent):
def get_name(self) -> str:
return "checkpoint_agent"
def get_description(self) -> str:
return "Manages project checkpoints (milestones): list, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

122
app/agents/note_agent.py Normal file
View File

@@ -0,0 +1,122 @@
"""Note agent — Markdown note management (list, get, create, update, delete)."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
return json.dumps({
"action": "list",
"table": "notes",
"filters": {"projectId": project_id or None},
})
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
return json.dumps({
"action": "get",
"table": "notes",
"data": {"id": note_id},
})
@tool
async def create_note(
title: str,
content: str,
project_id: str = "",
) -> str:
"""Create a new note.
title: note heading (required)
content: Markdown body text (required)
project_id: optional UUID linking this note to a project
"""
return json.dumps({
"action": "create_record",
"table": "notes",
"data": {
"title": title,
"content": content,
"projectId": project_id or None,
},
})
@tool
async def update_note(
note_id: str,
title: str = "",
content: str = "",
) -> str:
"""Update an existing note. Only pass fields that should change.
note_id: UUID of the note (required)
If you need to preserve existing content, call get_note first.
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if content:
updates["content"] = content
return json.dumps({
"action": "update_record",
"table": "notes",
"data": {"id": note_id, "updates": updates},
})
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "notes",
"data": {"id": note_id},
})
@registry.register
class NoteAgent(ChatAgent):
def get_name(self) -> str:
return "note_agent"
def get_description(self) -> str:
return "Manages notes: list, get, create, update, delete"
def get_tools(self) -> list[Any]:
return [list_notes, get_note, create_note, update_note, delete_note]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

157
app/agents/project_agent.py Normal file
View File

@@ -0,0 +1,157 @@
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool
async def list_projects(
client_id: str = "",
include_archived: int = 0,
) -> str:
"""List projects, optionally filtered by client_id.
include_archived: 1 to include archived projects, 0 for active only (default).
"""
return json.dumps({
"action": "list",
"table": "projects",
"filters": {
"clientId": client_id or None,
"includeArchived": bool(include_archived),
},
})
@tool
async def list_all_projects() -> str:
"""List every project regardless of client or status.
Use only when the user wants a complete cross-client overview.
"""
return json.dumps({
"action": "list_all",
"table": "projects",
})
@tool
async def get_project(project_id: str) -> str:
"""Fetch a single project by its UUID."""
return json.dumps({
"action": "get",
"table": "projects",
"data": {"id": project_id},
})
@tool
async def create_project(
name: str,
client_id: str = "",
) -> str:
"""Create a new project.
name: human-readable project name (required)
client_id: optional UUID of the owning client
"""
return json.dumps({
"action": "create_record",
"table": "projects",
"data": {
"name": name,
"clientId": client_id or None,
},
})
@tool
async def update_project(
project_id: str,
name: str = "",
client_id: str = "",
status: str = "",
ai_summary: str = "",
) -> str:
"""Update a project. Only pass fields that should change.
project_id: UUID of the project (required)
status: active | archived
ai_summary: AI-generated summary text (populate only when explicitly requested)
"""
updates: dict[str, Any] = {}
if name:
updates["name"] = name
if client_id:
updates["clientId"] = client_id
if status:
updates["status"] = status
if ai_summary:
updates["aiSummary"] = ai_summary
return json.dumps({
"action": "update_record",
"table": "projects",
"data": {"id": project_id, "updates": updates},
})
@tool
async def delete_project(project_id: str) -> str:
"""Permanently delete a project and orphan its tasks.
IMPORTANT: prefer update_project(status='archived') unless the user
has explicitly confirmed they want permanent deletion.
"""
return json.dumps({
"action": "delete_record",
"table": "projects",
"data": {"id": project_id},
})
@registry.register
class ProjectAgent(ChatAgent):
def get_name(self) -> str:
return "project_agent"
def get_description(self) -> str:
return "Manages projects: list, get, create, update, archive, delete"
def get_tools(self) -> list[Any]:
return [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

228
app/agents/task_agent.py Normal file
View File

@@ -0,0 +1,228 @@
"""Task agent — full CRUD for tasks and task comments."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from langchain_core.tools import tool
from app.core.agent_registry import ChatAgent, registry
from app.core.llm import get_llm
_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ────────────────────────────────────────────────────────
@tool
async def list_tasks(
project_id: str = "",
status: str = "",
search: str = "",
order_by: str = "",
) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt)."""
return json.dumps({
"action": "list",
"table": "tasks",
"filters": {
"projectId": project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
},
})
@tool
async def create_task(
title: str,
description: str = "",
status: str = "todo",
priority: str = "medium",
assignees: str = "[]",
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
is_approved: int = 0,
) -> str:
"""Create a new task.
title: task title (required)
description: optional details
status: todo | in_progress | done (default: todo)
priority: high | medium | low (default: medium)
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
is_approved: 0 until the user confirms; 1 when confirmed
"""
return json.dumps({
"action": "create_record",
"table": "tasks",
"data": {
"title": title,
"description": description or None,
"status": status,
"priority": priority,
"assignee": assignees,
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
"isApproved": is_approved,
},
})
@tool
async def update_task(
task_id: str,
title: str = "",
description: str = "",
status: str = "",
priority: str = "",
assignees: str = "",
due_date: int = -1,
project_id: str = "",
is_approved: int = -1,
) -> str:
"""Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
is_approved: -1 means unchanged; 0 or 1 sets the value
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if description:
updates["description"] = description
if status:
updates["status"] = status
if priority:
updates["priority"] = priority
if assignees:
updates["assignee"] = assignees
if due_date != -1:
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
if is_approved != -1:
updates["isApproved"] = is_approved
return json.dumps({
"action": "update_record",
"table": "tasks",
"data": {"id": task_id, "updates": updates},
})
@tool
async def delete_task(task_id: str) -> str:
"""Delete a task permanently by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "tasks",
"data": {"id": task_id},
})
@tool
async def list_tasks_due_today() -> str:
"""List all tasks whose due date falls on today's date."""
return json.dumps({
"action": "list_due_today",
"table": "tasks",
})
# ── Task comment tools ────────────────────────────────────────────────
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
return json.dumps({
"action": "list",
"table": "taskComments",
"filters": {"taskId": task_id},
})
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
"""Add a comment to a task.
task_id: UUID of the task to comment on
author: name or ID of the comment author
content: comment text
"""
return json.dumps({
"action": "create_record",
"table": "taskComments",
"data": {
"taskId": task_id,
"author": author,
"content": content,
},
})
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
return json.dumps({
"action": "delete_record",
"table": "taskComments",
"data": {"id": comment_id},
})
# ── Agent ─────────────────────────────────────────────────────────────
@registry.register
class TaskAgent(ChatAgent):
def get_name(self) -> str:
return "task_agent"
def get_description(self) -> str:
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
def get_tools(self) -> list[Any]:
return [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]
async def handle(self, query: str, context: dict[str, Any]) -> str:
llm = get_llm()
messages = [
SystemMessage(content=_SYSTEM_PROMPT),
HumanMessage(
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
),
]
return await self._tool_loop(llm, messages, self.get_tools())

0
app/api/__init__.py Normal file
View File

14
app/api/deps.py Normal file
View File

@@ -0,0 +1,14 @@
"""Shared FastAPI dependencies.
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
(the canonical location per Step 9). This module re-exports them so that all
existing route imports (``from app.api.deps import get_current_user``) continue
to work without modification.
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
instead of reading it from the JWT payload.
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
__all__ = ["get_current_user", "oauth2_scheme"]

View File

@@ -0,0 +1,19 @@
"""API middleware package.
Exports the three middleware components introduced in Step 9:
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
- Sanitizer: ``SanitizerMiddleware``
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
from app.api.middleware.sanitizer import SanitizerMiddleware
__all__ = [
"get_current_user",
"oauth2_scheme",
"TierRateLimitMiddleware",
"limiter",
"SanitizerMiddleware",
]

View File

@@ -0,0 +1,65 @@
"""Auth middleware — JWT validation dependency.
``get_current_user`` is the FastAPI dependency used by all protected routes.
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
from the ``subscriptions`` table so that tier changes take effect immediately
without requiring token re-issue.
Exempt routes (no JWT required):
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.db import get_session
from app.schemas import UserProfile
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
The JWT is used for identity and expiry only. The tier is fetched live
from the ``subscriptions`` table so that upgrades/downgrades take effect
immediately. Falls back to ``'free'`` when no subscription row exists.
Raises HTTP 401 on any invalid or expired token.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
# Live tier lookup — subscription row is the authoritative source.
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str = result.scalar_one_or_none() or "free"
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]

View File

@@ -0,0 +1,129 @@
"""Tier-aware rate limiting middleware.
Uses a per-user sliding-window counter (in-process, no Redis required).
The ``slowapi`` Limiter is also exported for optional route-level decoration.
Limits (requests per minute):
- free: 20
- pro: 60
- power: 120
- team: 200
Exempt paths bypass the limiter entirely:
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
- GET /api/v1/health
"""
from __future__ import annotations
import json
import time
from collections import defaultdict
from fastapi import Request, Response
from jose import JWTError, jwt
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.config.settings import settings
_TIER_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/billing/webhook",
"/api/v1/health",
}
)
def _get_user_id_from_jwt(request: Request) -> str:
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return get_remote_address(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
return payload.get("sub") or get_remote_address(request)
except JWTError:
return get_remote_address(request)
# Exported Limiter instance — available for optional route-level decoration.
limiter = Limiter(key_func=_get_user_id_from_jwt)
class TierRateLimitMiddleware(BaseHTTPMiddleware):
"""Sliding-window rate limiter applied globally across all non-exempt routes.
Each authenticated user gets their own 60-second window sized by tier.
Unauthenticated requests pass through (the auth dependency will reject them
with 401 before the route handler runs).
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
# user_id → list of request timestamps (float, seconds since epoch)
self._window: dict[str, list[float]] = defaultdict(list)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
if request.url.path in _EXEMPT_PATHS:
return await call_next(request)
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return await call_next(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str = payload.get("sub") or get_remote_address(request)
tier: str = payload.get("tier", "free")
except JWTError:
return await call_next(request)
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
now = time.monotonic()
window_start = now - 60.0
# Slide the window: discard timestamps older than 60 seconds.
timestamps = [t for t in self._window[user_id] if t > window_start]
if len(timestamps) >= limit:
retry_after = max(1, int(60 - (now - min(timestamps))))
return Response(
content=json.dumps(
{
"detail": (
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
f"Retry in {retry_after}s."
)
}
),
status_code=429,
headers={
"Retry-After": str(retry_after),
"Content-Type": "application/json",
},
)
timestamps.append(now)
self._window[user_id] = timestamps
return await call_next(request)

View File

@@ -0,0 +1,139 @@
"""Response sanitizer middleware.
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
that could reveal server-side prompt IP:
- System prompt openers ("You are a/an/the …")
- Agent routing metadata ("Available agents:", "intent classifier", …)
- LangChain tool schema fragments (``"type": "function"``)
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints
Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.
"""
from __future__ import annotations
import json
import logging
import re
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Detection patterns — order matters: fingerprints checked first (exact),
# then compiled regexes.
# ---------------------------------------------------------------------------
_FINGERPRINTS: tuple[str, ...] = (
"You are an intent classifier",
"Respond with just the agent name",
"Summarize these agent results",
"Available agents:",
"route to:",
)
_PATTERNS: tuple[re.Pattern[str], ...] = (
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
re.compile(r"Available agents\s*:", re.IGNORECASE),
re.compile(r"\bintent classifier\b", re.IGNORECASE),
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
re.compile(r"route\s+to\s*:", re.IGNORECASE),
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
)
def _sanitize_text(text: str) -> tuple[str, bool]:
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
Returns ``(cleaned_text, was_changed)``.
"""
# Fingerprint check — if any exact phrase is present, redact the whole string.
for fp in _FINGERPRINTS:
if fp in text:
return "[REDACTED]", True
changed = False
for pattern in _PATTERNS:
new_text, n = pattern.subn("[REDACTED]", text)
if n:
text = new_text
changed = True
return text, changed
class SanitizerMiddleware(BaseHTTPMiddleware):
"""Strip prompt IP from /api/v1/chat JSON responses."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
response: Response = await call_next(request)
# Only process chat endpoint responses.
if not request.url.path.startswith("/api/v1/chat"):
return response
# Read body — collect streaming chunks.
body_bytes = b""
async for chunk in response.body_iterator:
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
try:
body = json.loads(body_bytes.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
if not isinstance(body, dict):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Walk top-level string fields and sanitise.
sanitised_fields: list[str] = []
for key, value in body.items():
if isinstance(value, str):
cleaned, changed = _sanitize_text(value)
if changed:
body[key] = cleaned
sanitised_fields.append(key)
if sanitised_fields:
logger.warning(
"Sanitizer redacted prompt fragments",
extra={
"path": request.url.path,
"fields": sanitised_fields,
},
)
new_body = json.dumps(body).encode("utf-8")
headers = dict(response.headers)
headers["content-length"] = str(len(new_body))
return Response(
content=new_body,
status_code=response.status_code,
headers=headers,
media_type="application/json",
)

View File

197
app/api/routes/auth.py Normal file
View File

@@ -0,0 +1,197 @@
"""Auth routes: register, login, refresh, me.
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB.
"""
from __future__ import annotations
import hashlib
import time
import uuid
from datetime import datetime, timedelta, timezone
import bcrypt
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.config.settings import settings
from app.db import get_session
from app.models import RefreshToken, User
from app.schemas import AuthTokens, UserProfile
router = APIRouter(prefix="/auth", tags=["auth"])
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _hash_token(plain_token: str) -> str:
"""SHA-256 of the plain refresh token string."""
return hashlib.sha256(plain_token.encode()).hexdigest()
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (signed JWT, expires_at_ms)."""
now = int(time.time())
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": exp,
"iat": now,
}
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
return token, exp * 1000 # ms for client
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
class _LoginRequest(BaseModel):
email: str
password: str
class _RefreshRequest(BaseModel):
refresh_token: str
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
body: _RegisterRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Create a new account and return JWT tokens."""
existing = await db.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user = User(
id=str(uuid.uuid4()),
email=body.email,
password_hash=_hash_password(body.password),
tier="free",
)
db.add(user)
await db.flush() # get user.id without committing
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/login", response_model=AuthTokens)
async def login(
body: _LoginRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/refresh", response_model=AuthTokens)
async def refresh(
body: _RefreshRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
token_hash = _hash_token(body.refresh_token)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
rt = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
# Rotate: delete old token, issue new one.
await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id))
user = user_result.scalar_one_or_none()
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=new_expires,
)
db.add(new_rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user."""
return current_user

171
app/api/routes/backup.py Normal file
View File

@@ -0,0 +1,171 @@
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
PostgreSQL ``backup_metadata`` table.
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
treating "history" as a ``{backup_id}`` path parameter.
"""
from __future__ import annotations
import uuid
from email.utils import parsedate_to_datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import BackupMetadata as BackupMetadataModel
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/backup", tags=["backup"])
_blob_store = BlobStore()
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total backup bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
BackupMetadataModel.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_backup_quota(
user: UserProfile, size_bytes: int, db: AsyncSession
) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
current = await _current_backup_bytes(user.id, db)
tier_manager.enforce_backup_quota(
user.tier, current_bytes=current, additional_bytes=size_bytes
)
@router.put("")
async def upload_backup(
request: Request,
x_backup_version: int = Header(..., alias="X-Backup-Version"),
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Upload an E2E-encrypted backup blob.
Metadata is passed via custom headers; the raw body is the encrypted blob.
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
await _check_backup_quota(current_user, len(blob), db)
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
)
row = BackupMetadataModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
s3_key=s3_key,
version=x_backup_version,
timestamp=x_backup_timestamp,
checksum=x_backup_checksum,
size_bytes=len(blob),
)
db.add(row)
await db.commit()
return {"ok": True}
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[BackupMetadata]:
"""Return backup metadata records for the authenticated user (no blob bytes)."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
)
rows = result.scalars().all()
return [
BackupMetadata(
version=r.version,
timestamp=r.timestamp,
checksum=r.checksum,
chunk_count=1,
)
for r in rows
]
@router.get("")
async def download_backup(
request: Request,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
.limit(1)
)
latest = result.scalar_one_or_none()
if latest is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
ims_header = request.headers.get("If-Modified-Since")
if ims_header:
try:
ims_dt = parsedate_to_datetime(ims_header)
ims_ms = int(ims_dt.timestamp() * 1000)
if latest.timestamp <= ims_ms:
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
except Exception:
pass # malformed header — ignore and serve the blob
blob = await _blob_store.download(current_user.id, latest.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={
"X-Backup-Version": str(latest.version),
"X-Backup-Timestamp": str(latest.timestamp),
"X-Checksum": latest.checksum,
},
)
@router.delete("/{backup_id}", response_model=dict)
async def delete_backup(
backup_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a specific backup by ID."""
result = await db.execute(
select(BackupMetadataModel).where(
BackupMetadataModel.id == backup_id,
BackupMetadataModel.user_id == current_user.id,
)
)
target = result.scalar_one_or_none()
if target is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
await _blob_store.delete(current_user.id, target.s3_key)
await db.delete(target)
await db.commit()
return {"ok": True}

85
app/api/routes/billing.py Normal file
View File

@@ -0,0 +1,85 @@
"""Billing routes: Stripe checkout, webhook, subscription management.
Business logic lives in ``app.billing.stripe_service.StripeService``.
The route layer handles HTTP concerns (request parsing, response shaping)
and delegates everything else to the service singleton.
"""
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.stripe_service import stripe_service
from app.db import get_session
from app.schemas import BillingTier, UserProfile
router = APIRouter(prefix="/billing", tags=["billing"])
# ── Request bodies ─────────────────────────────────────────────────────
class _CheckoutRequest(BaseModel):
tier: BillingTier
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/checkout", response_model=dict)
async def create_checkout(
body: _CheckoutRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, str]:
"""Create a Stripe checkout session for a tier upgrade.
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
"""
url = stripe_service.create_checkout_session(current_user.id, body.tier)
return {"checkout_url": url}
@router.post("/webhook", response_model=dict)
async def stripe_webhook(
request: Request,
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Handle Stripe webhook events.
No JWT auth — authenticated via Stripe signature verification instead.
Returns 200 immediately when Stripe is not configured (local dev).
"""
payload = await request.body()
await stripe_service.handle_webhook(payload, stripe_signature, db)
return {"ok": True}
@router.get("/subscription", response_model=dict)
async def get_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Return the current subscription info for the authenticated user."""
sub = await stripe_service.get_subscription(current_user.id, db)
if sub is None:
return {
"tier": current_user.tier,
"status": "free",
"stripe_subscription_id": None,
"current_period_end": None,
}
return sub
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
async def cancel_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Cancel the active subscription."""
await stripe_service.cancel_subscription(current_user.id, db)
return {"ok": True}

78
app/api/routes/chat.py Normal file
View File

@@ -0,0 +1,78 @@
"""Chat routes: POST /chat and WebSocket /chat/stream."""
from __future__ import annotations
import asyncio
import json
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
from fastapi.responses import JSONResponse
from jose import JWTError, jwt
from app.api.deps import get_current_user
from app.config.settings import settings
from app.core.orchestrator import orchestrate, orchestrate_stream
from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
_HEARTBEAT_INTERVAL = 30 # seconds
@router.post("")
async def chat(
body: ChatRequest,
current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
"""Route a chat message through the orchestrator.
Returns ``ChatResponse`` for ``execution_mode='direct'``,
or ``ExecutionPlan`` for ``execution_mode='plan'``.
"""
result = await orchestrate(body)
return JSONResponse(content=result.model_dump())
@router.websocket("/stream")
async def chat_stream(websocket: WebSocket) -> None:
"""Streaming chat via WebSocket.
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
Protocol:
1. Client sends ``ChatRequest`` as the first JSON text frame.
2. Server streams response text chunks.
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
4. Server pings every 30 s to keep the connection alive.
"""
# Authenticate before accepting the connection
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
user_id: str | None = payload.get("sub")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008) # 1008 = Policy Violation
return
await websocket.accept()
try:
raw = await websocket.receive_text()
body = ChatRequest.model_validate_json(raw)
async def _heartbeat() -> None:
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"ping": True}))
heartbeat_task = asyncio.create_task(_heartbeat())
try:
async for chunk in orchestrate_stream(body):
await websocket.send_text(chunk)
finally:
heartbeat_task.cancel()
except WebSocketDisconnect:
pass

37
app/api/routes/plans.py Normal file
View File

@@ -0,0 +1,37 @@
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
from __future__ import annotations
from fastapi import APIRouter, Depends, HTTPException, status
from app.api.deps import get_current_user
from app.core.execution_plan import plan_cache
from app.schemas import ExecutionPlan, UserProfile
router = APIRouter(prefix="/plans", tags=["plans"])
@router.get("/playbook", response_model=list[ExecutionPlan])
async def list_playbooks(
current_user: UserProfile = Depends(get_current_user),
) -> list[ExecutionPlan]:
"""Return all cached execution plan playbooks for the authenticated user.
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
"""
return plan_cache.get_all_playbooks()
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
async def get_playbook(
plan_id: str,
current_user: UserProfile = Depends(get_current_user),
) -> ExecutionPlan:
"""Return a specific execution plan playbook by ID."""
plan = plan_cache.get_plan(plan_id)
if plan is None:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Plan not found: {plan_id}",
)
return plan

148
app/api/routes/plugins.py Normal file
View File

@@ -0,0 +1,148 @@
"""Plugins routes: browse and install plugins from the marketplace.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
"""
from __future__ import annotations
from typing import Any, Literal
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db import get_session
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.models import PluginInstallation, PluginReview as PluginReviewModel
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
# ── Tier gate ─────────────────────────────────────────────────────────
def _require_plugin_tier(user: UserProfile) -> None:
"""Raise HTTP 403 for users below Power tier."""
if user.tier not in ("power", "team"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Plugin marketplace requires Power tier or above",
)
# ── Local detail schema ────────────────────────────────────────────────
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any]
# ── Routes ────────────────────────────────────────────────────────────
@router.get("", response_model=PluginListResponse)
async def list_plugins(
category: str | None = Query(default=None),
q: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Fetch review ratings for this plugin
review_result = await db.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
)
reviews = review_result.scalars().all()
ratings = [
{
"reviewer_id": r.reviewer_id,
"decision": r.decision,
"notes": r.notes,
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
}
for r in reviews
]
return _PluginDetail(
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=ratings,
)
@router.post("/{plugin_id}/install", response_model=dict)
async def install_plugin(
plugin_id: str,
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Record the installation in plugin_installations
installation = PluginInstallation(
plugin_id=plugin_id,
user_id=current_user.id,
)
db.add(installation)
await db.flush()
await revenue_share.record_install(
db,
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
)
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
return {"ok": True, "download_url": download_url}
@router.delete("/{plugin_id}/install", response_model=dict)
async def uninstall_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
result = await db.execute(
select(PluginInstallation).where(
PluginInstallation.plugin_id == plugin_id,
PluginInstallation.user_id == current_user.id,
)
)
installation = result.scalar_one_or_none()
if installation is not None:
await db.delete(installation)
await db.commit()
await registry.record_uninstall(db, plugin_id)
return {"ok": True}

196
app/api/routes/storage.py Normal file
View File

@@ -0,0 +1,196 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import uuid
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# ── Local response schemas ─────────────────────────────────────────────
class _CreateResponse(BaseModel):
id: str
created_at: int
class _RecordMeta(BaseModel):
id: str
table: str
checksum: str
created_at: int
updated_at: int
# ── Helpers ────────────────────────────────────────────────────────────
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
async def list_records(
table: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in rows
]
@router.get("/records/{record_id}")
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record.checksum},
)
@router.put("/records/{record_id}", response_model=dict)
async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record.size_bytes
if delta > 0:
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@router.delete("/records/{record_id}", response_model=dict)
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}

56
app/api/routes/vectors.py Normal file
View File

@@ -0,0 +1,56 @@
"""Vectors routes: upsert, search, and delete cloud vector store entries."""
from __future__ import annotations
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.schemas import (
UserProfile,
VectorSearchRequest,
VectorSearchResponse,
VectorUpsertRequest,
)
from app.storage.encryption import reject_if_tampered
from app.storage.vector_store import VectorStore
router = APIRouter(prefix="/storage", tags=["vectors"])
_vector_store = VectorStore()
class _VectorDeleteRequest(BaseModel):
ids: list[str]
@router.post("/vectors/upsert", response_model=dict)
async def upsert_vectors(
body: VectorUpsertRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, int]:
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
for item in body.vectors:
reject_if_tampered(item.blob, item.checksum)
await _vector_store.upsert(current_user.id, body.vectors)
return {"upserted": len(body.vectors)}
@router.post("/vectors/search", response_model=VectorSearchResponse)
async def search_vectors(
body: VectorSearchRequest,
current_user: UserProfile = Depends(get_current_user),
) -> VectorSearchResponse:
"""Search the user-scoped vector namespace with an encrypted query blob."""
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
return VectorSearchResponse(results=results)
@router.delete("/vectors", response_model=dict)
async def delete_vectors(
body: _VectorDeleteRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Delete vectors by ID, scoped to the authenticated user."""
await _vector_store.delete(current_user.id, body.ids)
return {"ok": True}

4
app/billing/__init__.py Normal file
View File

@@ -0,0 +1,4 @@
from app.billing.stripe_service import stripe_service
from app.billing.tier_manager import tier_manager
__all__ = ["stripe_service", "tier_manager"]

View File

@@ -0,0 +1,256 @@
"""Stripe service: checkout sessions, webhook handling, subscription management.
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
configured, enabling local development without live credentials.
"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
# Stripe price IDs per tier — replace with real IDs in production .env
TIER_PRICE_IDS: dict[str, str] = {
"pro": "price_pro_monthly",
"power": "price_power_monthly",
"team": "price_team_monthly",
}
class StripeService:
"""Wraps all Stripe interactions and owns subscription persistence."""
# ── Internal helpers ────────────────────────────────────────────────
def _configured(self) -> bool:
return bool(settings.STRIPE_SECRET_KEY)
def _client(self) -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Public API ──────────────────────────────────────────────────────
def create_checkout_session(
self,
user_id: str,
tier: str,
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuva.app/billing/cancel",
) -> str:
"""Create a Stripe checkout session and return the URL.
Returns a stub URL when Stripe is not configured.
Raises ``HTTP 400`` for the free tier or an unknown tier.
"""
if tier == "free":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Cannot create a checkout session for the free tier",
)
price_id = TIER_PRICE_IDS.get(tier)
if not price_id:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Unknown tier: {tier}",
)
if not self._configured():
return "https://stripe.com/stub-checkout"
s = self._client()
session = s.checkout.Session.create(
payment_method_types=["card"],
mode="subscription",
line_items=[{"price": price_id, "quantity": 1}],
success_url=success_url,
cancel_url=cancel_url,
metadata={"user_id": user_id, "tier": tier},
)
return session.url
async def handle_webhook(
self,
payload: bytes,
sig_header: str,
db: AsyncSession,
) -> None:
"""Process a Stripe webhook event.
Verifies the signature, then dispatches on event type.
Raises ``HTTP 400`` on signature mismatch.
No-ops when Stripe is not configured.
"""
if not self._configured():
return
try:
s = self._client()
event = s.Webhook.construct_event(
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
)
except stripe_lib.error.SignatureVerificationError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid Stripe signature",
)
event_type: str = event["type"]
data: dict[str, Any] = event["data"]["object"]
if event_type == "checkout.session.completed":
user_id = data.get("metadata", {}).get("user_id")
tier = data.get("metadata", {}).get("tier", "free")
sub_id = data.get("subscription")
period_end_ts = data.get("current_period_end")
period_end = (
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
if period_end_ts
else None
)
if user_id:
await self._upsert_subscription(
db, user_id, sub_id, tier, "active", period_end
)
elif event_type == "customer.subscription.updated":
sub_id = data.get("id")
new_status = data.get("status", "active")
period_end_ts = data.get("current_period_end")
period_end = (
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
if period_end_ts
else None
)
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, status=new_status, current_period_end=period_end
)
elif event_type == "customer.subscription.deleted":
sub_id = data.get("id")
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, tier="free", status="canceled"
)
elif event_type == "invoice.payment_failed":
sub_id = data.get("subscription")
if sub_id:
await self._update_subscription_by_stripe_id(
db, sub_id, status="past_due"
)
await db.commit()
async def get_subscription(
self, user_id: str, db: AsyncSession
) -> dict[str, Any] | None:
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None:
return None
return {
"tier": sub.tier,
"stripe_subscription_id": sub.stripe_subscription_id,
"status": sub.status,
"current_period_end": (
int(sub.current_period_end.timestamp() * 1000)
if sub.current_period_end
else None
),
}
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
"""Cancel the user's Stripe subscription and downgrade them to free.
Raises ``HTTP 404`` when no active subscription exists.
"""
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None or not sub.stripe_subscription_id:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail="No active subscription found",
)
if self._configured():
s = self._client()
s.Subscription.cancel(sub.stripe_subscription_id)
sub.tier = "free"
sub.status = "canceled"
await db.commit()
# ── Private DB helpers ───────────────────────────────────────────────
async def _upsert_subscription(
self,
db: AsyncSession,
user_id: str,
stripe_subscription_id: str | None,
tier: str,
sub_status: str,
current_period_end: datetime | None,
) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id)
)
sub = result.scalar_one_or_none()
if sub is None:
sub = Subscription(user_id=user_id)
db.add(sub)
sub.stripe_subscription_id = stripe_subscription_id
sub.tier = tier
sub.status = sub_status
sub.current_period_end = current_period_end
async def _update_subscription_by_stripe_id(
self,
db: AsyncSession,
stripe_subscription_id: str,
*,
tier: str | None = None,
status: str | None = None,
current_period_end: datetime | None = None,
) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription).where(
Subscription.stripe_subscription_id == stripe_subscription_id
)
)
sub = result.scalar_one_or_none()
if sub is None:
return
if tier is not None:
sub.tier = tier
if status is not None:
sub.status = status
if current_period_end is not None:
sub.current_period_end = current_period_end
# Module-level singleton shared across the app.
stripe_service = StripeService()

189
app/billing/tier_manager.py Normal file
View File

@@ -0,0 +1,189 @@
"""Tier manager: feature matrix and quota enforcement.
``TierManager`` is the single source of truth for what each billing tier
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
Quota-enforcement helpers take ``tier`` directly — the caller already has it
from ``current_user.tier`` (provided by ``get_current_user``).
"""
from __future__ import annotations
from typing import Any
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas import BillingTier
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
FEATURES: dict[str, dict[str, Any]] = {
"free": {
"agents": 3,
"batch_active": 2,
"cloud_storage_gb": 0,
"backup_gb": 0,
"providers": 1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"pro": {
"agents": -1, # unlimited
"batch_active": 10,
"cloud_storage_gb": 5,
"backup_gb": 5,
"providers": -1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"power": {
"agents": -1,
"batch_active": -1, # unlimited
"cloud_storage_gb": 25,
"backup_gb": 25,
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": False,
},
"team": {
"agents": -1,
"batch_active": -1,
"cloud_storage_gb": -1, # unlimited
"backup_gb": -1, # unlimited
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": True,
},
}
# Requests-per-minute limit per tier.
RATE_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
class TierManager:
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
# ── Tier lookup ─────────────────────────────────────────────────────
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB.
Falls back to ``'free'`` when no subscription row exists.
"""
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str | None = result.scalar_one_or_none()
if tier is None or tier not in FEATURES:
return "free"
return tier # type: ignore[return-value]
# ── Feature access ───────────────────────────────────────────────────
def check_feature(self, tier: BillingTier, feature: str) -> bool:
"""Return ``True`` if ``tier`` has ``feature`` enabled.
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
"""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if value is None:
return False
if isinstance(value, bool):
return value
return value != 0
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
if not self.check_feature(tier, feature):
detail = (
f"Feature '{feature}' requires {tier_name} tier or above."
if tier_name
else f"Feature '{feature}' is not available on your current tier."
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
# ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int:
"""Return the requests-per-minute limit for ``tier``."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# ── Storage quota ────────────────────────────────────────────────────
def enforce_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Cloud storage is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
def enforce_backup_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
def check_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False
if limit_gb == -1:
return True
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app.
tier_manager = TierManager()

0
app/config/__init__.py Normal file
View File

40
app/config/settings.py Normal file
View File

@@ -0,0 +1,40 @@
from typing import Literal
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
JWT_SECRET: str = "change-me-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
OPENAI_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
ENV: Literal["dev", "prod"] = "dev"
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
settings = Settings()

0
app/core/__init__.py Normal file
View File

137
app/core/agent_registry.py Normal file
View File

@@ -0,0 +1,137 @@
"""Agent Registry — base classes and singleton registry for chat agents."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
"""Common base for all agents."""
def __init__(
self,
user_id: str = "",
shared_memory: dict[str, Any] | None = None,
vector_store_context: list[str] | None = None,
) -> None:
self.user_id = user_id
self.shared_memory: dict[str, Any] = shared_memory or {}
self.vector_store_context: list[str] = vector_store_context or []
@abstractmethod
def get_name(self) -> str: ...
@abstractmethod
def get_description(self) -> str: ...
@property
def skills(self) -> list[str]:
"""Override in subclasses to advertise capabilities."""
return []
class ChatAgent(BaseAgent):
"""Base class for LLM-powered chat agents."""
@abstractmethod
async def handle(self, query: str, context: dict[str, Any]) -> str:
"""Process a user query and return a text response."""
...
@abstractmethod
def get_tools(self) -> list[Any]:
"""Return LangChain tool definitions available to this agent."""
...
async def _tool_loop(
self,
llm: Any,
messages: list[Any],
tools: list[Any],
max_iter: int = 5,
) -> str:
"""Shared tool-calling loop.
Binds *tools* to *llm*, invokes iteratively until the model stops
requesting tool calls or *max_iter* is reached, and returns the
final text response.
"""
from langchain_core.messages import AIMessage, ToolMessage
llm_with_tools = llm.bind_tools(tools) if tools else llm
for _ in range(max_iter):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return str(response.content)
# Execute each requested tool call
tool_map = {t.name: t for t in tools}
for call in response.tool_calls:
tool_fn = tool_map.get(call["name"])
if tool_fn is None:
result = f"Unknown tool: {call['name']}"
else:
result = await tool_fn.ainvoke(call["args"])
messages.append(
ToolMessage(content=str(result), tool_call_id=call["id"])
)
# Exhausted iterations — ask model for a final answer without tools
response = await llm.ainvoke(messages)
return str(response.content)
class AgentRegistry:
"""Singleton registry for ChatAgent subclasses."""
_instance: AgentRegistry | None = None
def __init__(self) -> None:
self._agents: dict[str, type[ChatAgent]] = {}
def __new__(cls) -> AgentRegistry:
if cls._instance is None:
cls._instance = super().__new__(cls)
cls._instance._agents = {}
return cls._instance
# ── public API ───────────────────────────────────────────────────
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
"""Class decorator — registers an agent by its name."""
instance = agent_class()
name = instance.get_name()
self._agents[name] = agent_class
return agent_class
def get(self, name: str) -> ChatAgent:
"""Return a fresh instance of the named agent."""
cls = self._agents.get(name)
if cls is None:
raise KeyError(f"Agent not found: {name}")
return cls()
def list_agents(self) -> list[dict[str, str]]:
"""Return ``[{name, description}]`` for the orchestrator prompt."""
result: list[dict[str, str]] = []
for cls in self._agents.values():
inst = cls()
result.append(
{"name": inst.get_name(), "description": inst.get_description()}
)
return result
async def call_agent(
self, name: str, query: str, context: dict[str, Any]
) -> str:
"""Instantiate the named agent and call its ``handle`` method."""
agent = self.get(name)
return await agent.handle(query, context)
# Module-level singleton
registry = AgentRegistry()

222
app/core/execution_plan.py Normal file
View File

@@ -0,0 +1,222 @@
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
from __future__ import annotations
from collections import OrderedDict
from typing import Any
from app.schemas import ExecutionPlan, PlanStep
# ── Prompt Template Registry ──────────────────────────────────────────
class PromptTemplateRegistry:
"""Server-side store mapping template IDs to prompt text.
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
The actual prompt text is resolved here on the server, keeping prompt IP
out of API responses.
"""
def __init__(self) -> None:
self._templates: dict[str, str] = {}
def register(self, template_id: str, prompt_text: str) -> None:
self._templates[template_id] = prompt_text
def get(self, template_id: str) -> str:
"""Resolve a template ID to its prompt text.
Raises ``KeyError`` if the template is not registered.
"""
text = self._templates.get(template_id)
if text is None:
raise KeyError(f"Template not found: {template_id!r}")
return text
def has(self, template_id: str) -> bool:
return template_id in self._templates
def list_ids(self) -> list[str]:
"""Return all registered template IDs (never the text)."""
return list(self._templates.keys())
# ── Execution Plan Builder ────────────────────────────────────────────
class ExecutionPlanBuilder:
"""Fluent builder for ``ExecutionPlan`` objects.
Example::
plan = (
ExecutionPlanBuilder("task_agent")
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
.add_data_step("create_record", data_from_step=0)
.build()
)
"""
def __init__(self, agent: str) -> None:
self._agent = agent
self._steps: list[PlanStep] = []
# ── step adders ──────────────────────────────────────────────────
def add_step(
self, action: str, params: dict[str, Any] | None = None
) -> ExecutionPlanBuilder:
"""Append a generic action step with optional parameters."""
self._steps.append(PlanStep(action=action, variables=params))
return self
def add_llm_step(
self, template_id: str, variables: dict[str, Any] | None = None
) -> ExecutionPlanBuilder:
"""Append an LLM step referencing a server-side template by ID."""
self._steps.append(
PlanStep(action="llm", prompt_template=template_id, variables=variables)
)
return self
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
"""Append a step whose input comes from the output of an earlier step."""
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
return self
# ── build ────────────────────────────────────────────────────────
def build(self) -> ExecutionPlan:
"""Validate step references and return the ``ExecutionPlan``.
Raises ``ValueError`` if any ``data_from_step`` references a
non-existent or future step index.
"""
for i, step in enumerate(self._steps):
if step.data_from_step is not None:
if not (0 <= step.data_from_step < i):
raise ValueError(
f"Step {i}: data_from_step={step.data_from_step} must "
f"reference a preceding step index in range 0..{i - 1}"
)
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
class PlanCache:
"""In-memory LRU cache for ``ExecutionPlan`` objects.
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
The cache also serves as a runtime memoisation layer so that repeated
identical intent classifications can skip re-building the plan.
"""
def __init__(self, maxsize: int = 1000) -> None:
self._maxsize = maxsize
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
if key in self._cache:
del self._cache[key] # remove so re-insertion places it at the end
elif len(self._cache) >= self._maxsize:
self._cache.popitem(last=False) # evict least-recently-used
self._cache[key] = plan
def get_plan(self, key: str) -> ExecutionPlan | None:
"""Return the cached plan for *key*, or ``None`` if not present.
Accessing a plan marks it as most-recently used.
"""
if key not in self._cache:
return None
self._cache.move_to_end(key)
return self._cache[key]
def get_all_playbooks(self) -> list[ExecutionPlan]:
"""Return all cached plans (most-recently used last)."""
return list(self._cache.values())
# ── Module-level singletons ───────────────────────────────────────────
template_registry = PromptTemplateRegistry()
plan_cache = PlanCache()
def _register_builtin_templates() -> None:
"""Register the built-in server-side prompt templates.
These strings never leave the server. Clients only receive the IDs.
"""
_tpls: dict[str, str] = {
"tpl_task_agent_default": (
"You are a task management assistant. Help the user create, update, "
"list, and track tasks. Use correct status values (todo, in_progress, "
"done) and priority values (high, medium, low) from the workspace model."
),
"tpl_checkpoint_agent_default": (
"You are a project checkpoint assistant. Help the user create and manage "
"milestone checkpoints on their projects. Every checkpoint requires a "
"project_id and a date expressed as a Unix timestamp in milliseconds."
),
"tpl_project_agent_default": (
"You are a project management assistant. Help the user create, find, "
"update, and archive projects. Projects have a name, an optional client, "
"and a status of either active or archived."
),
"tpl_note_agent_default": (
"You are a note-taking assistant. Help the user create, retrieve, update, "
"and delete Markdown notes. Notes can optionally be linked to a project."
),
"tpl_task_extract_from_project": (
"Extract all actionable tasks from the provided project context. "
"Return a structured list of tasks, each with a title, inferred priority "
"(high, medium, or low), suggested status (todo), and a due_date in "
"milliseconds where a deadline can be inferred."
),
"tpl_note_weekly_summary": (
"Generate a weekly project summary note from the provided workspace data. "
"Include: tasks completed this week, tasks due soon, active projects, "
"and upcoming checkpoints. Format the output as clean Markdown."
),
}
for tid, text in _tpls.items():
template_registry.register(tid, text)
def _load_playbooks() -> None:
"""Pre-build and cache the built-in playbooks."""
playbooks: list[tuple[str, ExecutionPlan]] = [
(
"create_tasks_from_project",
ExecutionPlanBuilder("project_agent")
.add_llm_step(
"tpl_task_extract_from_project",
{"source": "project_context"},
)
.add_data_step("create_record", data_from_step=0)
.build(),
),
(
"generate_weekly_note",
ExecutionPlanBuilder("note_agent")
.add_llm_step(
"tpl_note_weekly_summary",
{"period": "last_7_days"},
)
.add_data_step("create_record", data_from_step=0)
.build(),
),
]
for key, plan in playbooks:
plan_cache.cache_plan(key, plan)
# Initialise on module load
_register_builtin_templates()
_load_playbooks()

68
app/core/llm.py Normal file
View File

@@ -0,0 +1,68 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_:
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
* Anthropic: ``anthropic/claude-3.5-sonnet``
* Google: ``gemini/gemini-pro``
* Ollama: ``ollama/llama3``
* Bedrock: ``bedrock/anthropic.claude-v2``
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
— no code changes required.
"""
from __future__ import annotations
from langchain_openai import ChatOpenAI
from litellm import get_supported_openai_params # noqa: F401 validates install
from app.config.settings import settings
def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string."""
if model.startswith("anthropic/"):
return getattr(settings, "ANTHROPIC_API_KEY", None) or None
if model.startswith("gemini/") or model.startswith("google/"):
return getattr(settings, "GOOGLE_API_KEY", None) or None
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
) -> ChatOpenAI:
"""Return a LangChain chat model backed by LiteLLM.
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
``openai`` client transparently when the model string contains a provider
prefix (``anthropic/…``, ``gemini/…``, etc.).
Parameters
----------
model:
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
temperature:
Sampling temperature. ``0`` = deterministic.
"""
model = model or settings.LLM_MODEL
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
)
def get_router_llm(
*,
temperature: float = 0,
) -> ChatOpenAI:
"""Return the lighter model used for intent classification / routing."""
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)

168
app/core/orchestrator.py Normal file
View File

@@ -0,0 +1,168 @@
"""Orchestrator — LLM-based intent router and agent pipeline."""
from __future__ import annotations
import json
from typing import Any, AsyncGenerator
from langchain_core.messages import HumanMessage, SystemMessage
from app.core.agent_registry import AgentRegistry
from app.core.llm import get_router_llm
from app.core.agent_registry import registry as _default_registry
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
_FALLBACK_AGENT = "task_agent"
_CLASSIFY_SYSTEM = (
"You are an intent classifier. Given the user message and context, decide "
"which agent to route to.\n"
"Available agents: {agents}\n"
"Respond with just the agent name, nothing else."
)
_SYNTHESIZE_HUMAN = (
"Combine the following agent results into one coherent response.\n\n"
"Agent results:\n{results}\n\n"
"Original message: {message}"
)
def _make_llm():
return get_router_llm()
async def classify_intent(
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> str:
"""Use gpt-4o-mini to classify intent and return the matching agent name.
Falls back to ``task_agent`` when the registry is empty or the model
returns a name that is not registered.
"""
agents = reg.list_agents()
if not agents:
return _FALLBACK_AGENT
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
# Truncate context to keep the classification prompt short
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
llm = _make_llm()
response = await llm.ainvoke(
[SystemMessage(content=system), HumanMessage(content=human)]
)
agent_name = str(response.content).strip().lower()
known = {a["name"] for a in agents}
return agent_name if agent_name in known else _FALLBACK_AGENT
async def route_single(
agent_name: str,
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> ChatResponse:
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
response_text = await reg.call_agent(agent_name, message, context)
return ChatResponse(response=response_text)
async def route_pipeline(
agent_names: list[str],
message: str,
context: dict[str, Any],
reg: AgentRegistry,
) -> ChatResponse:
"""Execute agents sequentially; each agent receives previous results in context.
A final LLM synthesis call merges all results into one coherent response.
"""
previous_results: list[str] = []
for agent_name in agent_names:
ctx = {**context, "previous_results": list(previous_results)}
result = await reg.call_agent(agent_name, message, ctx)
previous_results.append(result)
results_str = "\n\n".join(
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
)
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
llm = _make_llm()
synthesis = await llm.ainvoke([HumanMessage(content=human)])
return ChatResponse(response=str(synthesis.content))
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
"""Build an ``ExecutionPlan`` for the resolved agent.
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
If a default template exists for the agent, an LLM step is emitted;
otherwise a plain ``handle`` action step is used.
"""
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
template_id = f"tpl_{agent_name}_default"
builder = ExecutionPlanBuilder(agent_name)
if template_registry.has(template_id):
builder.add_llm_step(template_id, {"message": message})
else:
builder.add_step("handle", {"message": message})
return builder.build()
async def orchestrate(
request: ChatRequest,
reg: AgentRegistry | None = None,
) -> ChatResponse | ExecutionPlan:
"""Main orchestration entry point.
* Classifies the user's intent to select an agent.
* ``execution_mode == 'direct'``: routes to the agent and returns a
``ChatResponse``.
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
resolved agent and a template-ID-only step (prompt IP stays server-side).
"""
if reg is None:
reg = _default_registry
context = request.context.model_dump()
agent_name = await classify_intent(request.message, context, reg)
if request.execution_mode == "direct":
return await route_single(agent_name, request.message, context, reg)
# plan mode — return plan, do not execute
return _build_plan(agent_name, request.message)
async def orchestrate_stream(
request: ChatRequest,
reg: AgentRegistry | None = None,
) -> AsyncGenerator[str, None]:
"""Streaming orchestration — yields text chunks then a final JSON frame.
The final frame is a JSON object:
``{"done": true, "response": "...", "actions": []}``.
Agents do not yet support token-level streaming; the full response is
fetched first, then emitted in fixed-size chunks. Token-level streaming
will be wired in Step 6 when agents expose ``astream()``.
"""
if reg is None:
reg = _default_registry
context = request.context.model_dump()
agent_name = await classify_intent(request.message, context, reg)
response_text = await reg.call_agent(agent_name, request.message, context)
chunk_size = 50
for i in range(0, len(response_text), chunk_size):
yield response_text[i : i + chunk_size]
final = ChatResponse(response=response_text)
yield json.dumps({"done": True, **final.model_dump()})

40
app/db.py Normal file
View File

@@ -0,0 +1,40 @@
"""Database engine, session factory, and base model.
All app code uses the async SQLAlchemy API. Alembic migrations use the
synchronous psycopg2 URL for the CLI (see alembic/env.py).
Usage in routes:
from app.db import get_session
from sqlalchemy.ext.asyncio import AsyncSession
async def my_route(db: AsyncSession = Depends(get_session)):
result = await db.execute(select(User).where(User.email == email))
user = result.scalar_one_or_none()
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from app.config.settings import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_pre_ping=True,
echo=settings.ENV == "dev",
)
async_session = async_sessionmaker(engine, expire_on_commit=False)
class Base(DeclarativeBase):
"""Shared declarative base for all ORM models."""
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency that yields an async DB session per request."""
async with async_session() as session:
yield session

64
app/main.py Normal file
View File

@@ -0,0 +1,64 @@
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.middleware.rate_limit import TierRateLimitMiddleware
from app.api.middleware.sanitizer import SanitizerMiddleware
from app.config.settings import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: initialise DB connection pool and agent registry
from app.core.agent_registry import registry # noqa: F401 — triggers module load
import app.agents # noqa: F401 — triggers @registry.register decorators
yield
# Shutdown: dispose SQLAlchemy connection pool
from app.db import engine
await engine.dispose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Cloud API",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
# Request flow: TierRateLimit → Sanitizer → CORS → Router
# Response flow: Router → CORS → Sanitizer → TierRateLimit
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(plans.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")
app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "version": app.version}
return app
app = create_app()

View File

@@ -0,0 +1,7 @@
"""Plugin marketplace package.
Three service classes introduced in Step 10:
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
- ``ReviewQueue`` — approval workflow + security checklist
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
"""

View File

@@ -0,0 +1,212 @@
"""Plugin catalog registry backed by PostgreSQL.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. All data is persisted in the ``plugins`` table.
Module-level singleton::
from app.marketplace.plugin_registry import registry
"""
from __future__ import annotations
import json
from typing import Any, Literal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plugin
from app.schemas import PluginListResponse, PluginManifest
_PAGE_SIZE = 20
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
try:
permissions = json.loads(p.permissions) if p.permissions else []
except (json.JSONDecodeError, TypeError):
permissions = []
return PluginManifest(
id=p.id,
name=p.name,
description=p.description,
version=p.version,
author=p.author_name,
permissions=permissions,
category=p.category,
price_cents=p.price_cents,
)
class PluginRegistry:
"""PostgreSQL-backed plugin catalog.
All methods accept an ``AsyncSession`` parameter so the calling route
controls the session lifecycle.
"""
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
db: AsyncSession,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
base = select(Plugin).where(Plugin.status == "approved")
if category:
base = base.where(Plugin.category == category)
if query:
pattern = f"%{query}%"
base = base.where(
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
)
# Count
count_q = select(func.count()).select_from(base.subquery())
total = (await db.execute(count_q)).scalar_one()
# Sort
if sort == "installs":
base = base.order_by(Plugin.install_count.desc())
elif sort == "rating":
base = base.order_by(Plugin.avg_rating.desc())
else: # newest
base = base.order_by(Plugin.created_at.desc())
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
rows = (await db.execute(base)).scalars().all()
return PluginListResponse(
plugins=[_plugin_to_manifest(r) for r in rows],
total=total,
page=page,
)
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
p = result.scalar_one_or_none()
if p is None:
return None
return {
"manifest": _plugin_to_manifest(p),
"status": p.status,
"install_count": p.install_count,
"avg_rating": p.avg_rating,
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
db: AsyncSession,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
"""Add *manifest* to the catalog with ``status='pending_review'``.
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = existing.scalar_one_or_none()
if row is not None:
row.name = manifest.name
row.description = manifest.description
row.version = manifest.version
row.author_name = manifest.author
row.category = manifest.category
row.price_cents = manifest.price_cents
row.permissions = json.dumps(manifest.permissions)
row.status = "pending_review"
row.s3_package_key = package_s3_key
row.rejection_reason = None
else:
row = Plugin(
id=plugin_id,
name=manifest.name,
description=manifest.description,
version=manifest.version,
author_name=manifest.author,
category=manifest.category,
price_cents=manifest.price_cents,
permissions=json.dumps(manifest.permissions),
status="pending_review",
s3_package_key=package_s3_key,
install_count=0,
avg_rating=0.0,
)
db.add(row)
await db.commit()
return plugin_id
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "approved"
row.rejection_reason = None
await db.commit()
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "rejected"
row.rejection_reason = reason
await db.commit()
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = row.install_count + 1
await db.commit()
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = max(0, row.install_count - 1)
await db.commit()
# ── Internal helpers used by ReviewQueue ─────────────────────────
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review'."""
result = await db.execute(
select(Plugin).where(Plugin.status == "pending_review")
)
rows = result.scalars().all()
return [
{
"manifest": _plugin_to_manifest(r),
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
}
for r in rows
]
# Module-level singleton
registry = PluginRegistry()

View File

@@ -0,0 +1,125 @@
"""Plugin review workflow backed by PostgreSQL.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
Module-level singleton::
from app.marketplace.plugin_review import review_queue
"""
from __future__ import annotations
import re
from typing import Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import registry
from app.models import PluginReview as PluginReviewModel
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
{
"read:tasks",
"write:tasks",
"read:projects",
"write:projects",
"read:notes",
"write:notes",
"read:checkpoints",
"write:checkpoints",
"read:calendar",
"write:calendar",
}
)
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
def validate_manifest(manifest: PluginManifest) -> None:
"""Enforce the plugin security checklist.
Raises:
``ValueError`` on the first violation found. Callers should catch
this and return HTTP 422 / reject the submission.
Checks:
1. Plugin id matches ``^[a-z0-9-]+$``
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
3. No manifest field contains raw binary data
"""
if not _PLUGIN_ID_RE.match(manifest.id):
raise ValueError(
f"Invalid plugin id format: '{manifest.id}'. "
"Only lowercase letters, digits, and hyphens are allowed."
)
for perm in manifest.permissions:
if perm not in ALLOWED_PERMISSIONS:
raise ValueError(
f"Unknown permission: '{perm}'. "
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
)
for field_name, value in manifest.model_dump().items():
if isinstance(value, (bytes, bytearray)):
raise ValueError(
f"Binary content is not allowed in manifest field '{field_name}'."
)
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton.
Review records are persisted in the ``plugin_reviews`` table.
"""
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = await registry.get_pending_entries(db)
return [
{
"plugin_id": e["manifest"].id,
"manifest": e["manifest"],
"submitted_at": e["submitted_at"],
}
for e in entries
]
async def submit_review(
self,
db: AsyncSession,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
notes: str = "",
) -> None:
"""Record a review decision and update the plugin's status.
Raises:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(db, plugin_id)
else:
await registry.reject_plugin(db, plugin_id, reason=notes)
review = PluginReviewModel(
plugin_id=plugin_id,
reviewer_id=reviewer_id,
decision=decision,
notes=notes,
)
db.add(review)
await db.commit()
# Module-level singleton
review_queue = ReviewQueue()

View File

@@ -0,0 +1,233 @@
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
in the ``revenue_events`` table.
Module-level singleton::
from app.marketplace.revenue_share import revenue_share
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.models import Plugin, RevenueEvent
logger = logging.getLogger(__name__)
# ── Revenue split constants ───────────────────────────────────────────
DEVELOPER_SHARE: float = 0.70
PLATFORM_SHARE: float = 0.30
class RevenueShare:
"""Records installation revenue events and coordinates developer payouts.
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
is not configured, consistent with the rest of the billing layer.
"""
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
@staticmethod
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Core operations ──────────────────────────────────────────────
async def record_install(
self,
db: AsyncSession,
plugin_id: str,
user_id: str,
amount_cents: int,
) -> None:
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
For free plugins (``amount_cents == 0``) no payment is initiated but
the event is still recorded for analytics.
For paid plugins the developer receives 70 % via a Stripe Connect
destination charge. If Stripe is not configured or the charge fails
the installation still succeeds (the event is recorded and the install
count is incremented) — a warning is logged for monitoring.
"""
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
# Look up the plugin's author Stripe account from the DB
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = result.scalar_one_or_none()
developer_stripe_account: str | None = None
if plugin_row and plugin_row.author_id:
# Future: look up user.stripe_connect_account_id
developer_stripe_account = None # no real account yet
if developer_stripe_account:
try:
s = self._stripe()
transfer = s.Transfer.create(
amount=developer_share_cents,
currency="eur",
destination=developer_stripe_account,
description=f"Revenue share for plugin {plugin_id}",
metadata={"plugin_id": plugin_id, "user_id": user_id},
)
stripe_transfer_id = transfer["id"]
except Exception as exc:
logger.warning(
"Stripe Connect transfer failed for plugin %s: %s",
plugin_id,
exc,
)
else:
logger.debug(
"No Stripe account on file for plugin %s developer; "
"skipping transfer.",
plugin_id,
)
event = RevenueEvent(
plugin_id=plugin_id,
user_id=user_id,
amount_cents=amount_cents,
developer_share_cents=developer_share_cents,
stripe_transfer_id=stripe_transfer_id,
)
db.add(event)
await db.commit()
await registry.record_install(db, plugin_id)
async def get_earnings(
self,
db: AsyncSession,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
"""Return aggregated earnings for *developer_id*.
``period`` is an optional ``YYYY-MM`` string to restrict the window.
Returns::
{
"developer_id": str,
"period": str | None,
"total_installs": int,
"total_revenue_cents": int,
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer (by author_name match)
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
plugin_result = await db.execute(plugin_q)
developer_plugin_ids = [row[0] for row in plugin_result.all()]
if not developer_plugin_ids:
return {
"developer_id": developer_id,
"period": period,
"total_installs": 0,
"total_revenue_cents": 0,
"developer_share_cents": 0,
}
query = select(
func.count().label("total_installs"),
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
if period:
# Filter by YYYY-MM: extract year and month from created_at
try:
year, month = period.split("-")
query = query.where(
extract("year", RevenueEvent.created_at) == int(year),
extract("month", RevenueEvent.created_at) == int(month),
)
except ValueError:
pass # invalid period format — return all
result = await db.execute(query)
row = result.one()
return {
"developer_id": developer_id,
"period": period,
"total_installs": row.total_installs,
"total_revenue_cents": row.total_revenue,
"developer_share_cents": row.dev_share,
}
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
try:
year, month = period.split("-")
year_int, month_int = int(year), int(month)
except ValueError:
logger.warning("Invalid period format: %s", period)
return
result = await db.execute(
select(RevenueEvent).where(
RevenueEvent.plugin_id == plugin_id,
RevenueEvent.paid_at.is_(None),
extract("year", RevenueEvent.created_at) == year_int,
extract("month", RevenueEvent.created_at) == month_int,
)
)
unpaid = list(result.scalars().all())
total_dev_share = sum(e.developer_share_cents for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = plugin_result.scalar_one_or_none()
developer_stripe_account: str | None = None # Future: fetch from DB
if plugin_row and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
amount=total_dev_share,
currency="eur",
destination=developer_stripe_account,
description=f"Payout for plugin {plugin_id} period {period}",
)
except Exception as exc:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = datetime.now(timezone.utc)
for event in unpaid:
event.paid_at = paid_ts
await db.commit()
# Module-level singleton
revenue_share = RevenueShare()

269
app/models.py Normal file
View File

@@ -0,0 +1,269 @@
"""SQLAlchemy ORM models for all persistent tables.
Only auth, billing, storage metadata, and marketplace data live here.
User content (notes, tasks, etc.) is NEVER persisted server-side —
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
Table inventory:
users — account credentials + tier
refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records
storage_records — S3 blob metadata (no plaintext)
backup_metadata — encrypted backup manifests
plugins — marketplace plugin catalog
plugin_installations — per-user install records
plugin_reviews — admin review decisions
revenue_events — Stripe Connect 70/30 split ledger
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
String,
Text,
UniqueConstraint,
Uuid,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from app.db import Base
# ── Helpers ──────────────────────────────────────────────────────────────
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.now(timezone.utc)
# ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
# ── Models ────────────────────────────────────────────────────────────────
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
subscription: Mapped[Subscription | None] = relationship(
back_populates="user", uselist=False, cascade="all, delete-orphan"
)
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="refresh_tokens")
class Subscription(Base):
__tablename__ = "subscriptions"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, unique=True, index=True
)
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="subscription")
class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
# nullable until developer account system is built
author_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")

157
app/schemas.py Normal file
View File

@@ -0,0 +1,157 @@
"""Pydantic schemas — API request/response contracts.
Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
"""
from __future__ import annotations
from typing import Any, Literal
from pydantic import BaseModel, Field
# ── Billing ──────────────────────────────────────────────────────────
BillingTier = Literal["free", "pro", "power", "team"]
# ── Auth ─────────────────────────────────────────────────────────────
class AuthTokens(BaseModel):
access_token: str
refresh_token: str
expires_at: int
class UserProfile(BaseModel):
id: str
email: str
tier: BillingTier
# ── Chat ─────────────────────────────────────────────────────────────
class ChatContext(BaseModel):
user_profile: dict[str, Any] = Field(default_factory=dict)
relevant_documents: list[str] = Field(default_factory=list)
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
class PlanAction(BaseModel):
type: Literal[
"create_record",
"update_record",
"delete_record",
"index_document",
"send_notification",
]
table: str | None = None
data: dict[str, Any] | None = None
class ChatRequest(BaseModel):
message: str
context: ChatContext = Field(default_factory=ChatContext)
execution_mode: Literal["direct", "plan"] = "direct"
class ChatResponse(BaseModel):
response: str
actions: list[PlanAction] = Field(default_factory=list)
# ── Execution Plans ──────────────────────────────────────────────────
class PlanStep(BaseModel):
action: str
prompt_template: str | None = None
variables: dict[str, Any] | None = None
data_from_step: int | None = None
class ExecutionPlan(BaseModel):
agent: str
steps: list[PlanStep] = Field(default_factory=list)
# ── Backup ───────────────────────────────────────────────────────────
class BackupMetadata(BaseModel):
version: int
timestamp: int
checksum: str
chunk_count: int
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
class StorageRecord(BaseModel):
id: str
user_id: str
table: str
blob: bytes
checksum: str
created_at: int
updated_at: int
class StorageRecordCreate(BaseModel):
table: str
blob: bytes
checksum: str
class StorageRecordUpdate(BaseModel):
blob: bytes
checksum: str
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
class VectorItem(BaseModel):
id: str
blob: bytes # encrypted vector + metadata — backend never decrypts
checksum: str
class VectorUpsertRequest(BaseModel):
vectors: list[VectorItem]
class VectorSearchRequest(BaseModel):
query_blob: bytes # encrypted query — backend never decrypts
top_k: int = 10
class VectorSearchResult(BaseModel):
id: str
score: float
blob: bytes
class VectorSearchResponse(BaseModel):
results: list[VectorSearchResult]
# ── Plugin Marketplace ────────────────────────────────────────────────
class PluginManifest(BaseModel):
id: str
name: str
description: str
version: str
author: str
permissions: list[str]
category: str
price_cents: int = 0
class PluginListResponse(BaseModel):
plugins: list[PluginManifest]
total: int
page: int
class PluginInstallRequest(BaseModel):
plugin_id: str

1
app/storage/__init__.py Normal file
View File

@@ -0,0 +1 @@
"""Cloud storage layer — E2E encrypted blobs and vectors."""

107
app/storage/blob_store.py Normal file
View File

@@ -0,0 +1,107 @@
"""S3-backed store for E2E-encrypted blobs.
Keys are structured as ``{user_id}/{table}/{record_id}``.
The backend never inspects blob content — it stores and retrieves opaque bytes.
"""
from __future__ import annotations
from typing import Any
import boto3
from botocore.exceptions import ClientError
from app.config.settings import settings
class BlobStore:
"""Thin wrapper around boto3 S3.
All blobs must be E2E encrypted by the client before upload.
The backend adds SSE-S3 as an extra layer of at-rest encryption
but cannot decrypt the inner client-side payload.
"""
def _client(self) -> Any:
kwargs: dict[str, Any] = {
"region_name": settings.S3_REGION,
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
}
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
return boto3.client("s3", **kwargs)
@staticmethod
def _key(user_id: str, table: str, record_id: str) -> str:
return f"{user_id}/{table}/{record_id}"
async def upload(
self,
user_id: str,
table: str,
record_id: str,
blob: bytes,
checksum: str,
) -> str:
"""Store *blob* in S3 and return the S3 key.
Args:
user_id: Owner of the blob (used as key prefix).
table: Logical table name (e.g. ``"tasks"``).
record_id: Record UUID.
blob: Raw bytes (pre-encrypted by client).
checksum: SHA-256 hex digest supplied by the client; stored as
object metadata for download-time verification.
Returns:
The S3 key under which the blob was stored.
"""
key = self._key(user_id, table, record_id)
self._client().put_object(
Bucket=settings.S3_BUCKET,
Key=key,
Body=blob,
ServerSideEncryption="AES256", # SSE-S3 at rest
Metadata={"checksum": checksum},
)
return key
async def download(self, user_id: str, s3_key: str) -> bytes:
"""Retrieve the blob stored at *s3_key*.
*user_id* is retained in the signature so higher-level code can
enforce ownership without re-parsing the key.
Raises:
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
object does not exist.
"""
response = self._client().get_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
return response["Body"].read()
async def delete(self, user_id: str, s3_key: str) -> None:
"""Delete the object at *s3_key*.
S3 ``delete_object`` is idempotent — it succeeds even if the key does
not exist.
"""
self._client().delete_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
async def list_keys(self, user_id: str, table: str) -> list[str]:
"""Return all S3 keys for a given user + table combination.
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
"""
prefix = f"{user_id}/{table}/"
response = self._client().list_objects_v2(
Bucket=settings.S3_BUCKET,
Prefix=prefix,
)
return [obj["Key"] for obj in response.get("Contents", [])]

32
app/storage/encryption.py Normal file
View File

@@ -0,0 +1,32 @@
"""Integrity verification only — the backend NEVER decrypts user data."""
from __future__ import annotations
import hashlib
import hmac
from fastapi import HTTPException
def verify_checksum(blob: bytes, checksum: str) -> bool:
"""Return ``True`` if SHA-256(blob) matches *checksum*.
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
timing-based side-channel attacks.
"""
computed = hashlib.sha256(blob).hexdigest()
return hmac.compare_digest(computed, checksum)
def reject_if_tampered(blob: bytes, checksum: str) -> None:
"""Raise ``HTTP 400`` if the blob does not match its checksum.
Call this before storing or forwarding any client-provided blob.
The backend never holds decryption keys — this check only verifies
that the opaque bytes arrived intact.
"""
if not verify_checksum(blob, checksum):
raise HTTPException(
status_code=400,
detail="Checksum mismatch: blob integrity check failed",
)

205
app/storage/vector_store.py Normal file
View File

@@ -0,0 +1,205 @@
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
Vectors are pre-encrypted blobs from the client. The backend stores them
alongside a deterministic 32-dim float representation derived from the blob's
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
is a known trade-off documented in the backend plan.
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
``user_id`` payload field on a shared collection.
"""
from __future__ import annotations
import base64
import hashlib
from typing import Any
from pinecone import Pinecone
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
from app.config.settings import settings
from app.schemas import VectorItem, VectorSearchResult
_QDRANT_COLLECTION = "adiuva_vectors"
def _blob_to_vector(blob: bytes) -> list[float]:
"""Derive a 32-dim float vector from *blob* for storage purposes only.
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
normalises each byte to the range [-1.0, 1.0]. This vector carries no
semantic meaning on encrypted data.
"""
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
class VectorStore:
"""Thin wrapper around Pinecone or Qdrant.
The backend to use is selected at runtime:
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
"""
def _use_pinecone(self) -> bool:
return bool(settings.PINECONE_API_KEY)
# ── Pinecone helpers ──────────────────────────────────────────────
def _pinecone_index(self) -> Any:
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
return pc.Index(settings.PINECONE_INDEX)
# ── Qdrant helpers ────────────────────────────────────────────────
def _qdrant_client(self) -> Any:
return QdrantClient(
url=settings.QDRANT_URL,
api_key=settings.QDRANT_API_KEY or None,
)
# ── Public API ────────────────────────────────────────────────────
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
"""Store encrypted vectors in the backend.
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
so it can be returned verbatim during search.
Args:
user_id: Used as Pinecone namespace or Qdrant payload field.
vectors: List of encrypted vector items from the client.
"""
if self._use_pinecone():
await self._pinecone_upsert(user_id, vectors)
else:
await self._qdrant_upsert(user_id, vectors)
async def search(
self,
user_id: str,
query_blob: bytes,
top_k: int,
) -> list[VectorSearchResult]:
"""Query the vector store and return encrypted result blobs.
The query vector is derived from *query_blob* using the same
deterministic mapping as upsert.
Args:
user_id: Scopes the search to this user's namespace.
query_blob: Encrypted query from the client.
top_k: Maximum number of results to return.
Returns:
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
"""
if self._use_pinecone():
return await self._pinecone_search(user_id, query_blob, top_k)
return await self._qdrant_search(user_id, query_blob, top_k)
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
"""Remove vectors by ID, scoped to *user_id*.
Args:
user_id: Namespace / payload filter to prevent cross-user deletion.
vector_ids: List of vector IDs to remove.
"""
if self._use_pinecone():
await self._pinecone_delete(user_id, vector_ids)
else:
await self._qdrant_delete(user_id, vector_ids)
# ── Pinecone implementation ───────────────────────────────────────
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
index = self._pinecone_index()
records = [
{
"id": v.id,
"values": _blob_to_vector(v.blob),
"metadata": {
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
}
for v in vectors
]
index.upsert(vectors=records, namespace=user_id)
async def _pinecone_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
index = self._pinecone_index()
query_vector = _blob_to_vector(query_blob)
response = index.query(
vector=query_vector,
top_k=top_k,
namespace=user_id,
include_metadata=True,
)
results: list[VectorSearchResult] = []
for match in response.get("matches", []):
blob_bytes = base64.b64decode(match["metadata"]["blob"])
results.append(
VectorSearchResult(
id=match["id"],
score=match["score"],
blob=blob_bytes,
)
)
return results
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
index = self._pinecone_index()
index.delete(ids=vector_ids, namespace=user_id)
# ── Qdrant implementation ─────────────────────────────────────────
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
client = self._qdrant_client()
points = [
PointStruct(
id=v.id,
vector=_blob_to_vector(v.blob),
payload={
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
)
for v in vectors
]
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
async def _qdrant_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
client = self._qdrant_client()
query_vector = _blob_to_vector(query_blob)
hits = client.search(
collection_name=_QDRANT_COLLECTION,
query_vector=query_vector,
query_filter=Filter(
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
),
limit=top_k,
)
return [
VectorSearchResult(
id=str(hit.id),
score=hit.score,
blob=base64.b64decode(hit.payload["blob"]),
)
for hit in hits
]
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
client = self._qdrant_client()
client.delete(
collection_name=_QDRANT_COLLECTION,
points_selector=PointIdsList(points=vector_ids),
)

69
docker-compose.yml Normal file
View File

@@ -0,0 +1,69 @@
version: "3.9"
services:
app:
build: .
ports:
- "8000:8000"
env_file:
- .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
depends_on:
db:
condition: service_healthy
restart: unless-stopped
db:
image: postgres:16-alpine
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuva
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
# Optional Redis for future rate-limit or caching needs
# redis:
# image: redis:7-alpine
# restart: unless-stopped
# ── Local S3-compatible storage (MinIO) ──
minio:
image: minio/minio:latest
command: server /data --console-address ":9001"
ports:
- "9000:9000"
- "9001:9001"
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin
volumes:
- minio_data:/data
healthcheck:
test: ["CMD", "mc", "ready", "local"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
# ── Local vector store (Qdrant) ──
qdrant:
image: qdrant/qdrant:latest
ports:
- "6333:6333"
- "6334:6334"
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
volumes:
postgres_data:
minio_data:
qdrant_data:

27
requirements.txt Normal file
View File

@@ -0,0 +1,27 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
langchain>=0.3.0
langchain-openai>=0.3.0
litellm>=1.50.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
stripe>=11.0.0
boto3>=1.35.0
slowapi>=0.1.9
sqlalchemy>=2.0.0
asyncpg>=0.30.0
alembic>=1.14.0
bcrypt>=4.2.0
python-dotenv>=1.0.0
httpx>=0.28.0
websockets>=14.0
psycopg2-binary>=2.9.0
pytest>=8.0.0
pytest-asyncio>=0.24.0
aiosqlite>=0.20.0
moto[s3]>=5.0.0
pinecone>=5.0.0
qdrant-client>=1.7.0
ruff>=0.8.0

0
tests/__init__.py Normal file
View File

236
tests/conftest.py Normal file
View File

@@ -0,0 +1,236 @@
"""Shared test fixtures for database-backed tests.
Provides an async SQLite in-memory engine that auto-creates all tables,
a per-test session, and a FastAPI ``TestClient`` wired to use it.
"""
from __future__ import annotations
import hashlib
import json
import os
import time
import uuid
from collections.abc import AsyncGenerator, Generator
from unittest.mock import patch
import boto3
import pytest
import pytest_asyncio
from fastapi.testclient import TestClient
from jose import jwt
from moto import mock_aws
from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from app.config.settings import settings
from app.db import Base, get_session
from app.main import app
from app.models import Plugin, Subscription, User
# ── Fixed test user IDs (one per tier) ───────────────────────────────
TEST_USER_IDS: dict[str, str] = {
"free": "00000000-0000-0000-0000-000000000001",
"pro": "00000000-0000-0000-0000-000000000002",
"power": "00000000-0000-0000-0000-000000000003",
"team": "00000000-0000-0000-0000-000000000004",
}
# ── Async SQLite engine ──────────────────────────────────────────────
_TEST_ENGINE = create_async_engine(
"sqlite+aiosqlite://",
connect_args={"check_same_thread": False},
poolclass=StaticPool,
)
_TestSessionLocal = async_sessionmaker(
_TEST_ENGINE,
expire_on_commit=False,
)
# Enable foreign key enforcement for SQLite (off by default).
@event.listens_for(_TEST_ENGINE.sync_engine, "connect")
def _set_sqlite_pragma(dbapi_conn, _connection_record): # noqa: ANN001
cursor = dbapi_conn.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest_asyncio.fixture(autouse=True)
async def _create_tables():
"""Create all tables before each test, seed test users, then drop after."""
async with _TEST_ENGINE.begin() as conn:
await conn.run_sync(Base.metadata.create_all)
# Seed one User + Subscription per tier so FK constraints and auth work.
async with _TestSessionLocal() as session:
for tier, uid in TEST_USER_IDS.items():
session.add(User(
id=uid,
email=f"{tier}@test.com",
password_hash="$2b$12$fakehashfortesting000000000000000000000000000",
tier=tier,
))
session.add(Subscription(
id=str(uuid.uuid4()),
user_id=uid,
tier=tier,
stripe_subscription_id=f"sub_test_{tier}",
status="active",
))
await session.commit()
yield
async with _TEST_ENGINE.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
@pytest_asyncio.fixture
async def db_session() -> AsyncGenerator[AsyncSession, None]:
"""Yield a per-test async DB session."""
async with _TestSessionLocal() as session:
yield session
@pytest.fixture
def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # noqa: ANN001
"""FastAPI test client with ``get_session`` overridden to use the test DB."""
async def _override_get_session() -> AsyncGenerator[AsyncSession, None]:
yield db_session
app.dependency_overrides[get_session] = _override_get_session
with TestClient(app) as c:
yield c
app.dependency_overrides.pop(get_session, None)
# ── Seed data helpers ────────────────────────────────────────────────
_SEED_PLUGINS = [
Plugin(
id="plugin-github-sync",
name="GitHub Sync",
description="Sync tasks with GitHub Issues and pull requests.",
version="1.0.0",
author_name="Adiuva",
category="productivity",
price_cents=0,
permissions=json.dumps(["read:tasks", "write:tasks"]),
status="approved",
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
install_count=0,
avg_rating=0.0,
),
Plugin(
id="plugin-slack-notify",
name="Slack Notifier",
description="Post task and checkpoint updates to Slack channels.",
version="1.2.0",
author_name="Adiuva",
category="communication",
price_cents=499,
permissions=json.dumps(["read:tasks", "read:checkpoints"]),
status="approved",
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
install_count=0,
avg_rating=0.0,
),
Plugin(
id="plugin-time-tracker",
name="Time Tracker",
description="Track time spent on tasks with automatic reporting.",
version="0.9.1",
author_name="Third Party",
category="productivity",
price_cents=999,
permissions=json.dumps(["read:tasks", "write:tasks"]),
status="approved",
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
install_count=0,
avg_rating=0.0,
),
]
@pytest_asyncio.fixture
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
"""Insert the 3 default approved plugins and return them."""
plugins = []
for template in _SEED_PLUGINS:
p = Plugin(
id=template.id,
name=template.name,
description=template.description,
version=template.version,
author_name=template.author_name,
category=template.category,
price_cents=template.price_cents,
permissions=template.permissions,
status=template.status,
s3_package_key=template.s3_package_key,
install_count=template.install_count,
avg_rating=template.avg_rating,
)
db_session.add(p)
plugins.append(p)
await db_session.commit()
return plugins
# ── JWT helpers ──────────────────────────────────────────────────────
def make_jwt(
tier: str = "power",
user_id: str | None = None,
email: str | None = None,
) -> str:
"""Create a signed test JWT.
Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can
find the corresponding ``Subscription`` row in the test database.
"""
uid = user_id or TEST_USER_IDS.get(tier, str(uuid.uuid4()))
now = int(time.time())
payload = {
"sub": uid,
"email": email or f"{tier}@test.com",
"tier": tier,
"exp": now + 3600,
"iat": now,
}
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, str]:
"""Return an Authorization header dict for the given tier."""
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
# ── S3 mock fixture ──────────────────────────────────────────────────
S3_TEST_BUCKET = "test-bucket"
S3_TEST_REGION = "us-east-1"
@pytest.fixture
def s3_bucket():
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
with mock_aws():
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
client = boto3.client("s3", region_name=S3_TEST_REGION)
client.create_bucket(Bucket=S3_TEST_BUCKET)
with patch("app.storage.blob_store.settings") as mock_settings:
mock_settings.S3_BUCKET = S3_TEST_BUCKET
mock_settings.S3_REGION = S3_TEST_REGION
mock_settings.AWS_ACCESS_KEY_ID = "testing"
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
yield S3_TEST_BUCKET

View File

@@ -0,0 +1,214 @@
"""Unit tests for the agent registry, base classes, and tool loop."""
from __future__ import annotations
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
from app.core.agent_registry import AgentRegistry, ChatAgent
# ── Helpers ──────────────────────────────────────────────────────────
class _StubAgent(ChatAgent):
"""Minimal concrete agent for testing."""
def get_name(self) -> str:
return "stub"
def get_description(self) -> str:
return "A stub agent for tests"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"echo: {query}"
class _AnotherAgent(ChatAgent):
def get_name(self) -> str:
return "another"
def get_description(self) -> str:
return "Another stub"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return "another"
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _fresh_registry():
"""Reset the singleton between tests."""
AgentRegistry._instance = None
yield
AgentRegistry._instance = None
@pytest.fixture()
def reg() -> AgentRegistry:
return AgentRegistry()
# ── Tests ────────────────────────────────────────────────────────────
class TestRegisterAndGet:
def test_register_decorator(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
agent = reg.get("stub")
assert isinstance(agent, _StubAgent)
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError, match="not found"):
reg.get("nonexistent")
def test_register_multiple(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
reg.register(_AnotherAgent)
assert reg.get("stub").get_name() == "stub"
assert reg.get("another").get_name() == "another"
class TestListAgents:
def test_empty(self, reg: AgentRegistry) -> None:
assert reg.list_agents() == []
def test_list_after_register(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
agents = reg.list_agents()
assert len(agents) == 1
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
def test_list_multiple(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
reg.register(_AnotherAgent)
names = {a["name"] for a in reg.list_agents()}
assert names == {"stub", "another"}
class TestCallAgent:
@pytest.mark.asyncio
async def test_call_agent(self, reg: AgentRegistry) -> None:
reg.register(_StubAgent)
result = await reg.call_agent("stub", "hello", {})
assert result == "echo: hello"
@pytest.mark.asyncio
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError):
await reg.call_agent("nope", "hi", {})
class TestSingleton:
def test_singleton_identity(self) -> None:
a = AgentRegistry()
b = AgentRegistry()
assert a is b
class TestToolLoop:
@pytest.mark.asyncio
async def test_no_tool_calls(self) -> None:
"""When the LLM responds without tool calls, return content directly."""
agent = _StubAgent()
ai_msg = MagicMock()
ai_msg.content = "final answer"
ai_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(return_value=ai_msg)
result = await agent._tool_loop(llm, [], [])
assert result == "final answer"
@pytest.mark.asyncio
async def test_tool_call_then_answer(self) -> None:
"""LLM requests one tool call, gets result, then answers."""
agent = _StubAgent()
# First response: tool call
tool_call_msg = MagicMock()
tool_call_msg.content = ""
tool_call_msg.tool_calls = [
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
]
# Second response: final answer
final_msg = MagicMock()
final_msg.content = "done"
final_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
# Mock tool
tool = AsyncMock()
tool.name = "my_tool"
tool.ainvoke = AsyncMock(return_value="tool_result")
result = await agent._tool_loop(llm, [], [tool])
assert result == "done"
tool.ainvoke.assert_called_once_with({"x": 1})
@pytest.mark.asyncio
async def test_unknown_tool_handled(self) -> None:
"""Unknown tool names produce an error message instead of crashing."""
agent = _StubAgent()
tool_call_msg = MagicMock()
tool_call_msg.content = ""
tool_call_msg.tool_calls = [
{"id": "call_1", "name": "missing", "args": {}}
]
final_msg = MagicMock()
final_msg.content = "recovered"
final_msg.tool_calls = []
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm)
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
result = await agent._tool_loop(llm, [], [])
assert result == "recovered"
@pytest.mark.asyncio
async def test_max_iter_reached(self) -> None:
"""When max iterations are exhausted, a final no-tools call is made."""
agent = _StubAgent()
# Every response requests a tool call
loop_msg = MagicMock()
loop_msg.content = ""
loop_msg.tool_calls = [
{"id": "call_x", "name": "t", "args": {}}
]
final_msg = MagicMock()
final_msg.content = "gave up"
final_msg.tool_calls = []
tool = AsyncMock()
tool.name = "t"
tool.ainvoke = AsyncMock(return_value="ok")
llm_with_tools = AsyncMock()
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
llm = AsyncMock()
llm.bind_tools = MagicMock(return_value=llm_with_tools)
llm.ainvoke = AsyncMock(return_value=final_msg)
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
assert result == "gave up"
assert llm_with_tools.ainvoke.call_count == 2

620
tests/test_agents.py Normal file
View File

@@ -0,0 +1,620 @@
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import app.agents # noqa: F401 — triggers @registry.register decorators
from app.agents.checkpoint_agent import CheckpointAgent
from app.agents.note_agent import NoteAgent
from app.agents.project_agent import ProjectAgent
from app.agents.task_agent import TaskAgent
from app.core.agent_registry import registry
# ── Helpers ──────────────────────────────────────────────────────────
def _mock_llm(response_text: str) -> MagicMock:
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
msg = MagicMock()
msg.content = response_text
msg.tool_calls = []
llm = MagicMock()
bound = MagicMock()
bound.ainvoke = AsyncMock(return_value=msg)
llm.bind_tools = MagicMock(return_value=bound)
llm.ainvoke = AsyncMock(return_value=msg)
return llm
def _mock_llm_with_tool_call(
tool_name: str, tool_args: dict[str, Any], final_text: str
) -> MagicMock:
"""Mock LLM that fires one tool call then returns *final_text*."""
tool_msg = MagicMock()
tool_msg.content = ""
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
final_msg = MagicMock()
final_msg.content = final_text
final_msg.tool_calls = []
bound = MagicMock()
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
llm = MagicMock()
llm.bind_tools = MagicMock(return_value=bound)
llm.ainvoke = AsyncMock(return_value=final_msg)
return llm
# ── Registration ──────────────────────────────────────────────────────
class TestAgentRegistration:
def test_all_agents_registered(self) -> None:
names = {a["name"] for a in registry.list_agents()}
assert {
"task_agent", "checkpoint_agent", "project_agent", "note_agent"
}.issubset(names)
def test_registry_returns_correct_types(self) -> None:
assert isinstance(registry.get("task_agent"), TaskAgent)
assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent)
assert isinstance(registry.get("project_agent"), ProjectAgent)
assert isinstance(registry.get("note_agent"), NoteAgent)
def test_descriptions_present(self) -> None:
for agent_info in registry.list_agents():
assert agent_info["description"], f"Empty description: {agent_info['name']}"
# ── TaskAgent ─────────────────────────────────────────────────────────
class TestTaskAgent:
def test_name(self) -> None:
assert TaskAgent().get_name() == "task_agent"
def test_description(self) -> None:
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
def test_get_tools_count(self) -> None:
assert len(TaskAgent().get_tools()) == 8
def test_tool_names(self) -> None:
names = {t.name for t in TaskAgent().get_tools()}
assert names == {
"list_tasks",
"create_task",
"update_task",
"delete_task",
"list_tasks_due_today",
"list_task_comments",
"add_task_comment",
"delete_task_comment",
}
@pytest.mark.asyncio
async def test_handle_returns_string(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Task created.")
result = await TaskAgent().handle("create a task", {})
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Here are your tasks.")
result = await TaskAgent().handle("list my tasks", {})
assert result == "Here are your tasks."
@pytest.mark.asyncio
async def test_handle_with_create_task_tool_call(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_task",
{"title": "Buy groceries", "priority": "low"},
"Task 'Buy groceries' created.",
)
result = await TaskAgent().handle("add a grocery task", {})
assert result == "Task 'Buy groceries' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await TaskAgent().handle("help", {})
assert isinstance(result, str)
@pytest.mark.asyncio
async def test_handle_accepts_rich_context(self) -> None:
context = {
"user_profile": {"id": "u1", "tier": "pro"},
"recent_tasks": [{"id": "t1", "title": "Old task"}],
}
with patch("app.agents.task_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Tasks listed.")
result = await TaskAgent().handle("show tasks", context)
assert isinstance(result, str)
class TestTaskAgentTools:
@pytest.mark.asyncio
async def test_list_tasks_defaults(self) -> None:
from app.agents.task_agent import list_tasks
result = await list_tasks.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "tasks"
@pytest.mark.asyncio
async def test_list_tasks_with_status_filter(self) -> None:
from app.agents.task_agent import list_tasks
result = await list_tasks.ainvoke({"status": "done"})
data = json.loads(result)
assert data["filters"]["status"] == "done"
@pytest.mark.asyncio
async def test_create_task_defaults(self) -> None:
from app.agents.task_agent import create_task
result = await create_task.ainvoke({"title": "Test task"})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "tasks"
assert data["data"]["title"] == "Test task"
assert data["data"]["status"] == "todo"
assert data["data"]["priority"] == "medium"
@pytest.mark.asyncio
async def test_create_task_with_all_fields(self) -> None:
from app.agents.task_agent import create_task
result = await create_task.ainvoke({
"title": "Deploy",
"priority": "high",
"status": "in_progress",
"project_id": "p1",
"is_ai_suggested": 1,
})
data = json.loads(result)
assert data["data"]["priority"] == "high"
assert data["data"]["status"] == "in_progress"
assert data["data"]["projectId"] == "p1"
assert data["data"]["isAiSuggested"] == 1
@pytest.mark.asyncio
async def test_update_task_with_status(self) -> None:
from app.agents.task_agent import update_task
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "t1"
assert data["data"]["updates"]["status"] == "done"
@pytest.mark.asyncio
async def test_update_task_empty_updates(self) -> None:
from app.agents.task_agent import update_task
result = await update_task.ainvoke({"task_id": "t1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_task(self) -> None:
from app.agents.task_agent import delete_task
result = await delete_task.ainvoke({"task_id": "t1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "tasks"
assert data["data"]["id"] == "t1"
@pytest.mark.asyncio
async def test_list_tasks_due_today(self) -> None:
from app.agents.task_agent import list_tasks_due_today
result = await list_tasks_due_today.ainvoke({})
data = json.loads(result)
assert data["action"] == "list_due_today"
assert data["table"] == "tasks"
@pytest.mark.asyncio
async def test_list_task_comments(self) -> None:
from app.agents.task_agent import list_task_comments
result = await list_task_comments.ainvoke({"task_id": "t1"})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "taskComments"
assert data["filters"]["taskId"] == "t1"
@pytest.mark.asyncio
async def test_add_task_comment(self) -> None:
from app.agents.task_agent import add_task_comment
result = await add_task_comment.ainvoke({
"task_id": "t1",
"author": "Alice",
"content": "Looks good!",
})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "taskComments"
assert data["data"]["taskId"] == "t1"
assert data["data"]["author"] == "Alice"
assert data["data"]["content"] == "Looks good!"
@pytest.mark.asyncio
async def test_delete_task_comment(self) -> None:
from app.agents.task_agent import delete_task_comment
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "taskComments"
assert data["data"]["id"] == "c1"
# ── CheckpointAgent ───────────────────────────────────────────────────
class TestCheckpointAgent:
def test_name(self) -> None:
assert CheckpointAgent().get_name() == "checkpoint_agent"
def test_description(self) -> None:
assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(CheckpointAgent().get_tools()) == 4
def test_tool_names(self) -> None:
names = {t.name for t in CheckpointAgent().get_tools()}
assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("No checkpoints found.")
result = await CheckpointAgent().handle("list checkpoints", {})
assert result == "No checkpoints found."
@pytest.mark.asyncio
async def test_handle_with_create_tool_call(self) -> None:
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_checkpoint",
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
"Checkpoint 'MVP Launch' created.",
)
result = await CheckpointAgent().handle("add MVP checkpoint", {})
assert result == "Checkpoint 'MVP Launch' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.checkpoint_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await CheckpointAgent().handle("show milestones", {})
assert isinstance(result, str)
class TestCheckpointAgentTools:
@pytest.mark.asyncio
async def test_list_checkpoints_no_project(self) -> None:
from app.agents.checkpoint_agent import list_checkpoints
result = await list_checkpoints.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "checkpoints"
assert data["filters"]["projectId"] is None
@pytest.mark.asyncio
async def test_list_checkpoints_with_project(self) -> None:
from app.agents.checkpoint_agent import list_checkpoints
result = await list_checkpoints.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_create_checkpoint(self) -> None:
from app.agents.checkpoint_agent import create_checkpoint
result = await create_checkpoint.ainvoke({
"project_id": "p1",
"title": "Beta release",
"date": 1700000000000,
})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "checkpoints"
assert data["data"]["projectId"] == "p1"
assert data["data"]["title"] == "Beta release"
assert data["data"]["date"] == 1700000000000
@pytest.mark.asyncio
async def test_create_checkpoint_ai_suggested(self) -> None:
from app.agents.checkpoint_agent import create_checkpoint
result = await create_checkpoint.ainvoke({
"project_id": "p1",
"title": "Review",
"date": 1700000000000,
"is_ai_suggested": 1,
})
data = json.loads(result)
assert data["data"]["isAiSuggested"] == 1
assert data["data"]["isApproved"] == 0
@pytest.mark.asyncio
async def test_update_checkpoint_approve(self) -> None:
from app.agents.checkpoint_agent import update_checkpoint
result = await update_checkpoint.ainvoke({
"checkpoint_id": "c1",
"is_approved": 1,
})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "c1"
assert data["data"]["updates"]["isApproved"] == 1
@pytest.mark.asyncio
async def test_update_checkpoint_empty_updates(self) -> None:
from app.agents.checkpoint_agent import update_checkpoint
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_checkpoint(self) -> None:
from app.agents.checkpoint_agent import delete_checkpoint
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "checkpoints"
assert data["data"]["id"] == "c1"
# ── ProjectAgent ──────────────────────────────────────────────────────
class TestProjectAgent:
def test_name(self) -> None:
assert ProjectAgent().get_name() == "project_agent"
def test_description(self) -> None:
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
def test_get_tools_count(self) -> None:
assert len(ProjectAgent().get_tools()) == 6
def test_tool_names(self) -> None:
names = {t.name for t in ProjectAgent().get_tools()}
assert names == {
"list_projects",
"list_all_projects",
"get_project",
"create_project",
"update_project",
"delete_project",
}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Project Alpha is active.")
result = await ProjectAgent().handle("show my projects", {})
assert result == "Project Alpha is active."
@pytest.mark.asyncio
async def test_handle_with_create_project_tool_call(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_project",
{"name": "Pippo"},
"Project 'Pippo' created.",
)
result = await ProjectAgent().handle("create project Pippo", {})
assert result == "Project 'Pippo' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.project_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await ProjectAgent().handle("archive old project", {})
assert isinstance(result, str)
class TestProjectAgentTools:
@pytest.mark.asyncio
async def test_list_projects_defaults(self) -> None:
from app.agents.project_agent import list_projects
result = await list_projects.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "projects"
assert data["filters"]["includeArchived"] is False
@pytest.mark.asyncio
async def test_list_projects_include_archived(self) -> None:
from app.agents.project_agent import list_projects
result = await list_projects.ainvoke({"include_archived": 1})
data = json.loads(result)
assert data["filters"]["includeArchived"] is True
@pytest.mark.asyncio
async def test_list_all_projects(self) -> None:
from app.agents.project_agent import list_all_projects
result = await list_all_projects.ainvoke({})
data = json.loads(result)
assert data["action"] == "list_all"
assert data["table"] == "projects"
@pytest.mark.asyncio
async def test_get_project(self) -> None:
from app.agents.project_agent import get_project
result = await get_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["action"] == "get"
assert data["table"] == "projects"
assert data["data"]["id"] == "p1"
@pytest.mark.asyncio
async def test_create_project_name_only(self) -> None:
from app.agents.project_agent import create_project
result = await create_project.ainvoke({"name": "Alpha"})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["data"]["name"] == "Alpha"
assert data["data"]["clientId"] is None
@pytest.mark.asyncio
async def test_create_project_with_client(self) -> None:
from app.agents.project_agent import create_project
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
data = json.loads(result)
assert data["data"]["clientId"] == "cl1"
@pytest.mark.asyncio
async def test_update_project_archive(self) -> None:
from app.agents.project_agent import update_project
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "p1"
assert data["data"]["updates"]["status"] == "archived"
@pytest.mark.asyncio
async def test_update_project_empty_updates(self) -> None:
from app.agents.project_agent import update_project
result = await update_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_project(self) -> None:
from app.agents.project_agent import delete_project
result = await delete_project.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["data"]["id"] == "p1"
# ── NoteAgent ─────────────────────────────────────────────────────────
class TestNoteAgent:
def test_name(self) -> None:
assert NoteAgent().get_name() == "note_agent"
def test_description(self) -> None:
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
def test_get_tools_count(self) -> None:
assert len(NoteAgent().get_tools()) == 5
def test_tool_names(self) -> None:
names = {t.name for t in NoteAgent().get_tools()}
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
@pytest.mark.asyncio
async def test_handle_no_tool_calls(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Note created.")
result = await NoteAgent().handle("create a note", {})
assert result == "Note created."
@pytest.mark.asyncio
async def test_handle_with_create_note_tool_call(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm_with_tool_call(
"create_note",
{"title": "Daily log", "content": "# Today\nAll good."},
"Note 'Daily log' created.",
)
result = await NoteAgent().handle("log today's progress", {})
assert result == "Note 'Daily log' created."
@pytest.mark.asyncio
async def test_handle_accepts_empty_context(self) -> None:
with patch("app.agents.note_agent.get_llm") as mock_cls:
mock_cls.return_value = _mock_llm("Done.")
result = await NoteAgent().handle("show notes", {})
assert isinstance(result, str)
class TestNoteAgentTools:
@pytest.mark.asyncio
async def test_list_notes_no_project(self) -> None:
from app.agents.note_agent import list_notes
result = await list_notes.ainvoke({})
data = json.loads(result)
assert data["action"] == "list"
assert data["table"] == "notes"
assert data["filters"]["projectId"] is None
@pytest.mark.asyncio
async def test_list_notes_with_project(self) -> None:
from app.agents.note_agent import list_notes
result = await list_notes.ainvoke({"project_id": "p1"})
data = json.loads(result)
assert data["filters"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_get_note(self) -> None:
from app.agents.note_agent import get_note
result = await get_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["action"] == "get"
assert data["table"] == "notes"
assert data["data"]["id"] == "n1"
@pytest.mark.asyncio
async def test_create_note_minimal(self) -> None:
from app.agents.note_agent import create_note
result = await create_note.ainvoke({
"title": "Daily log",
"content": "# Today\nAll good.",
})
data = json.loads(result)
assert data["action"] == "create_record"
assert data["table"] == "notes"
assert data["data"]["title"] == "Daily log"
assert data["data"]["content"] == "# Today\nAll good."
assert data["data"]["projectId"] is None
@pytest.mark.asyncio
async def test_create_note_with_project(self) -> None:
from app.agents.note_agent import create_note
result = await create_note.ainvoke({
"title": "Sprint notes",
"content": "## Sprint 1",
"project_id": "p1",
})
data = json.loads(result)
assert data["data"]["projectId"] == "p1"
@pytest.mark.asyncio
async def test_update_note_content_only(self) -> None:
from app.agents.note_agent import update_note
result = await update_note.ainvoke({
"note_id": "n1",
"content": "# Updated content",
})
data = json.loads(result)
assert data["action"] == "update_record"
assert data["data"]["id"] == "n1"
assert data["data"]["updates"]["content"] == "# Updated content"
assert "title" not in data["data"]["updates"]
@pytest.mark.asyncio
async def test_update_note_empty_updates(self) -> None:
from app.agents.note_agent import update_note
result = await update_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["data"]["updates"] == {}
@pytest.mark.asyncio
async def test_delete_note(self) -> None:
from app.agents.note_agent import delete_note
result = await delete_note.ainvoke({"note_id": "n1"})
data = json.loads(result)
assert data["action"] == "delete_record"
assert data["table"] == "notes"
assert data["data"]["id"] == "n1"

207
tests/test_auth.py Normal file
View File

@@ -0,0 +1,207 @@
"""Tests for auth routes: register, login, refresh, me.
Exercises the full auth lifecycle through the FastAPI TestClient against the
in-memory SQLite test database seeded by ``conftest.py``.
"""
from __future__ import annotations
import time
import pytest
from jose import jwt
from app.config.settings import settings
from tests.conftest import auth_header, make_jwt, TEST_USER_IDS
# ── TestRegister ──────────────────────────────────────────────────────
class TestRegister:
"""POST /api/v1/auth/register"""
def test_register_success(self, client) -> None:
resp = client.post(
"/api/v1/auth/register",
json={"email": "new@example.com", "password": "Str0ngP@ss!"},
)
assert resp.status_code == 201
data = resp.json()
assert "access_token" in data
assert "refresh_token" in data
assert "expires_at" in data
# expires_at should be a future millisecond timestamp
assert data["expires_at"] > int(time.time() * 1000)
def test_register_returns_valid_jwt(self, client) -> None:
resp = client.post(
"/api/v1/auth/register",
json={"email": "jwt-check@example.com", "password": "P@ss1234"},
)
assert resp.status_code == 201
token = resp.json()["access_token"]
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
assert payload["email"] == "jwt-check@example.com"
assert payload["tier"] == "free"
assert "sub" in payload
def test_register_duplicate_email(self, client) -> None:
client.post(
"/api/v1/auth/register",
json={"email": "dupe@example.com", "password": "Pass1234"},
)
resp = client.post(
"/api/v1/auth/register",
json={"email": "dupe@example.com", "password": "Pass5678"},
)
assert resp.status_code == 409
def test_register_missing_password(self, client) -> None:
resp = client.post(
"/api/v1/auth/register",
json={"email": "no-pass@example.com"},
)
assert resp.status_code == 422
def test_register_missing_email(self, client) -> None:
resp = client.post(
"/api/v1/auth/register",
json={"password": "OnlyPass"},
)
assert resp.status_code == 422
# ── TestLogin ─────────────────────────────────────────────────────────
class TestLogin:
"""POST /api/v1/auth/login"""
def _register(self, client, email="login@example.com", password="MyP@ss123"):
client.post(
"/api/v1/auth/register",
json={"email": email, "password": password},
)
def test_login_success(self, client) -> None:
self._register(client)
resp = client.post(
"/api/v1/auth/login",
json={"email": "login@example.com", "password": "MyP@ss123"},
)
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data
assert "refresh_token" in data
assert "expires_at" in data
def test_login_wrong_password(self, client) -> None:
self._register(client)
resp = client.post(
"/api/v1/auth/login",
json={"email": "login@example.com", "password": "WrongPass!"},
)
assert resp.status_code == 401
def test_login_unknown_email(self, client) -> None:
resp = client.post(
"/api/v1/auth/login",
json={"email": "ghost@example.com", "password": "Whatever"},
)
assert resp.status_code == 401
# ── TestRefresh ───────────────────────────────────────────────────────
class TestRefresh:
"""POST /api/v1/auth/refresh"""
def _register_and_get_tokens(self, client, email="refresh@example.com"):
resp = client.post(
"/api/v1/auth/register",
json={"email": email, "password": "RefPass123!"},
)
return resp.json()
def test_refresh_returns_new_tokens(self, client) -> None:
tokens = self._register_and_get_tokens(client)
resp = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": tokens["refresh_token"]},
)
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data
assert "refresh_token" in data
# New refresh token should differ from old one (rotation)
assert data["refresh_token"] != tokens["refresh_token"]
def test_refresh_old_token_rejected(self, client) -> None:
"""After rotation, the original refresh token must be rejected."""
tokens = self._register_and_get_tokens(client, email="rotate@example.com")
old_rt = tokens["refresh_token"]
# First refresh succeeds and rotates the token
client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt})
# Second attempt with the old token must fail
resp = client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt})
assert resp.status_code == 401
def test_refresh_bogus_token(self, client) -> None:
resp = client.post(
"/api/v1/auth/refresh",
json={"refresh_token": "not-a-real-token"},
)
assert resp.status_code == 401
# ── TestMe ────────────────────────────────────────────────────────────
class TestMe:
"""GET /api/v1/auth/me"""
def test_me_with_valid_jwt(self, client) -> None:
resp = client.get("/api/v1/auth/me", headers=auth_header("power"))
assert resp.status_code == 200
data = resp.json()
assert data["id"] == TEST_USER_IDS["power"]
assert data["email"] == "power@test.com"
assert data["tier"] == "power"
def test_me_returns_correct_tier(self, client) -> None:
"""Tier comes from the live subscription row, not the JWT claim."""
resp = client.get("/api/v1/auth/me", headers=auth_header("free"))
assert resp.json()["tier"] == "free"
def test_me_missing_token(self, client) -> None:
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
def test_me_expired_token(self, client) -> None:
"""A JWT with ``exp`` in the past must be rejected."""
payload = {
"sub": TEST_USER_IDS["power"],
"email": "power@test.com",
"tier": "power",
"exp": int(time.time()) - 3600, # 1 hour ago
"iat": int(time.time()) - 7200,
}
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 401
def test_me_invalid_signature(self, client) -> None:
payload = {
"sub": TEST_USER_IDS["power"],
"email": "power@test.com",
"tier": "power",
"exp": int(time.time()) + 3600,
"iat": int(time.time()),
}
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 401

244
tests/test_backup.py Normal file
View File

@@ -0,0 +1,244 @@
"""Tests for backup routes: upload, download, history, delete.
Exercises the backup lifecycle through the FastAPI TestClient against the
in-memory SQLite test database and moto-mocked S3 bucket.
"""
from __future__ import annotations
import hashlib
import pytest
from tests.conftest import auth_header, TEST_USER_IDS
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-backup-blob-opaque-bytes"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_VERSION = 1
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
"""Return auth + backup metadata headers."""
headers = auth_header(tier)
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
headers["Content-Type"] = "application/octet-stream"
return headers
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
"""Upload a backup blob and return the response."""
return client.put(
"/api/v1/backup",
content=overrides.pop("blob", _BLOB),
headers=_backup_headers(tier, **overrides),
)
# ── TestUploadBackup ──────────────────────────────────────────────────
class TestUploadBackup:
"""PUT /api/v1/backup"""
def test_upload_success(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power")
assert resp.status_code == 200
assert resp.json() == {"ok": True}
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
_upload(client, tier="power")
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 1
assert history[0]["version"] == _VERSION
assert history[0]["timestamp"] == _TIMESTAMP
assert history[0]["checksum"] == _CHECKSUM
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power", checksum="0" * 64)
assert resp.status_code == 400
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
"""Free tier has backup_gb=0 → should return 402."""
resp = _upload(client, tier="free")
assert resp.status_code == 402
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
"""Pro tier has backup_gb=5 → small blob succeeds."""
resp = _upload(client, tier="pro")
assert resp.status_code == 200
# ── TestDownloadBackup ────────────────────────────────────────────────
class TestDownloadBackup:
"""GET /api/v1/backup"""
def test_download_latest(self, client, s3_bucket) -> None:
_upload(client, tier="power")
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == _BLOB
assert resp.headers["X-Checksum"] == _CHECKSUM
assert resp.headers["X-Backup-Version"] == str(_VERSION)
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 404
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
"""When If-Modified-Since is after the backup timestamp → 304."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
},
)
assert resp.status_code == 304
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
"""When If-Modified-Since is before the backup timestamp → serve blob."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
},
)
assert resp.status_code == 200
assert resp.content == _BLOB
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
"""When multiple backups exist, GET returns the one with the highest timestamp."""
_upload(client, tier="power", timestamp=1000)
blob2 = b"second-encrypted-backup"
checksum2 = hashlib.sha256(blob2).hexdigest()
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == blob2
# ── TestBackupHistory ─────────────────────────────────────────────────
class TestBackupHistory:
"""GET /api/v1/backup/history"""
def test_history_empty(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.json() == []
def test_history_returns_entries(self, client, s3_bucket) -> None:
_upload(client, tier="power", timestamp=1000)
_upload(client, tier="power", timestamp=2000)
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 2
# Ordered by timestamp descending
assert history[0]["timestamp"] == 2000
assert history[1]["timestamp"] == 1000
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
"""One user's backups should not appear in another user's history."""
_upload(client, tier="power")
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
assert resp.json() == []
# ── TestDeleteBackup ──────────────────────────────────────────────────
class TestDeleteBackup:
"""DELETE /api/v1/backup/{backup_id}"""
def _get_backup_id(self, client, tier="power") -> str:
"""Upload a backup and return its DB id from history."""
_upload(client, tier=tier)
history = client.get(
"/api/v1/backup/history", headers=auth_header(tier)
).json()
# History returns BackupMetadata schema which doesn't have `id`.
# We need to look it up via a different means.
# Since there's only 1 backup, find via history length.
# Actually the schema doesn't return id — let's verify via re-download.
# We'll use a workaround: upload, then list history to confirm it exists,
# then try to delete — but we need the id...
# Let's check if history includes an id field.
# The schema is: version, timestamp, checksum, chunk_count — no id.
# We'll need to query the DB directly or use a known ID.
# For testing, we'll search history then use the DB.
return None # pragma: no cover — overridden below
def test_delete_success(self, client, s3_bucket, db_session) -> None:
_upload(client, tier="power")
# Discover the backup_id via direct DB query
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# History should now be empty
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert history == []
def test_delete_nonexistent(self, client, s3_bucket) -> None:
resp = client.delete(
"/api/v1/backup/no-such-id", headers=auth_header("power")
)
assert resp.status_code == 404
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
"""Cannot delete another user's backup (ownership check returns 404)."""
_upload(client, tier="power")
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
# team user tries to delete power user's backup → 404
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
)
assert resp.status_code == 404

View File

@@ -0,0 +1,286 @@
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
from __future__ import annotations
import pytest
from app.core.execution_plan import (
ExecutionPlanBuilder,
PlanCache,
PromptTemplateRegistry,
plan_cache,
template_registry,
)
from app.schemas import ExecutionPlan
# ── PromptTemplateRegistry ────────────────────────────────────────────
class TestPromptTemplateRegistry:
def test_register_and_get(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_foo", "You are a foo agent.")
assert reg.get("tpl_foo") == "You are a foo agent."
def test_get_unknown_raises_key_error(self) -> None:
reg = PromptTemplateRegistry()
with pytest.raises(KeyError, match="tpl_missing"):
reg.get("tpl_missing")
def test_has_returns_true_for_registered(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_x", "prompt text")
assert reg.has("tpl_x") is True
def test_has_returns_false_for_unregistered(self) -> None:
reg = PromptTemplateRegistry()
assert reg.has("tpl_missing") is False
def test_list_ids_returns_all_registered_ids(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_a", "a")
reg.register("tpl_b", "b")
assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
def test_list_ids_does_not_return_prompt_text(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_secret", "top secret prompt")
ids = reg.list_ids()
assert "top secret prompt" not in ids
def test_overwrite_existing_template(self) -> None:
reg = PromptTemplateRegistry()
reg.register("tpl_x", "v1")
reg.register("tpl_x", "v2")
assert reg.get("tpl_x") == "v2"
def test_empty_registry_has_no_ids(self) -> None:
reg = PromptTemplateRegistry()
assert reg.list_ids() == []
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
class TestExecutionPlanBuilder:
def test_builds_empty_plan(self) -> None:
plan = ExecutionPlanBuilder("task_agent").build()
assert plan.agent == "task_agent"
assert plan.steps == []
def test_add_step_basic(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("create_task", {"priority": "high"})
.build()
)
assert len(plan.steps) == 1
assert plan.steps[0].action == "create_task"
assert plan.steps[0].variables == {"priority": "high"}
assert plan.steps[0].prompt_template is None
assert plan.steps[0].data_from_step is None
def test_add_step_no_params(self) -> None:
plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
assert plan.steps[0].variables is None
def test_add_llm_step(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_llm_step("tpl_task_default", {"message": "hi"})
.build()
)
assert plan.steps[0].action == "llm"
assert plan.steps[0].prompt_template == "tpl_task_default"
assert plan.steps[0].variables == {"message": "hi"}
def test_add_llm_step_no_variables(self) -> None:
plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
assert plan.steps[0].variables is None
def test_add_data_step(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("fetch_data")
.add_data_step("transform", data_from_step=0)
.build()
)
assert plan.steps[1].action == "transform"
assert plan.steps[1].data_from_step == 0
def test_fluent_chaining_returns_builder(self) -> None:
builder = ExecutionPlanBuilder("analytics_agent")
result = builder.add_step("a")
assert result is builder
def test_fluent_chain_multiple_steps(self) -> None:
plan = (
ExecutionPlanBuilder("analytics_agent")
.add_llm_step("tpl_analytics_default")
.add_step("format_output")
.add_data_step("store", data_from_step=0)
.build()
)
assert len(plan.steps) == 3
def test_build_validates_data_from_step_out_of_range(self) -> None:
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
def test_build_validates_data_from_step_self_reference(self) -> None:
"""data_from_step=0 on the first step (index 0) is invalid."""
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
def test_build_validates_data_from_step_negative(self) -> None:
with pytest.raises(ValueError, match="data_from_step"):
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
def test_valid_data_from_step_at_index_two(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("step0")
.add_step("step1")
.add_data_step("step2", data_from_step=1)
.build()
)
assert plan.steps[2].data_from_step == 1
def test_data_from_step_zero_valid_at_index_one(self) -> None:
plan = (
ExecutionPlanBuilder("task_agent")
.add_step("step0")
.add_data_step("step1", data_from_step=0)
.build()
)
assert plan.steps[1].data_from_step == 0
def test_build_returns_new_plan_each_call(self) -> None:
builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
plan1 = builder.build()
plan2 = builder.build()
assert plan1 is not plan2
assert plan1.steps == plan2.steps
def test_plan_is_execution_plan_instance(self) -> None:
plan = ExecutionPlanBuilder("task_agent").build()
assert isinstance(plan, ExecutionPlan)
# ── PlanCache ─────────────────────────────────────────────────────────
class TestPlanCache:
def _plan(self, agent: str = "a") -> ExecutionPlan:
return ExecutionPlanBuilder(agent).build()
def test_cache_and_get(self) -> None:
cache = PlanCache()
plan = self._plan()
cache.cache_plan("key1", plan)
assert cache.get_plan("key1") is plan
def test_get_missing_returns_none(self) -> None:
cache = PlanCache()
assert cache.get_plan("nonexistent") is None
def test_get_all_playbooks_empty(self) -> None:
cache = PlanCache()
assert cache.get_all_playbooks() == []
def test_get_all_playbooks_returns_all_stored(self) -> None:
cache = PlanCache()
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
playbooks = cache.get_all_playbooks()
assert len(playbooks) == 2
assert p1 in playbooks
assert p2 in playbooks
def test_lru_evicts_oldest_entry(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
cache.cache_plan("k3", p3) # k1 should be evicted
assert cache.get_plan("k1") is None
assert cache.get_plan("k2") is p2
assert cache.get_plan("k3") is p3
def test_lru_access_updates_recency(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
cache.cache_plan("k1", p1)
cache.cache_plan("k2", p2)
cache.get_plan("k1") # k1 is now most-recently used
cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
assert cache.get_plan("k1") is p1
assert cache.get_plan("k2") is None
assert cache.get_plan("k3") is p3
def test_overwrite_existing_key(self) -> None:
cache = PlanCache()
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("same_key", p1)
cache.cache_plan("same_key", p2)
assert cache.get_plan("same_key") is p2
assert len(cache.get_all_playbooks()) == 1
def test_overwrite_does_not_consume_capacity(self) -> None:
cache = PlanCache(maxsize=2)
p1, p2 = self._plan("a"), self._plan("b")
cache.cache_plan("k1", p1)
cache.cache_plan("k1", p2) # overwrite, not a new slot
cache.cache_plan("k2", p1) # should fit without eviction
assert cache.get_plan("k1") is p2
assert cache.get_plan("k2") is p1
# ── Module-level singletons ───────────────────────────────────────────
class TestModuleSingletons:
def test_template_registry_has_all_agent_defaults(self) -> None:
for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"):
assert template_registry.has(f"tpl_{agent}_default"), (
f"Missing template: tpl_{agent}_default"
)
def test_template_registry_has_operation_templates(self) -> None:
assert template_registry.has("tpl_task_extract_from_project")
assert template_registry.has("tpl_note_weekly_summary")
def test_template_registry_get_returns_non_empty_string(self) -> None:
text = template_registry.get("tpl_task_agent_default")
assert isinstance(text, str)
assert len(text) > 0
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
assert len(plan_cache.get_all_playbooks()) >= 2
def test_playbook_create_tasks_from_project(self) -> None:
plan = plan_cache.get_plan("create_tasks_from_project")
assert plan is not None
assert plan.agent == "project_agent"
assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
assert plan.steps[1].data_from_step == 0
def test_playbook_generate_weekly_note(self) -> None:
plan = plan_cache.get_plan("generate_weekly_note")
assert plan is not None
assert plan.agent == "note_agent"
assert len(plan.steps) == 2
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
assert plan.steps[1].data_from_step == 0
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
"""Plans must not embed prompt text — only template IDs."""
for plan in plan_cache.get_all_playbooks():
for step in plan.steps:
if step.prompt_template is not None:
assert step.prompt_template.startswith("tpl_"), (
f"prompt_template looks like raw text: {step.prompt_template!r}"
)

322
tests/test_middleware.py Normal file
View File

@@ -0,0 +1,322 @@
"""Tests for Step 9 middleware: auth, rate limiting, and sanitizer.
Auth tests: validated via GET /api/v1/auth/me (requires a Bearer JWT).
Rate limit: use unique user UUIDs per test so windows are independent;
the free-tier threshold (20 req/min) is exercised directly.
Sanitizer: the orchestrator is mocked to inject controlled prompt
fragments, and the chat endpoint response body is inspected.
"""
from __future__ import annotations
import time
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from jose import jwt
from app.config.settings import settings
from app.db import get_session
from app.main import app
from app.schemas import ChatResponse
from tests.conftest import TEST_USER_IDS
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
# ---------------------------------------------------------------------------
# Autouse: redirect all DB access to the in-memory SQLite test engine.
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
"""Route all get_session calls to the test SQLite session."""
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
_CHAT_BODY = {
"message": "hello",
"context": {
"user_profile": {},
"relevant_documents": [],
"recent_tasks": [],
"conversation_history": [],
},
"execution_mode": "direct",
}
def _make_jwt(
*,
user_id: str | None = None,
email: str = "test@example.com",
tier: str = "free",
exp_offset: int = 3600,
secret: str | None = None,
include_sub: bool = True,
) -> str:
"""Mint a test JWT signed with the configured (or custom) secret."""
uid = user_id or str(uuid.uuid4())
now = int(time.time())
payload: dict = {
"email": email,
"tier": tier,
"exp": now + exp_offset,
"iat": now,
}
if include_sub:
payload["sub"] = uid
key = secret or settings.JWT_SECRET
return jwt.encode(payload, key, algorithm=settings.JWT_ALGORITHM)
def _auth_header(token: str) -> dict[str, str]:
return {"Authorization": f"Bearer {token}"}
# ---------------------------------------------------------------------------
# Auth middleware
# ---------------------------------------------------------------------------
class TestAuthMiddleware:
"""Tests exercised via GET /api/v1/auth/me."""
def test_valid_token_returns_profile(self) -> None:
# Use the seeded pro user so the subscription lookup returns 'pro'.
uid = TEST_USER_IDS["pro"]
token = _make_jwt(user_id=uid, email="pro@test.com", tier="pro")
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 200
data = resp.json()
assert data["id"] == uid
assert data["email"] == "pro@test.com"
assert data["tier"] == "pro"
def test_missing_token_returns_401(self) -> None:
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me")
assert resp.status_code == 401
def test_expired_token_returns_401(self) -> None:
token = _make_jwt(exp_offset=-1) # already expired
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 401
def test_wrong_signature_returns_401(self) -> None:
token = _make_jwt(secret="totally-wrong-secret")
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 401
def test_missing_sub_claim_returns_401(self) -> None:
token = _make_jwt(include_sub=False)
with TestClient(app) as client:
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 401
def test_malformed_token_returns_401(self) -> None:
with TestClient(app) as client:
resp = client.get(
"/api/v1/auth/me", headers={"Authorization": "Bearer not.a.jwt"}
)
assert resp.status_code == 401
# ---------------------------------------------------------------------------
# Rate limiter middleware
# ---------------------------------------------------------------------------
class TestRateLimitMiddleware:
"""Each test uses a fresh unique user_id so windows never collide."""
def _unique_token(self, tier: str = "free") -> str:
return _make_jwt(user_id=str(uuid.uuid4()), tier=tier)
def test_free_tier_allows_up_to_20_requests(self) -> None:
token = self._unique_token("free")
with TestClient(app) as client:
for _ in range(20):
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 200
def test_free_tier_blocks_21st_request(self) -> None:
token = self._unique_token("free")
with TestClient(app) as client:
for _ in range(20):
client.get("/api/v1/auth/me", headers=_auth_header(token))
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 429
def test_429_includes_retry_after_header(self) -> None:
token = self._unique_token("free")
with TestClient(app) as client:
for _ in range(20):
client.get("/api/v1/auth/me", headers=_auth_header(token))
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 429
assert "retry-after" in resp.headers
retry_after = int(resp.headers["retry-after"])
assert retry_after >= 1
def test_429_response_has_detail_field(self) -> None:
token = self._unique_token("free")
with TestClient(app) as client:
for _ in range(20):
client.get("/api/v1/auth/me", headers=_auth_header(token))
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 429
assert "detail" in resp.json()
def test_pro_tier_allows_60_requests(self) -> None:
token = self._unique_token("pro")
with TestClient(app) as client:
# Sample: first 60 succeed, 61st is blocked.
for _ in range(60):
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 200
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
assert resp.status_code == 429
def test_independent_users_have_separate_windows(self) -> None:
token_a = self._unique_token("free")
token_b = self._unique_token("free")
with TestClient(app) as client:
# Exhaust user A's quota.
for _ in range(20):
client.get("/api/v1/auth/me", headers=_auth_header(token_a))
assert (
client.get(
"/api/v1/auth/me", headers=_auth_header(token_a)
).status_code
== 429
)
# User B's quota is untouched.
resp_b = client.get("/api/v1/auth/me", headers=_auth_header(token_b))
assert resp_b.status_code == 200
def test_exempt_path_register_never_rate_limited(self) -> None:
"""POST /auth/register is exempt — 25 calls should never return 429."""
with TestClient(app) as client:
for i in range(25):
resp = client.post(
"/api/v1/auth/register",
json={"email": f"user{i}_{uuid.uuid4()}@example.com", "password": "pw"},
)
# 201 on first, 409 on email collision — but never 429.
assert resp.status_code != 429
def test_exempt_path_login_never_rate_limited(self) -> None:
"""POST /auth/login is exempt — multiple failed attempts are not rate-limited."""
with TestClient(app) as client:
for _ in range(25):
resp = client.post(
"/api/v1/auth/login",
json={"email": "nosuchuser@example.com", "password": "wrong"},
)
assert resp.status_code != 429
def test_exempt_path_health_never_rate_limited(self) -> None:
with TestClient(app) as client:
for _ in range(25):
resp = client.get("/api/v1/health")
assert resp.status_code == 200
# ---------------------------------------------------------------------------
# Sanitizer middleware
# ---------------------------------------------------------------------------
class TestSanitizerMiddleware:
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
_CHAT_PATH = "/api/v1/chat"
def _token(self) -> str:
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
def _post_chat(self, client: TestClient, response_text: str) -> dict:
mock_response = ChatResponse(response=response_text, actions=[])
with patch(
"app.api.routes.chat.orchestrate",
new_callable=AsyncMock,
return_value=mock_response,
):
resp = client.post(
self._CHAT_PATH,
json=_CHAT_BODY,
headers=_auth_header(self._token()),
)
assert resp.status_code == 200
return resp.json()
def test_clean_response_passes_through_unchanged(self) -> None:
with TestClient(app) as client:
data = self._post_chat(client, "Sure, I created the task for you.")
assert data["response"] == "Sure, I created the task for you."
def test_strips_system_prompt_opener(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, "You are an intent classifier. Route to task_agent."
)
assert "You are" not in data["response"]
assert "[REDACTED]" in data["response"]
def test_strips_known_fingerprint(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, "Respond with just the agent name and nothing else."
)
assert data["response"] == "[REDACTED]"
def test_strips_tool_schema_fragment(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, 'Here is the schema: {"type": "function", "name": "foo"}'
)
assert '"type": "function"' not in data["response"]
def test_strips_reasoning_tag(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, "<thinking>I should route this to calendar_agent</thinking>Done."
)
assert "<thinking>" not in data["response"]
assert "[REDACTED]" in data["response"]
def test_strips_available_agents_fragment(self) -> None:
with TestClient(app) as client:
data = self._post_chat(
client, "Available agents: task_agent, calendar_agent"
)
assert "[REDACTED]" in data["response"]
def test_sanitizer_does_not_activate_for_non_chat_path(self) -> None:
"""GET /api/v1/plans/playbook should pass through the sanitizer untouched."""
token = self._token()
with TestClient(app) as client:
resp = client.get(
"/api/v1/plans/playbook",
headers=_auth_header(token),
)
# The sanitizer should not interfere — just check it returns something
# (200 or whatever the route returns; we only care it's not broken).
assert resp.status_code in (200, 401, 403, 404)
def test_sanitizer_preserves_empty_response(self) -> None:
with TestClient(app) as client:
data = self._post_chat(client, "")
assert data["response"] == ""

348
tests/test_orchestrator.py Normal file
View File

@@ -0,0 +1,348 @@
"""Integration tests for the orchestrator module."""
from __future__ import annotations
import json
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from app.core.agent_registry import AgentRegistry, ChatAgent
from app.core.orchestrator import (
classify_intent,
orchestrate,
orchestrate_stream,
route_pipeline,
route_single,
)
from app.schemas import ChatContext, ChatRequest, ChatResponse, ExecutionPlan
# ── Stub agents ──────────────────────────────────────────────────────
class _TaskAgent(ChatAgent):
def get_name(self) -> str:
return "task_agent"
def get_description(self) -> str:
return "Manages tasks: create, update, list, suggest"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"task: {query}"
class _CalendarAgent(ChatAgent):
def get_name(self) -> str:
return "calendar_agent"
def get_description(self) -> str:
return "Calendar management: events, conflicts, scheduling"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
return f"calendar: {query}"
# ── Helpers ──────────────────────────────────────────────────────────
def _mock_llm(response_text: str) -> MagicMock:
"""Return a mock LLM that always produces *response_text*."""
msg = MagicMock()
msg.content = response_text
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=msg)
return llm
# ── Fixtures ─────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _fresh_registry():
"""Reset the AgentRegistry singleton between tests."""
AgentRegistry._instance = None
yield
AgentRegistry._instance = None
@pytest.fixture()
def reg() -> AgentRegistry:
r = AgentRegistry()
r.register(_TaskAgent)
r.register(_CalendarAgent)
return r
# ── classify_intent ───────────────────────────────────────────────────
class TestClassifyIntent:
@pytest.mark.asyncio
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
result = await classify_intent("add a task", {}, reg)
assert result == "task_agent"
@pytest.mark.asyncio
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("calendar_agent")
result = await classify_intent("schedule a meeting", {}, reg)
assert result == "calendar_agent"
@pytest.mark.asyncio
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("nonexistent_agent")
result = await classify_intent("do something", {}, reg)
assert result == "task_agent"
@pytest.mark.asyncio
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
empty_reg = AgentRegistry()
# No LLM should be instantiated — early return path
with patch("app.core.orchestrator._make_llm") as mock_cls:
result = await classify_intent("anything", {}, empty_reg)
mock_cls.assert_not_called()
assert result == "task_agent"
@pytest.mark.asyncio
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm(" task_agent \n")
result = await classify_intent("create task", {}, reg)
assert result == "task_agent"
# ── route_single ─────────────────────────────────────────────────────
class TestRouteSingle:
@pytest.mark.asyncio
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "create a task", {}, reg)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "create a task", {}, reg)
assert result.response == "task: create a task"
@pytest.mark.asyncio
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
with pytest.raises(KeyError):
await route_single("nonexistent", "hello", {}, reg)
@pytest.mark.asyncio
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
result = await route_single("task_agent", "hi", {}, reg)
assert result.actions == []
# ── route_pipeline ────────────────────────────────────────────────────
class TestRoutePipeline:
@pytest.mark.asyncio
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("synthesized result")
result = await route_pipeline(
["task_agent", "calendar_agent"], "plan my week", {}, reg
)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("synthesized result")
result = await route_pipeline(
["task_agent", "calendar_agent"], "plan my week", {}, reg
)
assert result.response == "synthesized result"
@pytest.mark.asyncio
async def test_passes_previous_results_to_subsequent_agents(
self, reg: AgentRegistry
) -> None:
"""Each agent after the first should receive prior outputs in context."""
received_contexts: list[dict[str, Any]] = []
class _CapturingAgent(ChatAgent):
def get_name(self) -> str:
return "capture"
def get_description(self) -> str:
return "captures context for testing"
def get_tools(self) -> list[Any]:
return []
async def handle(self, query: str, context: dict[str, Any]) -> str:
received_contexts.append(dict(context))
return "captured"
reg.register(_CapturingAgent)
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("done")
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
# The second agent (capture) must have received previous results
assert len(received_contexts) == 1
assert "previous_results" in received_contexts[0]
assert received_contexts[0]["previous_results"] == ["task: hi"]
@pytest.mark.asyncio
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("single result")
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
assert result.response == "single result"
# ── orchestrate ───────────────────────────────────────────────────────
class TestOrchestrate:
@pytest.mark.asyncio
async def test_direct_mode_returns_chat_response(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
@pytest.mark.asyncio
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
assert result.response == "task: add a task"
@pytest.mark.asyncio
async def test_plan_mode_returns_execution_plan(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan my tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
@pytest.mark.asyncio
async def test_plan_mode_agent_matches_classified(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("calendar_agent")
request = ChatRequest(
message="schedule something", execution_mode="plan"
)
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert result.agent == "calendar_agent"
@pytest.mark.asyncio
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert len(result.steps) >= 1
@pytest.mark.asyncio
async def test_plan_mode_template_id_contains_agent_name(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="plan tasks", execution_mode="plan")
result = await orchestrate(request, reg)
assert isinstance(result, ExecutionPlan)
assert result.steps[0].prompt_template is not None
assert "task_agent" in result.steps[0].prompt_template
@pytest.mark.asyncio
async def test_default_execution_mode_is_direct(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
# execution_mode defaults to "direct"
request = ChatRequest(message="help me")
result = await orchestrate(request, reg)
assert isinstance(result, ChatResponse)
# ── orchestrate_stream ────────────────────────────────────────────────
class TestOrchestrateStream:
@pytest.mark.asyncio
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
assert len(chunks) >= 1
@pytest.mark.asyncio
async def test_last_chunk_is_final_json_frame(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="add a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
last = json.loads(chunks[-1])
assert last["done"] is True
assert "response" in last
assert "actions" in last
@pytest.mark.asyncio
async def test_final_frame_response_matches_agent_output(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(message="create a task", execution_mode="direct")
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
final = json.loads(chunks[-1])
assert final["response"] == "task: create a task"
@pytest.mark.asyncio
async def test_text_chunks_before_final_frame(
self, reg: AgentRegistry
) -> None:
with patch("app.core.orchestrator._make_llm") as mock_cls:
mock_cls.return_value = _mock_llm("task_agent")
request = ChatRequest(
message="x" * 200, execution_mode="direct"
) # long enough to produce multiple chunks
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
# All but the last chunk should be plain text (not valid final JSON)
non_final = chunks[:-1]
for chunk in non_final:
try:
parsed = json.loads(chunk)
assert parsed.get("done") is not True
except json.JSONDecodeError:
pass # plain text chunk — expected

402
tests/test_plugins.py Normal file
View File

@@ -0,0 +1,402 @@
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
Covers:
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
- ReviewQueue: pending queue, review decisions, manifest security checklist
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
- Route integration: tier gate, list/get/install/uninstall via TestClient
"""
from __future__ import annotations
import json
import uuid
import pytest
import pytest_asyncio
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import PluginRegistry
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
from app.marketplace.revenue_share import RevenueShare
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
from app.schemas import PluginManifest
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fresh_manifest(
plugin_id: str | None = None,
category: str = "productivity",
price_cents: int = 0,
permissions: list[str] | None = None,
) -> PluginManifest:
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
return PluginManifest(
id=pid,
name=f"Plugin {pid}",
description=f"Description for {pid}",
version="1.0.0",
author="test-author",
permissions=permissions or ["read:tasks"],
category=category,
price_cents=price_cents,
)
# ---------------------------------------------------------------------------
# PluginRegistry (DB-backed)
# ---------------------------------------------------------------------------
class TestPluginRegistry:
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.mark.asyncio
async def test_seed_plugins_are_listed(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session)
assert result.total == 3
assert all(p.id.startswith("plugin-") for p in result.plugins)
@pytest.mark.asyncio
async def test_list_approved_only(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
result = await reg.list_plugins(db_session)
ids = [p.id for p in result.plugins]
assert manifest.id not in ids # still pending
@pytest.mark.asyncio
async def test_list_filter_by_category(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, category="communication")
assert result.total == 1
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_list_filter_by_query(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, query="time")
assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker"
@pytest.mark.asyncio
async def test_list_sort_by_installs(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-slack-notify")
await reg.record_install(db_session, "plugin-slack-notify")
result = await reg.list_plugins(db_session, sort="installs")
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_get_plugin_found(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["manifest"].id == "plugin-github-sync"
assert "install_count" in entry
@pytest.mark.asyncio
async def test_get_plugin_not_found(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
entry = await reg.get_plugin(db_session, "no-such-plugin")
assert entry is None
@pytest.mark.asyncio
async def test_submit_sets_pending(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
assert plugin_id == manifest.id
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one()
assert row.status == "pending_review"
@pytest.mark.asyncio
async def test_approve_makes_visible(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.approve_plugin(db_session, manifest.id)
result = await reg.list_plugins(db_session)
assert manifest.id in [p.id for p in result.plugins]
@pytest.mark.asyncio
async def test_reject_stores_reason(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
assert row.rejection_reason == "Unsafe permissions"
listed = await reg.list_plugins(db_session)
assert manifest.id not in [p.id for p in listed.plugins]
@pytest.mark.asyncio
async def test_approve_unknown_raises_key_error(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
with pytest.raises(KeyError):
await reg.approve_plugin(db_session, "ghost-plugin")
@pytest.mark.asyncio
async def test_record_install_increments_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_decrements_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_floors_at_zero(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 0
# ---------------------------------------------------------------------------
# ReviewQueue (DB-backed)
# ---------------------------------------------------------------------------
class TestReviewQueue:
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.fixture
def queue(self) -> ReviewQueue:
return ReviewQueue()
@pytest.mark.asyncio
async def test_get_pending_returns_submitted_plugins(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
pending = await queue.get_pending(db_session)
assert any(p["plugin_id"] == manifest.id for p in pending)
@pytest.mark.asyncio
async def test_submit_review_approved(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "approved"
# Check review row was persisted
review_result = await db_session.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
)
review = review_result.scalar_one()
assert review.decision == "approved"
@pytest.mark.asyncio
async def test_submit_review_rejected(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
)
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
def test_validate_manifest_ok(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
validate_manifest(manifest) # should not raise
def test_validate_manifest_unknown_permission(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
with pytest.raises(ValueError, match="Unknown permission"):
validate_manifest(manifest)
def test_validate_manifest_invalid_id_format(self) -> None:
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
def test_validate_manifest_id_with_uppercase(self) -> None:
manifest = _fresh_manifest(plugin_id="UpperCase")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
# ---------------------------------------------------------------------------
# RevenueShare (DB-backed)
# ---------------------------------------------------------------------------
class TestRevenueShare:
@pytest.fixture
def rs(self) -> RevenueShare:
return RevenueShare()
@pytest.mark.asyncio
async def test_record_install_free_plugin(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
)
event = result.scalar_one()
assert event.developer_share_cents == 0
@pytest.mark.asyncio
async def test_record_install_paid_plugin_no_stripe(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
)
event = result.scalar_one()
assert event.amount_cents == 499
assert event.developer_share_cents == int(499 * 0.70)
@pytest.mark.asyncio
async def test_record_install_increments_registry_count(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
reg = PluginRegistry()
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_get_earnings_empty(
self, rs: RevenueShare, db_session: AsyncSession
) -> None:
result = await rs.get_earnings(db_session, "unknown-dev")
assert result["total_installs"] == 0
assert result["total_revenue_cents"] == 0
assert result["developer_share_cents"] == 0
@pytest.mark.asyncio
async def test_get_earnings_aggregates(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
result = await rs.get_earnings(db_session, "Adiuva")
assert result["total_installs"] == 2
assert result["total_revenue_cents"] == 998
assert result["developer_share_cents"] == int(499 * 0.70) * 2
# ---------------------------------------------------------------------------
# Route integration tests
# ---------------------------------------------------------------------------
class TestPluginRoutes:
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
assert resp.status_code == 403
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
assert resp.status_code == 403
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
assert resp.status_code == 200
data = resp.json()
assert "plugins" in data
assert data["total"] == 3
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
assert resp.status_code == 200
def test_get_plugin_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
assert resp.status_code == 200
data = resp.json()
assert data["plugin"]["id"] == "plugin-github-sync"
assert "install_count" in data
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
assert resp.status_code == 404
def test_install_plugin_free(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header(),
)
assert resp.status_code == 200
data = resp.json()
assert data["ok"] is True
assert "download_url" in data
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/ghost/install",
json={"plugin_id": "ghost"},
headers=auth_header(),
)
assert resp.status_code == 404
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
resp = client.delete(
"/api/v1/plugins/plugin-github-sync/install",
headers=auth_header(),
)
assert resp.status_code == 200
assert resp.json()["ok"] is True
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header("free"),
)
assert resp.status_code == 403

562
tests/test_storage.py Normal file
View File

@@ -0,0 +1,562 @@
"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes."""
from __future__ import annotations
import base64
import hashlib
from unittest.mock import MagicMock, patch
import boto3
import pytest
from botocore.exceptions import ClientError
from app.storage.encryption import reject_if_tampered, verify_checksum
from app.storage.blob_store import BlobStore
from app.storage.vector_store import VectorStore, _blob_to_vector
from app.schemas import VectorItem, VectorSearchResult
from tests.conftest import auth_header, S3_TEST_BUCKET
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-payload-opaque-to-server"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_BUCKET = S3_TEST_BUCKET
_REGION = "us-east-1"
def _pinecone_mock():
"""Return a mock Pinecone index with realistic return shapes."""
mock_index = MagicMock()
mock_index.query.return_value = {
"matches": [
{
"id": "v1",
"score": 0.95,
"metadata": {
"blob": base64.b64encode(b"result-blob").decode(),
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
"user_id": "u1",
},
}
]
}
mock_pc = MagicMock()
mock_pc.return_value.Index.return_value = mock_index
return mock_pc, mock_index
# ── TestEncryption ────────────────────────────────────────────────────
class TestEncryption:
def test_verify_checksum_correct(self) -> None:
assert verify_checksum(_BLOB, _CHECKSUM) is True
def test_verify_checksum_wrong(self) -> None:
assert verify_checksum(_BLOB, "0" * 64) is False
def test_verify_checksum_empty_checksum(self) -> None:
assert verify_checksum(_BLOB, "") is False
def test_verify_checksum_empty_blob(self) -> None:
expected = hashlib.sha256(b"").hexdigest()
assert verify_checksum(b"", expected) is True
def test_verify_checksum_tampered_blob(self) -> None:
tampered = _BLOB + b"\x00"
assert verify_checksum(tampered, _CHECKSUM) is False
def test_reject_if_tampered_passes_when_valid(self) -> None:
# Should not raise
reject_if_tampered(_BLOB, _CHECKSUM)
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
reject_if_tampered(_BLOB, "bad" * 20)
assert exc_info.value.status_code == 400
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
from fastapi import HTTPException
with pytest.raises(HTTPException) as exc_info:
reject_if_tampered(_BLOB, "bad" * 20)
assert "checksum" in exc_info.value.detail.lower()
def test_checksum_is_sha256_hex(self) -> None:
cs = hashlib.sha256(_BLOB).hexdigest()
assert len(cs) == 64
assert all(c in "0123456789abcdef" for c in cs)
# ── TestBlobStore ─────────────────────────────────────────────────────
class TestBlobStore:
@pytest.mark.asyncio
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
store = BlobStore()
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
assert key == "u1/tasks/r1"
@pytest.mark.asyncio
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
# Verify by downloading — no exception means object exists
retrieved = await store.download("u1", "u1/tasks/r1")
assert retrieved == _BLOB
@pytest.mark.asyncio
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
result = await store.download("u1", "u1/notes/n1")
assert result == b"note-data"
@pytest.mark.asyncio
async def test_delete_removes_object(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.delete("u1", "u1/tasks/r1")
with pytest.raises(ClientError) as exc_info:
await store.download("u1", "u1/tasks/r1")
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
@pytest.mark.asyncio
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
store = BlobStore()
# Delete a key that never existed — should not raise
await store.delete("u1", "u1/tasks/nonexistent")
@pytest.mark.asyncio
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
keys = await store.list_keys("u1", "tasks")
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
@pytest.mark.asyncio
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
keys = await store.list_keys("u1", "tasks")
assert "u1/notes/n1" not in keys
assert "u1/tasks/r1" in keys
@pytest.mark.asyncio
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
keys_u1 = await store.list_keys("u1", "tasks")
assert "u2/tasks/r1" not in keys_u1
@pytest.mark.asyncio
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
store = BlobStore()
keys = await store.list_keys("u1", "tasks")
assert keys == []
@pytest.mark.asyncio
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
# Verify S3 metadata was set — check via head_object
with patch("app.storage.blob_store.settings") as mock_settings:
mock_settings.S3_BUCKET = _BUCKET
mock_settings.S3_REGION = _REGION
mock_settings.AWS_ACCESS_KEY_ID = "testing"
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
client = boto3.client("s3", region_name=_REGION)
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
assert response.get("ServerSideEncryption") == "AES256"
@pytest.mark.asyncio
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
store = BlobStore()
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
client = boto3.client("s3", region_name=_REGION)
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
assert response["Metadata"]["checksum"] == _CHECKSUM
# ── _blob_to_vector helper ────────────────────────────────────────────
class TestBlobToVector:
def test_returns_32_floats(self) -> None:
v = _blob_to_vector(b"test")
assert len(v) == 32
def test_all_values_in_range(self) -> None:
v = _blob_to_vector(b"test")
assert all(-1.0 <= x <= 1.0 for x in v)
def test_deterministic(self) -> None:
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
def test_different_blobs_different_vectors(self) -> None:
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
# ── TestVectorStorePinecone ───────────────────────────────────────────
class TestVectorStorePinecone:
def _store(self) -> VectorStore:
store = VectorStore()
store._use_pinecone = lambda: True # type: ignore[method-assign]
return store
@pytest.mark.asyncio
async def test_upsert_calls_index_upsert(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
await store.upsert("u1", items)
mock_index.upsert.assert_called_once()
call_kwargs = mock_index.upsert.call_args[1]
assert call_kwargs.get("namespace") == "u1"
@pytest.mark.asyncio
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
await store.upsert("u1", items)
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
@pytest.mark.asyncio
async def test_search_calls_index_query(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.search("u1", b"query-blob", top_k=5)
mock_index.query.assert_called_once()
query_kwargs = mock_index.query.call_args[1]
assert query_kwargs.get("namespace") == "u1"
assert query_kwargs.get("top_k") == 5
assert query_kwargs.get("include_metadata") is True
@pytest.mark.asyncio
async def test_search_returns_vector_search_results(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
results = await store.search("u1", b"query", top_k=10)
assert len(results) == 1
assert isinstance(results[0], VectorSearchResult)
assert results[0].id == "v1"
assert results[0].score == 0.95
assert results[0].blob == b"result-blob"
@pytest.mark.asyncio
async def test_search_uses_derived_query_vector(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.search("u1", b"query-blob", top_k=3)
expected_vector = _blob_to_vector(b"query-blob")
actual_vector = mock_index.query.call_args[1].get("vector")
assert actual_vector == expected_vector
@pytest.mark.asyncio
async def test_delete_calls_index_delete(self) -> None:
mock_pc, mock_index = _pinecone_mock()
with patch("app.storage.vector_store.Pinecone", mock_pc):
store = self._store()
await store.delete("u1", ["v1", "v2"])
mock_index.delete.assert_called_once()
delete_kwargs = mock_index.delete.call_args[1]
assert delete_kwargs.get("namespace") == "u1"
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
class TestVectorStoreQdrant:
def _store(self) -> VectorStore:
store = VectorStore()
store._use_pinecone = lambda: False # type: ignore[method-assign]
return store
def _qdrant_mock(self) -> MagicMock:
mock_hit = MagicMock()
mock_hit.id = "v1"
mock_hit.score = 0.88
mock_hit.payload = {
"blob": base64.b64encode(b"qdrant-result").decode(),
"user_id": "u1",
}
mock_client = MagicMock()
mock_client.search.return_value = [mock_hit]
return mock_client
@pytest.mark.asyncio
async def test_upsert_calls_client_upsert(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
await store.upsert("u1", items)
mock_client.upsert.assert_called_once()
@pytest.mark.asyncio
async def test_upsert_uses_correct_collection(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
await store.upsert("u1", items)
call_kwargs = mock_client.upsert.call_args[1]
assert call_kwargs["collection_name"] == "adiuva_vectors"
@pytest.mark.asyncio
async def test_search_calls_client_search(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.search("u1", b"query", top_k=5)
mock_client.search.assert_called_once()
@pytest.mark.asyncio
async def test_search_passes_limit(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.search("u1", b"query", top_k=7)
call_kwargs = mock_client.search.call_args[1]
assert call_kwargs.get("limit") == 7
@pytest.mark.asyncio
async def test_search_returns_vector_search_results(self) -> None:
mock_client = self._qdrant_mock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
results = await store.search("u1", b"query", top_k=5)
assert len(results) == 1
assert isinstance(results[0], VectorSearchResult)
assert results[0].id == "v1"
assert results[0].score == 0.88
assert results[0].blob == b"qdrant-result"
@pytest.mark.asyncio
async def test_delete_calls_client_delete(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.delete("u1", ["v1", "v2"])
mock_client.delete.assert_called_once()
@pytest.mark.asyncio
async def test_delete_uses_correct_collection(self) -> None:
mock_client = MagicMock()
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
store = self._store()
await store.delete("u1", ["v1"])
call_kwargs = mock_client.delete.call_args[1]
assert call_kwargs["collection_name"] == "adiuva_vectors"
# ── TestStorageRoutes (integration) ───────────────────────────────────
class TestStorageRoutes:
"""Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records.
Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``.
So "hello" in JSON becomes ``b"hello"`` on the server. We use plain
ASCII strings as blob values and compute checksums accordingly.
"""
_BLOB_STR = "encrypted-payload-opaque-to-server"
_BLOB_BYTES = _BLOB_STR.encode()
_BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest()
@classmethod
def _create_payload(cls, blob_str: str | None = None) -> dict:
blob_str = blob_str or cls._BLOB_STR
checksum = hashlib.sha256(blob_str.encode()).hexdigest()
return {
"table": "tasks",
"blob": blob_str,
"checksum": checksum,
}
def _create_record(self, client, tier="power", blob_str=None):
payload = self._create_payload(blob_str)
return client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header(tier),
)
# ── Create ────────────────────────────────────────────────────────
def test_create_record(self, client, s3_bucket) -> None:
resp = self._create_record(client)
assert resp.status_code == 201
data = resp.json()
assert "id" in data
assert "created_at" in data
def test_create_record_bad_checksum(self, client, s3_bucket) -> None:
payload = {
"table": "tasks",
"blob": self._BLOB_STR,
"checksum": "0" * 64,
}
resp = client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header("power"),
)
assert resp.status_code == 400
def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None:
"""Free tier has cloud_storage_gb=0 → 402."""
resp = self._create_record(client, tier="free")
assert resp.status_code == 402
def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None:
"""Pro tier has cloud_storage_gb=5 → succeeds for small blob."""
resp = self._create_record(client, tier="pro")
assert resp.status_code == 201
# ── List ──────────────────────────────────────────────────────────
def test_list_records(self, client, s3_bucket) -> None:
self._create_record(client)
self._create_record(client, blob_str="second-blob")
resp = client.get(
"/api/v1/storage/records",
headers=auth_header("power"),
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 2
# Each entry has metadata, no blob bytes
for item in data:
assert "id" in item
assert "table" in item
assert "checksum" in item
assert "blob" not in item
def test_list_records_filter_by_table(self, client, s3_bucket) -> None:
self._create_record(client)
# Create in a different table
note_blob = "note-blob"
payload = {
"table": "notes",
"blob": note_blob,
"checksum": hashlib.sha256(note_blob.encode()).hexdigest(),
}
client.post(
"/api/v1/storage/records",
json=payload,
headers=auth_header("power"),
)
resp = client.get(
"/api/v1/storage/records?table=notes",
headers=auth_header("power"),
)
assert resp.status_code == 200
data = resp.json()
assert len(data) == 1
assert data[0]["table"] == "notes"
def test_list_records_isolated_per_user(self, client, s3_bucket) -> None:
"""One user's records should not appear in another user's list."""
self._create_record(client, tier="power")
resp = client.get(
"/api/v1/storage/records",
headers=auth_header("team"),
)
assert resp.json() == []
# ── Download ──────────────────────────────────────────────────────
def test_download_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.content == self._BLOB_BYTES
assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM
def test_download_record_not_found(self, client, s3_bucket) -> None:
resp = client.get(
"/api/v1/storage/records/nonexistent-id",
headers=auth_header("power"),
)
assert resp.status_code == 404
# ── Update ────────────────────────────────────────────────────────
def test_update_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
new_blob_str = "updated-encrypted-payload"
new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest()
resp = client.put(
f"/api/v1/storage/records/{record_id}",
json={"blob": new_blob_str, "checksum": new_checksum},
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# Verify download returns the updated blob
dl = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert dl.content == new_blob_str.encode()
def test_update_record_bad_checksum(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.put(
f"/api/v1/storage/records/{record_id}",
json={"blob": "some-data", "checksum": "0" * 64},
headers=auth_header("power"),
)
assert resp.status_code == 400
# ── Delete ────────────────────────────────────────────────────────
def test_delete_record(self, client, s3_bucket) -> None:
create_resp = self._create_record(client)
record_id = create_resp.json()["id"]
resp = client.delete(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# Subsequent GET should return 404
dl = client.get(
f"/api/v1/storage/records/{record_id}",
headers=auth_header("power"),
)
assert dl.status_code == 404
def test_delete_record_not_found(self, client, s3_bucket) -> None:
resp = client.delete(
"/api/v1/storage/records/nonexistent",
headers=auth_header("power"),
)
assert resp.status_code == 404