Compare commits
10 Commits
71fd1a0a7c
...
4c4df7335a
| Author | SHA1 | Date | |
|---|---|---|---|
| 4c4df7335a | |||
| c8ef7b119b | |||
| 35dd9ac86f | |||
| e72d72f4f6 | |||
| 14d1a7351d | |||
| 68955d2fc2 | |||
| 864dfdc4e6 | |||
| 0d16729036 | |||
| 82669d3704 | |||
| 4d0917f5df |
28
.env.example
Normal file
28
.env.example
Normal file
@@ -0,0 +1,28 @@
|
||||
# ── Application ──────────────────────────────────────────────────────────────
|
||||
ENV=dev
|
||||
|
||||
# ── Database ──────────────────────────────────────────────────────────────────
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||
JWT_SECRET=replace-with-a-long-random-secret
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
||||
# ── OpenAI ────────────────────────────────────────────────────────────────────
|
||||
OPENAI_API_KEY=sk-...
|
||||
|
||||
# ── Stripe ────────────────────────────────────────────────────────────────────
|
||||
STRIPE_SECRET_KEY=sk_test_...
|
||||
STRIPE_WEBHOOK_SECRET=whsec_...
|
||||
|
||||
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
||||
S3_BUCKET=adiuva-backups
|
||||
S3_REGION=us-east-1
|
||||
AWS_ACCESS_KEY_ID=AKIA...
|
||||
AWS_SECRET_ACCESS_KEY=...
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||
# Comma-separated list parsed by Settings (override default if needed)
|
||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||
21
.gitea/workflows/deploy.yaml
Normal file
21
.gitea/workflows/deploy.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
name: Deploy to Proxmox Docker
|
||||
run-name: Deploying ${{ gitea.sha }}
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main # O il nome del tuo branch principale
|
||||
|
||||
jobs:
|
||||
Deploy:
|
||||
runs-on: ubuntu-latest # Questo dipende dalle label che hai dato al tuo act_runner
|
||||
steps:
|
||||
- name: Deploying via SSH
|
||||
uses: appleboy/ssh-action@v1.0.0
|
||||
with:
|
||||
host: ${{ secrets.SSH_HOST }}
|
||||
username: ${{ secrets.SSH_USER }}
|
||||
key: ${{ secrets.SSH_KEY }}
|
||||
script: |
|
||||
cd /opt/adiuva-api
|
||||
git pull origin main
|
||||
docker compose up -d --build
|
||||
33
.gitignore
vendored
Normal file
33
.gitignore
vendored
Normal file
@@ -0,0 +1,33 @@
|
||||
# Python
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*.egg-info/
|
||||
dist/
|
||||
build/
|
||||
|
||||
# Virtual environment
|
||||
.venv/
|
||||
venv/
|
||||
env/
|
||||
|
||||
# Environment variables
|
||||
.env
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
|
||||
# Testing / coverage
|
||||
.pytest_cache/
|
||||
htmlcov/
|
||||
.coverage
|
||||
|
||||
# Docker
|
||||
*.log
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
Thumbs.db
|
||||
|
||||
# Claude Code
|
||||
.claude/
|
||||
334
BACKEND_PLAN.md
334
BACKEND_PLAN.md
@@ -2,8 +2,8 @@
|
||||
|
||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
||||
>
|
||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, and backup blob storage.
|
||||
> The backend NEVER persists user data. It receives context in requests, uses it for orchestration, and discards it.
|
||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
||||
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
||||
|
||||
---
|
||||
|
||||
@@ -20,7 +20,7 @@ adiuva-api/
|
||||
│ │ ├── orchestrator.py # LLM-based intent router
|
||||
│ │ ├── execution_plan.py # Plan builder + cache
|
||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
||||
│ ├── agents/
|
||||
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
||||
│ │ ├── __init__.py # Auto-registers all agents
|
||||
│ │ ├── task_agent.py
|
||||
│ │ ├── calendar_agent.py
|
||||
@@ -32,7 +32,10 @@ adiuva-api/
|
||||
│ │ │ ├── __init__.py
|
||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
||||
│ │ │ ├── plans.py # GET /plans/playbook
|
||||
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
||||
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
||||
│ │ │ ├── backup.py # PUT/GET /backup
|
||||
│ │ │ ├── plugins.py # Plugin marketplace
|
||||
│ │ │ ├── auth.py # Register/login/refresh
|
||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
||||
│ │ └── middleware/
|
||||
@@ -40,6 +43,16 @@ adiuva-api/
|
||||
│ │ ├── auth.py # JWT validation
|
||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
||||
│ ├── storage/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
||||
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
||||
│ │ └── encryption.py # Integrity verification only — NO decryption
|
||||
│ ├── marketplace/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
||||
│ │ ├── plugin_review.py # Review queue + approval workflow
|
||||
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
||||
│ ├── billing/
|
||||
│ │ ├── __init__.py
|
||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
||||
@@ -53,8 +66,10 @@ adiuva-api/
|
||||
│ ├── test_orchestrator.py
|
||||
│ ├── test_agents.py
|
||||
│ ├── test_auth.py
|
||||
│ └── test_backup.py
|
||||
├── alembic/ # DB migrations (auth/billing tables only)
|
||||
│ ├── test_backup.py
|
||||
│ ├── test_storage.py
|
||||
│ └── test_plugins.py
|
||||
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
|
||||
│ ├── alembic.ini
|
||||
│ └── versions/
|
||||
├── requirements.txt
|
||||
@@ -68,9 +83,9 @@ adiuva-api/
|
||||
|
||||
## Step-by-Step Implementation
|
||||
|
||||
### Step 1 — Project scaffolding
|
||||
- [ ] Initialize repo with the directory structure above
|
||||
- [ ] Write `requirements.txt`:
|
||||
### Step 1 — Project scaffolding ✅
|
||||
- [x] Initialize repo with the directory structure above
|
||||
- [x] Write `requirements.txt`:
|
||||
```
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.34.0
|
||||
@@ -91,29 +106,40 @@ adiuva-api/
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.24.0
|
||||
```
|
||||
- [ ] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
||||
- [ ] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod)
|
||||
- [ ] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
||||
- [ ] Write `docker-compose.yml`: app, postgres:16, optional redis
|
||||
- [ ] Write `.env.example`
|
||||
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
||||
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
||||
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
||||
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
||||
- [x] Write `.env.example`
|
||||
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
||||
|
||||
### Step 2 — Pydantic schemas (API contracts)
|
||||
- [ ] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
||||
### Step 2 — Pydantic schemas (API contracts) ✅
|
||||
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification']`, `table: str | None`, `data: dict | None`
|
||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
||||
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
||||
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
||||
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
||||
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
||||
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
||||
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
||||
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
||||
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
||||
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
||||
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
||||
- `PluginInstallRequest`: `plugin_id: str`
|
||||
- **Outcome:** All request/response models defined and validated.
|
||||
|
||||
### Step 3 — Agent Registry + base classes
|
||||
- [ ] `app/core/agent_registry.py`:
|
||||
### Step 3 — Agent Registry + base classes ✅
|
||||
- [x] `app/core/agent_registry.py`:
|
||||
- `BaseAgent(ABC)`:
|
||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
||||
@@ -127,11 +153,11 @@ adiuva-api/
|
||||
- `get(name) -> ChatAgent`
|
||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
||||
- [ ] Unit tests: register, get, list, call_agent with mock
|
||||
- [x] Unit tests: register, get, list, call_agent with mock
|
||||
- **Outcome:** Pluggable agent framework.
|
||||
|
||||
### Step 4 — Orchestrator
|
||||
- [ ] `app/core/orchestrator.py`:
|
||||
### Step 4 — Orchestrator ✅
|
||||
- [x] `app/core/orchestrator.py`:
|
||||
- `async classify_intent(message, context, registry) -> str`:
|
||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
||||
- Uses gpt-4o-mini via LangChain for low latency
|
||||
@@ -146,16 +172,17 @@ adiuva-api/
|
||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
||||
- Main entry point
|
||||
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
|
||||
- Classifies intent
|
||||
- If `execution_mode == 'direct'`: route + return response
|
||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
||||
- [ ] Integration tests with mocked LLM and mocked agents
|
||||
- [x] Integration tests with mocked LLM and mocked agents
|
||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
||||
|
||||
### Step 5 — Execution Plan generator
|
||||
- [ ] `app/core/execution_plan.py`:
|
||||
### Step 5 — Execution Plan generator ✅
|
||||
- [x] `app/core/execution_plan.py`:
|
||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
||||
- `ExecutionPlanBuilder`:
|
||||
- `add_step(action, params) -> self`
|
||||
@@ -168,32 +195,52 @@ adiuva-api/
|
||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
||||
|
||||
### Step 6 — Chat Agents
|
||||
- [ ] `app/agents/task_agent.py` — `@registry.register`:
|
||||
- Description: "Manages tasks: create, update, list, suggest"
|
||||
- Tools: `create_task(title, description, priority, due_date)`, `update_task(id, updates)`, `list_tasks(filters)`, `suggest_tasks(notes_context)`
|
||||
- System prompt: PM-oriented, validates task structure, infers priority from context
|
||||
- `handle()`: LLM + tool loop via `_tool_loop()`, returns response text + list of actions performed
|
||||
- [ ] `app/agents/calendar_agent.py` — `@registry.register`:
|
||||
- Description: "Calendar management: events, conflicts, scheduling"
|
||||
- Tools: `list_events(date_range)`, `detect_conflicts(events)`, `suggest_reschedule(conflict)`
|
||||
- Works with event metadata passed in context (never raw calendar data stored)
|
||||
- [ ] `app/agents/email_agent.py` — `@registry.register`:
|
||||
- Description: "Email analysis: classify, extract actions, draft responses"
|
||||
- Tools: `classify_email(metadata)`, `extract_action_items(metadata)`, `draft_response(thread_context)`
|
||||
- Only processes metadata sent by client — never raw email bodies
|
||||
- [ ] `app/agents/analytics_agent.py` — `@registry.register`:
|
||||
- Description: "Workspace analytics: metrics, reports, trends"
|
||||
- Tools: `calculate_metrics(task_data)`, `generate_report(period, data)`, `trend_analysis(data_points)`
|
||||
- Crunches numbers from context, returns structured insights
|
||||
- [ ] `app/agents/__init__.py`: imports all agent modules to trigger `@registry.register` decorators
|
||||
- [ ] Unit tests per agent with mocked LLM
|
||||
- **Outcome:** Four specialized agents, all registered and tested.
|
||||
### Step 6 — Chat Agents ✅
|
||||
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
||||
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
||||
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
||||
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
||||
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
||||
- [x] `app/agents/checkpoint_agent.py` — `@registry.register`:
|
||||
- Description: "Manages project checkpoints (milestones): list, create, update, delete"
|
||||
- Tools (4): `list_checkpoints(project_id)`, `create_checkpoint(project_id, title, date, is_ai_suggested, is_approved)`, `update_checkpoint(checkpoint_id, ...)`, `delete_checkpoint(checkpoint_id)`
|
||||
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
||||
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
||||
- Description: "Manages projects: list, get, create, update, archive, delete"
|
||||
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
||||
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
||||
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
||||
- Description: "Manages notes: list, get, create, update, delete"
|
||||
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
||||
- content is Markdown; `get_note` should be called before update to preserve existing content
|
||||
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
||||
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
||||
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Checkpoints, Projects, Notes), all registered and tested.
|
||||
|
||||
### Step 7 — API Routes
|
||||
### Step 7 — Storage Layer ✅
|
||||
- [x] `app/storage/blob_store.py`:
|
||||
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
||||
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
||||
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
||||
- [x] `app/storage/vector_store.py`:
|
||||
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
||||
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
||||
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
|
||||
- ANN on encrypted data: known accuracy trade-off, documented
|
||||
- [x] `app/storage/encryption.py`:
|
||||
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
||||
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
||||
- Backend NEVER holds decryption keys
|
||||
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
||||
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
||||
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
||||
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
||||
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
|
||||
|
||||
#### 7a — Chat endpoint
|
||||
- [ ] `app/api/routes/chat.py`:
|
||||
### Step 8 — API Routes ✅
|
||||
|
||||
#### 8a — Chat endpoint
|
||||
- [x] `app/api/routes/chat.py`:
|
||||
- `POST /api/v1/chat`:
|
||||
- Request: `ChatRequest`
|
||||
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
||||
@@ -204,48 +251,93 @@ adiuva-api/
|
||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
||||
- Heartbeat ping every 30s to keep connection alive
|
||||
|
||||
#### 7b — Plans endpoint
|
||||
- [ ] `app/api/routes/plans.py`:
|
||||
#### 8b — Plans endpoint
|
||||
- [x] `app/api/routes/plans.py`:
|
||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
||||
|
||||
#### 7c — Backup endpoint
|
||||
- [ ] `app/api/routes/backup.py`:
|
||||
#### 8c — Storage endpoint (cloud records)
|
||||
- [x] `app/api/routes/storage.py`:
|
||||
- `POST /api/v1/storage/records`: Create encrypted record
|
||||
- Request: `StorageRecordCreate`
|
||||
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
||||
- Response: `{id: str, created_at: int}`
|
||||
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
||||
- Query params: `table: str`, `page: int`, `limit: int`
|
||||
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
||||
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
||||
- Response: blob bytes + `X-Checksum` header
|
||||
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
||||
- Request: `StorageRecordUpdate`
|
||||
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
||||
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
|
||||
|
||||
#### 8d — Vectors endpoint (cloud vector store)
|
||||
- [x] `app/api/routes/vectors.py`:
|
||||
- `POST /api/v1/storage/vectors/upsert`:
|
||||
- Request: `VectorUpsertRequest`
|
||||
- Verifies checksums, delegates to `VectorStore.upsert()`
|
||||
- Response: `{upserted: int}`
|
||||
- `POST /api/v1/storage/vectors/search`:
|
||||
- Request: `VectorSearchRequest`
|
||||
- Delegates to `VectorStore.search()`
|
||||
- Response: `VectorSearchResponse`
|
||||
- `DELETE /api/v1/storage/vectors`:
|
||||
- Request: `{ids: list[str]}`
|
||||
|
||||
#### 8e — Backup endpoint
|
||||
- [x] `app/api/routes/backup.py`:
|
||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
||||
- Free: 0 (no backup)
|
||||
- Pro: 5 GB
|
||||
- Power: 50 GB
|
||||
- Power: 25 GB
|
||||
- Team: unlimited
|
||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
||||
|
||||
#### 7d — Auth endpoint
|
||||
- [ ] `app/api/routes/auth.py`:
|
||||
#### 8f — Plugins endpoint
|
||||
- [x] `app/api/routes/plugins.py`:
|
||||
- `GET /api/v1/plugins`:
|
||||
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
||||
- Response: `PluginListResponse`
|
||||
- Available from Power tier and above
|
||||
- `GET /api/v1/plugins/{id}`:
|
||||
- Response: `PluginManifest` + ratings + install count
|
||||
- `POST /api/v1/plugins/{id}/install`:
|
||||
- Request: `PluginInstallRequest`
|
||||
- Records installation for the user (billing tracking, analytics)
|
||||
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
||||
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
||||
- `DELETE /api/v1/plugins/{id}/install`:
|
||||
- Unregisters installation
|
||||
|
||||
#### 8g — Auth endpoint
|
||||
- [x] `app/api/routes/auth.py`:
|
||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
||||
|
||||
#### 7e — Billing endpoint
|
||||
- [ ] `app/api/routes/billing.py`:
|
||||
#### 8h — Billing endpoint
|
||||
- [x] `app/api/routes/billing.py`:
|
||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
||||
|
||||
- **Outcome:** Complete REST + WebSocket API.
|
||||
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
|
||||
|
||||
### Step 8 — Middleware
|
||||
### Step 9 — Middleware
|
||||
|
||||
#### 8a — Auth middleware
|
||||
#### 9a — Auth middleware
|
||||
- [ ] `app/api/middleware/auth.py`:
|
||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
||||
- Raises `401` on invalid/expired token
|
||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||
|
||||
#### 8b — Rate limiter
|
||||
#### 9b — Rate limiter
|
||||
- [ ] `app/api/middleware/rate_limit.py`:
|
||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
||||
- Tier-based limits:
|
||||
@@ -255,7 +347,7 @@ adiuva-api/
|
||||
- Team: 200 req/seat/min
|
||||
- Custom 429 response with `Retry-After` header
|
||||
|
||||
#### 8c — Sanitizer
|
||||
#### 9c — Sanitizer
|
||||
- [ ] `app/api/middleware/sanitizer.py`:
|
||||
- Response middleware that scans response bodies
|
||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
||||
@@ -264,7 +356,27 @@ adiuva-api/
|
||||
|
||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
||||
|
||||
### Step 9 — Billing & Tier management
|
||||
### Step 10 — Plugin Marketplace
|
||||
- [ ] `app/marketplace/plugin_registry.py`:
|
||||
- `PluginRegistry`:
|
||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
||||
- [ ] `app/marketplace/plugin_review.py`:
|
||||
- `ReviewQueue`:
|
||||
- `async get_pending() -> list[dict]`
|
||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
||||
- [ ] `app/marketplace/revenue_share.py`:
|
||||
- `RevenueShare`:
|
||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
||||
- `async get_earnings(developer_id, period) -> dict`
|
||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
||||
|
||||
### Step 11 — Billing & Tier management
|
||||
- [ ] `app/billing/stripe_service.py`:
|
||||
- `create_checkout_session(user_id, tier) -> str`
|
||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
||||
@@ -275,33 +387,77 @@ adiuva-api/
|
||||
- Feature matrix:
|
||||
```python
|
||||
FEATURES = {
|
||||
'free': {'agents': 3, 'batch': False, 'providers': 1, 'backup_gb': 0},
|
||||
'pro': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': 5},
|
||||
'power': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': 50, 'byok': True},
|
||||
'team': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': -1, 'sso': True},
|
||||
'free': {
|
||||
'agents': 3,
|
||||
'batch_active': 2,
|
||||
'cloud_storage_gb': 0,
|
||||
'backup_gb': 0,
|
||||
'providers': 1,
|
||||
'batch_builder': False,
|
||||
'plugin_marketplace': False,
|
||||
'sso': False,
|
||||
},
|
||||
'pro': {
|
||||
'agents': -1, # unlimited
|
||||
'batch_active': 10,
|
||||
'cloud_storage_gb': 5,
|
||||
'backup_gb': 5,
|
||||
'providers': -1,
|
||||
'batch_builder': False,
|
||||
'plugin_marketplace': False,
|
||||
'sso': False,
|
||||
},
|
||||
'power': {
|
||||
'agents': -1,
|
||||
'batch_active': -1, # unlimited
|
||||
'cloud_storage_gb': 25,
|
||||
'backup_gb': 25,
|
||||
'providers': -1,
|
||||
'batch_builder': True,
|
||||
'plugin_marketplace': True,
|
||||
'sso': False,
|
||||
},
|
||||
'team': {
|
||||
'agents': -1,
|
||||
'batch_active': -1,
|
||||
'cloud_storage_gb': -1,
|
||||
'backup_gb': -1,
|
||||
'providers': -1,
|
||||
'batch_builder': True,
|
||||
'plugin_marketplace': True,
|
||||
'sso': True,
|
||||
},
|
||||
}
|
||||
```
|
||||
- `get_tier(user_id) -> BillingTier`
|
||||
- `check_feature(user_id, feature) -> bool`
|
||||
- `get_rate_limit(tier) -> int`
|
||||
- **Outcome:** Stripe integration with tier-based feature gating.
|
||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
||||
|
||||
### Step 10 — Database (auth/billing only)
|
||||
### Step 12 — Database (auth/billing/marketplace only)
|
||||
- [ ] PostgreSQL schema via Alembic:
|
||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
||||
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
||||
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
||||
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
||||
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
||||
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
||||
- [ ] Initial Alembic migration
|
||||
- [ ] SQLAlchemy models in `app/models.py`
|
||||
- **Outcome:** Auth and billing persistence. Zero user data stored.
|
||||
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
||||
|
||||
### Step 11 — Testing & deployment
|
||||
- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed)
|
||||
### Step 13 — Testing & deployment
|
||||
- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
||||
- [ ] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
||||
- [ ] `tests/test_agents.py`: each agent with mocked tools
|
||||
- [ ] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
||||
- [ ] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
||||
- [ ] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
||||
- [ ] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
||||
- [ ] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
||||
- [ ] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
||||
- **Outcome:** Fully tested, deployable backend.
|
||||
@@ -320,10 +476,22 @@ adiuva-api/
|
||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
||||
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
||||
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
||||
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
||||
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
||||
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
||||
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
||||
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
||||
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
||||
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
||||
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
||||
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
||||
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
||||
@@ -339,20 +507,24 @@ adiuva-api/
|
||||
| Framework | FastAPI + Uvicorn |
|
||||
| LLM | LangChain + langchain-openai |
|
||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
||||
| Billing | stripe-python |
|
||||
| Storage | boto3 (S3) |
|
||||
| Billing | stripe-python + Stripe Connect |
|
||||
| Blob storage | boto3 (S3) |
|
||||
| Vector store | Pinecone or Qdrant (configurable) |
|
||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
||||
| Rate limiting | slowapi |
|
||||
| Testing | pytest + pytest-asyncio + httpx |
|
||||
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
||||
|
||||
---
|
||||
|
||||
## Development Rules
|
||||
|
||||
1. **NEVER persist user data.** The DB stores only auth, billing, and backup metadata. User context arrives in requests and is discarded after processing.
|
||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending.
|
||||
3. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
||||
4. **Type hints everywhere.** All functions have full type annotations.
|
||||
5. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
||||
6. **Structured logging.** JSON logs with request ID correlation.
|
||||
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
||||
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
||||
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
||||
5. **Type hints everywhere.** All functions have full type annotations.
|
||||
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
||||
7. **Structured logging.** JSON logs with request ID correlation.
|
||||
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
||||
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
||||
|
||||
31
Dockerfile
Normal file
31
Dockerfile
Normal file
@@ -0,0 +1,31 @@
|
||||
# ── builder ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY requirements.txt .
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||
|
||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS runtime
|
||||
|
||||
# Non-root user
|
||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
# Copy installed packages from builder
|
||||
COPY --from=builder /install /usr/local
|
||||
|
||||
# Copy application source
|
||||
COPY app/ app/
|
||||
|
||||
# Ensure appuser owns the working directory
|
||||
RUN chown -R appuser:appgroup /app
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
|
||||
0
app/__init__.py
Normal file
0
app/__init__.py
Normal file
5
app/agents/__init__.py
Normal file
5
app/agents/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Import all agent modules to trigger @registry.register decorators."""
|
||||
|
||||
from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
|
||||
|
||||
__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]
|
||||
122
app/agents/checkpoint_agent.py
Normal file
122
app/agents/checkpoint_agent.py
Normal file
@@ -0,0 +1,122 @@
|
||||
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
|
||||
"track progress on a project — they are not calendar events.\n\n"
|
||||
"Rules:\n"
|
||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
|
||||
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
||||
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
|
||||
" - Listing without a project_id returns all checkpoints across projects\n"
|
||||
" - Always echo the title and formatted date in your confirmation."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_checkpoints(project_id: str = "") -> str:
|
||||
"""List checkpoints. Provide project_id to scope to a specific project."""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "checkpoints",
|
||||
"filters": {"projectId": project_id or None},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def create_checkpoint(
|
||||
project_id: str,
|
||||
title: str,
|
||||
date: int,
|
||||
is_ai_suggested: int = 0,
|
||||
is_approved: int = 0,
|
||||
) -> str:
|
||||
"""Create a project checkpoint (milestone).
|
||||
project_id: REQUIRED UUID of the parent project
|
||||
title: descriptive name for the milestone
|
||||
date: Unix timestamp in milliseconds
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
is_approved: 0 until the user confirms
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "checkpoints",
|
||||
"data": {
|
||||
"projectId": project_id,
|
||||
"title": title,
|
||||
"date": date,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
"isApproved": is_approved,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def update_checkpoint(
|
||||
checkpoint_id: str,
|
||||
title: str = "",
|
||||
date: int = -1,
|
||||
is_approved: int = -1,
|
||||
) -> str:
|
||||
"""Update a checkpoint. Only pass fields that should change.
|
||||
checkpoint_id: UUID of the checkpoint (required)
|
||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if date != -1:
|
||||
updates["date"] = date
|
||||
if is_approved != -1:
|
||||
updates["isApproved"] = is_approved
|
||||
return json.dumps({
|
||||
"action": "update_record",
|
||||
"table": "checkpoints",
|
||||
"data": {"id": checkpoint_id, "updates": updates},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_checkpoint(checkpoint_id: str) -> str:
|
||||
"""Delete a checkpoint permanently by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "checkpoints",
|
||||
"data": {"id": checkpoint_id},
|
||||
})
|
||||
|
||||
|
||||
@registry.register
|
||||
class CheckpointAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "checkpoint_agent"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Manages project checkpoints (milestones): list, create, update, delete"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
||||
messages = [
|
||||
SystemMessage(content=_SYSTEM_PROMPT),
|
||||
HumanMessage(
|
||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
||||
),
|
||||
]
|
||||
return await self._tool_loop(llm, messages, self.get_tools())
|
||||
123
app/agents/note_agent.py
Normal file
123
app/agents/note_agent.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||
"and delete Markdown notes in their workspace.\n\n"
|
||||
"Rules:\n"
|
||||
" - content is always Markdown; preserve formatting when updating\n"
|
||||
" - project_id is optional; link a note to a project when mentioned\n"
|
||||
" - When updating, call get_note first if you need to read existing content\n"
|
||||
" before appending or replacing sections\n"
|
||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||
" when the user is working within a specific project\n"
|
||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||
" is already in the note (retrieved via get_note)."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_notes(project_id: str = "") -> str:
|
||||
"""List notes, optionally scoped to a project by project_id."""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "notes",
|
||||
"filters": {"projectId": project_id or None},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def get_note(note_id: str) -> str:
|
||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||
return json.dumps({
|
||||
"action": "get",
|
||||
"table": "notes",
|
||||
"data": {"id": note_id},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def create_note(
|
||||
title: str,
|
||||
content: str,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new note.
|
||||
title: note heading (required)
|
||||
content: Markdown body text (required)
|
||||
project_id: optional UUID linking this note to a project
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "notes",
|
||||
"data": {
|
||||
"title": title,
|
||||
"content": content,
|
||||
"projectId": project_id or None,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def update_note(
|
||||
note_id: str,
|
||||
title: str = "",
|
||||
content: str = "",
|
||||
) -> str:
|
||||
"""Update an existing note. Only pass fields that should change.
|
||||
note_id: UUID of the note (required)
|
||||
If you need to preserve existing content, call get_note first.
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if content:
|
||||
updates["content"] = content
|
||||
return json.dumps({
|
||||
"action": "update_record",
|
||||
"table": "notes",
|
||||
"data": {"id": note_id, "updates": updates},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_note(note_id: str) -> str:
|
||||
"""Delete a note permanently by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "notes",
|
||||
"data": {"id": note_id},
|
||||
})
|
||||
|
||||
|
||||
@registry.register
|
||||
class NoteAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "note_agent"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Manages notes: list, get, create, update, delete"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return [list_notes, get_note, create_note, update_note, delete_note]
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
||||
messages = [
|
||||
SystemMessage(content=_SYSTEM_PROMPT),
|
||||
HumanMessage(
|
||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
||||
),
|
||||
]
|
||||
return await self._tool_loop(llm, messages, self.get_tools())
|
||||
158
app/agents/project_agent.py
Normal file
158
app/agents/project_agent.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a project management assistant. You help users create, find,\n"
|
||||
"update, and archive projects in their workspace.\n\n"
|
||||
"Rules:\n"
|
||||
" - status must be one of: active, archived\n"
|
||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||
" derive it from context data — do not fabricate content\n"
|
||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||
" user wants a complete cross-client view including archived projects\n"
|
||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||
" list_projects if you only have a project name\n"
|
||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||
" only call delete_project when the user explicitly confirms deletion."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_projects(
|
||||
client_id: str = "",
|
||||
include_archived: int = 0,
|
||||
) -> str:
|
||||
"""List projects, optionally filtered by client_id.
|
||||
include_archived: 1 to include archived projects, 0 for active only (default).
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "projects",
|
||||
"filters": {
|
||||
"clientId": client_id or None,
|
||||
"includeArchived": bool(include_archived),
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def list_all_projects() -> str:
|
||||
"""List every project regardless of client or status.
|
||||
Use only when the user wants a complete cross-client overview.
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "list_all",
|
||||
"table": "projects",
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def get_project(project_id: str) -> str:
|
||||
"""Fetch a single project by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "get",
|
||||
"table": "projects",
|
||||
"data": {"id": project_id},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def create_project(
|
||||
name: str,
|
||||
client_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new project.
|
||||
name: human-readable project name (required)
|
||||
client_id: optional UUID of the owning client
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "projects",
|
||||
"data": {
|
||||
"name": name,
|
||||
"clientId": client_id or None,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def update_project(
|
||||
project_id: str,
|
||||
name: str = "",
|
||||
client_id: str = "",
|
||||
status: str = "",
|
||||
ai_summary: str = "",
|
||||
) -> str:
|
||||
"""Update a project. Only pass fields that should change.
|
||||
project_id: UUID of the project (required)
|
||||
status: active | archived
|
||||
ai_summary: AI-generated summary text (populate only when explicitly requested)
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if name:
|
||||
updates["name"] = name
|
||||
if client_id:
|
||||
updates["clientId"] = client_id
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if ai_summary:
|
||||
updates["aiSummary"] = ai_summary
|
||||
return json.dumps({
|
||||
"action": "update_record",
|
||||
"table": "projects",
|
||||
"data": {"id": project_id, "updates": updates},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_project(project_id: str) -> str:
|
||||
"""Permanently delete a project and orphan its tasks.
|
||||
IMPORTANT: prefer update_project(status='archived') unless the user
|
||||
has explicitly confirmed they want permanent deletion.
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "projects",
|
||||
"data": {"id": project_id},
|
||||
})
|
||||
|
||||
|
||||
@registry.register
|
||||
class ProjectAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "project_agent"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Manages projects: list, get, create, update, archive, delete"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return [
|
||||
list_projects,
|
||||
list_all_projects,
|
||||
get_project,
|
||||
create_project,
|
||||
update_project,
|
||||
delete_project,
|
||||
]
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
||||
messages = [
|
||||
SystemMessage(content=_SYSTEM_PROMPT),
|
||||
HumanMessage(
|
||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
||||
),
|
||||
]
|
||||
return await self._tool_loop(llm, messages, self.get_tools())
|
||||
229
app/agents/task_agent.py
Normal file
229
app/agents/task_agent.py
Normal file
@@ -0,0 +1,229 @@
|
||||
"""Task agent — full CRUD for tasks and task comments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_core.tools import tool
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_registry import ChatAgent, registry
|
||||
|
||||
_SYSTEM_PROMPT = (
|
||||
"You are a task management assistant for a project workspace.\n"
|
||||
"You create, update, list, and track tasks and their comments.\n\n"
|
||||
"Rules:\n"
|
||||
" - status must be one of: todo, in_progress, done\n"
|
||||
" - priority must be one of: high, medium, low\n"
|
||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
||||
" - project_id is optional; link to a project when the user mentions one\n"
|
||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||
" did not explicitly request; 0 otherwise\n"
|
||||
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||
" - Always confirm the action in plain, user-friendly language."
|
||||
)
|
||||
|
||||
|
||||
# ── Task tools ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
search: str = "",
|
||||
order_by: str = "",
|
||||
) -> str:
|
||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "tasks",
|
||||
"filters": {
|
||||
"projectId": project_id or None,
|
||||
"status": status or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def create_task(
|
||||
title: str,
|
||||
description: str = "",
|
||||
status: str = "todo",
|
||||
priority: str = "medium",
|
||||
assignees: str = "[]",
|
||||
due_date: int = 0,
|
||||
project_id: str = "",
|
||||
is_ai_suggested: int = 0,
|
||||
is_approved: int = 0,
|
||||
) -> str:
|
||||
"""Create a new task.
|
||||
title: task title (required)
|
||||
description: optional details
|
||||
status: todo | in_progress | done (default: todo)
|
||||
priority: high | medium | low (default: medium)
|
||||
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
|
||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||
project_id: optional UUID of the parent project
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
is_approved: 0 until the user confirms; 1 when confirmed
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "tasks",
|
||||
"data": {
|
||||
"title": title,
|
||||
"description": description or None,
|
||||
"status": status,
|
||||
"priority": priority,
|
||||
"assignee": assignees,
|
||||
"dueDate": due_date or None,
|
||||
"projectId": project_id or None,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
"isApproved": is_approved,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def update_task(
|
||||
task_id: str,
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignees: str = "",
|
||||
due_date: int = -1,
|
||||
project_id: str = "",
|
||||
is_approved: int = -1,
|
||||
) -> str:
|
||||
"""Update fields on an existing task. Only pass fields you want to change.
|
||||
task_id: the task's UUID (required)
|
||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||
is_approved: -1 means unchanged; 0 or 1 sets the value
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if description:
|
||||
updates["description"] = description
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if priority:
|
||||
updates["priority"] = priority
|
||||
if assignees:
|
||||
updates["assignee"] = assignees
|
||||
if due_date != -1:
|
||||
updates["dueDate"] = due_date or None
|
||||
if project_id:
|
||||
updates["projectId"] = project_id
|
||||
if is_approved != -1:
|
||||
updates["isApproved"] = is_approved
|
||||
return json.dumps({
|
||||
"action": "update_record",
|
||||
"table": "tasks",
|
||||
"data": {"id": task_id, "updates": updates},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task(task_id: str) -> str:
|
||||
"""Delete a task permanently by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "tasks",
|
||||
"data": {"id": task_id},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks_due_today() -> str:
|
||||
"""List all tasks whose due date falls on today's date."""
|
||||
return json.dumps({
|
||||
"action": "list_due_today",
|
||||
"table": "tasks",
|
||||
})
|
||||
|
||||
|
||||
# ── Task comment tools ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_task_comments(task_id: str) -> str:
|
||||
"""List all comments on a task by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "list",
|
||||
"table": "taskComments",
|
||||
"filters": {"taskId": task_id},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||
"""Add a comment to a task.
|
||||
task_id: UUID of the task to comment on
|
||||
author: name or ID of the comment author
|
||||
content: comment text
|
||||
"""
|
||||
return json.dumps({
|
||||
"action": "create_record",
|
||||
"table": "taskComments",
|
||||
"data": {
|
||||
"taskId": task_id,
|
||||
"author": author,
|
||||
"content": content,
|
||||
},
|
||||
})
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task_comment(comment_id: str) -> str:
|
||||
"""Delete a task comment by its UUID."""
|
||||
return json.dumps({
|
||||
"action": "delete_record",
|
||||
"table": "taskComments",
|
||||
"data": {"id": comment_id},
|
||||
})
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@registry.register
|
||||
class TaskAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "task_agent"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return [
|
||||
list_tasks,
|
||||
create_task,
|
||||
update_task,
|
||||
delete_task,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
add_task_comment,
|
||||
delete_task_comment,
|
||||
]
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
||||
messages = [
|
||||
SystemMessage(content=_SYSTEM_PROMPT),
|
||||
HumanMessage(
|
||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
||||
),
|
||||
]
|
||||
return await self._tool_loop(llm, messages, self.get_tools())
|
||||
0
app/api/__init__.py
Normal file
0
app/api/__init__.py
Normal file
46
app/api/deps.py
Normal file
46
app/api/deps.py
Normal file
@@ -0,0 +1,46 @@
|
||||
"""Shared FastAPI dependencies.
|
||||
|
||||
``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``.
|
||||
Step 9 will layer rate-limiting and sanitization middleware on top of this.
|
||||
Step 12 will add a DB look-up to fetch the live tier from PostgreSQL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
Raises ``HTTP 401`` on any invalid or expired token.
|
||||
The tier embedded in the JWT is used for feature-gating until Step 12
|
||||
adds a live DB lookup.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
tier: str = payload.get("tier", "free")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
||||
0
app/api/middleware/__init__.py
Normal file
0
app/api/middleware/__init__.py
Normal file
0
app/api/routes/__init__.py
Normal file
0
app/api/routes/__init__.py
Normal file
118
app/api/routes/auth.py
Normal file
118
app/api/routes/auth.py
Normal file
@@ -0,0 +1,118 @@
|
||||
"""Auth routes: register, login, refresh, me.
|
||||
|
||||
Users and refresh tokens are kept in an in-memory dict until Step 12
|
||||
migrates them to PostgreSQL.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
import bcrypt
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
||||
_users: dict[str, dict[str, Any]] = {} # email → user record
|
||||
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
|
||||
now = int(time.time())
|
||||
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
access_payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": access_exp,
|
||||
"iat": now,
|
||||
}
|
||||
access_token = jwt.encode(
|
||||
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
refresh_token = str(uuid.uuid4())
|
||||
_refresh_tokens[refresh_token] = user_id
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=refresh_token,
|
||||
expires_at=access_exp * 1000, # milliseconds for client
|
||||
)
|
||||
|
||||
|
||||
# ── Request bodies ────────────────────────────────────────────────────
|
||||
|
||||
class _RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class _LoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class _RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||
async def register(body: _RegisterRequest) -> AuthTokens:
|
||||
"""Create a new account and return JWT tokens."""
|
||||
if body.email in _users:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||
user_id = str(uuid.uuid4())
|
||||
_users[body.email] = {
|
||||
"id": user_id,
|
||||
"email": body.email,
|
||||
"password_hash": _hash_password(body.password),
|
||||
"tier": "free",
|
||||
}
|
||||
return _make_tokens(user_id, body.email, "free")
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokens)
|
||||
async def login(body: _LoginRequest) -> AuthTokens:
|
||||
"""Validate credentials and return JWT tokens."""
|
||||
user = _users.get(body.email)
|
||||
if not user or not _verify_password(body.password, user["password_hash"]):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokens)
|
||||
async def refresh(body: _RefreshRequest) -> AuthTokens:
|
||||
"""Rotate a refresh token and return a new token pair."""
|
||||
user_id = _refresh_tokens.pop(body.refresh_token, None)
|
||||
if user_id is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||
user = next((u for u in _users.values() if u["id"] == user_id), None)
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||
"""Return the profile for the authenticated user."""
|
||||
return current_user
|
||||
158
app/api/routes/backup.py
Normal file
158
app/api/routes/backup.py
Normal file
@@ -0,0 +1,158 @@
|
||||
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||
|
||||
Blobs are stored in S3 via BlobStore. Backup metadata is kept in an
|
||||
in-memory dict until Step 12 migrates it to PostgreSQL (backup_metadata table).
|
||||
|
||||
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||
treating "history" as a ``{backup_id}`` path parameter.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from email.utils import parsedate_to_datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.schemas import BackupMetadata, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
|
||||
router = APIRouter(prefix="/backup", tags=["backup"])
|
||||
|
||||
_blob_store = BlobStore()
|
||||
|
||||
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
||||
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
||||
|
||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
||||
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
|
||||
"free": 0,
|
||||
"pro": 5,
|
||||
"power": 25,
|
||||
"team": -1, # unlimited
|
||||
}
|
||||
|
||||
|
||||
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
|
||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||
limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)
|
||||
if limit_gb == 0:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail="Backup is not available on the free tier",
|
||||
)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024**3
|
||||
used = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
||||
if used + size_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
|
||||
@router.put("")
|
||||
async def upload_backup(
|
||||
request: Request,
|
||||
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
||||
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
||||
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Upload an E2E-encrypted backup blob.
|
||||
|
||||
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
||||
"""
|
||||
blob = await request.body()
|
||||
reject_if_tampered(blob, x_backup_checksum)
|
||||
_check_backup_quota(current_user.id, current_user.tier, len(blob))
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||
)
|
||||
|
||||
backup_record: dict[str, Any] = {
|
||||
"id": str(x_backup_timestamp),
|
||||
"s3_key": s3_key,
|
||||
"version": x_backup_version,
|
||||
"timestamp": x_backup_timestamp,
|
||||
"checksum": x_backup_checksum,
|
||||
"size_bytes": len(blob),
|
||||
}
|
||||
|
||||
user_backups = _backups.setdefault(current_user.id, [])
|
||||
user_backups.append(backup_record)
|
||||
user_backups.sort(key=lambda b: b["timestamp"], reverse=True)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/history", response_model=list[BackupMetadata])
|
||||
async def backup_history(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> list[BackupMetadata]:
|
||||
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
||||
return [
|
||||
BackupMetadata(
|
||||
version=b["version"],
|
||||
timestamp=b["timestamp"],
|
||||
checksum=b["checksum"],
|
||||
chunk_count=1, # single-chunk uploads for now — TODO(Step12): track real count
|
||||
)
|
||||
for b in _backups.get(current_user.id, [])
|
||||
]
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def download_backup(
|
||||
request: Request,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> Response:
|
||||
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
||||
user_backups = _backups.get(current_user.id, [])
|
||||
if not user_backups:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
||||
|
||||
latest = user_backups[0]
|
||||
|
||||
ims_header = request.headers.get("If-Modified-Since")
|
||||
if ims_header:
|
||||
try:
|
||||
ims_dt = parsedate_to_datetime(ims_header)
|
||||
ims_ms = int(ims_dt.timestamp() * 1000)
|
||||
if latest["timestamp"] <= ims_ms:
|
||||
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
||||
except Exception:
|
||||
pass # malformed header — ignore and serve the blob
|
||||
|
||||
blob = await _blob_store.download(current_user.id, latest["s3_key"])
|
||||
return Response(
|
||||
content=blob,
|
||||
media_type="application/octet-stream",
|
||||
headers={
|
||||
"X-Backup-Version": str(latest["version"]),
|
||||
"X-Backup-Timestamp": str(latest["timestamp"]),
|
||||
"X-Checksum": latest["checksum"],
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{backup_id}", response_model=dict)
|
||||
async def delete_backup(
|
||||
backup_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a specific backup by ID."""
|
||||
user_backups = _backups.get(current_user.id, [])
|
||||
target = next((b for b in user_backups if b["id"] == backup_id), None)
|
||||
if target is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
||||
|
||||
await _blob_store.delete(current_user.id, target["s3_key"])
|
||||
_backups[current_user.id] = [b for b in user_backups if b["id"] != backup_id]
|
||||
|
||||
return {"ok": True}
|
||||
184
app/api/routes/billing.py
Normal file
184
app/api/routes/billing.py
Normal file
@@ -0,0 +1,184 @@
|
||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||
|
||||
Subscription records are kept in-memory until Step 12 migrates them to
|
||||
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
|
||||
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import stripe as stripe_lib
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
|
||||
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
|
||||
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
|
||||
|
||||
_TIER_PRICE_IDS: dict[str, str] = {
|
||||
"pro": "price_pro_monthly", # replace with real Stripe price IDs
|
||||
"power": "price_power_monthly",
|
||||
"team": "price_team_monthly",
|
||||
}
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _stripe_configured() -> bool:
|
||||
return bool(settings.STRIPE_SECRET_KEY)
|
||||
|
||||
|
||||
def _stripe() -> Any:
|
||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||
return stripe_lib
|
||||
|
||||
|
||||
# ── Request bodies ─────────────────────────────────────────────────────
|
||||
|
||||
class _CheckoutRequest(BaseModel):
|
||||
tier: BillingTier
|
||||
|
||||
|
||||
# ── Routes ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/checkout", response_model=dict)
|
||||
async def create_checkout(
|
||||
body: _CheckoutRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, str]:
|
||||
"""Create a Stripe checkout session for a tier upgrade.
|
||||
|
||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||
"""
|
||||
if body.tier == "free":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Cannot create a checkout session for the free tier",
|
||||
)
|
||||
|
||||
if _stripe_configured():
|
||||
price_id = _TIER_PRICE_IDS.get(body.tier)
|
||||
if not price_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Unknown tier: {body.tier}",
|
||||
)
|
||||
s = _stripe()
|
||||
session = s.checkout.Session.create(
|
||||
payment_method_types=["card"],
|
||||
mode="subscription",
|
||||
line_items=[{"price": price_id, "quantity": 1}],
|
||||
success_url=(
|
||||
"https://app.adiuva.app/billing/success"
|
||||
"?session_id={CHECKOUT_SESSION_ID}"
|
||||
),
|
||||
cancel_url="https://app.adiuva.app/billing/cancel",
|
||||
metadata={"user_id": current_user.id, "tier": body.tier},
|
||||
)
|
||||
return {"checkout_url": session.url}
|
||||
|
||||
return {"checkout_url": "https://stripe.com/stub-checkout"}
|
||||
|
||||
|
||||
@router.post("/webhook", response_model=dict)
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||
) -> dict[str, bool]:
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
No JWT auth — authenticated via Stripe signature verification instead.
|
||||
Returns 200 immediately when Stripe is not configured (local dev).
|
||||
"""
|
||||
payload = await request.body()
|
||||
|
||||
if not _stripe_configured():
|
||||
return {"ok": True}
|
||||
|
||||
try:
|
||||
s = _stripe()
|
||||
event = s.Webhook.construct_event(
|
||||
payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
|
||||
)
|
||||
except stripe_lib.error.SignatureVerificationError:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid Stripe signature",
|
||||
)
|
||||
|
||||
event_type: str = event["type"]
|
||||
data: dict[str, Any] = event["data"]["object"]
|
||||
|
||||
if event_type == "checkout.session.completed":
|
||||
user_id = data.get("metadata", {}).get("user_id")
|
||||
tier = data.get("metadata", {}).get("tier", "free")
|
||||
sub_id = data.get("subscription")
|
||||
if user_id:
|
||||
_subscriptions[user_id] = {
|
||||
"tier": tier,
|
||||
"stripe_subscription_id": sub_id,
|
||||
"status": "active",
|
||||
"current_period_end": None,
|
||||
}
|
||||
|
||||
elif event_type == "customer.subscription.updated":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
|
||||
pass
|
||||
|
||||
elif event_type == "customer.subscription.deleted":
|
||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
||||
pass
|
||||
|
||||
elif event_type == "invoice.payment_failed":
|
||||
# TODO(Step12): flag subscription as past_due, notify user
|
||||
pass
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=dict)
|
||||
async def get_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Return the current subscription info for the authenticated user."""
|
||||
sub = _subscriptions.get(current_user.id)
|
||||
if sub is None:
|
||||
return {
|
||||
"tier": current_user.tier,
|
||||
"status": "free",
|
||||
"stripe_subscription_id": None,
|
||||
"current_period_end": None,
|
||||
}
|
||||
return sub
|
||||
|
||||
|
||||
@router.delete("/subscription", response_model=dict)
|
||||
async def cancel_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Cancel the active subscription."""
|
||||
sub = _subscriptions.get(current_user.id)
|
||||
if sub is None or not sub.get("stripe_subscription_id"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="No active subscription found",
|
||||
)
|
||||
|
||||
if _stripe_configured():
|
||||
s = _stripe()
|
||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
||||
|
||||
_subscriptions[current_user.id] = {
|
||||
**sub,
|
||||
"tier": "free",
|
||||
"status": "canceled",
|
||||
}
|
||||
|
||||
return {"ok": True}
|
||||
78
app/api/routes/chat.py
Normal file
78
app/api/routes/chat.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
|
||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||
from fastapi.responses import JSONResponse
|
||||
from jose import JWTError, jwt
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.core.orchestrator import orchestrate, orchestrate_stream
|
||||
from app.schemas import ChatRequest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def chat(
|
||||
body: ChatRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> JSONResponse:
|
||||
"""Route a chat message through the orchestrator.
|
||||
|
||||
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
||||
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
||||
"""
|
||||
result = await orchestrate(body)
|
||||
return JSONResponse(content=result.model_dump())
|
||||
|
||||
|
||||
@router.websocket("/stream")
|
||||
async def chat_stream(websocket: WebSocket) -> None:
|
||||
"""Streaming chat via WebSocket.
|
||||
|
||||
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
|
||||
|
||||
Protocol:
|
||||
1. Client sends ``ChatRequest`` as the first JSON text frame.
|
||||
2. Server streams response text chunks.
|
||||
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
|
||||
4. Server pings every 30 s to keep the connection alive.
|
||||
"""
|
||||
# Authenticate before accepting the connection
|
||||
token = websocket.query_params.get("token", "")
|
||||
try:
|
||||
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
||||
user_id: str | None = payload.get("sub")
|
||||
if not user_id:
|
||||
raise JWTError("missing sub")
|
||||
except JWTError:
|
||||
await websocket.close(code=1008) # 1008 = Policy Violation
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
try:
|
||||
raw = await websocket.receive_text()
|
||||
body = ChatRequest.model_validate_json(raw)
|
||||
|
||||
async def _heartbeat() -> None:
|
||||
while True:
|
||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||
await websocket.send_text(json.dumps({"ping": True}))
|
||||
|
||||
heartbeat_task = asyncio.create_task(_heartbeat())
|
||||
try:
|
||||
async for chunk in orchestrate_stream(body):
|
||||
await websocket.send_text(chunk)
|
||||
finally:
|
||||
heartbeat_task.cancel()
|
||||
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
37
app/api/routes/plans.py
Normal file
37
app/api/routes/plans.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.execution_plan import plan_cache
|
||||
from app.schemas import ExecutionPlan, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
||||
|
||||
|
||||
@router.get("/playbook", response_model=list[ExecutionPlan])
|
||||
async def list_playbooks(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> list[ExecutionPlan]:
|
||||
"""Return all cached execution plan playbooks for the authenticated user.
|
||||
|
||||
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
|
||||
"""
|
||||
return plan_cache.get_all_playbooks()
|
||||
|
||||
|
||||
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
|
||||
async def get_playbook(
|
||||
plan_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> ExecutionPlan:
|
||||
"""Return a specific execution plan playbook by ID."""
|
||||
plan = plan_cache.get_plan(plan_id)
|
||||
if plan is None:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Plan not found: {plan_id}",
|
||||
)
|
||||
return plan
|
||||
174
app/api/routes/plugins.py
Normal file
174
app/api/routes/plugins.py
Normal file
@@ -0,0 +1,174 @@
|
||||
"""Plugins routes: browse and install plugins from the marketplace.
|
||||
|
||||
The catalog and installation records are kept in-memory as stubs.
|
||||
Step 10 replaces these with PluginRegistry, RevenueShare, and the plugins DB table.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.config.settings import settings
|
||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||
|
||||
# ── In-memory catalog (Step 10 replaces with PluginRegistry + DB) ─────
|
||||
|
||||
_plugin_catalog: list[PluginManifest] = [
|
||||
PluginManifest(
|
||||
id="plugin-github-sync",
|
||||
name="GitHub Sync",
|
||||
description="Sync tasks with GitHub Issues and pull requests.",
|
||||
version="1.0.0",
|
||||
author="Adiuva",
|
||||
permissions=["read:tasks", "write:tasks"],
|
||||
category="productivity",
|
||||
price_cents=0,
|
||||
),
|
||||
PluginManifest(
|
||||
id="plugin-slack-notify",
|
||||
name="Slack Notifier",
|
||||
description="Post task and checkpoint updates to Slack channels.",
|
||||
version="1.2.0",
|
||||
author="Adiuva",
|
||||
permissions=["read:tasks", "read:checkpoints"],
|
||||
category="communication",
|
||||
price_cents=499,
|
||||
),
|
||||
PluginManifest(
|
||||
id="plugin-time-tracker",
|
||||
name="Time Tracker",
|
||||
description="Track time spent on tasks with automatic reporting.",
|
||||
version="0.9.1",
|
||||
author="Third Party",
|
||||
permissions=["read:tasks", "write:tasks"],
|
||||
category="productivity",
|
||||
price_cents=999,
|
||||
),
|
||||
]
|
||||
|
||||
# plugin_id → set of user_ids who have installed it
|
||||
_installations: dict[str, set[str]] = {}
|
||||
|
||||
|
||||
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||
|
||||
def _require_plugin_tier(user: UserProfile) -> None:
|
||||
"""Raise HTTP 403 for users below Power tier."""
|
||||
if user.tier not in ("power", "team"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="Plugin marketplace requires Power tier or above",
|
||||
)
|
||||
|
||||
|
||||
# ── Filter + sort helpers ──────────────────────────────────────────────
|
||||
|
||||
def _apply_filters(
|
||||
plugins: list[PluginManifest],
|
||||
category: str | None,
|
||||
q: str | None,
|
||||
) -> list[PluginManifest]:
|
||||
result = plugins
|
||||
if category:
|
||||
result = [p for p in result if p.category == category]
|
||||
if q:
|
||||
q_lower = q.lower()
|
||||
result = [
|
||||
p for p in result
|
||||
if q_lower in p.name.lower() or q_lower in p.description.lower()
|
||||
]
|
||||
return result
|
||||
|
||||
|
||||
def _apply_sort(
|
||||
plugins: list[PluginManifest],
|
||||
sort: str,
|
||||
) -> list[PluginManifest]:
|
||||
if sort == "installs":
|
||||
return sorted(plugins, key=lambda p: len(_installations.get(p.id, set())), reverse=True)
|
||||
if sort == "rating":
|
||||
# Placeholder until Step 10 introduces avg_rating from DB
|
||||
return sorted(plugins, key=lambda p: -p.price_cents)
|
||||
return plugins # "newest" = catalog insertion order
|
||||
|
||||
|
||||
# ── Local detail schema ────────────────────────────────────────────────
|
||||
|
||||
class _PluginDetail(BaseModel):
|
||||
plugin: PluginManifest
|
||||
install_count: int
|
||||
ratings: list[Any] # Step 10 populates from plugin_reviews table
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("", response_model=PluginListResponse)
|
||||
async def list_plugins(
|
||||
category: str | None = Query(default=None),
|
||||
q: str | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> PluginListResponse:
|
||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||
_require_plugin_tier(current_user)
|
||||
filtered = _apply_filters(_plugin_catalog, category, q)
|
||||
sorted_plugins = _apply_sort(filtered, sort)
|
||||
return PluginListResponse(plugins=sorted_plugins, total=len(sorted_plugins), page=page)
|
||||
|
||||
|
||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||
async def get_plugin(
|
||||
plugin_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _PluginDetail:
|
||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||
_require_plugin_tier(current_user)
|
||||
plugin = next((p for p in _plugin_catalog if p.id == plugin_id), None)
|
||||
if plugin is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||
return _PluginDetail(
|
||||
plugin=plugin,
|
||||
install_count=len(_installations.get(plugin_id, set())),
|
||||
ratings=[], # Step 10 populates from plugin_reviews table
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{plugin_id}/install", response_model=dict)
|
||||
async def install_plugin(
|
||||
plugin_id: str,
|
||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, Any]:
|
||||
"""Install a plugin. Triggers Stripe Connect for paid plugins when configured.
|
||||
|
||||
Requires Power tier or above.
|
||||
"""
|
||||
_require_plugin_tier(current_user)
|
||||
plugin = next((p for p in _plugin_catalog if p.id == plugin_id), None)
|
||||
if plugin is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||
|
||||
if plugin.price_cents > 0 and settings.STRIPE_SECRET_KEY:
|
||||
# TODO(Step10): stripe.PaymentIntent.create with destination charge (70/30 split)
|
||||
pass
|
||||
|
||||
_installations.setdefault(plugin_id, set()).add(current_user.id)
|
||||
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
||||
return {"ok": True, "download_url": download_url}
|
||||
|
||||
|
||||
@router.delete("/{plugin_id}/install", response_model=dict)
|
||||
async def uninstall_plugin(
|
||||
plugin_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Unregister a plugin installation."""
|
||||
_installations.get(plugin_id, set()).discard(current_user.id)
|
||||
return {"ok": True}
|
||||
185
app/api/routes/storage.py
Normal file
185
app/api/routes/storage.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||
|
||||
Blobs are stored in S3 via BlobStore. Record metadata is kept in an
|
||||
in-memory dict until Step 12 migrates it to PostgreSQL (storage_records table).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
|
||||
router = APIRouter(prefix="/storage", tags=["storage"])
|
||||
|
||||
_blob_store = BlobStore()
|
||||
|
||||
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
||||
_records: dict[str, dict[str, Any]] = {}
|
||||
|
||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
||||
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
|
||||
"free": 0,
|
||||
"pro": 5,
|
||||
"power": 25,
|
||||
"team": -1, # unlimited
|
||||
}
|
||||
|
||||
|
||||
# ── Local response schemas ─────────────────────────────────────────────
|
||||
|
||||
class _CreateResponse(BaseModel):
|
||||
id: str
|
||||
created_at: int
|
||||
|
||||
|
||||
class _RecordMeta(BaseModel):
|
||||
id: str
|
||||
table: str
|
||||
checksum: str
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
# ── Helpers ────────────────────────────────────────────────────────────
|
||||
|
||||
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
|
||||
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
|
||||
limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
|
||||
if limit_gb == -1:
|
||||
return # unlimited
|
||||
limit_bytes = limit_gb * 1024**3
|
||||
used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
||||
if used + additional_bytes > limit_bytes:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||
)
|
||||
|
||||
|
||||
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
||||
"""Look up a record and verify ownership. Always returns 404 on mismatch
|
||||
to prevent user enumeration attacks."""
|
||||
record = _records.get(record_id)
|
||||
if record is None or record["user_id"] != user_id:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||
return record
|
||||
|
||||
|
||||
# ── Routes ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_record(
|
||||
body: StorageRecordCreate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _CreateResponse:
|
||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||
reject_if_tampered(body.blob, body.checksum)
|
||||
_check_quota(current_user.id, current_user.tier, len(body.blob))
|
||||
|
||||
record_id = str(uuid.uuid4())
|
||||
now = int(time.time() * 1000)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, body.table, record_id, body.blob, body.checksum
|
||||
)
|
||||
|
||||
_records[record_id] = {
|
||||
"id": record_id,
|
||||
"user_id": current_user.id,
|
||||
"table": body.table,
|
||||
"s3_key": s3_key,
|
||||
"checksum": body.checksum,
|
||||
"size_bytes": len(body.blob),
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
}
|
||||
|
||||
return _CreateResponse(id=record_id, created_at=now)
|
||||
|
||||
|
||||
@router.get("/records", response_model=list[_RecordMeta])
|
||||
async def list_records(
|
||||
table: str | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
limit: int = Query(default=50, ge=1, le=200),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> list[_RecordMeta]:
|
||||
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
||||
all_records = [
|
||||
r for r in _records.values()
|
||||
if r["user_id"] == current_user.id and (table is None or r["table"] == table)
|
||||
]
|
||||
start = (page - 1) * limit
|
||||
page_records = all_records[start : start + limit]
|
||||
return [
|
||||
_RecordMeta(
|
||||
id=r["id"],
|
||||
table=r["table"],
|
||||
checksum=r["checksum"],
|
||||
created_at=r["created_at"],
|
||||
updated_at=r["updated_at"],
|
||||
)
|
||||
for r in page_records
|
||||
]
|
||||
|
||||
|
||||
@router.get("/records/{record_id}")
|
||||
async def download_record(
|
||||
record_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> Response:
|
||||
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
||||
record = _get_record_for_user(record_id, current_user.id)
|
||||
blob = await _blob_store.download(current_user.id, record["s3_key"])
|
||||
return Response(
|
||||
content=blob,
|
||||
media_type="application/octet-stream",
|
||||
headers={"X-Checksum": record["checksum"]},
|
||||
)
|
||||
|
||||
|
||||
@router.put("/records/{record_id}", response_model=dict)
|
||||
async def update_record(
|
||||
record_id: str,
|
||||
body: StorageRecordUpdate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
||||
record = _get_record_for_user(record_id, current_user.id)
|
||||
reject_if_tampered(body.blob, body.checksum)
|
||||
|
||||
delta = len(body.blob) - record["size_bytes"]
|
||||
if delta > 0:
|
||||
_check_quota(current_user.id, current_user.tier, delta)
|
||||
|
||||
s3_key = await _blob_store.upload(
|
||||
current_user.id, record["table"], record_id, body.blob, body.checksum
|
||||
)
|
||||
|
||||
record["s3_key"] = s3_key
|
||||
record["checksum"] = body.checksum
|
||||
record["size_bytes"] = len(body.blob)
|
||||
record["updated_at"] = int(time.time() * 1000)
|
||||
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.delete("/records/{record_id}", response_model=dict)
|
||||
async def delete_record(
|
||||
record_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a record and its S3 blob."""
|
||||
record = _get_record_for_user(record_id, current_user.id)
|
||||
await _blob_store.delete(current_user.id, record["s3_key"])
|
||||
del _records[record_id]
|
||||
return {"ok": True}
|
||||
56
app/api/routes/vectors.py
Normal file
56
app/api/routes/vectors.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Vectors routes: upsert, search, and delete cloud vector store entries."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.schemas import (
|
||||
UserProfile,
|
||||
VectorSearchRequest,
|
||||
VectorSearchResponse,
|
||||
VectorUpsertRequest,
|
||||
)
|
||||
from app.storage.encryption import reject_if_tampered
|
||||
from app.storage.vector_store import VectorStore
|
||||
|
||||
router = APIRouter(prefix="/storage", tags=["vectors"])
|
||||
|
||||
_vector_store = VectorStore()
|
||||
|
||||
|
||||
class _VectorDeleteRequest(BaseModel):
|
||||
ids: list[str]
|
||||
|
||||
|
||||
@router.post("/vectors/upsert", response_model=dict)
|
||||
async def upsert_vectors(
|
||||
body: VectorUpsertRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, int]:
|
||||
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
||||
for item in body.vectors:
|
||||
reject_if_tampered(item.blob, item.checksum)
|
||||
await _vector_store.upsert(current_user.id, body.vectors)
|
||||
return {"upserted": len(body.vectors)}
|
||||
|
||||
|
||||
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
||||
async def search_vectors(
|
||||
body: VectorSearchRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> VectorSearchResponse:
|
||||
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
||||
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
||||
return VectorSearchResponse(results=results)
|
||||
|
||||
|
||||
@router.delete("/vectors", response_model=dict)
|
||||
async def delete_vectors(
|
||||
body: _VectorDeleteRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete vectors by ID, scoped to the authenticated user."""
|
||||
await _vector_store.delete(current_user.id, body.ids)
|
||||
return {"ok": True}
|
||||
0
app/billing/__init__.py
Normal file
0
app/billing/__init__.py
Normal file
0
app/config/__init__.py
Normal file
0
app/config/__init__.py
Normal file
36
app/config/settings.py
Normal file
36
app/config/settings.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from typing import Literal
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
||||
JWT_SECRET: str = "change-me-in-production"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||
|
||||
STRIPE_SECRET_KEY: str = ""
|
||||
STRIPE_WEBHOOK_SECRET: str = ""
|
||||
|
||||
S3_BUCKET: str = ""
|
||||
S3_REGION: str = "us-east-1"
|
||||
AWS_ACCESS_KEY_ID: str = ""
|
||||
AWS_SECRET_ACCESS_KEY: str = ""
|
||||
|
||||
PINECONE_API_KEY: str = ""
|
||||
PINECONE_INDEX: str = "adiuva"
|
||||
QDRANT_URL: str = ""
|
||||
QDRANT_API_KEY: str = ""
|
||||
|
||||
OPENAI_API_KEY: str = ""
|
||||
|
||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||
|
||||
ENV: Literal["dev", "prod"] = "dev"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
0
app/core/__init__.py
Normal file
0
app/core/__init__.py
Normal file
137
app/core/agent_registry.py
Normal file
137
app/core/agent_registry.py
Normal file
@@ -0,0 +1,137 @@
|
||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""Common base for all agents."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str = "",
|
||||
shared_memory: dict[str, Any] | None = None,
|
||||
vector_store_context: list[str] | None = None,
|
||||
) -> None:
|
||||
self.user_id = user_id
|
||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||
self.vector_store_context: list[str] = vector_store_context or []
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def skills(self) -> list[str]:
|
||||
"""Override in subclasses to advertise capabilities."""
|
||||
return []
|
||||
|
||||
|
||||
class ChatAgent(BaseAgent):
|
||||
"""Base class for LLM-powered chat agents."""
|
||||
|
||||
@abstractmethod
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
"""Process a user query and return a text response."""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
def get_tools(self) -> list[Any]:
|
||||
"""Return LangChain tool definitions available to this agent."""
|
||||
...
|
||||
|
||||
async def _tool_loop(
|
||||
self,
|
||||
llm: Any,
|
||||
messages: list[Any],
|
||||
tools: list[Any],
|
||||
max_iter: int = 5,
|
||||
) -> str:
|
||||
"""Shared tool-calling loop.
|
||||
|
||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
||||
requesting tool calls or *max_iter* is reached, and returns the
|
||||
final text response.
|
||||
"""
|
||||
from langchain_core.messages import AIMessage, ToolMessage
|
||||
|
||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
||||
|
||||
for _ in range(max_iter):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
return str(response.content)
|
||||
|
||||
# Execute each requested tool call
|
||||
tool_map = {t.name: t for t in tools}
|
||||
for call in response.tool_calls:
|
||||
tool_fn = tool_map.get(call["name"])
|
||||
if tool_fn is None:
|
||||
result = f"Unknown tool: {call['name']}"
|
||||
else:
|
||||
result = await tool_fn.ainvoke(call["args"])
|
||||
messages.append(
|
||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
||||
)
|
||||
|
||||
# Exhausted iterations — ask model for a final answer without tools
|
||||
response = await llm.ainvoke(messages)
|
||||
return str(response.content)
|
||||
|
||||
|
||||
class AgentRegistry:
|
||||
"""Singleton registry for ChatAgent subclasses."""
|
||||
|
||||
_instance: AgentRegistry | None = None
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._agents: dict[str, type[ChatAgent]] = {}
|
||||
|
||||
def __new__(cls) -> AgentRegistry:
|
||||
if cls._instance is None:
|
||||
cls._instance = super().__new__(cls)
|
||||
cls._instance._agents = {}
|
||||
return cls._instance
|
||||
|
||||
# ── public API ───────────────────────────────────────────────────
|
||||
|
||||
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
||||
"""Class decorator — registers an agent by its name."""
|
||||
instance = agent_class()
|
||||
name = instance.get_name()
|
||||
self._agents[name] = agent_class
|
||||
return agent_class
|
||||
|
||||
def get(self, name: str) -> ChatAgent:
|
||||
"""Return a fresh instance of the named agent."""
|
||||
cls = self._agents.get(name)
|
||||
if cls is None:
|
||||
raise KeyError(f"Agent not found: {name}")
|
||||
return cls()
|
||||
|
||||
def list_agents(self) -> list[dict[str, str]]:
|
||||
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
||||
result: list[dict[str, str]] = []
|
||||
for cls in self._agents.values():
|
||||
inst = cls()
|
||||
result.append(
|
||||
{"name": inst.get_name(), "description": inst.get_description()}
|
||||
)
|
||||
return result
|
||||
|
||||
async def call_agent(
|
||||
self, name: str, query: str, context: dict[str, Any]
|
||||
) -> str:
|
||||
"""Instantiate the named agent and call its ``handle`` method."""
|
||||
agent = self.get(name)
|
||||
return await agent.handle(query, context)
|
||||
|
||||
|
||||
# Module-level singleton
|
||||
registry = AgentRegistry()
|
||||
222
app/core/execution_plan.py
Normal file
222
app/core/execution_plan.py
Normal file
@@ -0,0 +1,222 @@
|
||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import OrderedDict
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import ExecutionPlan, PlanStep
|
||||
|
||||
|
||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
||||
|
||||
|
||||
class PromptTemplateRegistry:
|
||||
"""Server-side store mapping template IDs to prompt text.
|
||||
|
||||
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
|
||||
The actual prompt text is resolved here on the server, keeping prompt IP
|
||||
out of API responses.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._templates: dict[str, str] = {}
|
||||
|
||||
def register(self, template_id: str, prompt_text: str) -> None:
|
||||
self._templates[template_id] = prompt_text
|
||||
|
||||
def get(self, template_id: str) -> str:
|
||||
"""Resolve a template ID to its prompt text.
|
||||
|
||||
Raises ``KeyError`` if the template is not registered.
|
||||
"""
|
||||
text = self._templates.get(template_id)
|
||||
if text is None:
|
||||
raise KeyError(f"Template not found: {template_id!r}")
|
||||
return text
|
||||
|
||||
def has(self, template_id: str) -> bool:
|
||||
return template_id in self._templates
|
||||
|
||||
def list_ids(self) -> list[str]:
|
||||
"""Return all registered template IDs (never the text)."""
|
||||
return list(self._templates.keys())
|
||||
|
||||
|
||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
||||
|
||||
|
||||
class ExecutionPlanBuilder:
|
||||
"""Fluent builder for ``ExecutionPlan`` objects.
|
||||
|
||||
Example::
|
||||
|
||||
plan = (
|
||||
ExecutionPlanBuilder("task_agent")
|
||||
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
|
||||
.add_data_step("create_record", data_from_step=0)
|
||||
.build()
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(self, agent: str) -> None:
|
||||
self._agent = agent
|
||||
self._steps: list[PlanStep] = []
|
||||
|
||||
# ── step adders ──────────────────────────────────────────────────
|
||||
|
||||
def add_step(
|
||||
self, action: str, params: dict[str, Any] | None = None
|
||||
) -> ExecutionPlanBuilder:
|
||||
"""Append a generic action step with optional parameters."""
|
||||
self._steps.append(PlanStep(action=action, variables=params))
|
||||
return self
|
||||
|
||||
def add_llm_step(
|
||||
self, template_id: str, variables: dict[str, Any] | None = None
|
||||
) -> ExecutionPlanBuilder:
|
||||
"""Append an LLM step referencing a server-side template by ID."""
|
||||
self._steps.append(
|
||||
PlanStep(action="llm", prompt_template=template_id, variables=variables)
|
||||
)
|
||||
return self
|
||||
|
||||
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
|
||||
"""Append a step whose input comes from the output of an earlier step."""
|
||||
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
|
||||
return self
|
||||
|
||||
# ── build ────────────────────────────────────────────────────────
|
||||
|
||||
def build(self) -> ExecutionPlan:
|
||||
"""Validate step references and return the ``ExecutionPlan``.
|
||||
|
||||
Raises ``ValueError`` if any ``data_from_step`` references a
|
||||
non-existent or future step index.
|
||||
"""
|
||||
for i, step in enumerate(self._steps):
|
||||
if step.data_from_step is not None:
|
||||
if not (0 <= step.data_from_step < i):
|
||||
raise ValueError(
|
||||
f"Step {i}: data_from_step={step.data_from_step} must "
|
||||
f"reference a preceding step index in range 0..{i - 1}"
|
||||
)
|
||||
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
||||
|
||||
|
||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class PlanCache:
|
||||
"""In-memory LRU cache for ``ExecutionPlan`` objects.
|
||||
|
||||
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
|
||||
The cache also serves as a runtime memoisation layer so that repeated
|
||||
identical intent classifications can skip re-building the plan.
|
||||
"""
|
||||
|
||||
def __init__(self, maxsize: int = 1000) -> None:
|
||||
self._maxsize = maxsize
|
||||
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
|
||||
|
||||
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
|
||||
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
|
||||
if key in self._cache:
|
||||
del self._cache[key] # remove so re-insertion places it at the end
|
||||
elif len(self._cache) >= self._maxsize:
|
||||
self._cache.popitem(last=False) # evict least-recently-used
|
||||
self._cache[key] = plan
|
||||
|
||||
def get_plan(self, key: str) -> ExecutionPlan | None:
|
||||
"""Return the cached plan for *key*, or ``None`` if not present.
|
||||
|
||||
Accessing a plan marks it as most-recently used.
|
||||
"""
|
||||
if key not in self._cache:
|
||||
return None
|
||||
self._cache.move_to_end(key)
|
||||
return self._cache[key]
|
||||
|
||||
def get_all_playbooks(self) -> list[ExecutionPlan]:
|
||||
"""Return all cached plans (most-recently used last)."""
|
||||
return list(self._cache.values())
|
||||
|
||||
|
||||
# ── Module-level singletons ───────────────────────────────────────────
|
||||
|
||||
template_registry = PromptTemplateRegistry()
|
||||
plan_cache = PlanCache()
|
||||
|
||||
|
||||
def _register_builtin_templates() -> None:
|
||||
"""Register the built-in server-side prompt templates.
|
||||
|
||||
These strings never leave the server. Clients only receive the IDs.
|
||||
"""
|
||||
_tpls: dict[str, str] = {
|
||||
"tpl_task_agent_default": (
|
||||
"You are a task management assistant. Help the user create, update, "
|
||||
"list, and track tasks. Use correct status values (todo, in_progress, "
|
||||
"done) and priority values (high, medium, low) from the workspace model."
|
||||
),
|
||||
"tpl_checkpoint_agent_default": (
|
||||
"You are a project checkpoint assistant. Help the user create and manage "
|
||||
"milestone checkpoints on their projects. Every checkpoint requires a "
|
||||
"project_id and a date expressed as a Unix timestamp in milliseconds."
|
||||
),
|
||||
"tpl_project_agent_default": (
|
||||
"You are a project management assistant. Help the user create, find, "
|
||||
"update, and archive projects. Projects have a name, an optional client, "
|
||||
"and a status of either active or archived."
|
||||
),
|
||||
"tpl_note_agent_default": (
|
||||
"You are a note-taking assistant. Help the user create, retrieve, update, "
|
||||
"and delete Markdown notes. Notes can optionally be linked to a project."
|
||||
),
|
||||
"tpl_task_extract_from_project": (
|
||||
"Extract all actionable tasks from the provided project context. "
|
||||
"Return a structured list of tasks, each with a title, inferred priority "
|
||||
"(high, medium, or low), suggested status (todo), and a due_date in "
|
||||
"milliseconds where a deadline can be inferred."
|
||||
),
|
||||
"tpl_note_weekly_summary": (
|
||||
"Generate a weekly project summary note from the provided workspace data. "
|
||||
"Include: tasks completed this week, tasks due soon, active projects, "
|
||||
"and upcoming checkpoints. Format the output as clean Markdown."
|
||||
),
|
||||
}
|
||||
for tid, text in _tpls.items():
|
||||
template_registry.register(tid, text)
|
||||
|
||||
|
||||
def _load_playbooks() -> None:
|
||||
"""Pre-build and cache the built-in playbooks."""
|
||||
playbooks: list[tuple[str, ExecutionPlan]] = [
|
||||
(
|
||||
"create_tasks_from_project",
|
||||
ExecutionPlanBuilder("project_agent")
|
||||
.add_llm_step(
|
||||
"tpl_task_extract_from_project",
|
||||
{"source": "project_context"},
|
||||
)
|
||||
.add_data_step("create_record", data_from_step=0)
|
||||
.build(),
|
||||
),
|
||||
(
|
||||
"generate_weekly_note",
|
||||
ExecutionPlanBuilder("note_agent")
|
||||
.add_llm_step(
|
||||
"tpl_note_weekly_summary",
|
||||
{"period": "last_7_days"},
|
||||
)
|
||||
.add_data_step("create_record", data_from_step=0)
|
||||
.build(),
|
||||
),
|
||||
]
|
||||
for key, plan in playbooks:
|
||||
plan_cache.cache_plan(key, plan)
|
||||
|
||||
|
||||
# Initialise on module load
|
||||
_register_builtin_templates()
|
||||
_load_playbooks()
|
||||
169
app/core/orchestrator.py
Normal file
169
app/core/orchestrator.py
Normal file
@@ -0,0 +1,169 @@
|
||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any, AsyncGenerator
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from langchain_openai import ChatOpenAI
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_registry import AgentRegistry
|
||||
from app.core.agent_registry import registry as _default_registry
|
||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||
|
||||
_FALLBACK_AGENT = "task_agent"
|
||||
|
||||
_CLASSIFY_SYSTEM = (
|
||||
"You are an intent classifier. Given the user message and context, decide "
|
||||
"which agent to route to.\n"
|
||||
"Available agents: {agents}\n"
|
||||
"Respond with just the agent name, nothing else."
|
||||
)
|
||||
|
||||
_SYNTHESIZE_HUMAN = (
|
||||
"Combine the following agent results into one coherent response.\n\n"
|
||||
"Agent results:\n{results}\n\n"
|
||||
"Original message: {message}"
|
||||
)
|
||||
|
||||
|
||||
def _make_llm(model: str = "gpt-4o-mini") -> ChatOpenAI:
|
||||
return ChatOpenAI(model=model, temperature=0, api_key=settings.OPENAI_API_KEY)
|
||||
|
||||
|
||||
async def classify_intent(
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
reg: AgentRegistry,
|
||||
) -> str:
|
||||
"""Use gpt-4o-mini to classify intent and return the matching agent name.
|
||||
|
||||
Falls back to ``task_agent`` when the registry is empty or the model
|
||||
returns a name that is not registered.
|
||||
"""
|
||||
agents = reg.list_agents()
|
||||
if not agents:
|
||||
return _FALLBACK_AGENT
|
||||
|
||||
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
|
||||
# Truncate context to keep the classification prompt short
|
||||
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
|
||||
|
||||
llm = _make_llm()
|
||||
response = await llm.ainvoke(
|
||||
[SystemMessage(content=system), HumanMessage(content=human)]
|
||||
)
|
||||
|
||||
agent_name = str(response.content).strip().lower()
|
||||
known = {a["name"] for a in agents}
|
||||
return agent_name if agent_name in known else _FALLBACK_AGENT
|
||||
|
||||
|
||||
async def route_single(
|
||||
agent_name: str,
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
reg: AgentRegistry,
|
||||
) -> ChatResponse:
|
||||
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
|
||||
response_text = await reg.call_agent(agent_name, message, context)
|
||||
return ChatResponse(response=response_text)
|
||||
|
||||
|
||||
async def route_pipeline(
|
||||
agent_names: list[str],
|
||||
message: str,
|
||||
context: dict[str, Any],
|
||||
reg: AgentRegistry,
|
||||
) -> ChatResponse:
|
||||
"""Execute agents sequentially; each agent receives previous results in context.
|
||||
|
||||
A final LLM synthesis call merges all results into one coherent response.
|
||||
"""
|
||||
previous_results: list[str] = []
|
||||
|
||||
for agent_name in agent_names:
|
||||
ctx = {**context, "previous_results": list(previous_results)}
|
||||
result = await reg.call_agent(agent_name, message, ctx)
|
||||
previous_results.append(result)
|
||||
|
||||
results_str = "\n\n".join(
|
||||
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
|
||||
)
|
||||
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
|
||||
llm = _make_llm()
|
||||
synthesis = await llm.ainvoke([HumanMessage(content=human)])
|
||||
return ChatResponse(response=str(synthesis.content))
|
||||
|
||||
|
||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
|
||||
"""Build an ``ExecutionPlan`` for the resolved agent.
|
||||
|
||||
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
|
||||
If a default template exists for the agent, an LLM step is emitted;
|
||||
otherwise a plain ``handle`` action step is used.
|
||||
"""
|
||||
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
|
||||
|
||||
template_id = f"tpl_{agent_name}_default"
|
||||
builder = ExecutionPlanBuilder(agent_name)
|
||||
if template_registry.has(template_id):
|
||||
builder.add_llm_step(template_id, {"message": message})
|
||||
else:
|
||||
builder.add_step("handle", {"message": message})
|
||||
return builder.build()
|
||||
|
||||
|
||||
async def orchestrate(
|
||||
request: ChatRequest,
|
||||
reg: AgentRegistry | None = None,
|
||||
) -> ChatResponse | ExecutionPlan:
|
||||
"""Main orchestration entry point.
|
||||
|
||||
* Classifies the user's intent to select an agent.
|
||||
* ``execution_mode == 'direct'``: routes to the agent and returns a
|
||||
``ChatResponse``.
|
||||
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
|
||||
resolved agent and a template-ID-only step (prompt IP stays server-side).
|
||||
"""
|
||||
if reg is None:
|
||||
reg = _default_registry
|
||||
|
||||
context = request.context.model_dump()
|
||||
agent_name = await classify_intent(request.message, context, reg)
|
||||
|
||||
if request.execution_mode == "direct":
|
||||
return await route_single(agent_name, request.message, context, reg)
|
||||
|
||||
# plan mode — return plan, do not execute
|
||||
return _build_plan(agent_name, request.message)
|
||||
|
||||
|
||||
async def orchestrate_stream(
|
||||
request: ChatRequest,
|
||||
reg: AgentRegistry | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""Streaming orchestration — yields text chunks then a final JSON frame.
|
||||
|
||||
The final frame is a JSON object:
|
||||
``{"done": true, "response": "...", "actions": []}``.
|
||||
|
||||
Agents do not yet support token-level streaming; the full response is
|
||||
fetched first, then emitted in fixed-size chunks. Token-level streaming
|
||||
will be wired in Step 6 when agents expose ``astream()``.
|
||||
"""
|
||||
if reg is None:
|
||||
reg = _default_registry
|
||||
|
||||
context = request.context.model_dump()
|
||||
agent_name = await classify_intent(request.message, context, reg)
|
||||
response_text = await reg.call_agent(agent_name, request.message, context)
|
||||
|
||||
chunk_size = 50
|
||||
for i in range(0, len(response_text), chunk_size):
|
||||
yield response_text[i : i + chunk_size]
|
||||
|
||||
final = ChatResponse(response=response_text)
|
||||
yield json.dumps({"done": True, **final.model_dump()})
|
||||
55
app/main.py
Normal file
55
app/main.py
Normal file
@@ -0,0 +1,55 @@
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup: initialise DB connection pool and agent registry
|
||||
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
||||
|
||||
yield
|
||||
|
||||
# Shutdown: nothing to clean up for now
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="Adiuva Cloud API",
|
||||
version="0.1.0",
|
||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||
redoc_url=None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(plans.router, prefix="/api/v1")
|
||||
app.include_router(storage.router, prefix="/api/v1")
|
||||
app.include_router(vectors.router, prefix="/api/v1")
|
||||
app.include_router(backup.router, prefix="/api/v1")
|
||||
app.include_router(plugins.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
|
||||
@app.get("/api/v1/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
return {"status": "ok", "version": app.version}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
157
app/schemas.py
Normal file
157
app/schemas.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""Pydantic schemas — API request/response contracts.
|
||||
|
||||
Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
# ── Billing ──────────────────────────────────────────────────────────
|
||||
|
||||
BillingTier = Literal["free", "pro", "power", "team"]
|
||||
|
||||
|
||||
# ── Auth ─────────────────────────────────────────────────────────────
|
||||
|
||||
class AuthTokens(BaseModel):
|
||||
access_token: str
|
||||
refresh_token: str
|
||||
expires_at: int
|
||||
|
||||
|
||||
class UserProfile(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
tier: BillingTier
|
||||
|
||||
|
||||
# ── Chat ─────────────────────────────────────────────────────────────
|
||||
|
||||
class ChatContext(BaseModel):
|
||||
user_profile: dict[str, Any] = Field(default_factory=dict)
|
||||
relevant_documents: list[str] = Field(default_factory=list)
|
||||
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PlanAction(BaseModel):
|
||||
type: Literal[
|
||||
"create_record",
|
||||
"update_record",
|
||||
"delete_record",
|
||||
"index_document",
|
||||
"send_notification",
|
||||
]
|
||||
table: str | None = None
|
||||
data: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
message: str
|
||||
context: ChatContext = Field(default_factory=ChatContext)
|
||||
execution_mode: Literal["direct", "plan"] = "direct"
|
||||
|
||||
|
||||
class ChatResponse(BaseModel):
|
||||
response: str
|
||||
actions: list[PlanAction] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Execution Plans ──────────────────────────────────────────────────
|
||||
|
||||
class PlanStep(BaseModel):
|
||||
action: str
|
||||
prompt_template: str | None = None
|
||||
variables: dict[str, Any] | None = None
|
||||
data_from_step: int | None = None
|
||||
|
||||
|
||||
class ExecutionPlan(BaseModel):
|
||||
agent: str
|
||||
steps: list[PlanStep] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Backup ───────────────────────────────────────────────────────────
|
||||
|
||||
class BackupMetadata(BaseModel):
|
||||
version: int
|
||||
timestamp: int
|
||||
checksum: str
|
||||
chunk_count: int
|
||||
|
||||
|
||||
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
||||
|
||||
class StorageRecord(BaseModel):
|
||||
id: str
|
||||
user_id: str
|
||||
table: str
|
||||
blob: bytes
|
||||
checksum: str
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
class StorageRecordCreate(BaseModel):
|
||||
table: str
|
||||
blob: bytes
|
||||
checksum: str
|
||||
|
||||
|
||||
class StorageRecordUpdate(BaseModel):
|
||||
blob: bytes
|
||||
checksum: str
|
||||
|
||||
|
||||
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
||||
|
||||
class VectorItem(BaseModel):
|
||||
id: str
|
||||
blob: bytes # encrypted vector + metadata — backend never decrypts
|
||||
checksum: str
|
||||
|
||||
|
||||
class VectorUpsertRequest(BaseModel):
|
||||
vectors: list[VectorItem]
|
||||
|
||||
|
||||
class VectorSearchRequest(BaseModel):
|
||||
query_blob: bytes # encrypted query — backend never decrypts
|
||||
top_k: int = 10
|
||||
|
||||
|
||||
class VectorSearchResult(BaseModel):
|
||||
id: str
|
||||
score: float
|
||||
blob: bytes
|
||||
|
||||
|
||||
class VectorSearchResponse(BaseModel):
|
||||
results: list[VectorSearchResult]
|
||||
|
||||
|
||||
# ── Plugin Marketplace ────────────────────────────────────────────────
|
||||
|
||||
class PluginManifest(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
version: str
|
||||
author: str
|
||||
permissions: list[str]
|
||||
category: str
|
||||
price_cents: int = 0
|
||||
|
||||
|
||||
class PluginListResponse(BaseModel):
|
||||
plugins: list[PluginManifest]
|
||||
total: int
|
||||
page: int
|
||||
|
||||
|
||||
class PluginInstallRequest(BaseModel):
|
||||
plugin_id: str
|
||||
1
app/storage/__init__.py
Normal file
1
app/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
||||
105
app/storage/blob_store.py
Normal file
105
app/storage/blob_store.py
Normal file
@@ -0,0 +1,105 @@
|
||||
"""S3-backed store for E2E-encrypted blobs.
|
||||
|
||||
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
||||
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
import boto3
|
||||
from botocore.exceptions import ClientError
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
|
||||
class BlobStore:
|
||||
"""Thin wrapper around boto3 S3.
|
||||
|
||||
All blobs must be E2E encrypted by the client before upload.
|
||||
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
||||
but cannot decrypt the inner client-side payload.
|
||||
"""
|
||||
|
||||
def _client(self) -> Any:
|
||||
return boto3.client(
|
||||
"s3",
|
||||
region_name=settings.S3_REGION,
|
||||
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
|
||||
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _key(user_id: str, table: str, record_id: str) -> str:
|
||||
return f"{user_id}/{table}/{record_id}"
|
||||
|
||||
async def upload(
|
||||
self,
|
||||
user_id: str,
|
||||
table: str,
|
||||
record_id: str,
|
||||
blob: bytes,
|
||||
checksum: str,
|
||||
) -> str:
|
||||
"""Store *blob* in S3 and return the S3 key.
|
||||
|
||||
Args:
|
||||
user_id: Owner of the blob (used as key prefix).
|
||||
table: Logical table name (e.g. ``"tasks"``).
|
||||
record_id: Record UUID.
|
||||
blob: Raw bytes (pre-encrypted by client).
|
||||
checksum: SHA-256 hex digest supplied by the client; stored as
|
||||
object metadata for download-time verification.
|
||||
|
||||
Returns:
|
||||
The S3 key under which the blob was stored.
|
||||
"""
|
||||
key = self._key(user_id, table, record_id)
|
||||
self._client().put_object(
|
||||
Bucket=settings.S3_BUCKET,
|
||||
Key=key,
|
||||
Body=blob,
|
||||
ServerSideEncryption="AES256", # SSE-S3 at rest
|
||||
Metadata={"checksum": checksum},
|
||||
)
|
||||
return key
|
||||
|
||||
async def download(self, user_id: str, s3_key: str) -> bytes:
|
||||
"""Retrieve the blob stored at *s3_key*.
|
||||
|
||||
*user_id* is retained in the signature so higher-level code can
|
||||
enforce ownership without re-parsing the key.
|
||||
|
||||
Raises:
|
||||
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
||||
object does not exist.
|
||||
"""
|
||||
response = self._client().get_object(
|
||||
Bucket=settings.S3_BUCKET,
|
||||
Key=s3_key,
|
||||
)
|
||||
return response["Body"].read()
|
||||
|
||||
async def delete(self, user_id: str, s3_key: str) -> None:
|
||||
"""Delete the object at *s3_key*.
|
||||
|
||||
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
||||
not exist.
|
||||
"""
|
||||
self._client().delete_object(
|
||||
Bucket=settings.S3_BUCKET,
|
||||
Key=s3_key,
|
||||
)
|
||||
|
||||
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
||||
"""Return all S3 keys for a given user + table combination.
|
||||
|
||||
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
||||
"""
|
||||
prefix = f"{user_id}/{table}/"
|
||||
response = self._client().list_objects_v2(
|
||||
Bucket=settings.S3_BUCKET,
|
||||
Prefix=prefix,
|
||||
)
|
||||
return [obj["Key"] for obj in response.get("Contents", [])]
|
||||
32
app/storage/encryption.py
Normal file
32
app/storage/encryption.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Integrity verification only — the backend NEVER decrypts user data."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
||||
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
||||
|
||||
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
||||
timing-based side-channel attacks.
|
||||
"""
|
||||
computed = hashlib.sha256(blob).hexdigest()
|
||||
return hmac.compare_digest(computed, checksum)
|
||||
|
||||
|
||||
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
||||
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
||||
|
||||
Call this before storing or forwarding any client-provided blob.
|
||||
The backend never holds decryption keys — this check only verifies
|
||||
that the opaque bytes arrived intact.
|
||||
"""
|
||||
if not verify_checksum(blob, checksum):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Checksum mismatch: blob integrity check failed",
|
||||
)
|
||||
205
app/storage/vector_store.py
Normal file
205
app/storage/vector_store.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
||||
|
||||
Vectors are pre-encrypted blobs from the client. The backend stores them
|
||||
alongside a deterministic 32-dim float representation derived from the blob's
|
||||
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
||||
is a known trade-off documented in the backend plan.
|
||||
|
||||
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
||||
``user_id`` payload field on a shared collection.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
from typing import Any
|
||||
|
||||
from pinecone import Pinecone
|
||||
from qdrant_client import QdrantClient
|
||||
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.schemas import VectorItem, VectorSearchResult
|
||||
|
||||
_QDRANT_COLLECTION = "adiuva_vectors"
|
||||
|
||||
|
||||
def _blob_to_vector(blob: bytes) -> list[float]:
|
||||
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
||||
|
||||
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
||||
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
||||
semantic meaning on encrypted data.
|
||||
"""
|
||||
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
||||
|
||||
|
||||
class VectorStore:
|
||||
"""Thin wrapper around Pinecone or Qdrant.
|
||||
|
||||
The backend to use is selected at runtime:
|
||||
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
||||
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
||||
"""
|
||||
|
||||
def _use_pinecone(self) -> bool:
|
||||
return bool(settings.PINECONE_API_KEY)
|
||||
|
||||
# ── Pinecone helpers ──────────────────────────────────────────────
|
||||
|
||||
def _pinecone_index(self) -> Any:
|
||||
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
||||
return pc.Index(settings.PINECONE_INDEX)
|
||||
|
||||
# ── Qdrant helpers ────────────────────────────────────────────────
|
||||
|
||||
def _qdrant_client(self) -> Any:
|
||||
return QdrantClient(
|
||||
url=settings.QDRANT_URL,
|
||||
api_key=settings.QDRANT_API_KEY or None,
|
||||
)
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────
|
||||
|
||||
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||
"""Store encrypted vectors in the backend.
|
||||
|
||||
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
||||
so it can be returned verbatim during search.
|
||||
|
||||
Args:
|
||||
user_id: Used as Pinecone namespace or Qdrant payload field.
|
||||
vectors: List of encrypted vector items from the client.
|
||||
"""
|
||||
if self._use_pinecone():
|
||||
await self._pinecone_upsert(user_id, vectors)
|
||||
else:
|
||||
await self._qdrant_upsert(user_id, vectors)
|
||||
|
||||
async def search(
|
||||
self,
|
||||
user_id: str,
|
||||
query_blob: bytes,
|
||||
top_k: int,
|
||||
) -> list[VectorSearchResult]:
|
||||
"""Query the vector store and return encrypted result blobs.
|
||||
|
||||
The query vector is derived from *query_blob* using the same
|
||||
deterministic mapping as upsert.
|
||||
|
||||
Args:
|
||||
user_id: Scopes the search to this user's namespace.
|
||||
query_blob: Encrypted query from the client.
|
||||
top_k: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
||||
"""
|
||||
if self._use_pinecone():
|
||||
return await self._pinecone_search(user_id, query_blob, top_k)
|
||||
return await self._qdrant_search(user_id, query_blob, top_k)
|
||||
|
||||
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||
"""Remove vectors by ID, scoped to *user_id*.
|
||||
|
||||
Args:
|
||||
user_id: Namespace / payload filter to prevent cross-user deletion.
|
||||
vector_ids: List of vector IDs to remove.
|
||||
"""
|
||||
if self._use_pinecone():
|
||||
await self._pinecone_delete(user_id, vector_ids)
|
||||
else:
|
||||
await self._qdrant_delete(user_id, vector_ids)
|
||||
|
||||
# ── Pinecone implementation ───────────────────────────────────────
|
||||
|
||||
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||
index = self._pinecone_index()
|
||||
records = [
|
||||
{
|
||||
"id": v.id,
|
||||
"values": _blob_to_vector(v.blob),
|
||||
"metadata": {
|
||||
"blob": base64.b64encode(v.blob).decode(),
|
||||
"checksum": v.checksum,
|
||||
"user_id": user_id,
|
||||
},
|
||||
}
|
||||
for v in vectors
|
||||
]
|
||||
index.upsert(vectors=records, namespace=user_id)
|
||||
|
||||
async def _pinecone_search(
|
||||
self, user_id: str, query_blob: bytes, top_k: int
|
||||
) -> list[VectorSearchResult]:
|
||||
index = self._pinecone_index()
|
||||
query_vector = _blob_to_vector(query_blob)
|
||||
response = index.query(
|
||||
vector=query_vector,
|
||||
top_k=top_k,
|
||||
namespace=user_id,
|
||||
include_metadata=True,
|
||||
)
|
||||
results: list[VectorSearchResult] = []
|
||||
for match in response.get("matches", []):
|
||||
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
||||
results.append(
|
||||
VectorSearchResult(
|
||||
id=match["id"],
|
||||
score=match["score"],
|
||||
blob=blob_bytes,
|
||||
)
|
||||
)
|
||||
return results
|
||||
|
||||
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||
index = self._pinecone_index()
|
||||
index.delete(ids=vector_ids, namespace=user_id)
|
||||
|
||||
# ── Qdrant implementation ─────────────────────────────────────────
|
||||
|
||||
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||
client = self._qdrant_client()
|
||||
points = [
|
||||
PointStruct(
|
||||
id=v.id,
|
||||
vector=_blob_to_vector(v.blob),
|
||||
payload={
|
||||
"blob": base64.b64encode(v.blob).decode(),
|
||||
"checksum": v.checksum,
|
||||
"user_id": user_id,
|
||||
},
|
||||
)
|
||||
for v in vectors
|
||||
]
|
||||
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
||||
|
||||
async def _qdrant_search(
|
||||
self, user_id: str, query_blob: bytes, top_k: int
|
||||
) -> list[VectorSearchResult]:
|
||||
client = self._qdrant_client()
|
||||
query_vector = _blob_to_vector(query_blob)
|
||||
hits = client.search(
|
||||
collection_name=_QDRANT_COLLECTION,
|
||||
query_vector=query_vector,
|
||||
query_filter=Filter(
|
||||
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
||||
),
|
||||
limit=top_k,
|
||||
)
|
||||
return [
|
||||
VectorSearchResult(
|
||||
id=str(hit.id),
|
||||
score=hit.score,
|
||||
blob=base64.b64decode(hit.payload["blob"]),
|
||||
)
|
||||
for hit in hits
|
||||
]
|
||||
|
||||
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||
client = self._qdrant_client()
|
||||
client.delete(
|
||||
collection_name=_QDRANT_COLLECTION,
|
||||
points_selector=PointIdsList(points=vector_ids),
|
||||
)
|
||||
38
docker-compose.yml
Normal file
38
docker-compose.yml
Normal file
@@ -0,0 +1,38 @@
|
||||
version: "3.9"
|
||||
|
||||
services:
|
||||
app:
|
||||
build: .
|
||||
ports:
|
||||
- "8000:8000"
|
||||
env_file:
|
||||
- .env
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
restart: unless-stopped
|
||||
|
||||
db:
|
||||
image: postgres:16-alpine
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: adiuva
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
# Optional Redis for future rate-limit or caching needs
|
||||
# redis:
|
||||
# image: redis:7-alpine
|
||||
# restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
22
requirements.txt
Normal file
22
requirements.txt
Normal file
@@ -0,0 +1,22 @@
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.34.0
|
||||
langchain>=0.3.0
|
||||
langchain-openai>=0.3.0
|
||||
pydantic>=2.10.0
|
||||
pydantic-settings>=2.7.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
stripe>=11.0.0
|
||||
boto3>=1.35.0
|
||||
slowapi>=0.1.9
|
||||
sqlalchemy>=2.0.0
|
||||
asyncpg>=0.30.0
|
||||
alembic>=1.14.0
|
||||
bcrypt>=4.2.0
|
||||
python-dotenv>=1.0.0
|
||||
httpx>=0.28.0
|
||||
websockets>=14.0
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.24.0
|
||||
moto[s3]>=5.0.0
|
||||
pinecone>=5.0.0
|
||||
qdrant-client>=1.7.0
|
||||
0
tests/__init__.py
Normal file
0
tests/__init__.py
Normal file
214
tests/test_agent_registry.py
Normal file
214
tests/test_agent_registry.py
Normal file
@@ -0,0 +1,214 @@
|
||||
"""Unit tests for the agent registry, base classes, and tool loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
class _StubAgent(ChatAgent):
|
||||
"""Minimal concrete agent for testing."""
|
||||
|
||||
def get_name(self) -> str:
|
||||
return "stub"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "A stub agent for tests"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return f"echo: {query}"
|
||||
|
||||
|
||||
class _AnotherAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "another"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Another stub"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return "another"
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fresh_registry():
|
||||
"""Reset the singleton between tests."""
|
||||
AgentRegistry._instance = None
|
||||
yield
|
||||
AgentRegistry._instance = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def reg() -> AgentRegistry:
|
||||
return AgentRegistry()
|
||||
|
||||
|
||||
# ── Tests ────────────────────────────────────────────────────────────
|
||||
|
||||
class TestRegisterAndGet:
|
||||
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
agent = reg.get("stub")
|
||||
assert isinstance(agent, _StubAgent)
|
||||
|
||||
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError, match="not found"):
|
||||
reg.get("nonexistent")
|
||||
|
||||
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
reg.register(_AnotherAgent)
|
||||
assert reg.get("stub").get_name() == "stub"
|
||||
assert reg.get("another").get_name() == "another"
|
||||
|
||||
|
||||
class TestListAgents:
|
||||
def test_empty(self, reg: AgentRegistry) -> None:
|
||||
assert reg.list_agents() == []
|
||||
|
||||
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
agents = reg.list_agents()
|
||||
assert len(agents) == 1
|
||||
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
||||
|
||||
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
reg.register(_AnotherAgent)
|
||||
names = {a["name"] for a in reg.list_agents()}
|
||||
assert names == {"stub", "another"}
|
||||
|
||||
|
||||
class TestCallAgent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
||||
reg.register(_StubAgent)
|
||||
result = await reg.call_agent("stub", "hello", {})
|
||||
assert result == "echo: hello"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError):
|
||||
await reg.call_agent("nope", "hi", {})
|
||||
|
||||
|
||||
class TestSingleton:
|
||||
def test_singleton_identity(self) -> None:
|
||||
a = AgentRegistry()
|
||||
b = AgentRegistry()
|
||||
assert a is b
|
||||
|
||||
|
||||
class TestToolLoop:
|
||||
@pytest.mark.asyncio
|
||||
async def test_no_tool_calls(self) -> None:
|
||||
"""When the LLM responds without tool calls, return content directly."""
|
||||
agent = _StubAgent()
|
||||
|
||||
ai_msg = MagicMock()
|
||||
ai_msg.content = "final answer"
|
||||
ai_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
||||
|
||||
result = await agent._tool_loop(llm, [], [])
|
||||
assert result == "final answer"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_tool_call_then_answer(self) -> None:
|
||||
"""LLM requests one tool call, gets result, then answers."""
|
||||
agent = _StubAgent()
|
||||
|
||||
# First response: tool call
|
||||
tool_call_msg = MagicMock()
|
||||
tool_call_msg.content = ""
|
||||
tool_call_msg.tool_calls = [
|
||||
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
||||
]
|
||||
|
||||
# Second response: final answer
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "done"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||
|
||||
# Mock tool
|
||||
tool = AsyncMock()
|
||||
tool.name = "my_tool"
|
||||
tool.ainvoke = AsyncMock(return_value="tool_result")
|
||||
|
||||
result = await agent._tool_loop(llm, [], [tool])
|
||||
assert result == "done"
|
||||
tool.ainvoke.assert_called_once_with({"x": 1})
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_tool_handled(self) -> None:
|
||||
"""Unknown tool names produce an error message instead of crashing."""
|
||||
agent = _StubAgent()
|
||||
|
||||
tool_call_msg = MagicMock()
|
||||
tool_call_msg.content = ""
|
||||
tool_call_msg.tool_calls = [
|
||||
{"id": "call_1", "name": "missing", "args": {}}
|
||||
]
|
||||
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "recovered"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm)
|
||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
||||
|
||||
result = await agent._tool_loop(llm, [], [])
|
||||
assert result == "recovered"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_max_iter_reached(self) -> None:
|
||||
"""When max iterations are exhausted, a final no-tools call is made."""
|
||||
agent = _StubAgent()
|
||||
|
||||
# Every response requests a tool call
|
||||
loop_msg = MagicMock()
|
||||
loop_msg.content = ""
|
||||
loop_msg.tool_calls = [
|
||||
{"id": "call_x", "name": "t", "args": {}}
|
||||
]
|
||||
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = "gave up"
|
||||
final_msg.tool_calls = []
|
||||
|
||||
tool = AsyncMock()
|
||||
tool.name = "t"
|
||||
tool.ainvoke = AsyncMock(return_value="ok")
|
||||
|
||||
llm_with_tools = AsyncMock()
|
||||
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
||||
|
||||
llm = AsyncMock()
|
||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||
|
||||
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
||||
assert result == "gave up"
|
||||
assert llm_with_tools.ainvoke.call_count == 2
|
||||
620
tests/test_agents.py
Normal file
620
tests/test_agents.py
Normal file
@@ -0,0 +1,620 @@
|
||||
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
||||
from app.agents.checkpoint_agent import CheckpointAgent
|
||||
from app.agents.note_agent import NoteAgent
|
||||
from app.agents.project_agent import ProjectAgent
|
||||
from app.agents.task_agent import TaskAgent
|
||||
from app.core.agent_registry import registry
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_llm(response_text: str) -> MagicMock:
|
||||
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
|
||||
msg = MagicMock()
|
||||
msg.content = response_text
|
||||
msg.tool_calls = []
|
||||
llm = MagicMock()
|
||||
bound = MagicMock()
|
||||
bound.ainvoke = AsyncMock(return_value=msg)
|
||||
llm.bind_tools = MagicMock(return_value=bound)
|
||||
llm.ainvoke = AsyncMock(return_value=msg)
|
||||
return llm
|
||||
|
||||
|
||||
def _mock_llm_with_tool_call(
|
||||
tool_name: str, tool_args: dict[str, Any], final_text: str
|
||||
) -> MagicMock:
|
||||
"""Mock LLM that fires one tool call then returns *final_text*."""
|
||||
tool_msg = MagicMock()
|
||||
tool_msg.content = ""
|
||||
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
|
||||
|
||||
final_msg = MagicMock()
|
||||
final_msg.content = final_text
|
||||
final_msg.tool_calls = []
|
||||
|
||||
bound = MagicMock()
|
||||
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
|
||||
|
||||
llm = MagicMock()
|
||||
llm.bind_tools = MagicMock(return_value=bound)
|
||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||
return llm
|
||||
|
||||
|
||||
# ── Registration ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestAgentRegistration:
|
||||
def test_all_agents_registered(self) -> None:
|
||||
names = {a["name"] for a in registry.list_agents()}
|
||||
assert {
|
||||
"task_agent", "checkpoint_agent", "project_agent", "note_agent"
|
||||
}.issubset(names)
|
||||
|
||||
def test_registry_returns_correct_types(self) -> None:
|
||||
assert isinstance(registry.get("task_agent"), TaskAgent)
|
||||
assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent)
|
||||
assert isinstance(registry.get("project_agent"), ProjectAgent)
|
||||
assert isinstance(registry.get("note_agent"), NoteAgent)
|
||||
|
||||
def test_descriptions_present(self) -> None:
|
||||
for agent_info in registry.list_agents():
|
||||
assert agent_info["description"], f"Empty description: {agent_info['name']}"
|
||||
|
||||
|
||||
# ── TaskAgent ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestTaskAgent:
|
||||
def test_name(self) -> None:
|
||||
assert TaskAgent().get_name() == "task_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(TaskAgent().get_tools()) == 8
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in TaskAgent().get_tools()}
|
||||
assert names == {
|
||||
"list_tasks",
|
||||
"create_task",
|
||||
"update_task",
|
||||
"delete_task",
|
||||
"list_tasks_due_today",
|
||||
"list_task_comments",
|
||||
"add_task_comment",
|
||||
"delete_task_comment",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_returns_string(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Task created.")
|
||||
result = await TaskAgent().handle("create a task", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Here are your tasks.")
|
||||
result = await TaskAgent().handle("list my tasks", {})
|
||||
assert result == "Here are your tasks."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_task_tool_call(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_task",
|
||||
{"title": "Buy groceries", "priority": "low"},
|
||||
"Task 'Buy groceries' created.",
|
||||
)
|
||||
result = await TaskAgent().handle("add a grocery task", {})
|
||||
assert result == "Task 'Buy groceries' created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await TaskAgent().handle("help", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_rich_context(self) -> None:
|
||||
context = {
|
||||
"user_profile": {"id": "u1", "tier": "pro"},
|
||||
"recent_tasks": [{"id": "t1", "title": "Old task"}],
|
||||
}
|
||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Tasks listed.")
|
||||
result = await TaskAgent().handle("show tasks", context)
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestTaskAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_defaults(self) -> None:
|
||||
from app.agents.task_agent import list_tasks
|
||||
result = await list_tasks.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "tasks"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_with_status_filter(self) -> None:
|
||||
from app.agents.task_agent import list_tasks
|
||||
result = await list_tasks.ainvoke({"status": "done"})
|
||||
data = json.loads(result)
|
||||
assert data["filters"]["status"] == "done"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_defaults(self) -> None:
|
||||
from app.agents.task_agent import create_task
|
||||
result = await create_task.ainvoke({"title": "Test task"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "create_record"
|
||||
assert data["table"] == "tasks"
|
||||
assert data["data"]["title"] == "Test task"
|
||||
assert data["data"]["status"] == "todo"
|
||||
assert data["data"]["priority"] == "medium"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_task_with_all_fields(self) -> None:
|
||||
from app.agents.task_agent import create_task
|
||||
result = await create_task.ainvoke({
|
||||
"title": "Deploy",
|
||||
"priority": "high",
|
||||
"status": "in_progress",
|
||||
"project_id": "p1",
|
||||
"is_ai_suggested": 1,
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["priority"] == "high"
|
||||
assert data["data"]["status"] == "in_progress"
|
||||
assert data["data"]["projectId"] == "p1"
|
||||
assert data["data"]["isAiSuggested"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_with_status(self) -> None:
|
||||
from app.agents.task_agent import update_task
|
||||
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "update_record"
|
||||
assert data["data"]["id"] == "t1"
|
||||
assert data["data"]["updates"]["status"] == "done"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_task_empty_updates(self) -> None:
|
||||
from app.agents.task_agent import update_task
|
||||
result = await update_task.ainvoke({"task_id": "t1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_task(self) -> None:
|
||||
from app.agents.task_agent import delete_task
|
||||
result = await delete_task.ainvoke({"task_id": "t1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["table"] == "tasks"
|
||||
assert data["data"]["id"] == "t1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_tasks_due_today(self) -> None:
|
||||
from app.agents.task_agent import list_tasks_due_today
|
||||
result = await list_tasks_due_today.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list_due_today"
|
||||
assert data["table"] == "tasks"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_task_comments(self) -> None:
|
||||
from app.agents.task_agent import list_task_comments
|
||||
result = await list_task_comments.ainvoke({"task_id": "t1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "taskComments"
|
||||
assert data["filters"]["taskId"] == "t1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_add_task_comment(self) -> None:
|
||||
from app.agents.task_agent import add_task_comment
|
||||
result = await add_task_comment.ainvoke({
|
||||
"task_id": "t1",
|
||||
"author": "Alice",
|
||||
"content": "Looks good!",
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "create_record"
|
||||
assert data["table"] == "taskComments"
|
||||
assert data["data"]["taskId"] == "t1"
|
||||
assert data["data"]["author"] == "Alice"
|
||||
assert data["data"]["content"] == "Looks good!"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_task_comment(self) -> None:
|
||||
from app.agents.task_agent import delete_task_comment
|
||||
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["table"] == "taskComments"
|
||||
assert data["data"]["id"] == "c1"
|
||||
|
||||
|
||||
# ── CheckpointAgent ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestCheckpointAgent:
|
||||
def test_name(self) -> None:
|
||||
assert CheckpointAgent().get_name() == "checkpoint_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(CheckpointAgent().get_tools()) == 4
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in CheckpointAgent().get_tools()}
|
||||
assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("No checkpoints found.")
|
||||
result = await CheckpointAgent().handle("list checkpoints", {})
|
||||
assert result == "No checkpoints found."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_tool_call(self) -> None:
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_checkpoint",
|
||||
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
|
||||
"Checkpoint 'MVP Launch' created.",
|
||||
)
|
||||
result = await CheckpointAgent().handle("add MVP checkpoint", {})
|
||||
assert result == "Checkpoint 'MVP Launch' created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await CheckpointAgent().handle("show milestones", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestCheckpointAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_checkpoints_no_project(self) -> None:
|
||||
from app.agents.checkpoint_agent import list_checkpoints
|
||||
result = await list_checkpoints.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "checkpoints"
|
||||
assert data["filters"]["projectId"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_checkpoints_with_project(self) -> None:
|
||||
from app.agents.checkpoint_agent import list_checkpoints
|
||||
result = await list_checkpoints.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["filters"]["projectId"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkpoint(self) -> None:
|
||||
from app.agents.checkpoint_agent import create_checkpoint
|
||||
result = await create_checkpoint.ainvoke({
|
||||
"project_id": "p1",
|
||||
"title": "Beta release",
|
||||
"date": 1700000000000,
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "create_record"
|
||||
assert data["table"] == "checkpoints"
|
||||
assert data["data"]["projectId"] == "p1"
|
||||
assert data["data"]["title"] == "Beta release"
|
||||
assert data["data"]["date"] == 1700000000000
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_checkpoint_ai_suggested(self) -> None:
|
||||
from app.agents.checkpoint_agent import create_checkpoint
|
||||
result = await create_checkpoint.ainvoke({
|
||||
"project_id": "p1",
|
||||
"title": "Review",
|
||||
"date": 1700000000000,
|
||||
"is_ai_suggested": 1,
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["isAiSuggested"] == 1
|
||||
assert data["data"]["isApproved"] == 0
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_checkpoint_approve(self) -> None:
|
||||
from app.agents.checkpoint_agent import update_checkpoint
|
||||
result = await update_checkpoint.ainvoke({
|
||||
"checkpoint_id": "c1",
|
||||
"is_approved": 1,
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "update_record"
|
||||
assert data["data"]["id"] == "c1"
|
||||
assert data["data"]["updates"]["isApproved"] == 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_checkpoint_empty_updates(self) -> None:
|
||||
from app.agents.checkpoint_agent import update_checkpoint
|
||||
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_checkpoint(self) -> None:
|
||||
from app.agents.checkpoint_agent import delete_checkpoint
|
||||
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["table"] == "checkpoints"
|
||||
assert data["data"]["id"] == "c1"
|
||||
|
||||
|
||||
# ── ProjectAgent ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestProjectAgent:
|
||||
def test_name(self) -> None:
|
||||
assert ProjectAgent().get_name() == "project_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(ProjectAgent().get_tools()) == 6
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in ProjectAgent().get_tools()}
|
||||
assert names == {
|
||||
"list_projects",
|
||||
"list_all_projects",
|
||||
"get_project",
|
||||
"create_project",
|
||||
"update_project",
|
||||
"delete_project",
|
||||
}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Project Alpha is active.")
|
||||
result = await ProjectAgent().handle("show my projects", {})
|
||||
assert result == "Project Alpha is active."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_project_tool_call(self) -> None:
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_project",
|
||||
{"name": "Pippo"},
|
||||
"Project 'Pippo' created.",
|
||||
)
|
||||
result = await ProjectAgent().handle("create project Pippo", {})
|
||||
assert result == "Project 'Pippo' created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await ProjectAgent().handle("archive old project", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestProjectAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_projects_defaults(self) -> None:
|
||||
from app.agents.project_agent import list_projects
|
||||
result = await list_projects.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "projects"
|
||||
assert data["filters"]["includeArchived"] is False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_projects_include_archived(self) -> None:
|
||||
from app.agents.project_agent import list_projects
|
||||
result = await list_projects.ainvoke({"include_archived": 1})
|
||||
data = json.loads(result)
|
||||
assert data["filters"]["includeArchived"] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_all_projects(self) -> None:
|
||||
from app.agents.project_agent import list_all_projects
|
||||
result = await list_all_projects.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list_all"
|
||||
assert data["table"] == "projects"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_project(self) -> None:
|
||||
from app.agents.project_agent import get_project
|
||||
result = await get_project.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "get"
|
||||
assert data["table"] == "projects"
|
||||
assert data["data"]["id"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project_name_only(self) -> None:
|
||||
from app.agents.project_agent import create_project
|
||||
result = await create_project.ainvoke({"name": "Alpha"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "create_record"
|
||||
assert data["data"]["name"] == "Alpha"
|
||||
assert data["data"]["clientId"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_project_with_client(self) -> None:
|
||||
from app.agents.project_agent import create_project
|
||||
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["clientId"] == "cl1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_project_archive(self) -> None:
|
||||
from app.agents.project_agent import update_project
|
||||
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "update_record"
|
||||
assert data["data"]["id"] == "p1"
|
||||
assert data["data"]["updates"]["status"] == "archived"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_project_empty_updates(self) -> None:
|
||||
from app.agents.project_agent import update_project
|
||||
result = await update_project.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_project(self) -> None:
|
||||
from app.agents.project_agent import delete_project
|
||||
result = await delete_project.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["data"]["id"] == "p1"
|
||||
|
||||
|
||||
# ── NoteAgent ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestNoteAgent:
|
||||
def test_name(self) -> None:
|
||||
assert NoteAgent().get_name() == "note_agent"
|
||||
|
||||
def test_description(self) -> None:
|
||||
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
|
||||
|
||||
def test_get_tools_count(self) -> None:
|
||||
assert len(NoteAgent().get_tools()) == 5
|
||||
|
||||
def test_tool_names(self) -> None:
|
||||
names = {t.name for t in NoteAgent().get_tools()}
|
||||
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_no_tool_calls(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Note created.")
|
||||
result = await NoteAgent().handle("create a note", {})
|
||||
assert result == "Note created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_with_create_note_tool_call(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
||||
"create_note",
|
||||
{"title": "Daily log", "content": "# Today\nAll good."},
|
||||
"Note 'Daily log' created.",
|
||||
)
|
||||
result = await NoteAgent().handle("log today's progress", {})
|
||||
assert result == "Note 'Daily log' created."
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_handle_accepts_empty_context(self) -> None:
|
||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("Done.")
|
||||
result = await NoteAgent().handle("show notes", {})
|
||||
assert isinstance(result, str)
|
||||
|
||||
|
||||
class TestNoteAgentTools:
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_notes_no_project(self) -> None:
|
||||
from app.agents.note_agent import list_notes
|
||||
result = await list_notes.ainvoke({})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "list"
|
||||
assert data["table"] == "notes"
|
||||
assert data["filters"]["projectId"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_notes_with_project(self) -> None:
|
||||
from app.agents.note_agent import list_notes
|
||||
result = await list_notes.ainvoke({"project_id": "p1"})
|
||||
data = json.loads(result)
|
||||
assert data["filters"]["projectId"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_note(self) -> None:
|
||||
from app.agents.note_agent import get_note
|
||||
result = await get_note.ainvoke({"note_id": "n1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "get"
|
||||
assert data["table"] == "notes"
|
||||
assert data["data"]["id"] == "n1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_note_minimal(self) -> None:
|
||||
from app.agents.note_agent import create_note
|
||||
result = await create_note.ainvoke({
|
||||
"title": "Daily log",
|
||||
"content": "# Today\nAll good.",
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "create_record"
|
||||
assert data["table"] == "notes"
|
||||
assert data["data"]["title"] == "Daily log"
|
||||
assert data["data"]["content"] == "# Today\nAll good."
|
||||
assert data["data"]["projectId"] is None
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_note_with_project(self) -> None:
|
||||
from app.agents.note_agent import create_note
|
||||
result = await create_note.ainvoke({
|
||||
"title": "Sprint notes",
|
||||
"content": "## Sprint 1",
|
||||
"project_id": "p1",
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["projectId"] == "p1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_note_content_only(self) -> None:
|
||||
from app.agents.note_agent import update_note
|
||||
result = await update_note.ainvoke({
|
||||
"note_id": "n1",
|
||||
"content": "# Updated content",
|
||||
})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "update_record"
|
||||
assert data["data"]["id"] == "n1"
|
||||
assert data["data"]["updates"]["content"] == "# Updated content"
|
||||
assert "title" not in data["data"]["updates"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_note_empty_updates(self) -> None:
|
||||
from app.agents.note_agent import update_note
|
||||
result = await update_note.ainvoke({"note_id": "n1"})
|
||||
data = json.loads(result)
|
||||
assert data["data"]["updates"] == {}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_note(self) -> None:
|
||||
from app.agents.note_agent import delete_note
|
||||
result = await delete_note.ainvoke({"note_id": "n1"})
|
||||
data = json.loads(result)
|
||||
assert data["action"] == "delete_record"
|
||||
assert data["table"] == "notes"
|
||||
assert data["data"]["id"] == "n1"
|
||||
286
tests/test_execution_plan.py
Normal file
286
tests/test_execution_plan.py
Normal file
@@ -0,0 +1,286 @@
|
||||
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.execution_plan import (
|
||||
ExecutionPlanBuilder,
|
||||
PlanCache,
|
||||
PromptTemplateRegistry,
|
||||
plan_cache,
|
||||
template_registry,
|
||||
)
|
||||
from app.schemas import ExecutionPlan
|
||||
|
||||
|
||||
# ── PromptTemplateRegistry ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPromptTemplateRegistry:
|
||||
def test_register_and_get(self) -> None:
|
||||
reg = PromptTemplateRegistry()
|
||||
reg.register("tpl_foo", "You are a foo agent.")
|
||||
assert reg.get("tpl_foo") == "You are a foo agent."
|
||||
|
||||
def test_get_unknown_raises_key_error(self) -> None:
|
||||
reg = PromptTemplateRegistry()
|
||||
with pytest.raises(KeyError, match="tpl_missing"):
|
||||
reg.get("tpl_missing")
|
||||
|
||||
def test_has_returns_true_for_registered(self) -> None:
|
||||
reg = PromptTemplateRegistry()
|
||||
reg.register("tpl_x", "prompt text")
|
||||
assert reg.has("tpl_x") is True
|
||||
|
||||
def test_has_returns_false_for_unregistered(self) -> None:
|
||||
reg = PromptTemplateRegistry()
|
||||
assert reg.has("tpl_missing") is False
|
||||
|
||||
def test_list_ids_returns_all_registered_ids(self) -> None:
|
||||
reg = PromptTemplateRegistry()
|
||||
reg.register("tpl_a", "a")
|
||||
reg.register("tpl_b", "b")
|
||||
assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
|
||||
|
||||
def test_list_ids_does_not_return_prompt_text(self) -> None:
|
||||
reg = PromptTemplateRegistry()
|
||||
reg.register("tpl_secret", "top secret prompt")
|
||||
ids = reg.list_ids()
|
||||
assert "top secret prompt" not in ids
|
||||
|
||||
def test_overwrite_existing_template(self) -> None:
|
||||
reg = PromptTemplateRegistry()
|
||||
reg.register("tpl_x", "v1")
|
||||
reg.register("tpl_x", "v2")
|
||||
assert reg.get("tpl_x") == "v2"
|
||||
|
||||
def test_empty_registry_has_no_ids(self) -> None:
|
||||
reg = PromptTemplateRegistry()
|
||||
assert reg.list_ids() == []
|
||||
|
||||
|
||||
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestExecutionPlanBuilder:
|
||||
def test_builds_empty_plan(self) -> None:
|
||||
plan = ExecutionPlanBuilder("task_agent").build()
|
||||
assert plan.agent == "task_agent"
|
||||
assert plan.steps == []
|
||||
|
||||
def test_add_step_basic(self) -> None:
|
||||
plan = (
|
||||
ExecutionPlanBuilder("task_agent")
|
||||
.add_step("create_task", {"priority": "high"})
|
||||
.build()
|
||||
)
|
||||
assert len(plan.steps) == 1
|
||||
assert plan.steps[0].action == "create_task"
|
||||
assert plan.steps[0].variables == {"priority": "high"}
|
||||
assert plan.steps[0].prompt_template is None
|
||||
assert plan.steps[0].data_from_step is None
|
||||
|
||||
def test_add_step_no_params(self) -> None:
|
||||
plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
|
||||
assert plan.steps[0].variables is None
|
||||
|
||||
def test_add_llm_step(self) -> None:
|
||||
plan = (
|
||||
ExecutionPlanBuilder("task_agent")
|
||||
.add_llm_step("tpl_task_default", {"message": "hi"})
|
||||
.build()
|
||||
)
|
||||
assert plan.steps[0].action == "llm"
|
||||
assert plan.steps[0].prompt_template == "tpl_task_default"
|
||||
assert plan.steps[0].variables == {"message": "hi"}
|
||||
|
||||
def test_add_llm_step_no_variables(self) -> None:
|
||||
plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
|
||||
assert plan.steps[0].variables is None
|
||||
|
||||
def test_add_data_step(self) -> None:
|
||||
plan = (
|
||||
ExecutionPlanBuilder("task_agent")
|
||||
.add_step("fetch_data")
|
||||
.add_data_step("transform", data_from_step=0)
|
||||
.build()
|
||||
)
|
||||
assert plan.steps[1].action == "transform"
|
||||
assert plan.steps[1].data_from_step == 0
|
||||
|
||||
def test_fluent_chaining_returns_builder(self) -> None:
|
||||
builder = ExecutionPlanBuilder("analytics_agent")
|
||||
result = builder.add_step("a")
|
||||
assert result is builder
|
||||
|
||||
def test_fluent_chain_multiple_steps(self) -> None:
|
||||
plan = (
|
||||
ExecutionPlanBuilder("analytics_agent")
|
||||
.add_llm_step("tpl_analytics_default")
|
||||
.add_step("format_output")
|
||||
.add_data_step("store", data_from_step=0)
|
||||
.build()
|
||||
)
|
||||
assert len(plan.steps) == 3
|
||||
|
||||
def test_build_validates_data_from_step_out_of_range(self) -> None:
|
||||
with pytest.raises(ValueError, match="data_from_step"):
|
||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
|
||||
|
||||
def test_build_validates_data_from_step_self_reference(self) -> None:
|
||||
"""data_from_step=0 on the first step (index 0) is invalid."""
|
||||
with pytest.raises(ValueError, match="data_from_step"):
|
||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
|
||||
|
||||
def test_build_validates_data_from_step_negative(self) -> None:
|
||||
with pytest.raises(ValueError, match="data_from_step"):
|
||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
|
||||
|
||||
def test_valid_data_from_step_at_index_two(self) -> None:
|
||||
plan = (
|
||||
ExecutionPlanBuilder("task_agent")
|
||||
.add_step("step0")
|
||||
.add_step("step1")
|
||||
.add_data_step("step2", data_from_step=1)
|
||||
.build()
|
||||
)
|
||||
assert plan.steps[2].data_from_step == 1
|
||||
|
||||
def test_data_from_step_zero_valid_at_index_one(self) -> None:
|
||||
plan = (
|
||||
ExecutionPlanBuilder("task_agent")
|
||||
.add_step("step0")
|
||||
.add_data_step("step1", data_from_step=0)
|
||||
.build()
|
||||
)
|
||||
assert plan.steps[1].data_from_step == 0
|
||||
|
||||
def test_build_returns_new_plan_each_call(self) -> None:
|
||||
builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
|
||||
plan1 = builder.build()
|
||||
plan2 = builder.build()
|
||||
assert plan1 is not plan2
|
||||
assert plan1.steps == plan2.steps
|
||||
|
||||
def test_plan_is_execution_plan_instance(self) -> None:
|
||||
plan = ExecutionPlanBuilder("task_agent").build()
|
||||
assert isinstance(plan, ExecutionPlan)
|
||||
|
||||
|
||||
# ── PlanCache ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestPlanCache:
|
||||
def _plan(self, agent: str = "a") -> ExecutionPlan:
|
||||
return ExecutionPlanBuilder(agent).build()
|
||||
|
||||
def test_cache_and_get(self) -> None:
|
||||
cache = PlanCache()
|
||||
plan = self._plan()
|
||||
cache.cache_plan("key1", plan)
|
||||
assert cache.get_plan("key1") is plan
|
||||
|
||||
def test_get_missing_returns_none(self) -> None:
|
||||
cache = PlanCache()
|
||||
assert cache.get_plan("nonexistent") is None
|
||||
|
||||
def test_get_all_playbooks_empty(self) -> None:
|
||||
cache = PlanCache()
|
||||
assert cache.get_all_playbooks() == []
|
||||
|
||||
def test_get_all_playbooks_returns_all_stored(self) -> None:
|
||||
cache = PlanCache()
|
||||
p1, p2 = self._plan("a"), self._plan("b")
|
||||
cache.cache_plan("k1", p1)
|
||||
cache.cache_plan("k2", p2)
|
||||
playbooks = cache.get_all_playbooks()
|
||||
assert len(playbooks) == 2
|
||||
assert p1 in playbooks
|
||||
assert p2 in playbooks
|
||||
|
||||
def test_lru_evicts_oldest_entry(self) -> None:
|
||||
cache = PlanCache(maxsize=2)
|
||||
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
|
||||
cache.cache_plan("k1", p1)
|
||||
cache.cache_plan("k2", p2)
|
||||
cache.cache_plan("k3", p3) # k1 should be evicted
|
||||
assert cache.get_plan("k1") is None
|
||||
assert cache.get_plan("k2") is p2
|
||||
assert cache.get_plan("k3") is p3
|
||||
|
||||
def test_lru_access_updates_recency(self) -> None:
|
||||
cache = PlanCache(maxsize=2)
|
||||
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
|
||||
cache.cache_plan("k1", p1)
|
||||
cache.cache_plan("k2", p2)
|
||||
cache.get_plan("k1") # k1 is now most-recently used
|
||||
cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
|
||||
assert cache.get_plan("k1") is p1
|
||||
assert cache.get_plan("k2") is None
|
||||
assert cache.get_plan("k3") is p3
|
||||
|
||||
def test_overwrite_existing_key(self) -> None:
|
||||
cache = PlanCache()
|
||||
p1, p2 = self._plan("a"), self._plan("b")
|
||||
cache.cache_plan("same_key", p1)
|
||||
cache.cache_plan("same_key", p2)
|
||||
assert cache.get_plan("same_key") is p2
|
||||
assert len(cache.get_all_playbooks()) == 1
|
||||
|
||||
def test_overwrite_does_not_consume_capacity(self) -> None:
|
||||
cache = PlanCache(maxsize=2)
|
||||
p1, p2 = self._plan("a"), self._plan("b")
|
||||
cache.cache_plan("k1", p1)
|
||||
cache.cache_plan("k1", p2) # overwrite, not a new slot
|
||||
cache.cache_plan("k2", p1) # should fit without eviction
|
||||
assert cache.get_plan("k1") is p2
|
||||
assert cache.get_plan("k2") is p1
|
||||
|
||||
|
||||
# ── Module-level singletons ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestModuleSingletons:
|
||||
def test_template_registry_has_all_agent_defaults(self) -> None:
|
||||
for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"):
|
||||
assert template_registry.has(f"tpl_{agent}_default"), (
|
||||
f"Missing template: tpl_{agent}_default"
|
||||
)
|
||||
|
||||
def test_template_registry_has_operation_templates(self) -> None:
|
||||
assert template_registry.has("tpl_task_extract_from_project")
|
||||
assert template_registry.has("tpl_note_weekly_summary")
|
||||
|
||||
def test_template_registry_get_returns_non_empty_string(self) -> None:
|
||||
text = template_registry.get("tpl_task_agent_default")
|
||||
assert isinstance(text, str)
|
||||
assert len(text) > 0
|
||||
|
||||
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
|
||||
assert len(plan_cache.get_all_playbooks()) >= 2
|
||||
|
||||
def test_playbook_create_tasks_from_project(self) -> None:
|
||||
plan = plan_cache.get_plan("create_tasks_from_project")
|
||||
assert plan is not None
|
||||
assert plan.agent == "project_agent"
|
||||
assert len(plan.steps) == 2
|
||||
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
|
||||
assert plan.steps[1].data_from_step == 0
|
||||
|
||||
def test_playbook_generate_weekly_note(self) -> None:
|
||||
plan = plan_cache.get_plan("generate_weekly_note")
|
||||
assert plan is not None
|
||||
assert plan.agent == "note_agent"
|
||||
assert len(plan.steps) == 2
|
||||
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
|
||||
assert plan.steps[1].data_from_step == 0
|
||||
|
||||
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
|
||||
"""Plans must not embed prompt text — only template IDs."""
|
||||
for plan in plan_cache.get_all_playbooks():
|
||||
for step in plan.steps:
|
||||
if step.prompt_template is not None:
|
||||
assert step.prompt_template.startswith("tpl_"), (
|
||||
f"prompt_template looks like raw text: {step.prompt_template!r}"
|
||||
)
|
||||
348
tests/test_orchestrator.py
Normal file
348
tests/test_orchestrator.py
Normal file
@@ -0,0 +1,348 @@
|
||||
"""Integration tests for the orchestrator module."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||
from app.core.orchestrator import (
|
||||
classify_intent,
|
||||
orchestrate,
|
||||
orchestrate_stream,
|
||||
route_pipeline,
|
||||
route_single,
|
||||
)
|
||||
from app.schemas import ChatContext, ChatRequest, ChatResponse, ExecutionPlan
|
||||
|
||||
|
||||
# ── Stub agents ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _TaskAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "task_agent"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Manages tasks: create, update, list, suggest"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return f"task: {query}"
|
||||
|
||||
|
||||
class _CalendarAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "calendar_agent"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "Calendar management: events, conflicts, scheduling"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
return f"calendar: {query}"
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _mock_llm(response_text: str) -> MagicMock:
|
||||
"""Return a mock LLM that always produces *response_text*."""
|
||||
msg = MagicMock()
|
||||
msg.content = response_text
|
||||
llm = MagicMock()
|
||||
llm.ainvoke = AsyncMock(return_value=msg)
|
||||
return llm
|
||||
|
||||
|
||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _fresh_registry():
|
||||
"""Reset the AgentRegistry singleton between tests."""
|
||||
AgentRegistry._instance = None
|
||||
yield
|
||||
AgentRegistry._instance = None
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def reg() -> AgentRegistry:
|
||||
r = AgentRegistry()
|
||||
r.register(_TaskAgent)
|
||||
r.register(_CalendarAgent)
|
||||
return r
|
||||
|
||||
|
||||
# ── classify_intent ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestClassifyIntent:
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
result = await classify_intent("add a task", {}, reg)
|
||||
assert result == "task_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
||||
result = await classify_intent("schedule a meeting", {}, reg)
|
||||
assert result == "calendar_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("nonexistent_agent")
|
||||
result = await classify_intent("do something", {}, reg)
|
||||
assert result == "task_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
|
||||
empty_reg = AgentRegistry()
|
||||
# No LLM should be instantiated — early return path
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
result = await classify_intent("anything", {}, empty_reg)
|
||||
mock_cls.assert_not_called()
|
||||
assert result == "task_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm(" task_agent \n")
|
||||
result = await classify_intent("create task", {}, reg)
|
||||
assert result == "task_agent"
|
||||
|
||||
|
||||
# ── route_single ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRouteSingle:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
||||
result = await route_single("task_agent", "create a task", {}, reg)
|
||||
assert isinstance(result, ChatResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
|
||||
result = await route_single("task_agent", "create a task", {}, reg)
|
||||
assert result.response == "task: create a task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
|
||||
with pytest.raises(KeyError):
|
||||
await route_single("nonexistent", "hello", {}, reg)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
|
||||
result = await route_single("task_agent", "hi", {}, reg)
|
||||
assert result.actions == []
|
||||
|
||||
|
||||
# ── route_pipeline ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestRoutePipeline:
|
||||
@pytest.mark.asyncio
|
||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("synthesized result")
|
||||
result = await route_pipeline(
|
||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
||||
)
|
||||
assert isinstance(result, ChatResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("synthesized result")
|
||||
result = await route_pipeline(
|
||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
||||
)
|
||||
assert result.response == "synthesized result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_passes_previous_results_to_subsequent_agents(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
"""Each agent after the first should receive prior outputs in context."""
|
||||
received_contexts: list[dict[str, Any]] = []
|
||||
|
||||
class _CapturingAgent(ChatAgent):
|
||||
def get_name(self) -> str:
|
||||
return "capture"
|
||||
|
||||
def get_description(self) -> str:
|
||||
return "captures context for testing"
|
||||
|
||||
def get_tools(self) -> list[Any]:
|
||||
return []
|
||||
|
||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
||||
received_contexts.append(dict(context))
|
||||
return "captured"
|
||||
|
||||
reg.register(_CapturingAgent)
|
||||
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("done")
|
||||
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
|
||||
|
||||
# The second agent (capture) must have received previous results
|
||||
assert len(received_contexts) == 1
|
||||
assert "previous_results" in received_contexts[0]
|
||||
assert received_contexts[0]["previous_results"] == ["task: hi"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("single result")
|
||||
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
|
||||
assert result.response == "single result"
|
||||
|
||||
|
||||
# ── orchestrate ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOrchestrate:
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_mode_returns_chat_response(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ChatResponse)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ChatResponse)
|
||||
assert result.response == "task: add a task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_returns_execution_plan(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan my tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ExecutionPlan)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_agent_matches_classified(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
||||
request = ChatRequest(
|
||||
message="schedule something", execution_mode="plan"
|
||||
)
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ExecutionPlan)
|
||||
assert result.agent == "calendar_agent"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ExecutionPlan)
|
||||
assert len(result.steps) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_plan_mode_template_id_contains_agent_name(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ExecutionPlan)
|
||||
assert result.steps[0].prompt_template is not None
|
||||
assert "task_agent" in result.steps[0].prompt_template
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_default_execution_mode_is_direct(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
# execution_mode defaults to "direct"
|
||||
request = ChatRequest(message="help me")
|
||||
result = await orchestrate(request, reg)
|
||||
assert isinstance(result, ChatResponse)
|
||||
|
||||
|
||||
# ── orchestrate_stream ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestOrchestrateStream:
|
||||
@pytest.mark.asyncio
|
||||
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
assert len(chunks) >= 1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_last_chunk_is_final_json_frame(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
|
||||
last = json.loads(chunks[-1])
|
||||
assert last["done"] is True
|
||||
assert "response" in last
|
||||
assert "actions" in last
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_final_frame_response_matches_agent_output(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(message="create a task", execution_mode="direct")
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
|
||||
final = json.loads(chunks[-1])
|
||||
assert final["response"] == "task: create a task"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_text_chunks_before_final_frame(
|
||||
self, reg: AgentRegistry
|
||||
) -> None:
|
||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
||||
mock_cls.return_value = _mock_llm("task_agent")
|
||||
request = ChatRequest(
|
||||
message="x" * 200, execution_mode="direct"
|
||||
) # long enough to produce multiple chunks
|
||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
||||
|
||||
# All but the last chunk should be plain text (not valid final JSON)
|
||||
non_final = chunks[:-1]
|
||||
for chunk in non_final:
|
||||
try:
|
||||
parsed = json.loads(chunk)
|
||||
assert parsed.get("done") is not True
|
||||
except json.JSONDecodeError:
|
||||
pass # plain text chunk — expected
|
||||
385
tests/test_storage.py
Normal file
385
tests/test_storage.py
Normal file
@@ -0,0 +1,385 @@
|
||||
"""Tests for the storage layer: encryption, BlobStore, and VectorStore."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import boto3
|
||||
import pytest
|
||||
from botocore.exceptions import ClientError
|
||||
from moto import mock_aws
|
||||
|
||||
from app.storage.encryption import reject_if_tampered, verify_checksum
|
||||
from app.storage.blob_store import BlobStore
|
||||
from app.storage.vector_store import VectorStore, _blob_to_vector
|
||||
from app.schemas import VectorItem, VectorSearchResult
|
||||
|
||||
|
||||
# ── Helpers ───────────────────────────────────────────────────────────
|
||||
|
||||
_BLOB = b"encrypted-payload-opaque-to-server"
|
||||
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||
_BUCKET = "test-bucket"
|
||||
_REGION = "us-east-1"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def s3_bucket():
|
||||
"""Create a mocked S3 bucket and expose its name."""
|
||||
with mock_aws():
|
||||
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
||||
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
||||
os.environ.setdefault("AWS_DEFAULT_REGION", _REGION)
|
||||
client = boto3.client("s3", region_name=_REGION)
|
||||
client.create_bucket(Bucket=_BUCKET)
|
||||
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||
mock_settings.S3_BUCKET = _BUCKET
|
||||
mock_settings.S3_REGION = _REGION
|
||||
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||
yield _BUCKET
|
||||
|
||||
|
||||
def _pinecone_mock():
|
||||
"""Return a mock Pinecone index with realistic return shapes."""
|
||||
mock_index = MagicMock()
|
||||
mock_index.query.return_value = {
|
||||
"matches": [
|
||||
{
|
||||
"id": "v1",
|
||||
"score": 0.95,
|
||||
"metadata": {
|
||||
"blob": base64.b64encode(b"result-blob").decode(),
|
||||
"checksum": hashlib.sha256(b"result-blob").hexdigest(),
|
||||
"user_id": "u1",
|
||||
},
|
||||
}
|
||||
]
|
||||
}
|
||||
mock_pc = MagicMock()
|
||||
mock_pc.return_value.Index.return_value = mock_index
|
||||
return mock_pc, mock_index
|
||||
|
||||
|
||||
# ── TestEncryption ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestEncryption:
|
||||
def test_verify_checksum_correct(self) -> None:
|
||||
assert verify_checksum(_BLOB, _CHECKSUM) is True
|
||||
|
||||
def test_verify_checksum_wrong(self) -> None:
|
||||
assert verify_checksum(_BLOB, "0" * 64) is False
|
||||
|
||||
def test_verify_checksum_empty_checksum(self) -> None:
|
||||
assert verify_checksum(_BLOB, "") is False
|
||||
|
||||
def test_verify_checksum_empty_blob(self) -> None:
|
||||
expected = hashlib.sha256(b"").hexdigest()
|
||||
assert verify_checksum(b"", expected) is True
|
||||
|
||||
def test_verify_checksum_tampered_blob(self) -> None:
|
||||
tampered = _BLOB + b"\x00"
|
||||
assert verify_checksum(tampered, _CHECKSUM) is False
|
||||
|
||||
def test_reject_if_tampered_passes_when_valid(self) -> None:
|
||||
# Should not raise
|
||||
reject_if_tampered(_BLOB, _CHECKSUM)
|
||||
|
||||
def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
reject_if_tampered(_BLOB, "bad" * 20)
|
||||
assert exc_info.value.status_code == 400
|
||||
|
||||
def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
reject_if_tampered(_BLOB, "bad" * 20)
|
||||
assert "checksum" in exc_info.value.detail.lower()
|
||||
|
||||
def test_checksum_is_sha256_hex(self) -> None:
|
||||
cs = hashlib.sha256(_BLOB).hexdigest()
|
||||
assert len(cs) == 64
|
||||
assert all(c in "0123456789abcdef" for c in cs)
|
||||
|
||||
|
||||
# ── TestBlobStore ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBlobStore:
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
assert key == "u1/tasks/r1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
# Verify by downloading — no exception means object exists
|
||||
retrieved = await store.download("u1", "u1/tasks/r1")
|
||||
assert retrieved == _BLOB
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "notes", "n1", b"note-data", hashlib.sha256(b"note-data").hexdigest())
|
||||
result = await store.download("u1", "u1/notes/n1")
|
||||
assert result == b"note-data"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_removes_object(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.delete("u1", "u1/tasks/r1")
|
||||
with pytest.raises(ClientError) as exc_info:
|
||||
await store.download("u1", "u1/tasks/r1")
|
||||
assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
# Delete a key that never existed — should not raise
|
||||
await store.delete("u1", "u1/tasks/nonexistent")
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert "u1/notes/n1" not in keys
|
||||
assert "u1/tasks/r1" in keys
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
keys_u1 = await store.list_keys("u1", "tasks")
|
||||
assert "u2/tasks/r1" not in keys_u1
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
keys = await store.list_keys("u1", "tasks")
|
||||
assert keys == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
# Verify S3 metadata was set — check via head_object
|
||||
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||
mock_settings.S3_BUCKET = _BUCKET
|
||||
mock_settings.S3_REGION = _REGION
|
||||
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||
client = boto3.client("s3", region_name=_REGION)
|
||||
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||
assert response.get("ServerSideEncryption") == "AES256"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
|
||||
store = BlobStore()
|
||||
await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
|
||||
client = boto3.client("s3", region_name=_REGION)
|
||||
response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
|
||||
assert response["Metadata"]["checksum"] == _CHECKSUM
|
||||
|
||||
|
||||
# ── _blob_to_vector helper ────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestBlobToVector:
|
||||
def test_returns_32_floats(self) -> None:
|
||||
v = _blob_to_vector(b"test")
|
||||
assert len(v) == 32
|
||||
|
||||
def test_all_values_in_range(self) -> None:
|
||||
v = _blob_to_vector(b"test")
|
||||
assert all(-1.0 <= x <= 1.0 for x in v)
|
||||
|
||||
def test_deterministic(self) -> None:
|
||||
assert _blob_to_vector(b"same") == _blob_to_vector(b"same")
|
||||
|
||||
def test_different_blobs_different_vectors(self) -> None:
|
||||
assert _blob_to_vector(b"aaa") != _blob_to_vector(b"bbb")
|
||||
|
||||
|
||||
# ── TestVectorStorePinecone ───────────────────────────────────────────
|
||||
|
||||
|
||||
class TestVectorStorePinecone:
|
||||
def _store(self) -> VectorStore:
|
||||
store = VectorStore()
|
||||
store._use_pinecone = lambda: True # type: ignore[method-assign]
|
||||
return store
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_calls_index_upsert(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc-blob", checksum=hashlib.sha256(b"enc-blob").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
mock_index.upsert.assert_called_once()
|
||||
call_kwargs = mock_index.upsert.call_args[1]
|
||||
assert call_kwargs.get("namespace") == "u1"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"secret", checksum=hashlib.sha256(b"secret").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
vectors_arg = mock_index.upsert.call_args[1]["vectors"]
|
||||
assert vectors_arg[0]["metadata"]["blob"] == base64.b64encode(b"secret").decode()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_calls_index_query(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query-blob", top_k=5)
|
||||
mock_index.query.assert_called_once()
|
||||
query_kwargs = mock_index.query.call_args[1]
|
||||
assert query_kwargs.get("namespace") == "u1"
|
||||
assert query_kwargs.get("top_k") == 5
|
||||
assert query_kwargs.get("include_metadata") is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_returns_vector_search_results(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
results = await store.search("u1", b"query", top_k=10)
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], VectorSearchResult)
|
||||
assert results[0].id == "v1"
|
||||
assert results[0].score == 0.95
|
||||
assert results[0].blob == b"result-blob"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_uses_derived_query_vector(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query-blob", top_k=3)
|
||||
expected_vector = _blob_to_vector(b"query-blob")
|
||||
actual_vector = mock_index.query.call_args[1].get("vector")
|
||||
assert actual_vector == expected_vector
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_calls_index_delete(self) -> None:
|
||||
mock_pc, mock_index = _pinecone_mock()
|
||||
with patch("app.storage.vector_store.Pinecone", mock_pc):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1", "v2"])
|
||||
mock_index.delete.assert_called_once()
|
||||
delete_kwargs = mock_index.delete.call_args[1]
|
||||
assert delete_kwargs.get("namespace") == "u1"
|
||||
assert set(delete_kwargs.get("ids", [])) == {"v1", "v2"}
|
||||
|
||||
|
||||
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
|
||||
|
||||
|
||||
class TestVectorStoreQdrant:
|
||||
def _store(self) -> VectorStore:
|
||||
store = VectorStore()
|
||||
store._use_pinecone = lambda: False # type: ignore[method-assign]
|
||||
return store
|
||||
|
||||
def _qdrant_mock(self) -> MagicMock:
|
||||
mock_hit = MagicMock()
|
||||
mock_hit.id = "v1"
|
||||
mock_hit.score = 0.88
|
||||
mock_hit.payload = {
|
||||
"blob": base64.b64encode(b"qdrant-result").decode(),
|
||||
"user_id": "u1",
|
||||
}
|
||||
mock_client = MagicMock()
|
||||
mock_client.search.return_value = [mock_hit]
|
||||
return mock_client
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_calls_client_upsert(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
mock_client.upsert.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_upsert_uses_correct_collection(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
|
||||
await store.upsert("u1", items)
|
||||
call_kwargs = mock_client.upsert.call_args[1]
|
||||
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_calls_client_search(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query", top_k=5)
|
||||
mock_client.search.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_passes_limit(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.search("u1", b"query", top_k=7)
|
||||
call_kwargs = mock_client.search.call_args[1]
|
||||
assert call_kwargs.get("limit") == 7
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_search_returns_vector_search_results(self) -> None:
|
||||
mock_client = self._qdrant_mock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
results = await store.search("u1", b"query", top_k=5)
|
||||
assert len(results) == 1
|
||||
assert isinstance(results[0], VectorSearchResult)
|
||||
assert results[0].id == "v1"
|
||||
assert results[0].score == 0.88
|
||||
assert results[0].blob == b"qdrant-result"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_calls_client_delete(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1", "v2"])
|
||||
mock_client.delete.assert_called_once()
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_delete_uses_correct_collection(self) -> None:
|
||||
mock_client = MagicMock()
|
||||
with patch("app.storage.vector_store.QdrantClient", return_value=mock_client):
|
||||
store = self._store()
|
||||
await store.delete("u1", ["v1"])
|
||||
call_kwargs = mock_client.delete.call_args[1]
|
||||
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||
Reference in New Issue
Block a user