Compare commits
5 Commits
864dfdc4e6
...
c8ef7b119b
| Author | SHA1 | Date | |
|---|---|---|---|
| c8ef7b119b | |||
| 35dd9ac86f | |||
| e72d72f4f6 | |||
| 14d1a7351d | |||
| 68955d2fc2 |
311
BACKEND_PLAN.md
311
BACKEND_PLAN.md
@@ -2,8 +2,8 @@
|
|||||||
|
|
||||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
||||||
>
|
>
|
||||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, and backup blob storage.
|
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
||||||
> The backend NEVER persists user data. It receives context in requests, uses it for orchestration, and discards it.
|
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
@@ -20,7 +20,7 @@ adiuva-api/
|
|||||||
│ │ ├── orchestrator.py # LLM-based intent router
|
│ │ ├── orchestrator.py # LLM-based intent router
|
||||||
│ │ ├── execution_plan.py # Plan builder + cache
|
│ │ ├── execution_plan.py # Plan builder + cache
|
||||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
│ │ └── plugin_loader.py # Dynamic agent loading
|
||||||
│ ├── agents/
|
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
||||||
│ │ ├── __init__.py # Auto-registers all agents
|
│ │ ├── __init__.py # Auto-registers all agents
|
||||||
│ │ ├── task_agent.py
|
│ │ ├── task_agent.py
|
||||||
│ │ ├── calendar_agent.py
|
│ │ ├── calendar_agent.py
|
||||||
@@ -32,7 +32,10 @@ adiuva-api/
|
|||||||
│ │ │ ├── __init__.py
|
│ │ │ ├── __init__.py
|
||||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
||||||
│ │ │ ├── plans.py # GET /plans/playbook
|
│ │ │ ├── plans.py # GET /plans/playbook
|
||||||
|
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
||||||
|
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
||||||
│ │ │ ├── backup.py # PUT/GET /backup
|
│ │ │ ├── backup.py # PUT/GET /backup
|
||||||
|
│ │ │ ├── plugins.py # Plugin marketplace
|
||||||
│ │ │ ├── auth.py # Register/login/refresh
|
│ │ │ ├── auth.py # Register/login/refresh
|
||||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
│ │ │ └── billing.py # Checkout/webhook/subscription
|
||||||
│ │ └── middleware/
|
│ │ └── middleware/
|
||||||
@@ -40,6 +43,16 @@ adiuva-api/
|
|||||||
│ │ ├── auth.py # JWT validation
|
│ │ ├── auth.py # JWT validation
|
||||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
||||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
||||||
|
│ ├── storage/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
||||||
|
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
||||||
|
│ │ └── encryption.py # Integrity verification only — NO decryption
|
||||||
|
│ ├── marketplace/
|
||||||
|
│ │ ├── __init__.py
|
||||||
|
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
||||||
|
│ │ ├── plugin_review.py # Review queue + approval workflow
|
||||||
|
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
||||||
│ ├── billing/
|
│ ├── billing/
|
||||||
│ │ ├── __init__.py
|
│ │ ├── __init__.py
|
||||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
||||||
@@ -53,8 +66,10 @@ adiuva-api/
|
|||||||
│ ├── test_orchestrator.py
|
│ ├── test_orchestrator.py
|
||||||
│ ├── test_agents.py
|
│ ├── test_agents.py
|
||||||
│ ├── test_auth.py
|
│ ├── test_auth.py
|
||||||
│ └── test_backup.py
|
│ ├── test_backup.py
|
||||||
├── alembic/ # DB migrations (auth/billing tables only)
|
│ ├── test_storage.py
|
||||||
|
│ └── test_plugins.py
|
||||||
|
├── alembic/ # DB migrations (auth/billing/marketplace + storage/backup metadata tables only)
|
||||||
│ ├── alembic.ini
|
│ ├── alembic.ini
|
||||||
│ └── versions/
|
│ └── versions/
|
||||||
├── requirements.txt
|
├── requirements.txt
|
||||||
@@ -92,7 +107,7 @@ adiuva-api/
|
|||||||
pytest-asyncio>=0.24.0
|
pytest-asyncio>=0.24.0
|
||||||
```
|
```
|
||||||
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
||||||
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod)
|
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
||||||
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
||||||
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
||||||
- [x] Write `.env.example`
|
- [x] Write `.env.example`
|
||||||
@@ -103,13 +118,24 @@ adiuva-api/
|
|||||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
||||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
||||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
||||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification']`, `table: str | None`, `data: dict | None`
|
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
||||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
||||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
||||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
||||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
||||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
||||||
|
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
||||||
|
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
||||||
|
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
||||||
|
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
||||||
|
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
||||||
|
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
||||||
|
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
||||||
|
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
||||||
|
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
||||||
|
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
||||||
|
- `PluginInstallRequest`: `plugin_id: str`
|
||||||
- **Outcome:** All request/response models defined and validated.
|
- **Outcome:** All request/response models defined and validated.
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes ✅
|
### Step 3 — Agent Registry + base classes ✅
|
||||||
@@ -130,8 +156,8 @@ adiuva-api/
|
|||||||
- [x] Unit tests: register, get, list, call_agent with mock
|
- [x] Unit tests: register, get, list, call_agent with mock
|
||||||
- **Outcome:** Pluggable agent framework.
|
- **Outcome:** Pluggable agent framework.
|
||||||
|
|
||||||
### Step 4 — Orchestrator
|
### Step 4 — Orchestrator ✅
|
||||||
- [ ] `app/core/orchestrator.py`:
|
- [x] `app/core/orchestrator.py`:
|
||||||
- `async classify_intent(message, context, registry) -> str`:
|
- `async classify_intent(message, context, registry) -> str`:
|
||||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
||||||
- Uses gpt-4o-mini via LangChain for low latency
|
- Uses gpt-4o-mini via LangChain for low latency
|
||||||
@@ -146,16 +172,17 @@ adiuva-api/
|
|||||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
||||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
||||||
- Main entry point
|
- Main entry point
|
||||||
|
- The orchestrator is agnostic to where context comes from — the client may have assembled it from local or cloud storage
|
||||||
- Classifies intent
|
- Classifies intent
|
||||||
- If `execution_mode == 'direct'`: route + return response
|
- If `execution_mode == 'direct'`: route + return response
|
||||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
||||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
||||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
- Same as orchestrate but yields tokens for WebSocket streaming
|
||||||
- [ ] Integration tests with mocked LLM and mocked agents
|
- [x] Integration tests with mocked LLM and mocked agents
|
||||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
||||||
|
|
||||||
### Step 5 — Execution Plan generator
|
### Step 5 — Execution Plan generator ✅
|
||||||
- [ ] `app/core/execution_plan.py`:
|
- [x] `app/core/execution_plan.py`:
|
||||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
||||||
- `ExecutionPlanBuilder`:
|
- `ExecutionPlanBuilder`:
|
||||||
- `add_step(action, params) -> self`
|
- `add_step(action, params) -> self`
|
||||||
@@ -168,32 +195,52 @@ adiuva-api/
|
|||||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
||||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
||||||
|
|
||||||
### Step 6 — Chat Agents
|
### Step 6 — Chat Agents ✅
|
||||||
- [ ] `app/agents/task_agent.py` — `@registry.register`:
|
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
||||||
- Description: "Manages tasks: create, update, list, suggest"
|
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
||||||
- Tools: `create_task(title, description, priority, due_date)`, `update_task(id, updates)`, `list_tasks(filters)`, `suggest_tasks(notes_context)`
|
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
||||||
- System prompt: PM-oriented, validates task structure, infers priority from context
|
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
||||||
- `handle()`: LLM + tool loop via `_tool_loop()`, returns response text + list of actions performed
|
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
||||||
- [ ] `app/agents/calendar_agent.py` — `@registry.register`:
|
- [x] `app/agents/checkpoint_agent.py` — `@registry.register`:
|
||||||
- Description: "Calendar management: events, conflicts, scheduling"
|
- Description: "Manages project checkpoints (milestones): list, create, update, delete"
|
||||||
- Tools: `list_events(date_range)`, `detect_conflicts(events)`, `suggest_reschedule(conflict)`
|
- Tools (4): `list_checkpoints(project_id)`, `create_checkpoint(project_id, title, date, is_ai_suggested, is_approved)`, `update_checkpoint(checkpoint_id, ...)`, `delete_checkpoint(checkpoint_id)`
|
||||||
- Works with event metadata passed in context (never raw calendar data stored)
|
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
||||||
- [ ] `app/agents/email_agent.py` — `@registry.register`:
|
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
||||||
- Description: "Email analysis: classify, extract actions, draft responses"
|
- Description: "Manages projects: list, get, create, update, archive, delete"
|
||||||
- Tools: `classify_email(metadata)`, `extract_action_items(metadata)`, `draft_response(thread_context)`
|
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
||||||
- Only processes metadata sent by client — never raw email bodies
|
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
||||||
- [ ] `app/agents/analytics_agent.py` — `@registry.register`:
|
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
||||||
- Description: "Workspace analytics: metrics, reports, trends"
|
- Description: "Manages notes: list, get, create, update, delete"
|
||||||
- Tools: `calculate_metrics(task_data)`, `generate_report(period, data)`, `trend_analysis(data_points)`
|
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
||||||
- Crunches numbers from context, returns structured insights
|
- content is Markdown; `get_note` should be called before update to preserve existing content
|
||||||
- [ ] `app/agents/__init__.py`: imports all agent modules to trigger `@registry.register` decorators
|
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
||||||
- [ ] Unit tests per agent with mocked LLM
|
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
||||||
- **Outcome:** Four specialized agents, all registered and tested.
|
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Checkpoints, Projects, Notes), all registered and tested.
|
||||||
|
|
||||||
### Step 7 — API Routes
|
### Step 7 — Storage Layer ✅
|
||||||
|
- [x] `app/storage/blob_store.py`:
|
||||||
|
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
||||||
|
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
||||||
|
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
||||||
|
- [x] `app/storage/vector_store.py`:
|
||||||
|
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
||||||
|
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
||||||
|
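- Each item is indexed as a 32-dimensional float vector derived from a SHA-256 digest; the encrypted blob is stored base64-encoded in the record's metadata (Pinecone) or payload (Qdrant)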
|
||||||
|
- Approximate nearest-neighbour (ANN) search over encrypted data has a known accuracy trade-off; this limitation is documented
|
||||||
|
- [x] `app/storage/encryption.py`:
|
||||||
|
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
||||||
|
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
||||||
|
- Backend NEVER holds decryption keys
|
||||||
|
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
||||||
|
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
||||||
|
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
||||||
|
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
||||||
|
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
|
||||||
|
|
||||||
#### 7a — Chat endpoint
|
### Step 8 — API Routes ✅
|
||||||
- [ ] `app/api/routes/chat.py`:
|
|
||||||
|
#### 8a — Chat endpoint
|
||||||
|
- [x] `app/api/routes/chat.py`:
|
||||||
- `POST /api/v1/chat`:
|
- `POST /api/v1/chat`:
|
||||||
- Request: `ChatRequest`
|
- Request: `ChatRequest`
|
||||||
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
||||||
@@ -204,48 +251,93 @@ adiuva-api/
|
|||||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
||||||
- Heartbeat ping every 30s to keep connection alive
|
- Heartbeat ping every 30s to keep connection alive
|
||||||
|
|
||||||
#### 7b — Plans endpoint
|
#### 8b — Plans endpoint
|
||||||
- [ ] `app/api/routes/plans.py`:
|
- [x] `app/api/routes/plans.py`:
|
||||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
||||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
||||||
|
|
||||||
#### 7c — Backup endpoint
|
#### 8c — Storage endpoint (cloud records)
|
||||||
- [ ] `app/api/routes/backup.py`:
|
- [x] `app/api/routes/storage.py`:
|
||||||
|
- `POST /api/v1/storage/records`: Create encrypted record
|
||||||
|
- Request: `StorageRecordCreate`
|
||||||
|
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
||||||
|
- Response: `{id: str, created_at: int}`
|
||||||
|
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
||||||
|
- Query params: `table: str`, `page: int`, `limit: int`
|
||||||
|
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
||||||
|
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
||||||
|
- Response: blob bytes + `X-Checksum` header
|
||||||
|
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
||||||
|
- Request: `StorageRecordUpdate`
|
||||||
|
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
||||||
|
- All routes enforce the tier's `cloud_storage_gb` quota via `TierManager.check_quota(user_id)`
|
||||||
|
|
||||||
|
#### 8d — Vectors endpoint (cloud vector store)
|
||||||
|
- [x] `app/api/routes/vectors.py`:
|
||||||
|
- `POST /api/v1/storage/vectors/upsert`:
|
||||||
|
- Request: `VectorUpsertRequest`
|
||||||
|
- Verifies checksums, delegates to `VectorStore.upsert()`
|
||||||
|
- Response: `{upserted: int}`
|
||||||
|
- `POST /api/v1/storage/vectors/search`:
|
||||||
|
- Request: `VectorSearchRequest`
|
||||||
|
- Delegates to `VectorStore.search()`
|
||||||
|
- Response: `VectorSearchResponse`
|
||||||
|
- `DELETE /api/v1/storage/vectors`:
|
||||||
|
- Request: `{ids: list[str]}`
|
||||||
|
|
||||||
|
#### 8e — Backup endpoint
|
||||||
|
- [x] `app/api/routes/backup.py`:
|
||||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
||||||
- Free: 0 (no backup)
|
- Free: 0 (no backup)
|
||||||
- Pro: 5 GB
|
- Pro: 5 GB
|
||||||
- Power: 50 GB
|
- Power: 25 GB
|
||||||
- Team: unlimited
|
- Team: unlimited
|
||||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
||||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
||||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
||||||
|
|
||||||
#### 7d — Auth endpoint
|
#### 8f — Plugins endpoint
|
||||||
- [ ] `app/api/routes/auth.py`:
|
- [x] `app/api/routes/plugins.py`:
|
||||||
|
- `GET /api/v1/plugins`:
|
||||||
|
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
||||||
|
- Response: `PluginListResponse`
|
||||||
|
- Available from Power tier and above
|
||||||
|
- `GET /api/v1/plugins/{id}`:
|
||||||
|
- Response: `PluginManifest` + ratings + install count
|
||||||
|
- `POST /api/v1/plugins/{id}/install`:
|
||||||
|
- Request: `PluginInstallRequest`
|
||||||
|
- Records installation for the user (billing tracking, analytics)
|
||||||
|
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
||||||
|
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
||||||
|
- `DELETE /api/v1/plugins/{id}/install`:
|
||||||
|
- Unregisters installation
|
||||||
|
|
||||||
|
#### 8g — Auth endpoint
|
||||||
|
- [x] `app/api/routes/auth.py`:
|
||||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
||||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
||||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
||||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
||||||
|
|
||||||
#### 7e — Billing endpoint
|
#### 8h — Billing endpoint
|
||||||
- [ ] `app/api/routes/billing.py`:
|
- [x] `app/api/routes/billing.py`:
|
||||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
||||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
||||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
||||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
||||||
|
|
||||||
- **Outcome:** Complete REST + WebSocket API.
|
- **Outcome:** Complete REST + WebSocket API covering chat orchestration, plans, storage, vectors, backup, plugin marketplace, auth, and billing.
|
||||||
|
|
||||||
### Step 8 — Middleware
|
### Step 9 — Middleware
|
||||||
|
|
||||||
#### 8a — Auth middleware
|
#### 9a — Auth middleware
|
||||||
- [ ] `app/api/middleware/auth.py`:
|
- [ ] `app/api/middleware/auth.py`:
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
||||||
- Raises `401` on invalid/expired token
|
- Raises `401` on invalid/expired token
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||||
|
|
||||||
#### 8b — Rate limiter
|
#### 9b — Rate limiter
|
||||||
- [ ] `app/api/middleware/rate_limit.py`:
|
- [ ] `app/api/middleware/rate_limit.py`:
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
||||||
- Tier-based limits:
|
- Tier-based limits:
|
||||||
@@ -255,7 +347,7 @@ adiuva-api/
|
|||||||
- Team: 200 req/seat/min
|
- Team: 200 req/seat/min
|
||||||
- Custom 429 response with `Retry-After` header
|
- Custom 429 response with `Retry-After` header
|
||||||
|
|
||||||
#### 8c — Sanitizer
|
#### 9c — Sanitizer
|
||||||
- [ ] `app/api/middleware/sanitizer.py`:
|
- [ ] `app/api/middleware/sanitizer.py`:
|
||||||
- Response middleware that scans response bodies
|
- Response middleware that scans response bodies
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
||||||
@@ -264,7 +356,27 @@ adiuva-api/
|
|||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
||||||
|
|
||||||
### Step 9 — Billing & Tier management
|
### Step 10 — Plugin Marketplace
|
||||||
|
- [ ] `app/marketplace/plugin_registry.py`:
|
||||||
|
- `PluginRegistry`:
|
||||||
|
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
||||||
|
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
||||||
|
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
||||||
|
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
||||||
|
- `async reject_plugin(plugin_id, reason: str) -> None`
|
||||||
|
- [ ] `app/marketplace/plugin_review.py`:
|
||||||
|
- `ReviewQueue`:
|
||||||
|
- `async get_pending() -> list[dict]`
|
||||||
|
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
||||||
|
- Security checklist enforced before approval: the manifest schema is valid, all requested permissions come from the allowed set, and the manifest contains no binary blobs
|
||||||
|
- [ ] `app/marketplace/revenue_share.py`:
|
||||||
|
- `RevenueShare`:
|
||||||
|
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
||||||
|
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
||||||
|
- `async get_earnings(developer_id, period) -> dict`
|
||||||
|
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
||||||
|
|
||||||
|
### Step 11 — Billing & Tier management
|
||||||
- [ ] `app/billing/stripe_service.py`:
|
- [ ] `app/billing/stripe_service.py`:
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
- `create_checkout_session(user_id, tier) -> str`
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
||||||
@@ -275,33 +387,77 @@ adiuva-api/
|
|||||||
- Feature matrix:
|
- Feature matrix:
|
||||||
```python
|
```python
|
||||||
FEATURES = {
|
FEATURES = {
|
||||||
'free': {'agents': 3, 'batch': False, 'providers': 1, 'backup_gb': 0},
|
'free': {
|
||||||
'pro': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': 5},
|
'agents': 3,
|
||||||
'power': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': 50, 'byok': True},
|
'batch_active': 2,
|
||||||
'team': {'agents': -1, 'batch': True, 'providers': -1, 'backup_gb': -1, 'sso': True},
|
'cloud_storage_gb': 0,
|
||||||
|
'backup_gb': 0,
|
||||||
|
'providers': 1,
|
||||||
|
'batch_builder': False,
|
||||||
|
'plugin_marketplace': False,
|
||||||
|
'sso': False,
|
||||||
|
},
|
||||||
|
'pro': {
|
||||||
|
'agents': -1, # unlimited
|
||||||
|
'batch_active': 10,
|
||||||
|
'cloud_storage_gb': 5,
|
||||||
|
'backup_gb': 5,
|
||||||
|
'providers': -1,
|
||||||
|
'batch_builder': False,
|
||||||
|
'plugin_marketplace': False,
|
||||||
|
'sso': False,
|
||||||
|
},
|
||||||
|
'power': {
|
||||||
|
'agents': -1,
|
||||||
|
'batch_active': -1, # unlimited
|
||||||
|
'cloud_storage_gb': 25,
|
||||||
|
'backup_gb': 25,
|
||||||
|
'providers': -1,
|
||||||
|
'batch_builder': True,
|
||||||
|
'plugin_marketplace': True,
|
||||||
|
'sso': False,
|
||||||
|
},
|
||||||
|
'team': {
|
||||||
|
'agents': -1,
|
||||||
|
'batch_active': -1,
|
||||||
|
'cloud_storage_gb': -1,
|
||||||
|
'backup_gb': -1,
|
||||||
|
'providers': -1,
|
||||||
|
'batch_builder': True,
|
||||||
|
'plugin_marketplace': True,
|
||||||
|
'sso': True,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
- `get_tier(user_id) -> BillingTier`
|
- `get_tier(user_id) -> BillingTier`
|
||||||
- `check_feature(user_id, feature) -> bool`
|
- `check_feature(user_id, feature) -> bool`
|
||||||
- `get_rate_limit(tier) -> int`
|
- `get_rate_limit(tier) -> int`
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating.
|
- `check_quota(user_id) -> bool` — checks the user's current cloud storage usage against their tier's `cloud_storage_gb` limit
|
||||||
|
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
||||||
|
|
||||||
### Step 10 — Database (auth/billing only)
|
### Step 12 — Database (auth, billing, storage metadata, and marketplace only)
|
||||||
- [ ] PostgreSQL schema via Alembic:
|
- [ ] PostgreSQL schema via Alembic:
|
||||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
||||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
||||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
||||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
||||||
|
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
||||||
|
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
||||||
|
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
||||||
|
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
||||||
|
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
||||||
- [ ] Initial Alembic migration
|
- [ ] Initial Alembic migration
|
||||||
- [ ] SQLAlchemy models in `app/models.py`
|
- [ ] SQLAlchemy models in `app/models.py`
|
||||||
- **Outcome:** Auth and billing persistence. Zero user data stored.
|
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
||||||
|
|
||||||
### Step 11 — Testing & deployment
|
### Step 13 — Testing & deployment
|
||||||
- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed)
|
- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
||||||
- [ ] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
- [ ] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
||||||
- [ ] `tests/test_agents.py`: each agent with mocked tools
|
- [ ] `tests/test_agents.py`: each agent with mocked tools
|
||||||
- [ ] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
- [ ] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
||||||
- [ ] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
- [ ] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
||||||
|
- [ ] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
||||||
|
- [ ] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
||||||
- [ ] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
- [ ] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
||||||
- [ ] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
- [ ] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
||||||
- **Outcome:** Fully tested, deployable backend.
|
- **Outcome:** Fully tested, deployable backend.
|
||||||
@@ -320,10 +476,22 @@ adiuva-api/
|
|||||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
||||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
||||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
||||||
|
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
||||||
|
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
||||||
|
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
||||||
|
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
||||||
|
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
||||||
|
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
||||||
|
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
||||||
|
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
||||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
||||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
||||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
||||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
||||||
|
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
||||||
|
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
||||||
|
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
||||||
|
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
||||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
||||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
||||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
||||||
@@ -339,21 +507,24 @@ adiuva-api/
|
|||||||
| Framework | FastAPI + Uvicorn |
|
| Framework | FastAPI + Uvicorn |
|
||||||
| LLM | LangChain + langchain-openai |
|
| LLM | LangChain + langchain-openai |
|
||||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
| Auth | PyJWT + bcrypt + OAuth2 |
|
||||||
| Billing | stripe-python |
|
| Billing | stripe-python + Stripe Connect |
|
||||||
| Storage | boto3 (S3) |
|
| Blob storage | boto3 (S3) |
|
||||||
|
| Vector store | Pinecone or Qdrant (configurable) |
|
||||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
||||||
| Rate limiting | slowapi |
|
| Rate limiting | slowapi |
|
||||||
| Testing | pytest + pytest-asyncio + httpx |
|
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
||||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
||||||
|
|
||||||
---
|
---
|
||||||
|
|
||||||
## Development Rules
|
## Development Rules
|
||||||
|
|
||||||
1. **NEVER persist user data.** The DB stores only auth, billing, and backup metadata. User context arrives in requests and is discarded after processing.
|
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
||||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending.
|
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
||||||
3. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
||||||
4. **Type hints everywhere.** All functions have full type annotations.
|
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
||||||
5. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
5. **Type hints everywhere.** All functions have full type annotations.
|
||||||
6. **Structured logging.** JSON logs with request ID correlation.
|
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
||||||
7. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
7. **Structured logging.** JSON logs with request ID correlation.
|
||||||
|
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
||||||
|
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
||||||
|
|||||||
@@ -0,0 +1,5 @@
|
|||||||
|
"""Import all agent modules to trigger @registry.register decorators."""
|
||||||
|
|
||||||
|
from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
|
__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
122
app/agents/checkpoint_agent.py
Normal file
122
app/agents/checkpoint_agent.py
Normal file
@@ -0,0 +1,122 @@
|
|||||||
|
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
|
||||||
|
"track progress on a project — they are not calendar events.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
||||||
|
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
||||||
|
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
|
||||||
|
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
||||||
|
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Listing without a project_id returns all checkpoints across projects\n"
|
||||||
|
" - Always echo the title and formatted date in your confirmation."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def list_checkpoints(project_id: str = "") -> str:
    """List checkpoints. Provide project_id to scope to a specific project."""
    # An empty project_id is serialized as null, i.e. no project filter is sent.
    project_filter = project_id if project_id else None
    request = {
        "action": "list",
        "table": "checkpoints",
        "filters": {"projectId": project_filter},
    }
    # The tool does no I/O itself; it only returns the JSON-encoded request.
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def create_checkpoint(
    project_id: str,
    title: str,
    date: int,
    is_ai_suggested: int = 0,
    is_approved: int = 0,
) -> str:
    """Create a project checkpoint (milestone).
    project_id: REQUIRED UUID of the parent project
    title: descriptive name for the milestone
    date: Unix timestamp in milliseconds
    is_ai_suggested: 1 if proactively suggested, 0 if user-requested
    is_approved: 0 until the user confirms
    """
    # snake_case arguments are mapped onto the camelCase record fields.
    record = dict(
        projectId=project_id,
        title=title,
        date=date,
        isAiSuggested=is_ai_suggested,
        isApproved=is_approved,
    )
    envelope = {"action": "create_record", "table": "checkpoints", "data": record}
    # Returns the JSON-encoded create request; nothing is persisted here.
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def update_checkpoint(
    checkpoint_id: str,
    title: str = "",
    date: int = -1,
    is_approved: int = -1,
) -> str:
    """Update a checkpoint. Only pass fields that should change.
    checkpoint_id: UUID of the checkpoint (required)
    date: -1 means unchanged; any other value sets the new date (ms timestamp)
    is_approved: -1 means unchanged; 0 or 1 sets the approval state
    """
    # Each entry: (record field, new value, whether the caller supplied a change).
    # An empty title and the -1 sentinels mean "leave this field alone".
    candidates = (
        ("title", title, bool(title)),
        ("date", date, date != -1),
        ("isApproved", is_approved, is_approved != -1),
    )
    updates: dict[str, Any] = {
        field: value for field, value, changed in candidates if changed
    }
    envelope = {
        "action": "update_record",
        "table": "checkpoints",
        "data": {"id": checkpoint_id, "updates": updates},
    }
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def delete_checkpoint(checkpoint_id: str) -> str:
    """Delete a checkpoint permanently by its UUID."""
    # Only the JSON-encoded delete request is produced here.
    target = {"id": checkpoint_id}
    envelope = {"action": "delete_record", "table": "checkpoints", "data": target}
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
class CheckpointAgent(ChatAgent):
    """Chat agent exposing the checkpoint tools (list, create, update, delete)."""

    def get_name(self) -> str:
        """Return the identifier this agent is registered under."""
        return "checkpoint_agent"

    def get_description(self) -> str:
        """Return a one-line summary of what this agent handles."""
        return "Manages project checkpoints (milestones): list, create, update, delete"

    def get_tools(self) -> list[Any]:
        """Return a fresh list of the checkpoint tools defined in this module."""
        return [
            list_checkpoints,
            create_checkpoint,
            update_checkpoint,
            delete_checkpoint,
        ]

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Answer `query` by delegating to `self._tool_loop` with this agent's tools.

        query: the user's request text
        context: request context; it is JSON-encoded and only the first
            1000 characters are put into the prompt
        """
        llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
        # Cap the serialized context so the prompt stays bounded in size.
        context_excerpt = json.dumps(context)[:1000]
        system_message = SystemMessage(content=_SYSTEM_PROMPT)
        user_message = HumanMessage(
            content=f"User query: {query}\nContext: {context_excerpt}"
        )
        return await self._tool_loop(llm, [system_message, user_message], self.get_tools())
|
||||||
123
app/agents/note_agent.py
Normal file
123
app/agents/note_agent.py
Normal file
@@ -0,0 +1,123 @@
|
|||||||
|
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
||||||
|
"and delete Markdown notes in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - content is always Markdown; preserve formatting when updating\n"
|
||||||
|
" - project_id is optional; link a note to a project when mentioned\n"
|
||||||
|
" - When updating, call get_note first if you need to read existing content\n"
|
||||||
|
" before appending or replacing sections\n"
|
||||||
|
" - list_notes without project_id returns all notes; scope with project_id\n"
|
||||||
|
" when the user is working within a specific project\n"
|
||||||
|
" - Do not fabricate note content — reflect what the user provides or what\n"
|
||||||
|
" is already in the note (retrieved via get_note)."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def list_notes(project_id: str = "") -> str:
    """List notes, optionally scoped to a project by project_id."""
    # An empty project_id is serialized as null, i.e. no project filter is sent.
    project_filter = project_id if project_id else None
    request = {
        "action": "list",
        "table": "notes",
        "filters": {"projectId": project_filter},
    }
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def get_note(note_id: str) -> str:
    """Fetch a single note by its UUID to read its full Markdown content."""
    # Builds the JSON-encoded "get" request for the given note id.
    lookup = {"id": note_id}
    request = {"action": "get", "table": "notes", "data": lookup}
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def create_note(
    title: str,
    content: str,
    project_id: str = "",
) -> str:
    """Create a new note.
    title: note heading (required)
    content: Markdown body text (required)
    project_id: optional UUID linking this note to a project
    """
    # An empty project_id becomes null so the note is created without a project link.
    linked_project = project_id if project_id else None
    record = dict(title=title, content=content, projectId=linked_project)
    envelope = {"action": "create_record", "table": "notes", "data": record}
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def update_note(
    note_id: str,
    title: str = "",
    content: str = "",
) -> str:
    """Update an existing note. Only pass fields that should change.
    note_id: UUID of the note (required)
    If you need to preserve existing content, call get_note first.
    """
    # Empty strings mean "unchanged", so only non-empty values are forwarded.
    candidates = (("title", title), ("content", content))
    updates: dict[str, Any] = {field: text for field, text in candidates if text}
    envelope = {
        "action": "update_record",
        "table": "notes",
        "data": {"id": note_id, "updates": updates},
    }
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def delete_note(note_id: str) -> str:
    """Delete a note permanently by its UUID."""
    # Only the JSON-encoded delete request is produced here.
    target = {"id": note_id}
    envelope = {"action": "delete_record", "table": "notes", "data": target}
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
class NoteAgent(ChatAgent):
    """Chat agent exposing the note tools (list, get, create, update, delete)."""

    def get_name(self) -> str:
        """Return the identifier this agent is registered under."""
        return "note_agent"

    def get_description(self) -> str:
        """Return a one-line summary of what this agent handles."""
        return "Manages notes: list, get, create, update, delete"

    def get_tools(self) -> list[Any]:
        """Return a fresh list of the note tools defined in this module."""
        return [
            list_notes,
            get_note,
            create_note,
            update_note,
            delete_note,
        ]

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Answer `query` by delegating to `self._tool_loop` with this agent's tools.

        query: the user's request text
        context: request context; it is JSON-encoded and only the first
            1000 characters are put into the prompt
        """
        llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
        # Cap the serialized context so the prompt stays bounded in size.
        context_excerpt = json.dumps(context)[:1000]
        system_message = SystemMessage(content=_SYSTEM_PROMPT)
        user_message = HumanMessage(
            content=f"User query: {query}\nContext: {context_excerpt}"
        )
        return await self._tool_loop(llm, [system_message, user_message], self.get_tools())
|
||||||
158
app/agents/project_agent.py
Normal file
158
app/agents/project_agent.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a project management assistant. You help users create, find,\n"
|
||||||
|
"update, and archive projects in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: active, archived\n"
|
||||||
|
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
||||||
|
" - ai_summary is populated only when the user asks for a project summary;\n"
|
||||||
|
" derive it from context data — do not fabricate content\n"
|
||||||
|
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
||||||
|
" user wants a complete cross-client view including archived projects\n"
|
||||||
|
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
||||||
|
" list_projects if you only have a project name\n"
|
||||||
|
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
||||||
|
" only call delete_project when the user explicitly confirms deletion."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def list_projects(
    client_id: str = "",
    include_archived: int = 0,
) -> str:
    """List projects, optionally filtered by client_id.
    include_archived: 1 to include archived projects, 0 for active only (default).
    """
    # Empty client_id -> null (no client filter); the integer flag is sent as a boolean.
    filters = {
        "clientId": client_id if client_id else None,
        "includeArchived": bool(include_archived),
    }
    request = {"action": "list", "table": "projects", "filters": filters}
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def list_all_projects() -> str:
    """List every project regardless of client or status.
    Use only when the user wants a complete cross-client overview.
    """
    # No filters are attached; the request carries just the action and the table.
    request = dict(action="list_all", table="projects")
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def get_project(project_id: str) -> str:
    """Fetch a single project by its UUID."""
    # Builds the JSON-encoded "get" request for the given project id.
    lookup = {"id": project_id}
    request = {"action": "get", "table": "projects", "data": lookup}
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def create_project(
    name: str,
    client_id: str = "",
) -> str:
    """Create a new project.
    name: human-readable project name (required)
    client_id: optional UUID of the owning client
    """
    # An empty client_id becomes null so the project is created without a client.
    owner = client_id if client_id else None
    record = dict(name=name, clientId=owner)
    envelope = {"action": "create_record", "table": "projects", "data": record}
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def update_project(
    project_id: str,
    name: str = "",
    client_id: str = "",
    status: str = "",
    ai_summary: str = "",
) -> str:
    """Update a project. Only pass fields that should change.
    project_id: UUID of the project (required)
    status: active | archived
    ai_summary: AI-generated summary text (populate only when explicitly requested)
    """
    # Empty strings mean "unchanged"; only non-empty values reach the update map.
    # The status value is forwarded as given; it is not validated here.
    candidates = (
        ("name", name),
        ("clientId", client_id),
        ("status", status),
        ("aiSummary", ai_summary),
    )
    updates: dict[str, Any] = {field: text for field, text in candidates if text}
    envelope = {
        "action": "update_record",
        "table": "projects",
        "data": {"id": project_id, "updates": updates},
    }
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def delete_project(project_id: str) -> str:
    """Permanently delete a project and orphan its tasks.
    IMPORTANT: prefer update_project(status='archived') unless the user
    has explicitly confirmed they want permanent deletion.
    """
    # Only the JSON-encoded delete request is produced here.
    target = {"id": project_id}
    envelope = {"action": "delete_record", "table": "projects", "data": target}
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
class ProjectAgent(ChatAgent):
    """Chat agent exposing the project tools (list, get, create, update, delete)."""

    def get_name(self) -> str:
        """Return the identifier this agent is registered under."""
        return "project_agent"

    def get_description(self) -> str:
        """Return a one-line summary of what this agent handles."""
        return "Manages projects: list, get, create, update, archive, delete"

    def get_tools(self) -> list[Any]:
        """Return a fresh list of the project tools defined in this module."""
        scoped_and_global_listing = [list_projects, list_all_projects]
        record_operations = [get_project, create_project, update_project, delete_project]
        return scoped_and_global_listing + record_operations

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Answer `query` by delegating to `self._tool_loop` with this agent's tools.

        query: the user's request text
        context: request context; it is JSON-encoded and only the first
            1000 characters are put into the prompt
        """
        llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
        # Cap the serialized context so the prompt stays bounded in size.
        context_excerpt = json.dumps(context)[:1000]
        system_message = SystemMessage(content=_SYSTEM_PROMPT)
        user_message = HumanMessage(
            content=f"User query: {query}\nContext: {context_excerpt}"
        )
        return await self._tool_loop(llm, [system_message, user_message], self.get_tools())
|
||||||
229
app/agents/task_agent.py
Normal file
229
app/agents/task_agent.py
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
"""Task agent — full CRUD for tasks and task comments."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.agent_registry import ChatAgent, registry
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT = (
|
||||||
|
"You are a task management assistant for a project workspace.\n"
|
||||||
|
"You create, update, list, and track tasks and their comments.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: todo, in_progress, done\n"
|
||||||
|
" - priority must be one of: high, medium, low\n"
|
||||||
|
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
||||||
|
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
||||||
|
" - project_id is optional; link to a project when the user mentions one\n"
|
||||||
|
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
||||||
|
" did not explicitly request; 0 otherwise\n"
|
||||||
|
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
||||||
|
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
||||||
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def list_tasks(
    project_id: str = "",
    status: str = "",
    search: str = "",
    order_by: str = "",
) -> str:
    """List tasks, optionally filtered by project_id, status (todo|in_progress|done),
    a search string, or an order_by field name (dueDate|priority|createdAt)."""
    # Every filter is optional: an empty argument is serialized as null.
    raw_filters = (
        ("projectId", project_id),
        ("status", status),
        ("search", search),
        ("orderBy", order_by),
    )
    filters = {key: (value if value else None) for key, value in raw_filters}
    request = {"action": "list", "table": "tasks", "filters": filters}
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def create_task(
    title: str,
    description: str = "",
    status: str = "todo",
    priority: str = "medium",
    assignees: str = "[]",
    due_date: int = 0,
    project_id: str = "",
    is_ai_suggested: int = 0,
    is_approved: int = 0,
) -> str:
    """Create a new task.
    title: task title (required)
    description: optional details
    status: todo | in_progress | done (default: todo)
    priority: high | medium | low (default: medium)
    assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
    due_date: Unix timestamp in milliseconds; 0 means no due date
    project_id: optional UUID of the parent project
    is_ai_suggested: 1 if proactively suggested, 0 if user-requested
    is_approved: 0 until the user confirms; 1 when confirmed
    """
    # Empty description / project_id and a due_date of 0 are serialized as null.
    # The assignees string is forwarded unchanged under the "assignee" field.
    record = dict(
        title=title,
        description=description if description else None,
        status=status,
        priority=priority,
        assignee=assignees,
        dueDate=due_date if due_date else None,
        projectId=project_id if project_id else None,
        isAiSuggested=is_ai_suggested,
        isApproved=is_approved,
    )
    envelope = {"action": "create_record", "table": "tasks", "data": record}
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def update_task(
    task_id: str,
    title: str = "",
    description: str = "",
    status: str = "",
    priority: str = "",
    assignees: str = "",
    due_date: int = -1,
    project_id: str = "",
    is_approved: int = -1,
) -> str:
    """Update fields on an existing task. Only pass fields you want to change.
    task_id: the task's UUID (required)
    due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
    is_approved: -1 means unchanged; 0 or 1 sets the value
    """
    # Each entry: (record field, new value, whether the caller supplied a change).
    # Empty strings and the -1 sentinels mean "leave this field alone"; a
    # due_date of 0 is sent as null so the due date is cleared.
    candidates = (
        ("title", title, bool(title)),
        ("description", description, bool(description)),
        ("status", status, bool(status)),
        ("priority", priority, bool(priority)),
        ("assignee", assignees, bool(assignees)),
        ("dueDate", due_date if due_date else None, due_date != -1),
        ("projectId", project_id, bool(project_id)),
        ("isApproved", is_approved, is_approved != -1),
    )
    updates: dict[str, Any] = {
        field: value for field, value, changed in candidates if changed
    }
    envelope = {
        "action": "update_record",
        "table": "tasks",
        "data": {"id": task_id, "updates": updates},
    }
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def delete_task(task_id: str) -> str:
    """Delete a task permanently by its UUID."""
    # Only the JSON-encoded delete request is produced here.
    target = {"id": task_id}
    envelope = {"action": "delete_record", "table": "tasks", "data": target}
    return json.dumps(envelope)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def list_tasks_due_today() -> str:
    """List all tasks whose due date falls on today's date."""
    # No date is computed here; the request names a dedicated action instead.
    request = dict(action="list_due_today", table="tasks")
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Task comment tools ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def list_task_comments(task_id: str) -> str:
    """List all comments on a task by its UUID."""
    # The task id is always sent as the filter; it is not optional here.
    comment_filter = {"taskId": task_id}
    request = {"action": "list", "table": "taskComments", "filters": comment_filter}
    return json.dumps(request)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
    """Add a comment to a task.
    task_id: UUID of the task to comment on
    author: name or ID of the comment author
    content: comment text
    """
    # NOTE(review): docstring kept verbatim — presumably consumed by @tool.
    # Build the new comment row first, then wrap it in the action envelope.
    comment_row = {
        "taskId": task_id,
        "author": author,
        "content": content,
    }
    action = {
        "action": "create_record",
        "table": "taskComments",
        "data": comment_row,
    }
    return json.dumps(action)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
async def delete_task_comment(comment_id: str) -> str:
    """Delete a task comment by its UUID."""
    # NOTE(review): docstring kept verbatim — presumably consumed by @tool.
    # Only the descriptor is produced here; nothing is deleted in this function.
    action = {
        "action": "delete_record",
        "table": "taskComments",
        "data": {"id": comment_id},
    }
    return json.dumps(action)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@registry.register
class TaskAgent(ChatAgent):
    # Chat agent bundling the task and task-comment tools defined in this module.

    def get_name(self) -> str:
        # Identifier this agent reports for itself.
        return "task_agent"

    def get_description(self) -> str:
        # One-line capability summary.
        return "Manages tasks and comments: list, create, update, delete, due-today, comments"

    def get_tools(self) -> list[Any]:
        # Task tools first, comment tools second; a fresh list on every call.
        task_tools = [
            list_tasks,
            create_task,
            update_task,
            delete_task,
            list_tasks_due_today,
        ]
        comment_tools = [
            list_task_comments,
            add_task_comment,
            delete_task_comment,
        ]
        return task_tools + comment_tools

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        # Run the query through the tool loop with a deterministic (temperature 0) model.
        model = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
        # Only the first 1000 characters of the serialised context are sent.
        context_snippet = json.dumps(context)[:1000]
        user_turn = HumanMessage(
            content=f"User query: {query}\nContext: {context_snippet}"
        )
        conversation = [SystemMessage(content=_SYSTEM_PROMPT), user_turn]
        return await self._tool_loop(model, conversation, self.get_tools())
|
||||||
46
app/api/deps.py
Normal file
46
app/api/deps.py
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
"""Shared FastAPI dependencies.
|
||||||
|
|
||||||
|
``get_current_user`` decodes the Bearer JWT and returns a ``UserProfile``.
|
||||||
|
Step 9 will layer rate-limiting and sanitization middleware on top of this.
|
||||||
|
Step 12 will add a DB look-up to fetch the live tier from PostgreSQL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import BillingTier, UserProfile
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
    token: str = Depends(oauth2_scheme),
) -> UserProfile:
    """Decode the Bearer JWT and build the caller's ``UserProfile``.

    Raises ``HTTP 401`` when the token fails to decode or lacks the ``sub``
    or ``email`` claim. The ``tier`` claim (default ``"free"``) is taken from
    the token as-is; Step 12 is expected to replace it with a live DB lookup.
    """
    unauthorized = HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail="Could not validate credentials",
        headers={"WWW-Authenticate": "Bearer"},
    )

    try:
        claims = jwt.decode(
            token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
        )
    except JWTError:
        raise unauthorized

    user_id: str | None = claims.get("sub")
    email: str | None = claims.get("email")
    tier: str = claims.get("tier", "free")

    # Both identity claims are mandatory.
    if not (user_id and email):
        raise unauthorized

    return UserProfile(id=user_id, email=email, tier=tier)  # type: ignore[arg-type]
|
||||||
118
app/api/routes/auth.py
Normal file
118
app/api/routes/auth.py
Normal file
@@ -0,0 +1,118 @@
|
|||||||
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
|
Users and refresh tokens are kept in an in-memory dict until Step 12
|
||||||
|
migrates them to PostgreSQL.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from jose import jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
||||||
|
_users: dict[str, dict[str, Any]] = {} # email → user record
|
||||||
|
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _hash_password(password: str) -> str:
    """Return a salted bcrypt hash of ``password`` as text."""
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode(), salt)
    return digest.decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_password(password: str, hashed: str) -> bool:
    """Return whether ``password`` matches the stored bcrypt hash ``hashed``."""
    candidate = password.encode()
    stored = hashed.encode()
    return bcrypt.checkpw(candidate, stored)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
    """Issue an access JWT plus a new refresh token for one user.

    The refresh token is a random UUID string recorded in the module-level
    ``_refresh_tokens`` map. ``expires_at`` in the result is the access-token
    expiry in milliseconds.
    """
    issued_at = int(time.time())
    expires_at = issued_at + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60

    claims = {
        "sub": user_id,
        "email": email,
        "tier": tier,
        "exp": expires_at,
        "iat": issued_at,
    }
    signed = jwt.encode(claims, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)

    # Register the refresh token so /refresh can map it back to the user.
    opaque = str(uuid.uuid4())
    _refresh_tokens[opaque] = user_id

    return AuthTokens(
        access_token=signed,
        refresh_token=opaque,
        expires_at=expires_at * 1000,  # milliseconds for client
    )
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _RegisterRequest(BaseModel):
    # Request body for POST /auth/register.
    email: str  # used as the key of the in-memory user store
    password: str  # plaintext; hashed by the register handler before storage
|
||||||
|
|
||||||
|
|
||||||
|
class _LoginRequest(BaseModel):
    # Request body for POST /auth/login.
    email: str  # looked up in the in-memory user store
    password: str  # plaintext; checked against the stored hash
|
||||||
|
|
||||||
|
|
||||||
|
class _RefreshRequest(BaseModel):
    # Request body for POST /auth/refresh.
    refresh_token: str  # token previously handed out by _make_tokens
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(body: _RegisterRequest) -> AuthTokens:
    """Create a new account and return JWT tokens."""
    # Emails are compared exactly as supplied; a duplicate yields 409.
    if body.email in _users:
        raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")

    new_id = str(uuid.uuid4())
    account = {
        "id": new_id,
        "email": body.email,
        "password_hash": _hash_password(body.password),
        "tier": "free",  # every new account starts on the free tier
    }
    _users[body.email] = account
    return _make_tokens(new_id, body.email, "free")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=AuthTokens)
async def login(body: _LoginRequest) -> AuthTokens:
    """Validate credentials and return JWT tokens."""
    account = _users.get(body.email)
    # Unknown email and wrong password share one 401 response.
    if account and _verify_password(body.password, account["password_hash"]):
        return _make_tokens(account["id"], account["email"], account["tier"])
    raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=AuthTokens)
async def refresh(body: _RefreshRequest) -> AuthTokens:
    """Rotate a refresh token and return a new token pair."""
    # pop() makes the presented token single-use, even if the lookup below fails.
    owner_id = _refresh_tokens.pop(body.refresh_token, None)
    if owner_id is None:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")

    # _users is keyed by email, so finding a user by id needs a linear scan.
    for account in _users.values():
        if account["id"] == owner_id:
            return _make_tokens(account["id"], account["email"], account["tier"])

    raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserProfile)
async def me(
    current_user: UserProfile = Depends(get_current_user),
) -> UserProfile:
    """Return the profile for the authenticated user."""
    # The profile is exactly what get_current_user derived from the token.
    return current_user
|
||||||
158
app/api/routes/backup.py
Normal file
158
app/api/routes/backup.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Backup metadata is kept in an
|
||||||
|
in-memory dict until Step 12 migrates it to PostgreSQL (backup_metadata table).
|
||||||
|
|
||||||
|
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||||
|
treating "history" as a ``{backup_id}`` path parameter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.schemas import BackupMetadata, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/backup", tags=["backup"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
||||||
|
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
||||||
|
|
||||||
|
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
||||||
|
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
|
||||||
|
"free": 0,
|
||||||
|
"pro": 5,
|
||||||
|
"power": 25,
|
||||||
|
"team": -1, # unlimited
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
    """Raise HTTP 402 if the upload would exceed the tier's backup limit.

    A limit of ``0`` (also used for tiers missing from the table) rejects
    every upload; ``-1`` means unlimited. Otherwise the stored total for the
    user plus ``size_bytes`` must fit within the limit.
    """
    allowance_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)

    if allowance_gb == -1:
        return  # unlimited

    if allowance_gb == 0:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail="Backup is not available on the free tier",
        )

    # Add up what the user already stores, then test the projected total.
    already_stored = 0
    for record in _backups.get(user_id, []):
        already_stored += record["size_bytes"]

    allowance_bytes = allowance_gb * 1024**3
    if already_stored + size_bytes > allowance_bytes:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail=f"Backup quota exceeded for tier '{tier}'",
        )
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("")
async def upload_backup(
    request: Request,
    x_backup_version: int = Header(..., alias="X-Backup-Version"),
    x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
    x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
    """Upload an E2E-encrypted backup blob.

    Metadata is passed via custom headers; the raw body is the encrypted blob.
    The backup id is the timestamp header rendered as a string. Uploading
    again with the same timestamp replaces the earlier metadata record rather
    than adding a duplicate.

    Raises HTTP 402 via ``_check_backup_quota`` when the tier does not allow
    the upload.
    """
    blob = await request.body()
    reject_if_tampered(blob, x_backup_checksum)
    _check_backup_quota(current_user.id, current_user.tier, len(blob))

    backup_id = str(x_backup_timestamp)
    s3_key = await _blob_store.upload(
        current_user.id, "backup", backup_id, blob, x_backup_checksum
    )

    backup_record: dict[str, Any] = {
        "id": backup_id,
        "s3_key": s3_key,
        "version": x_backup_version,
        "timestamp": x_backup_timestamp,
        "checksum": x_backup_checksum,
        "size_bytes": len(blob),
    }

    user_backups = _backups.setdefault(current_user.id, [])
    # Ids must stay unique: delete_backup removes every record with a given id
    # but deletes only one blob, and duplicates would also be double-counted
    # by the quota check.
    superseded = [b for b in user_backups if b["id"] == backup_id]
    user_backups[:] = [b for b in user_backups if b["id"] != backup_id]
    user_backups.append(backup_record)
    user_backups.sort(key=lambda b: b["timestamp"], reverse=True)  # newest first

    # If the store handed out a different key for the same id, the older blob
    # is no longer referenced by any record, so remove it.
    for old in superseded:
        if old["s3_key"] != s3_key:
            await _blob_store.delete(current_user.id, old["s3_key"])

    return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
    current_user: UserProfile = Depends(get_current_user),
) -> list[BackupMetadata]:
    """Return backup metadata records for the authenticated user (no blob bytes)."""
    history: list[BackupMetadata] = []
    for record in _backups.get(current_user.id, []):
        entry = BackupMetadata(
            version=record["version"],
            timestamp=record["timestamp"],
            checksum=record["checksum"],
            chunk_count=1,  # single-chunk uploads for now — TODO(Step12): track real count
        )
        history.append(entry)
    return history
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
async def download_backup(
    request: Request,
    current_user: UserProfile = Depends(get_current_user),
) -> Response:
    """Download the latest backup blob. Supports ``If-Modified-Since``.

    Returns 404 when the user has no backups, 304 when the latest backup is
    not newer than the ``If-Modified-Since`` date, and otherwise the blob
    with its version, timestamp and checksum in response headers.
    """
    user_backups = _backups.get(current_user.id, [])
    if not user_backups:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")

    # upload_backup keeps the list sorted newest-first.
    latest = user_backups[0]

    ims_header = request.headers.get("If-Modified-Since")
    if ims_header:
        try:
            ims_seconds = int(parsedate_to_datetime(ims_header).timestamp())
        except (TypeError, ValueError, OverflowError):
            ims_seconds = None  # malformed header — ignore and serve the blob
        # HTTP dates have one-second resolution while backup timestamps are in
        # milliseconds, so compare whole seconds. Comparing milliseconds would
        # deny a 304 to a client echoing the backup's own (truncated) time.
        if ims_seconds is not None and latest["timestamp"] // 1000 <= ims_seconds:
            return Response(status_code=status.HTTP_304_NOT_MODIFIED)

    blob = await _blob_store.download(current_user.id, latest["s3_key"])
    return Response(
        content=blob,
        media_type="application/octet-stream",
        headers={
            "X-Backup-Version": str(latest["version"]),
            "X-Backup-Timestamp": str(latest["timestamp"]),
            "X-Checksum": latest["checksum"],
        },
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{backup_id}", response_model=dict)
async def delete_backup(
    backup_id: str,
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
    """Delete a specific backup by ID."""
    records = _backups.get(current_user.id, [])

    # Find the first record carrying this id.
    doomed = None
    for record in records:
        if record["id"] == backup_id:
            doomed = record
            break
    if doomed is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")

    # Remove the blob first; metadata is dropped only once that succeeds.
    await _blob_store.delete(current_user.id, doomed["s3_key"])
    remaining = [record for record in records if record["id"] != backup_id]
    _backups[current_user.id] = remaining

    return {"ok": True}
|
||||||
184
app/api/routes/billing.py
Normal file
184
app/api/routes/billing.py
Normal file
@@ -0,0 +1,184 @@
|
|||||||
|
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||||
|
|
||||||
|
Subscription records are kept in-memory until Step 12 migrates them to
|
||||||
|
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
|
||||||
|
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import BillingTier, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
|
|
||||||
|
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
|
||||||
|
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
|
||||||
|
|
||||||
|
_TIER_PRICE_IDS: dict[str, str] = {
|
||||||
|
"pro": "price_pro_monthly", # replace with real Stripe price IDs
|
||||||
|
"power": "price_power_monthly",
|
||||||
|
"team": "price_team_monthly",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _stripe_configured() -> bool:
    """Return whether ``settings.STRIPE_SECRET_KEY`` holds a truthy value."""
    secret = settings.STRIPE_SECRET_KEY
    return bool(secret)
|
||||||
|
|
||||||
|
|
||||||
|
def _stripe() -> Any:
    """Return the Stripe module after setting its API key from settings."""
    # The key is set on the module itself, so this assignment is process-wide.
    stripe_lib.api_key = settings.STRIPE_SECRET_KEY
    client = stripe_lib
    return client
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CheckoutRequest(BaseModel):
    # Request body for POST /billing/checkout.
    tier: BillingTier  # target tier; the handler rejects "free"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/checkout", response_model=dict)
async def create_checkout(
    body: _CheckoutRequest,
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, str]:
    """Create a Stripe checkout session for a tier upgrade.

    Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
    """
    if body.tier == "free":
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Cannot create a checkout session for the free tier",
        )

    # Local development path: no Stripe call is made.
    if not _stripe_configured():
        return {"checkout_url": "https://stripe.com/stub-checkout"}

    price_id = _TIER_PRICE_IDS.get(body.tier)
    if not price_id:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Unknown tier: {body.tier}",
        )

    # {CHECKOUT_SESSION_ID} is passed through literally (no f-string here).
    success_url = (
        "https://app.adiuva.app/billing/success"
        "?session_id={CHECKOUT_SESSION_ID}"
    )
    # The metadata is read back by the webhook handler to learn user and tier.
    checkout = _stripe().checkout.Session.create(
        payment_method_types=["card"],
        mode="subscription",
        line_items=[{"price": price_id, "quantity": 1}],
        success_url=success_url,
        cancel_url="https://app.adiuva.app/billing/cancel",
        metadata={"user_id": current_user.id, "tier": body.tier},
    )
    return {"checkout_url": checkout.url}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/webhook", response_model=dict)
async def stripe_webhook(
    request: Request,
    stripe_signature: str = Header(default="", alias="Stripe-Signature"),
) -> dict[str, bool]:
    """Handle Stripe webhook events.

    No JWT auth — authenticated via Stripe signature verification instead.
    Returns 200 immediately when Stripe is not configured (local dev).

    Raises HTTP 400 when the payload cannot be parsed or the signature does
    not verify. Only ``checkout.session.completed`` changes state today; the
    other recognised event types are placeholders for Step 12.
    """
    payload = await request.body()

    if not _stripe_configured():
        return {"ok": True}

    try:
        s = _stripe()
        event = s.Webhook.construct_event(
            payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
        )
    except ValueError:
        # construct_event raises ValueError for a body that is not valid JSON;
        # left uncaught it would surface as a 500.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid Stripe payload",
        )
    except stripe_lib.error.SignatureVerificationError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid Stripe signature",
        )

    event_type: str = event["type"]
    data: dict[str, Any] = event["data"]["object"]

    if event_type == "checkout.session.completed":
        # metadata may be present but null; treat that like an empty mapping.
        metadata = data.get("metadata") or {}
        user_id = metadata.get("user_id")
        tier = metadata.get("tier", "free")
        sub_id = data.get("subscription")
        if user_id:
            _subscriptions[user_id] = {
                "tier": tier,
                "stripe_subscription_id": sub_id,
                "status": "active",
                "current_period_end": None,
            }

    elif event_type == "customer.subscription.updated":
        # TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
        pass

    elif event_type == "customer.subscription.deleted":
        # TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
        pass

    elif event_type == "invoice.payment_failed":
        # TODO(Step12): flag subscription as past_due, notify user
        pass

    return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/subscription", response_model=dict)
async def get_subscription(
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, Any]:
    """Return the current subscription info for the authenticated user."""
    record = _subscriptions.get(current_user.id)
    if record is not None:
        return record

    # No stored record: build a placeholder from the tier in the user's profile.
    placeholder = {
        "tier": current_user.tier,
        "status": "free",
        "stripe_subscription_id": None,
        "current_period_end": None,
    }
    return placeholder
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/subscription", response_model=dict)
async def cancel_subscription(
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
    """Cancel the active subscription.

    Raises HTTP 404 when the user has no subscription record, the record has
    no Stripe subscription id, or it is already marked ``canceled``. On
    success the stored record is kept with tier ``free`` and status
    ``canceled``.
    """
    sub = _subscriptions.get(current_user.id)
    # A record this handler already cancelled is not active; without the
    # status check a repeat DELETE would call Stripe again for it.
    if (
        sub is None
        or not sub.get("stripe_subscription_id")
        or sub.get("status") == "canceled"
    ):
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="No active subscription found",
        )

    if _stripe_configured():
        s = _stripe()
        s.Subscription.cancel(sub["stripe_subscription_id"])

    # Keep the Stripe id for reference while downgrading the stored record.
    _subscriptions[current_user.id] = {
        **sub,
        "tier": "free",
        "status": "canceled",
    }

    return {"ok": True}
|
||||||
78
app/api/routes/chat.py
Normal file
78
app/api/routes/chat.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.orchestrator import orchestrate, orchestrate_stream
|
||||||
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
async def chat(
    body: ChatRequest,
    current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
    """Route a chat message through the orchestrator.

    Returns ``ChatResponse`` for ``execution_mode='direct'``,
    or ``ExecutionPlan`` for ``execution_mode='plan'``.
    """
    # current_user is only an auth gate here; it is not passed to orchestrate.
    outcome = await orchestrate(body)
    serialised = outcome.model_dump()
    return JSONResponse(content=serialised)
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/stream")
async def chat_stream(websocket: WebSocket) -> None:
    """Streaming chat via WebSocket.

    Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).

    Protocol:
    1. Client sends ``ChatRequest`` as the first JSON text frame.
    2. Server streams response text chunks.
    3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
    4. Server pings every 30 s to keep the connection alive.

    A missing or invalid token closes the socket with code 1008 before it is
    accepted. A first frame that does not validate as ``ChatRequest`` closes
    it with code 1007.
    """
    # Authenticate before accepting the connection
    token = websocket.query_params.get("token", "")
    try:
        payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
        user_id: str | None = payload.get("sub")
        if not user_id:
            raise JWTError("missing sub")
    except JWTError:
        await websocket.close(code=1008)  # 1008 = Policy Violation
        return

    await websocket.accept()

    async def _heartbeat() -> None:
        # Periodic JSON ping frame, sent until the task is cancelled.
        while True:
            await asyncio.sleep(_HEARTBEAT_INTERVAL)
            await websocket.send_text(json.dumps({"ping": True}))

    try:
        raw = await websocket.receive_text()
        try:
            body = ChatRequest.model_validate_json(raw)
        except ValueError:
            # Pydantic validation errors subclass ValueError. Close cleanly
            # instead of letting the handler crash on a bad first frame.
            await websocket.close(code=1007)  # 1007 = Invalid frame payload data
            return

        heartbeat_task = asyncio.create_task(_heartbeat())
        try:
            async for chunk in orchestrate_stream(body):
                await websocket.send_text(chunk)
        finally:
            # Stop pinging on completion, failure, or disconnect.
            heartbeat_task.cancel()

    except WebSocketDisconnect:
        pass
|
||||||
37
app/api/routes/plans.py
Normal file
37
app/api/routes/plans.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.execution_plan import plan_cache
|
||||||
|
from app.schemas import ExecutionPlan, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/plans", tags=["plans"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/playbook", response_model=list[ExecutionPlan])
async def list_playbooks(
    current_user: UserProfile = Depends(get_current_user),
) -> list[ExecutionPlan]:
    """Return all cached execution plan playbooks for the authenticated user.

    TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
    """
    # NOTE(review): current_user only gates access; nothing here filters the
    # cache by user — confirm that is intended.
    playbooks = plan_cache.get_all_playbooks()
    return playbooks
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
async def get_playbook(
    plan_id: str,
    current_user: UserProfile = Depends(get_current_user),
) -> ExecutionPlan:
    """Return a specific execution plan playbook by ID."""
    cached = plan_cache.get_plan(plan_id)
    if cached is not None:
        return cached
    # A cache miss is reported as 404 with the requested id echoed back.
    raise HTTPException(
        status_code=status.HTTP_404_NOT_FOUND,
        detail=f"Plan not found: {plan_id}",
    )
|
||||||
174
app/api/routes/plugins.py
Normal file
174
app/api/routes/plugins.py
Normal file
@@ -0,0 +1,174 @@
|
|||||||
|
"""Plugins routes: browse and install plugins from the marketplace.
|
||||||
|
|
||||||
|
The catalog and installation records are kept in-memory as stubs.
|
||||||
|
Step 10 replaces these with PluginRegistry, RevenueShare, and the plugins DB table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||||
|
|
||||||
|
# ── In-memory catalog (Step 10 replaces with PluginRegistry + DB) ─────
|
||||||
|
|
||||||
|
_plugin_catalog: list[PluginManifest] = [
|
||||||
|
PluginManifest(
|
||||||
|
id="plugin-github-sync",
|
||||||
|
name="GitHub Sync",
|
||||||
|
description="Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
version="1.0.0",
|
||||||
|
author="Adiuva",
|
||||||
|
permissions=["read:tasks", "write:tasks"],
|
||||||
|
category="productivity",
|
||||||
|
price_cents=0,
|
||||||
|
),
|
||||||
|
PluginManifest(
|
||||||
|
id="plugin-slack-notify",
|
||||||
|
name="Slack Notifier",
|
||||||
|
description="Post task and checkpoint updates to Slack channels.",
|
||||||
|
version="1.2.0",
|
||||||
|
author="Adiuva",
|
||||||
|
permissions=["read:tasks", "read:checkpoints"],
|
||||||
|
category="communication",
|
||||||
|
price_cents=499,
|
||||||
|
),
|
||||||
|
PluginManifest(
|
||||||
|
id="plugin-time-tracker",
|
||||||
|
name="Time Tracker",
|
||||||
|
description="Track time spent on tasks with automatic reporting.",
|
||||||
|
version="0.9.1",
|
||||||
|
author="Third Party",
|
||||||
|
permissions=["read:tasks", "write:tasks"],
|
||||||
|
category="productivity",
|
||||||
|
price_cents=999,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
# plugin_id → set of user_ids who have installed it
|
||||||
|
_installations: dict[str, set[str]] = {}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _require_plugin_tier(user: UserProfile) -> None:
|
||||||
|
"""Raise HTTP 403 for users below Power tier."""
|
||||||
|
if user.tier not in ("power", "team"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Plugin marketplace requires Power tier or above",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Filter + sort helpers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _apply_filters(
|
||||||
|
plugins: list[PluginManifest],
|
||||||
|
category: str | None,
|
||||||
|
q: str | None,
|
||||||
|
) -> list[PluginManifest]:
|
||||||
|
result = plugins
|
||||||
|
if category:
|
||||||
|
result = [p for p in result if p.category == category]
|
||||||
|
if q:
|
||||||
|
q_lower = q.lower()
|
||||||
|
result = [
|
||||||
|
p for p in result
|
||||||
|
if q_lower in p.name.lower() or q_lower in p.description.lower()
|
||||||
|
]
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_sort(
|
||||||
|
plugins: list[PluginManifest],
|
||||||
|
sort: str,
|
||||||
|
) -> list[PluginManifest]:
|
||||||
|
if sort == "installs":
|
||||||
|
return sorted(plugins, key=lambda p: len(_installations.get(p.id, set())), reverse=True)
|
||||||
|
if sort == "rating":
|
||||||
|
# Placeholder until Step 10 introduces avg_rating from DB
|
||||||
|
return sorted(plugins, key=lambda p: -p.price_cents)
|
||||||
|
return plugins # "newest" = catalog insertion order
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local detail schema ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _PluginDetail(BaseModel):
    # Response body of the single-plugin detail route.

    # Manifest of the requested plugin.
    plugin: PluginManifest
    # Number of user ids recorded as having installed the plugin.
    install_count: int
    # Step 10 populates from plugin_reviews table
    ratings: list[Any]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("", response_model=PluginListResponse)
async def list_plugins(
    category: str | None = Query(default=None),
    q: str | None = Query(default=None),
    page: int = Query(default=1, ge=1),
    sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
    current_user: UserProfile = Depends(get_current_user),
) -> PluginListResponse:
    """Browse the plugin marketplace. Requires Power tier or above.

    Results are filtered, sorted, then returned one page at a time.
    ``total`` is the number of matches across all pages; ``plugins`` holds
    only the entries of the requested ``page`` (empty past the last page).
    """
    _require_plugin_tier(current_user)
    matches = _apply_sort(_apply_filters(_plugin_catalog, category, q), sort)

    # Previously ``page`` was echoed back but never applied, so every page
    # returned the complete list. Slice with a fixed page size instead.
    page_size = 50
    start = (page - 1) * page_size
    return PluginListResponse(
        plugins=matches[start : start + page_size],
        total=len(matches),
        page=page,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
    plugin_id: str,
    current_user: UserProfile = Depends(get_current_user),
) -> _PluginDetail:
    """Get full plugin details including install count. Requires Power tier or above."""
    _require_plugin_tier(current_user)

    # Linear scan of the catalog; the for/else raises when nothing matched.
    for candidate in _plugin_catalog:
        if candidate.id == plugin_id:
            break
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")

    installed_by = _installations.get(plugin_id, set())
    return _PluginDetail(
        plugin=candidate,
        install_count=len(installed_by),
        ratings=[],  # Step 10 populates from plugin_reviews table
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{plugin_id}/install", response_model=dict)
async def install_plugin(
    plugin_id: str,
    body: PluginInstallRequest,  # noqa: ARG001 — reserved for future fields
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, Any]:
    """Install a plugin. Triggers Stripe Connect for paid plugins when configured.

    Requires Power tier or above.
    """
    _require_plugin_tier(current_user)

    for plugin in _plugin_catalog:
        if plugin.id == plugin_id:
            break
    else:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")

    if plugin.price_cents > 0 and settings.STRIPE_SECRET_KEY:
        # TODO(Step10): stripe.PaymentIntent.create with destination charge (70/30 split)
        # NOTE(review): no charge is made yet, so paid plugins install like free ones.
        pass

    # Record the installation; the set makes repeated installs a no-op.
    installed_by = _installations.setdefault(plugin_id, set())
    installed_by.add(current_user.id)

    return {
        "ok": True,
        "download_url": f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip",
    }
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{plugin_id}/install", response_model=dict)
async def uninstall_plugin(
    plugin_id: str,
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
    """Unregister a plugin installation."""
    user_id = current_user.id
    installed_by = _installations.get(plugin_id)
    # Unknown plugin ids and users who never installed are both silent no-ops.
    if installed_by is not None:
        installed_by.discard(user_id)
    return {"ok": True}
|
||||||
185
app/api/routes/storage.py
Normal file
185
app/api/routes/storage.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Record metadata is kept in an
|
||||||
|
in-memory dict until Step 12 migrates it to PostgreSQL (storage_records table).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["storage"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
||||||
|
_records: dict[str, dict[str, Any]] = {}
|
||||||
|
|
||||||
|
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
||||||
|
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
|
||||||
|
"free": 0,
|
||||||
|
"pro": 5,
|
||||||
|
"power": 25,
|
||||||
|
"team": -1, # unlimited
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local response schemas ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CreateResponse(BaseModel):
    # Response body returned after a record has been created.

    # Identifier assigned to the new record.
    id: str
    # Creation time as set by the create route (epoch milliseconds).
    created_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordMeta(BaseModel):
    # Metadata-only view of a stored record; the blob itself is not included.

    id: str
    # Client-supplied table name the record belongs to.
    table: str
    # Checksum supplied by the client alongside the blob.
    checksum: str
    # Timestamps as stored by the create/update routes (epoch milliseconds).
    created_at: int
    updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
    """Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit.

    Unknown tiers are treated as having a limit of 0 GB; a limit of ``-1``
    means unlimited. Current usage is the sum of ``size_bytes`` over the
    user's in-memory records.
    """
    limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
    if limit_gb == -1:
        return  # unlimited

    used = 0
    for rec in _records.values():
        if rec["user_id"] == user_id:
            used += rec["size_bytes"]

    if used + additional_bytes <= limit_gb * 1024**3:
        return
    raise HTTPException(
        status_code=status.HTTP_402_PAYMENT_REQUIRED,
        detail=f"Storage quota exceeded for tier '{tier}'",
    )
|
||||||
|
|
||||||
|
|
||||||
|
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
    """Return the record with *record_id* if it belongs to *user_id*.

    A missing record and a record owned by someone else both raise the same
    HTTP 404, so callers cannot probe which ids exist for other users.
    """
    found = _records.get(record_id)
    if found is not None and found["user_id"] == user_id:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
async def create_record(
    body: StorageRecordCreate,
    current_user: UserProfile = Depends(get_current_user),
) -> _CreateResponse:
    """Upload a new E2E-encrypted blob. Verifies checksum before storing."""
    # Integrity first, then quota: a tampered upload is rejected before it
    # is counted against the user's storage.
    reject_if_tampered(body.blob, body.checksum)
    blob_size = len(body.blob)
    _check_quota(current_user.id, current_user.tier, blob_size)

    new_id = str(uuid.uuid4())
    created_ms = int(time.time() * 1000)

    stored_key = await _blob_store.upload(
        current_user.id, body.table, new_id, body.blob, body.checksum
    )

    # Metadata is only recorded once the blob upload has returned.
    metadata: dict[str, Any] = {
        "id": new_id,
        "user_id": current_user.id,
        "table": body.table,
        "s3_key": stored_key,
        "checksum": body.checksum,
        "size_bytes": blob_size,
        "created_at": created_ms,
        "updated_at": created_ms,
    }
    _records[new_id] = metadata

    return _CreateResponse(id=new_id, created_at=created_ms)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records", response_model=list[_RecordMeta])
async def list_records(
    table: str | None = Query(default=None),
    page: int = Query(default=1, ge=1),
    limit: int = Query(default=50, ge=1, le=200),
    current_user: UserProfile = Depends(get_current_user),
) -> list[_RecordMeta]:
    """List record metadata for the authenticated user. Blob bytes are never returned."""
    owned: list[dict[str, Any]] = []
    for rec in _records.values():
        if rec["user_id"] != current_user.id:
            continue
        if table is not None and rec["table"] != table:
            continue
        owned.append(rec)

    # Pages are 1-based; a page past the end yields an empty list.
    offset = (page - 1) * limit
    result: list[_RecordMeta] = []
    for rec in owned[offset : offset + limit]:
        result.append(
            _RecordMeta(
                id=rec["id"],
                table=rec["table"],
                checksum=rec["checksum"],
                created_at=rec["created_at"],
                updated_at=rec["updated_at"],
            )
        )
    return result
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records/{record_id}")
async def download_record(
    record_id: str,
    current_user: UserProfile = Depends(get_current_user),
) -> Response:
    """Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
    # Raises 404 when the record is missing or owned by another user.
    meta = _get_record_for_user(record_id, current_user.id)
    payload = await _blob_store.download(current_user.id, meta["s3_key"])
    checksum_header = {"X-Checksum": meta["checksum"]}
    return Response(
        content=payload,
        media_type="application/octet-stream",
        headers=checksum_header,
    )
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/records/{record_id}", response_model=dict)
async def update_record(
    record_id: str,
    body: StorageRecordUpdate,
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
    """Replace the blob for an existing record. Verifies checksum before storing."""
    meta = _get_record_for_user(record_id, current_user.id)
    reject_if_tampered(body.blob, body.checksum)

    # Only growth counts against the quota; shrinking or same-size
    # replacements skip the check entirely.
    new_size = len(body.blob)
    growth = new_size - meta["size_bytes"]
    if growth > 0:
        _check_quota(current_user.id, current_user.tier, growth)

    meta["s3_key"] = await _blob_store.upload(
        current_user.id, meta["table"], record_id, body.blob, body.checksum
    )
    meta.update(
        checksum=body.checksum,
        size_bytes=new_size,
        updated_at=int(time.time() * 1000),
    )

    return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/records/{record_id}", response_model=dict)
async def delete_record(
    record_id: str,
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
    """Delete a record and its S3 blob."""
    # Raises 404 when the record is missing or owned by another user.
    record = _get_record_for_user(record_id, current_user.id)
    await _blob_store.delete(current_user.id, record["s3_key"])
    # Another request for the same id can run while the blob deletion above
    # is awaited and remove the entry first. ``pop`` with a default keeps the
    # slower request from failing with KeyError (HTTP 500), which
    # ``del _records[record_id]`` would raise.
    _records.pop(record_id, None)
    return {"ok": True}
|
||||||
56
app/api/routes/vectors.py
Normal file
56
app/api/routes/vectors.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Vectors routes: upsert, search, and delete cloud vector store entries."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.schemas import (
|
||||||
|
UserProfile,
|
||||||
|
VectorSearchRequest,
|
||||||
|
VectorSearchResponse,
|
||||||
|
VectorUpsertRequest,
|
||||||
|
)
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
from app.storage.vector_store import VectorStore
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["vectors"])
|
||||||
|
|
||||||
|
_vector_store = VectorStore()
|
||||||
|
|
||||||
|
|
||||||
|
class _VectorDeleteRequest(BaseModel):
    # Request body for the vector delete route.

    # Ids of the vectors to delete.
    ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/upsert", response_model=dict)
async def upsert_vectors(
    body: VectorUpsertRequest,
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, int]:
    """Verify checksums and store encrypted vectors in the user-scoped namespace."""
    incoming = body.vectors
    # Every item is verified before anything is written, so one bad
    # checksum rejects the whole batch.
    for vector in incoming:
        reject_if_tampered(vector.blob, vector.checksum)
    await _vector_store.upsert(current_user.id, incoming)
    return {"upserted": len(incoming)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/search", response_model=VectorSearchResponse)
async def search_vectors(
    body: VectorSearchRequest,
    current_user: UserProfile = Depends(get_current_user),
) -> VectorSearchResponse:
    """Search the user-scoped vector namespace with an encrypted query blob."""
    owner_id = current_user.id
    hits = await _vector_store.search(owner_id, body.query_blob, body.top_k)
    return VectorSearchResponse(results=hits)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/vectors", response_model=dict)
async def delete_vectors(
    body: _VectorDeleteRequest,
    current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
    """Delete vectors by ID, scoped to the authenticated user."""
    owner_id = current_user.id
    await _vector_store.delete(owner_id, body.ids)
    return {"ok": True}
|
||||||
@@ -17,6 +17,11 @@ class Settings(BaseSettings):
|
|||||||
AWS_ACCESS_KEY_ID: str = ""
|
AWS_ACCESS_KEY_ID: str = ""
|
||||||
AWS_SECRET_ACCESS_KEY: str = ""
|
AWS_SECRET_ACCESS_KEY: str = ""
|
||||||
|
|
||||||
|
PINECONE_API_KEY: str = ""
|
||||||
|
PINECONE_INDEX: str = "adiuva"
|
||||||
|
QDRANT_URL: str = ""
|
||||||
|
QDRANT_API_KEY: str = ""
|
||||||
|
|
||||||
OPENAI_API_KEY: str = ""
|
OPENAI_API_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|||||||
222
app/core/execution_plan.py
Normal file
222
app/core/execution_plan.py
Normal file
@@ -0,0 +1,222 @@
|
|||||||
|
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections import OrderedDict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.schemas import ExecutionPlan, PlanStep
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt Template Registry ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PromptTemplateRegistry:
    """In-process mapping from template ID to prompt text.

    API responses carry only template IDs (e.g. ``"tpl_task_agent_default"``);
    the prompt text is looked up here on the server so that it is never part
    of a response.
    """

    def __init__(self) -> None:
        self._templates: dict[str, str] = {}

    def register(self, template_id: str, prompt_text: str) -> None:
        """Store *prompt_text* under *template_id*, replacing any previous entry."""
        self._templates[template_id] = prompt_text

    def get(self, template_id: str) -> str:
        """Return the prompt text registered for *template_id*.

        Raises:
            KeyError: If no template is registered under that ID.
        """
        prompt = self._templates.get(template_id)
        if prompt is not None:
            return prompt
        raise KeyError(f"Template not found: {template_id!r}")

    def has(self, template_id: str) -> bool:
        """Report whether *template_id* has been registered."""
        return template_id in self._templates

    def list_ids(self) -> list[str]:
        """Return the registered template IDs in registration order (IDs only)."""
        return list(self._templates)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Execution Plan Builder ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class ExecutionPlanBuilder:
    """Chainable builder that assembles an ``ExecutionPlan`` step by step.

    Every ``add_*`` method returns the builder itself, so calls can be
    chained. Example::

        plan = (
            ExecutionPlanBuilder("task_agent")
            .add_llm_step("tpl_task_agent_default", {"message": user_msg})
            .add_data_step("create_record", data_from_step=0)
            .build()
        )
    """

    def __init__(self, agent: str) -> None:
        self._agent = agent
        self._steps: list[PlanStep] = []

    # ── step adders ──────────────────────────────────────────────────

    def add_step(
        self, action: str, params: dict[str, Any] | None = None
    ) -> ExecutionPlanBuilder:
        """Append a generic *action* step; *params* become the step's variables."""
        generic = PlanStep(action=action, variables=params)
        self._steps.append(generic)
        return self

    def add_llm_step(
        self, template_id: str, variables: dict[str, Any] | None = None
    ) -> ExecutionPlanBuilder:
        """Append an ``"llm"`` step that names a server-side template by ID."""
        llm_step = PlanStep(action="llm", prompt_template=template_id, variables=variables)
        self._steps.append(llm_step)
        return self

    def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
        """Append a step that takes its input from the step at *data_from_step*."""
        dependent = PlanStep(action=action, data_from_step=data_from_step)
        self._steps.append(dependent)
        return self

    # ── build ────────────────────────────────────────────────────────

    def build(self) -> ExecutionPlan:
        """Check every ``data_from_step`` reference and produce the plan.

        Raises:
            ValueError: If a step refers to itself, to a later step, or to a
                negative index.
        """
        for i, step in enumerate(self._steps):
            source = step.data_from_step
            if source is None:
                continue
            # A step may only consume output from a step that runs before it.
            if source < 0 or source >= i:
                raise ValueError(
                    f"Step {i}: data_from_step={step.data_from_step} must "
                    f"reference a preceding step index in range 0..{i - 1}"
                )
        # Copy the list so later add_* calls do not alter an already-built plan.
        return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class PlanCache:
    """In-memory LRU cache for ``ExecutionPlan`` objects.

    Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
    The cache also serves as a runtime memoisation layer so that repeated
    identical intent classifications can skip re-building the plan.

    A ``maxsize`` of zero or less disables caching: ``cache_plan`` becomes a
    no-op and every lookup misses.
    """

    def __init__(self, maxsize: int = 1000) -> None:
        self._maxsize = maxsize
        # Ordered oldest → newest; the first entry is the next eviction victim.
        self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()

    def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
        """Store *plan* under *key*, evicting the LRU entry if at capacity."""
        if self._maxsize <= 0:
            # Without this guard the capacity check below is true on an empty
            # cache and ``popitem`` raises KeyError.
            return
        if key in self._cache:
            del self._cache[key]  # remove so re-insertion places it at the end
        elif len(self._cache) >= self._maxsize:
            self._cache.popitem(last=False)  # evict least-recently-used
        self._cache[key] = plan

    def get_plan(self, key: str) -> ExecutionPlan | None:
        """Return the cached plan for *key*, or ``None`` if not present.

        Accessing a plan marks it as most-recently used.
        """
        if key not in self._cache:
            return None
        self._cache.move_to_end(key)
        return self._cache[key]

    def get_all_playbooks(self) -> list[ExecutionPlan]:
        """Return all cached plans (most-recently used last)."""
        return list(self._cache.values())
|
||||||
|
|
||||||
|
|
||||||
|
# ── Module-level singletons ───────────────────────────────────────────
|
||||||
|
|
||||||
|
template_registry = PromptTemplateRegistry()
|
||||||
|
plan_cache = PlanCache()
|
||||||
|
|
||||||
|
|
||||||
|
def _register_builtin_templates() -> None:
    """Populate the module-level ``template_registry`` with the built-in prompts.

    Only the template IDs are meant to reach clients; the prompt text below
    is looked up through the registry on the server.
    """
    builtin: tuple[tuple[str, str], ...] = (
        (
            "tpl_task_agent_default",
            "You are a task management assistant. Help the user create, update, "
            "list, and track tasks. Use correct status values (todo, in_progress, "
            "done) and priority values (high, medium, low) from the workspace model.",
        ),
        (
            "tpl_checkpoint_agent_default",
            "You are a project checkpoint assistant. Help the user create and manage "
            "milestone checkpoints on their projects. Every checkpoint requires a "
            "project_id and a date expressed as a Unix timestamp in milliseconds.",
        ),
        (
            "tpl_project_agent_default",
            "You are a project management assistant. Help the user create, find, "
            "update, and archive projects. Projects have a name, an optional client, "
            "and a status of either active or archived.",
        ),
        (
            "tpl_note_agent_default",
            "You are a note-taking assistant. Help the user create, retrieve, update, "
            "and delete Markdown notes. Notes can optionally be linked to a project.",
        ),
        (
            "tpl_task_extract_from_project",
            "Extract all actionable tasks from the provided project context. "
            "Return a structured list of tasks, each with a title, inferred priority "
            "(high, medium, or low), suggested status (todo), and a due_date in "
            "milliseconds where a deadline can be inferred.",
        ),
        (
            "tpl_note_weekly_summary",
            "Generate a weekly project summary note from the provided workspace data. "
            "Include: tasks completed this week, tasks due soon, active projects, "
            "and upcoming checkpoints. Format the output as clean Markdown.",
        ),
    )
    for template_id, prompt_text in builtin:
        template_registry.register(template_id, prompt_text)
|
||||||
|
|
||||||
|
|
||||||
|
def _load_playbooks() -> None:
    """Build the built-in playbooks and store them in the module-level ``plan_cache``.

    Each playbook is an LLM step followed by a ``create_record`` step that
    consumes the LLM step's output.
    """
    # (cache key, agent, template id, template variables)
    specs = (
        (
            "create_tasks_from_project",
            "project_agent",
            "tpl_task_extract_from_project",
            {"source": "project_context"},
        ),
        (
            "generate_weekly_note",
            "note_agent",
            "tpl_note_weekly_summary",
            {"period": "last_7_days"},
        ),
    )

    def _two_step_plan(agent: str, template_id: str, variables: dict) -> ExecutionPlan:
        builder = ExecutionPlanBuilder(agent)
        builder.add_llm_step(template_id, variables)
        builder.add_data_step("create_record", data_from_step=0)
        return builder.build()

    # Build every plan before caching any, so a build failure caches nothing.
    built = [
        (key, _two_step_plan(agent, template_id, variables))
        for key, agent, template_id, variables in specs
    ]
    for key, plan in built:
        plan_cache.cache_plan(key, plan)
|
||||||
|
|
||||||
|
|
||||||
|
# Initialise on module load
|
||||||
|
_register_builtin_templates()
|
||||||
|
_load_playbooks()
|
||||||
169
app/core/orchestrator.py
Normal file
169
app/core/orchestrator.py
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.agent_registry import AgentRegistry
|
||||||
|
from app.core.agent_registry import registry as _default_registry
|
||||||
|
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
||||||
|
|
||||||
|
_FALLBACK_AGENT = "task_agent"
|
||||||
|
|
||||||
|
_CLASSIFY_SYSTEM = (
|
||||||
|
"You are an intent classifier. Given the user message and context, decide "
|
||||||
|
"which agent to route to.\n"
|
||||||
|
"Available agents: {agents}\n"
|
||||||
|
"Respond with just the agent name, nothing else."
|
||||||
|
)
|
||||||
|
|
||||||
|
_SYNTHESIZE_HUMAN = (
|
||||||
|
"Combine the following agent results into one coherent response.\n\n"
|
||||||
|
"Agent results:\n{results}\n\n"
|
||||||
|
"Original message: {message}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_llm(model: str = "gpt-4o-mini") -> ChatOpenAI:
    """Create a ``ChatOpenAI`` client for *model* with temperature 0.

    The API key is read from ``settings.OPENAI_API_KEY``.
    """
    llm_options = {
        "model": model,
        "temperature": 0,
        "api_key": settings.OPENAI_API_KEY,
    }
    return ChatOpenAI(**llm_options)
|
||||||
|
|
||||||
|
|
||||||
|
async def classify_intent(
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> str:
    """Use gpt-4o-mini to classify intent and return the matching agent name.

    Falls back to ``task_agent`` when the registry is empty or the model
    returns a name that is not registered.

    Args:
        message: The user's chat message.
        context: Request context; only the first 500 characters of its JSON
            form are sent to the model.
        reg: Registry whose ``list_agents()`` entries each carry a ``"name"``.

    Returns:
        The registered name of the selected agent, or the fallback agent name.
    """
    agents = reg.list_agents()
    if not agents:
        return _FALLBACK_AGENT

    system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
    # Truncate context to keep the classification prompt short.
    # ``default=str`` is deliberate: this text is only a prompt hint, so a
    # non-JSON value in the context should be stringified rather than abort
    # the whole request with TypeError.
    context_summary = json.dumps(context, default=str)[:500]
    human = f"Message: {message}\nContext summary: {context_summary}"

    llm = _make_llm()
    response = await llm.ainvoke(
        [SystemMessage(content=system), HumanMessage(content=human)]
    )

    # The agent list is shown to the model as JSON, so replies such as
    # ``"task_agent"`` or ``task_agent.`` are plausible. Strip surrounding
    # quotes, backticks and full stops before matching.
    answer = str(response.content).strip().strip("\"'`.").strip().lower()

    # Match case-insensitively but return the name exactly as registered.
    registered = {str(a["name"]).lower(): a["name"] for a in agents}
    return registered.get(answer, _FALLBACK_AGENT)
|
||||||
|
|
||||||
|
|
||||||
|
async def route_single(
    agent_name: str,
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> ChatResponse:
    """Call one agent through the registry and return its reply as a ``ChatResponse``."""
    reply = await reg.call_agent(agent_name, message, context)
    return ChatResponse(response=reply)
|
||||||
|
|
||||||
|
|
||||||
|
async def route_pipeline(
    agent_names: list[str],
    message: str,
    context: dict[str, Any],
    reg: AgentRegistry,
) -> ChatResponse:
    """Run the named agents one after another, then merge their replies.

    Each agent gets the shared context plus a ``previous_results`` list with
    the replies of the agents that ran before it. A final LLM call combines
    all replies into a single response.
    """
    collected: list[str] = []

    for name in agent_names:
        # Each agent receives its own snapshot of the earlier replies.
        agent_context = dict(context)
        agent_context["previous_results"] = collected.copy()
        collected.append(await reg.call_agent(name, message, agent_context))

    labelled = [f"[{name}]: {reply}" for name, reply in zip(agent_names, collected)]
    prompt = _SYNTHESIZE_HUMAN.format(results="\n\n".join(labelled), message=message)

    llm = _make_llm()
    synthesis = await llm.ainvoke([HumanMessage(content=prompt)])
    return ChatResponse(response=str(synthesis.content))
|
||||||
|
|
||||||
|
|
||||||
|
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
    """Create a one-step ``ExecutionPlan`` for *agent_name*.

    When the template ``tpl_<agent_name>_default`` is registered, the step is
    an LLM step referencing that template ID. Otherwise it is a plain
    ``handle`` action step. Either way the step's variables carry the user
    message.
    """
    from app.core.execution_plan import ExecutionPlanBuilder, template_registry

    step_variables = {"message": message}
    default_template = f"tpl_{agent_name}_default"
    plan_builder = ExecutionPlanBuilder(agent_name)

    if not template_registry.has(default_template):
        return plan_builder.add_step("handle", step_variables).build()
    return plan_builder.add_llm_step(default_template, step_variables).build()
|
||||||
|
|
||||||
|
|
||||||
|
async def orchestrate(
    request: ChatRequest,
    reg: AgentRegistry | None = None,
) -> ChatResponse | ExecutionPlan:
    """Classify the request, then either run the agent or return a plan.

    * ``execution_mode == 'direct'``: the selected agent is called and its
      reply is returned as a ``ChatResponse``.
    * Any other mode is treated as plan mode: an ``ExecutionPlan`` for the
      selected agent is returned and nothing is executed.

    When *reg* is ``None`` the module's default registry is used.
    """
    active_registry = _default_registry if reg is None else reg

    context = request.context.model_dump()
    chosen_agent = await classify_intent(request.message, context, active_registry)

    if request.execution_mode != "direct":
        # plan mode — return plan, do not execute
        return _build_plan(chosen_agent, request.message)

    return await route_single(chosen_agent, request.message, context, active_registry)
|
||||||
|
|
||||||
|
|
||||||
|
async def orchestrate_stream(
    request: ChatRequest,
    reg: AgentRegistry | None = None,
) -> AsyncGenerator[str, None]:
    """Yield the agent's reply in text chunks, followed by one final JSON frame.

    The final frame is a JSON object:
    ``{"done": true, "response": "...", "actions": []}``.

    Agents do not yet support token-level streaming; the full response is
    fetched first, then emitted in fixed-size chunks. Token-level streaming
    will be wired in Step 6 when agents expose ``astream()``.
    """
    active_registry = _default_registry if reg is None else reg

    context = request.context.model_dump()
    chosen_agent = await classify_intent(request.message, context, active_registry)
    full_reply = await active_registry.call_agent(chosen_agent, request.message, context)

    chunk_size = 50
    offset = 0
    reply_length = len(full_reply)
    while offset < reply_length:
        yield full_reply[offset : offset + chunk_size]
        offset += chunk_size

    # The closing frame repeats the whole reply alongside the "done" marker.
    closing_frame = {"done": True, **ChatResponse(response=full_reply).model_dump()}
    yield json.dumps(closing_frame)
|
||||||
17
app/main.py
17
app/main.py
@@ -34,13 +34,16 @@ def create_app() -> FastAPI:
|
|||||||
allow_headers=["*"],
|
allow_headers=["*"],
|
||||||
)
|
)
|
||||||
|
|
||||||
# Routers (registered when implemented)
|
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
||||||
# from app.api.routes import auth, chat, plans, backup, billing
|
|
||||||
# app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
# app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
# app.include_router(plans.router, prefix="/api/v1")
|
app.include_router(plans.router, prefix="/api/v1")
|
||||||
# app.include_router(backup.router, prefix="/api/v1")
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
# app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
|
|||||||
@@ -82,3 +82,76 @@ class BackupMetadata(BaseModel):
|
|||||||
timestamp: int
|
timestamp: int
|
||||||
checksum: str
|
checksum: str
|
||||||
chunk_count: int
|
chunk_count: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
||||||
|
|
||||||
|
class StorageRecord(BaseModel):
    """A stored cloud-storage record: identifiers, opaque blob, and bookkeeping."""

    id: str
    user_id: str
    table: str  # logical table name the record belongs to
    # Opaque payload; this section of the schema is for E2E-encrypted blobs.
    # NOTE(review): arbitrary binary in a ``bytes`` field may not round-trip
    # through JSON unchanged — confirm the wire encoding with the client.
    blob: bytes
    checksum: str  # integrity digest for ``blob``
    created_at: int  # presumably an epoch timestamp — units not shown here, confirm
    updated_at: int  # presumably an epoch timestamp — units not shown here, confirm
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordCreate(BaseModel):
    """Request body for creating a storage record."""

    table: str  # logical table name the new record belongs to
    blob: bytes  # opaque payload (section is for E2E-encrypted blobs)
    checksum: str  # integrity digest supplied alongside ``blob``
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordUpdate(BaseModel):
    """Request body for replacing the blob of an existing storage record."""

    blob: bytes  # replacement opaque payload
    checksum: str  # integrity digest supplied alongside the new ``blob``
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
||||||
|
|
||||||
|
class VectorItem(BaseModel):
    """One client-supplied vector entry: an ID, an opaque blob, and its checksum."""

    id: str
    blob: bytes  # encrypted vector + metadata — backend never decrypts
    checksum: str  # integrity digest supplied alongside ``blob``
|
||||||
|
|
||||||
|
|
||||||
|
class VectorUpsertRequest(BaseModel):
    """Request body carrying a batch of vector items to upsert."""

    vectors: list[VectorItem]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchRequest(BaseModel):
    """Request body for a vector search."""

    query_blob: bytes  # encrypted query — backend never decrypts
    # Maximum number of results requested; defaults to 10.
    # NOTE(review): no lower/upper bound is declared here — confirm whether
    # the route validates it.
    top_k: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResult(BaseModel):
    """A single search hit: the vector ID, its score, and the stored blob."""

    id: str
    score: float  # score reported by the vector backend for this hit
    blob: bytes  # stored opaque blob, returned as-is
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResponse(BaseModel):
    """Response body wrapping the list of search hits."""

    results: list[VectorSearchResult]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin Marketplace ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PluginManifest(BaseModel):
    """Descriptive metadata for one marketplace plugin."""

    id: str
    name: str
    description: str
    version: str
    author: str
    permissions: list[str]  # permission identifiers; format not shown here
    category: str
    price_cents: int = 0  # defaults to 0; presumably 0 means free — confirm
|
||||||
|
|
||||||
|
|
||||||
|
class PluginListResponse(BaseModel):
    """Response body for a paged plugin listing."""

    plugins: list[PluginManifest]
    total: int  # presumably the total match count across pages — confirm
    page: int  # page index of this response; base (0 or 1) not shown here
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallRequest(BaseModel):
    """Request body naming the plugin to install."""

    plugin_id: str
|
||||||
|
|||||||
1
app/storage/__init__.py
Normal file
1
app/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
||||||
105
app/storage/blob_store.py
Normal file
105
app/storage/blob_store.py
Normal file
@@ -0,0 +1,105 @@
|
|||||||
|
"""S3-backed store for E2E-encrypted blobs.
|
||||||
|
|
||||||
|
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
||||||
|
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class BlobStore:
    """Thin wrapper around boto3 S3 for opaque, client-encrypted blobs.

    All blobs must be E2E encrypted by the client before upload.
    The backend adds SSE-S3 as an extra layer of at-rest encryption
    but cannot decrypt the inner client-side payload.

    NOTE(review): the boto3 calls below are made without ``await`` inside
    ``async def`` methods — confirm whether they should be offloaded to a
    worker thread so they do not stall the event loop.
    """

    def _client(self) -> Any:
        """Build an S3 client from the configured region and credentials."""
        return boto3.client(
            "s3",
            region_name=settings.S3_REGION,
            aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
            aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
        )

    @staticmethod
    def _key(user_id: str, table: str, record_id: str) -> str:
        """Return the S3 key ``{user_id}/{table}/{record_id}``."""
        return f"{user_id}/{table}/{record_id}"

    async def upload(
        self,
        user_id: str,
        table: str,
        record_id: str,
        blob: bytes,
        checksum: str,
    ) -> str:
        """Store *blob* in S3 and return the S3 key.

        Args:
            user_id: Owner of the blob (used as key prefix).
            table: Logical table name (e.g. ``"tasks"``).
            record_id: Record UUID.
            blob: Raw bytes (pre-encrypted by client).
            checksum: SHA-256 hex digest supplied by the client; stored as
                object metadata for download-time verification.

        Returns:
            The S3 key under which the blob was stored.
        """
        key = self._key(user_id, table, record_id)
        self._client().put_object(
            Bucket=settings.S3_BUCKET,
            Key=key,
            Body=blob,
            ServerSideEncryption="AES256",  # SSE-S3 at rest
            Metadata={"checksum": checksum},
        )
        return key

    async def download(self, user_id: str, s3_key: str) -> bytes:
        """Retrieve the blob stored at *s3_key*.

        *user_id* is retained in the signature so higher-level code can
        enforce ownership without re-parsing the key.

        NOTE(review): *user_id* is not checked against *s3_key* here — confirm
        that every caller verifies the key belongs to the requesting user.

        Raises:
            ``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
            object does not exist.
        """
        response = self._client().get_object(
            Bucket=settings.S3_BUCKET,
            Key=s3_key,
        )
        return response["Body"].read()

    async def delete(self, user_id: str, s3_key: str) -> None:
        """Delete the object at *s3_key*.

        S3 ``delete_object`` is idempotent — it succeeds even if the key does
        not exist.

        NOTE(review): *user_id* is not checked against *s3_key* here — confirm
        that every caller verifies the key belongs to the requesting user.
        """
        self._client().delete_object(
            Bucket=settings.S3_BUCKET,
            Key=s3_key,
        )

    async def list_keys(self, user_id: str, table: str) -> list[str]:
        """Return all S3 keys for a given user + table combination.

        Uses the prefix ``{user_id}/{table}/`` to scope the listing. The
        listing is paginated: a single ``list_objects_v2`` response is capped
        (1000 keys), so one call alone would silently drop the remainder.
        """
        prefix = f"{user_id}/{table}/"
        paginator = self._client().get_paginator("list_objects_v2")
        keys: list[str] = []
        for page in paginator.paginate(Bucket=settings.S3_BUCKET, Prefix=prefix):
            # A page with no matching objects has no "Contents" entry at all.
            keys.extend(obj["Key"] for obj in page.get("Contents", []))
        return keys
|
||||||
32
app/storage/encryption.py
Normal file
32
app/storage/encryption.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Integrity verification only — the backend NEVER decrypts user data."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def verify_checksum(blob: bytes, checksum: str) -> bool:
    """Return ``True`` if SHA-256(blob) matches the hex digest *checksum*.

    Uses ``hmac.compare_digest`` for constant-time comparison to prevent
    timing-based side-channel attacks.

    A checksum containing non-ASCII characters can never be a hex digest and
    is reported as a mismatch; ``hmac.compare_digest`` would otherwise raise
    ``TypeError`` for such a string. Hex case is ignored, since upper- and
    lower-case spellings denote the same digest.
    """
    if not checksum.isascii():
        return False
    computed = hashlib.sha256(blob).hexdigest()
    return hmac.compare_digest(computed, checksum.lower())
|
||||||
|
|
||||||
|
|
||||||
|
def reject_if_tampered(blob: bytes, checksum: str) -> None:
    """Abort the request with ``HTTP 400`` when *blob* fails its checksum.

    Intended to run before a client-provided blob is stored or forwarded.
    Only integrity is checked: the bytes are compared against *checksum*
    and are never decrypted.

    Raises:
        HTTPException: status 400 when the checksum does not match.
    """
    if verify_checksum(blob, checksum):
        return
    raise HTTPException(
        status_code=400,
        detail="Checksum mismatch: blob integrity check failed",
    )
|
||||||
205
app/storage/vector_store.py
Normal file
205
app/storage/vector_store.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
||||||
|
|
||||||
|
Vectors are pre-encrypted blobs from the client. The backend stores them
|
||||||
|
alongside a deterministic 32-dim float representation derived from the blob's
|
||||||
|
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
||||||
|
is a known trade-off documented in the backend plan.
|
||||||
|
|
||||||
|
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
||||||
|
``user_id`` payload field on a shared collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pinecone import Pinecone
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
|
||||||
|
_QDRANT_COLLECTION = "adiuva_vectors"
|
||||||
|
|
||||||
|
|
||||||
|
def _blob_to_vector(blob: bytes) -> list[float]:
|
||||||
|
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
||||||
|
|
||||||
|
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
||||||
|
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
||||||
|
semantic meaning on encrypted data.
|
||||||
|
"""
|
||||||
|
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStore:
    """Thin wrapper around Pinecone or Qdrant for client-encrypted vectors.

    The backend to use is selected at call time:
    - Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
    - Qdrant: otherwise (requires ``settings.QDRANT_URL``).

    Per-user isolation: Pinecone calls pass ``namespace=user_id``; Qdrant
    calls filter on the ``user_id`` payload field of the shared collection.

    NOTE(review): the SDK calls below are made without ``await`` inside
    ``async def`` methods — confirm whether they should be offloaded so they
    do not stall the event loop.
    """

    def _use_pinecone(self) -> bool:
        """Return ``True`` when a Pinecone API key is configured."""
        return bool(settings.PINECONE_API_KEY)

    # ── Pinecone helpers ──────────────────────────────────────────────

    def _pinecone_index(self) -> Any:
        """Open the configured Pinecone index."""
        pc = Pinecone(api_key=settings.PINECONE_API_KEY)
        return pc.Index(settings.PINECONE_INDEX)

    # ── Qdrant helpers ────────────────────────────────────────────────

    def _qdrant_client(self) -> Any:
        """Build a Qdrant client; an empty API key is passed as ``None``."""
        return QdrantClient(
            url=settings.QDRANT_URL,
            api_key=settings.QDRANT_API_KEY or None,
        )

    # ── Public API ────────────────────────────────────────────────────

    async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
        """Store encrypted vectors in the backend.

        Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
        so it can be returned verbatim during search. An empty *vectors* list
        is a no-op and does not contact the backend.

        Args:
            user_id: Used as Pinecone namespace or Qdrant payload field.
            vectors: List of encrypted vector items from the client.
        """
        if not vectors:
            return
        if self._use_pinecone():
            await self._pinecone_upsert(user_id, vectors)
        else:
            await self._qdrant_upsert(user_id, vectors)

    async def search(
        self,
        user_id: str,
        query_blob: bytes,
        top_k: int,
    ) -> list[VectorSearchResult]:
        """Query the vector store and return encrypted result blobs.

        The query vector is derived from *query_blob* using the same
        deterministic mapping as upsert.

        Args:
            user_id: Scopes the search to this user's namespace.
            query_blob: Encrypted query from the client.
            top_k: Maximum number of results to return.

        Returns:
            List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
        """
        if self._use_pinecone():
            return await self._pinecone_search(user_id, query_blob, top_k)
        return await self._qdrant_search(user_id, query_blob, top_k)

    async def delete(self, user_id: str, vector_ids: list[str]) -> None:
        """Remove vectors by ID, scoped to *user_id*.

        An empty *vector_ids* list is a no-op and does not contact the backend.

        Args:
            user_id: Namespace / payload filter to prevent cross-user deletion.
            vector_ids: List of vector IDs to remove.
        """
        if not vector_ids:
            return
        if self._use_pinecone():
            await self._pinecone_delete(user_id, vector_ids)
        else:
            await self._qdrant_delete(user_id, vector_ids)

    # ── Pinecone implementation ───────────────────────────────────────

    async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
        """Upsert *vectors* into the Pinecone namespace named after *user_id*."""
        index = self._pinecone_index()
        records = [
            {
                "id": v.id,
                "values": _blob_to_vector(v.blob),
                "metadata": {
                    "blob": base64.b64encode(v.blob).decode(),
                    "checksum": v.checksum,
                    "user_id": user_id,
                },
            }
            for v in vectors
        ]
        index.upsert(vectors=records, namespace=user_id)

    async def _pinecone_search(
        self, user_id: str, query_blob: bytes, top_k: int
    ) -> list[VectorSearchResult]:
        """Query the user's Pinecone namespace and decode the stored blobs."""
        index = self._pinecone_index()
        response = index.query(
            vector=_blob_to_vector(query_blob),
            top_k=top_k,
            namespace=user_id,
            include_metadata=True,
        )
        return [
            VectorSearchResult(
                id=match["id"],
                score=match["score"],
                blob=base64.b64decode(match["metadata"]["blob"]),
            )
            for match in response.get("matches", [])
        ]

    async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
        """Delete *vector_ids* from the Pinecone namespace named after *user_id*."""
        index = self._pinecone_index()
        index.delete(ids=vector_ids, namespace=user_id)

    # ── Qdrant implementation ─────────────────────────────────────────

    async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
        """Upsert *vectors* into the shared Qdrant collection, tagged with *user_id*.

        NOTE(review): point IDs are client-supplied and the collection is
        shared, so upserting an ID already held by another user would
        presumably overwrite that point — confirm, and consider namespacing
        IDs per user.
        """
        client = self._qdrant_client()
        points = [
            PointStruct(
                id=v.id,
                vector=_blob_to_vector(v.blob),
                payload={
                    "blob": base64.b64encode(v.blob).decode(),
                    "checksum": v.checksum,
                    "user_id": user_id,
                },
            )
            for v in vectors
        ]
        client.upsert(collection_name=_QDRANT_COLLECTION, points=points)

    async def _qdrant_search(
        self, user_id: str, query_blob: bytes, top_k: int
    ) -> list[VectorSearchResult]:
        """Search the shared collection, restricted to points tagged with *user_id*."""
        client = self._qdrant_client()
        hits = client.search(
            collection_name=_QDRANT_COLLECTION,
            query_vector=_blob_to_vector(query_blob),
            query_filter=Filter(
                must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
            ),
            limit=top_k,
        )
        return [
            VectorSearchResult(
                id=str(hit.id),
                score=hit.score,
                blob=base64.b64decode(hit.payload["blob"]),
            )
            for hit in hits
        ]

    async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
        """Delete *vector_ids* from the shared collection, only where owned by *user_id*.

        The collection is shared between users, so deleting by ID alone would
        let one user remove another user's points. The selector therefore
        requires both an ID match and a matching ``user_id`` payload.
        """
        # Local import: these two names are only needed for the scoped delete.
        from qdrant_client.models import FilterSelector, HasIdCondition

        client = self._qdrant_client()
        owned_by_user = Filter(
            must=[
                HasIdCondition(has_id=vector_ids),
                FieldCondition(key="user_id", match=MatchValue(value=user_id)),
            ]
        )
        client.delete(
            collection_name=_QDRANT_COLLECTION,
            points_selector=FilterSelector(filter=owned_by_user),
        )
|
||||||
@@ -17,3 +17,6 @@ httpx>=0.28.0
|
|||||||
websockets>=14.0
|
websockets>=14.0
|
||||||
pytest>=8.0.0
|
pytest>=8.0.0
|
||||||
pytest-asyncio>=0.24.0
|
pytest-asyncio>=0.24.0
|
||||||
|
moto[s3]>=5.0.0
|
||||||
|
pinecone>=5.0.0
|
||||||
|
qdrant-client>=1.7.0
|
||||||
|
|||||||
620
tests/test_agents.py
Normal file
620
tests/test_agents.py
Normal file
@@ -0,0 +1,620 @@
|
|||||||
|
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
import app.agents # noqa: F401 — triggers @registry.register decorators
|
||||||
|
from app.agents.checkpoint_agent import CheckpointAgent
|
||||||
|
from app.agents.note_agent import NoteAgent
|
||||||
|
from app.agents.project_agent import ProjectAgent
|
||||||
|
from app.agents.task_agent import TaskAgent
|
||||||
|
from app.core.agent_registry import registry
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_llm(response_text: str) -> MagicMock:
|
||||||
|
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = response_text
|
||||||
|
msg.tool_calls = []
|
||||||
|
llm = MagicMock()
|
||||||
|
bound = MagicMock()
|
||||||
|
bound.ainvoke = AsyncMock(return_value=msg)
|
||||||
|
llm.bind_tools = MagicMock(return_value=bound)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=msg)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_llm_with_tool_call(
|
||||||
|
tool_name: str, tool_args: dict[str, Any], final_text: str
|
||||||
|
) -> MagicMock:
|
||||||
|
"""Mock LLM that fires one tool call then returns *final_text*."""
|
||||||
|
tool_msg = MagicMock()
|
||||||
|
tool_msg.content = ""
|
||||||
|
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
|
||||||
|
|
||||||
|
final_msg = MagicMock()
|
||||||
|
final_msg.content = final_text
|
||||||
|
final_msg.tool_calls = []
|
||||||
|
|
||||||
|
bound = MagicMock()
|
||||||
|
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
|
||||||
|
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.bind_tools = MagicMock(return_value=bound)
|
||||||
|
llm.ainvoke = AsyncMock(return_value=final_msg)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
# ── Registration ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestAgentRegistration:
    """Checks that the four domain agents are present in the shared registry."""

    def test_all_agents_registered(self) -> None:
        expected = {"task_agent", "checkpoint_agent", "project_agent", "note_agent"}
        registered = {entry["name"] for entry in registry.list_agents()}
        assert expected.issubset(registered)

    def test_registry_returns_correct_types(self) -> None:
        expected_types = {
            "task_agent": TaskAgent,
            "checkpoint_agent": CheckpointAgent,
            "project_agent": ProjectAgent,
            "note_agent": NoteAgent,
        }
        for agent_name, agent_cls in expected_types.items():
            assert isinstance(registry.get(agent_name), agent_cls)

    def test_descriptions_present(self) -> None:
        # Every registered agent must expose a non-empty description.
        for entry in registry.list_agents():
            assert entry["description"], f"Empty description: {entry['name']}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TaskAgent ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskAgent:
    """Metadata and ``handle`` behaviour of ``TaskAgent`` with the LLM patched out."""

    def test_name(self) -> None:
        agent = TaskAgent()
        assert agent.get_name() == "task_agent"

    def test_description(self) -> None:
        expected = "Manages tasks and comments: list, create, update, delete, due-today, comments"
        assert TaskAgent().get_description() == expected

    def test_get_tools_count(self) -> None:
        tools = TaskAgent().get_tools()
        assert len(tools) == 8

    def test_tool_names(self) -> None:
        expected = {
            "list_tasks",
            "create_task",
            "update_task",
            "delete_task",
            "list_tasks_due_today",
            "list_task_comments",
            "add_task_comment",
            "delete_task_comment",
        }
        assert {tool.name for tool in TaskAgent().get_tools()} == expected

    @pytest.mark.asyncio
    async def test_handle_returns_string(self) -> None:
        fake_llm = _mock_llm("Task created.")
        with patch("app.agents.task_agent.ChatOpenAI", return_value=fake_llm):
            reply = await TaskAgent().handle("create a task", {})
        assert isinstance(reply, str)

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        fake_llm = _mock_llm("Here are your tasks.")
        with patch("app.agents.task_agent.ChatOpenAI", return_value=fake_llm):
            reply = await TaskAgent().handle("list my tasks", {})
        assert reply == "Here are your tasks."

    @pytest.mark.asyncio
    async def test_handle_with_create_task_tool_call(self) -> None:
        fake_llm = _mock_llm_with_tool_call(
            "create_task",
            {"title": "Buy groceries", "priority": "low"},
            "Task 'Buy groceries' created.",
        )
        with patch("app.agents.task_agent.ChatOpenAI", return_value=fake_llm):
            reply = await TaskAgent().handle("add a grocery task", {})
        assert reply == "Task 'Buy groceries' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        fake_llm = _mock_llm("Done.")
        with patch("app.agents.task_agent.ChatOpenAI", return_value=fake_llm):
            reply = await TaskAgent().handle("help", {})
        assert isinstance(reply, str)

    @pytest.mark.asyncio
    async def test_handle_accepts_rich_context(self) -> None:
        rich_context = {
            "user_profile": {"id": "u1", "tier": "pro"},
            "recent_tasks": [{"id": "t1", "title": "Old task"}],
        }
        fake_llm = _mock_llm("Tasks listed.")
        with patch("app.agents.task_agent.ChatOpenAI", return_value=fake_llm):
            reply = await TaskAgent().handle("show tasks", rich_context)
        assert isinstance(reply, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestTaskAgentTools:
    """Each task tool is invoked directly and its JSON payload is inspected."""

    @pytest.mark.asyncio
    async def test_list_tasks_defaults(self) -> None:
        from app.agents.task_agent import list_tasks

        payload = json.loads(await list_tasks.ainvoke({}))
        assert payload["action"] == "list"
        assert payload["table"] == "tasks"

    @pytest.mark.asyncio
    async def test_list_tasks_with_status_filter(self) -> None:
        from app.agents.task_agent import list_tasks

        payload = json.loads(await list_tasks.ainvoke({"status": "done"}))
        assert payload["filters"]["status"] == "done"

    @pytest.mark.asyncio
    async def test_create_task_defaults(self) -> None:
        from app.agents.task_agent import create_task

        payload = json.loads(await create_task.ainvoke({"title": "Test task"}))
        assert payload["action"] == "create_record"
        assert payload["table"] == "tasks"
        record = payload["data"]
        assert record["title"] == "Test task"
        assert record["status"] == "todo"
        assert record["priority"] == "medium"

    @pytest.mark.asyncio
    async def test_create_task_with_all_fields(self) -> None:
        from app.agents.task_agent import create_task

        tool_input = {
            "title": "Deploy",
            "priority": "high",
            "status": "in_progress",
            "project_id": "p1",
            "is_ai_suggested": 1,
        }
        record = json.loads(await create_task.ainvoke(tool_input))["data"]
        assert record["priority"] == "high"
        assert record["status"] == "in_progress"
        assert record["projectId"] == "p1"
        assert record["isAiSuggested"] == 1

    @pytest.mark.asyncio
    async def test_update_task_with_status(self) -> None:
        from app.agents.task_agent import update_task

        raw = await update_task.ainvoke({"task_id": "t1", "status": "done"})
        payload = json.loads(raw)
        assert payload["action"] == "update_record"
        assert payload["data"]["id"] == "t1"
        assert payload["data"]["updates"]["status"] == "done"

    @pytest.mark.asyncio
    async def test_update_task_empty_updates(self) -> None:
        from app.agents.task_agent import update_task

        payload = json.loads(await update_task.ainvoke({"task_id": "t1"}))
        assert payload["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_task(self) -> None:
        from app.agents.task_agent import delete_task

        payload = json.loads(await delete_task.ainvoke({"task_id": "t1"}))
        assert payload["action"] == "delete_record"
        assert payload["table"] == "tasks"
        assert payload["data"]["id"] == "t1"

    @pytest.mark.asyncio
    async def test_list_tasks_due_today(self) -> None:
        from app.agents.task_agent import list_tasks_due_today

        payload = json.loads(await list_tasks_due_today.ainvoke({}))
        assert payload["action"] == "list_due_today"
        assert payload["table"] == "tasks"

    @pytest.mark.asyncio
    async def test_list_task_comments(self) -> None:
        from app.agents.task_agent import list_task_comments

        payload = json.loads(await list_task_comments.ainvoke({"task_id": "t1"}))
        assert payload["action"] == "list"
        assert payload["table"] == "taskComments"
        assert payload["filters"]["taskId"] == "t1"

    @pytest.mark.asyncio
    async def test_add_task_comment(self) -> None:
        from app.agents.task_agent import add_task_comment

        tool_input = {
            "task_id": "t1",
            "author": "Alice",
            "content": "Looks good!",
        }
        payload = json.loads(await add_task_comment.ainvoke(tool_input))
        assert payload["action"] == "create_record"
        assert payload["table"] == "taskComments"
        comment = payload["data"]
        assert comment["taskId"] == "t1"
        assert comment["author"] == "Alice"
        assert comment["content"] == "Looks good!"

    @pytest.mark.asyncio
    async def test_delete_task_comment(self) -> None:
        from app.agents.task_agent import delete_task_comment

        payload = json.loads(await delete_task_comment.ainvoke({"comment_id": "c1"}))
        assert payload["action"] == "delete_record"
        assert payload["table"] == "taskComments"
        assert payload["data"]["id"] == "c1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── CheckpointAgent ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointAgent:
    """Metadata and ``handle`` behaviour of ``CheckpointAgent`` with the LLM patched out."""

    def test_name(self) -> None:
        agent = CheckpointAgent()
        assert agent.get_name() == "checkpoint_agent"

    def test_description(self) -> None:
        expected = "Manages project checkpoints (milestones): list, create, update, delete"
        assert CheckpointAgent().get_description() == expected

    def test_get_tools_count(self) -> None:
        tools = CheckpointAgent().get_tools()
        assert len(tools) == 4

    def test_tool_names(self) -> None:
        expected = {
            "list_checkpoints",
            "create_checkpoint",
            "update_checkpoint",
            "delete_checkpoint",
        }
        assert {tool.name for tool in CheckpointAgent().get_tools()} == expected

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        fake_llm = _mock_llm("No checkpoints found.")
        with patch("app.agents.checkpoint_agent.ChatOpenAI", return_value=fake_llm):
            reply = await CheckpointAgent().handle("list checkpoints", {})
        assert reply == "No checkpoints found."

    @pytest.mark.asyncio
    async def test_handle_with_create_tool_call(self) -> None:
        fake_llm = _mock_llm_with_tool_call(
            "create_checkpoint",
            {"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
            "Checkpoint 'MVP Launch' created.",
        )
        with patch("app.agents.checkpoint_agent.ChatOpenAI", return_value=fake_llm):
            reply = await CheckpointAgent().handle("add MVP checkpoint", {})
        assert reply == "Checkpoint 'MVP Launch' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        fake_llm = _mock_llm("Done.")
        with patch("app.agents.checkpoint_agent.ChatOpenAI", return_value=fake_llm):
            reply = await CheckpointAgent().handle("show milestones", {})
        assert isinstance(reply, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCheckpointAgentTools:
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_no_project(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import list_checkpoints
|
||||||
|
result = await list_checkpoints.ainvoke({})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "list"
|
||||||
|
assert data["table"] == "checkpoints"
|
||||||
|
assert data["filters"]["projectId"] is None
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_checkpoints_with_project(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import list_checkpoints
|
||||||
|
result = await list_checkpoints.ainvoke({"project_id": "p1"})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["filters"]["projectId"] == "p1"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_checkpoint(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import create_checkpoint
|
||||||
|
result = await create_checkpoint.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"title": "Beta release",
|
||||||
|
"date": 1700000000000,
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "create_record"
|
||||||
|
assert data["table"] == "checkpoints"
|
||||||
|
assert data["data"]["projectId"] == "p1"
|
||||||
|
assert data["data"]["title"] == "Beta release"
|
||||||
|
assert data["data"]["date"] == 1700000000000
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_create_checkpoint_ai_suggested(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import create_checkpoint
|
||||||
|
result = await create_checkpoint.ainvoke({
|
||||||
|
"project_id": "p1",
|
||||||
|
"title": "Review",
|
||||||
|
"date": 1700000000000,
|
||||||
|
"is_ai_suggested": 1,
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["data"]["isAiSuggested"] == 1
|
||||||
|
assert data["data"]["isApproved"] == 0
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_checkpoint_approve(self) -> None:
|
||||||
|
from app.agents.checkpoint_agent import update_checkpoint
|
||||||
|
result = await update_checkpoint.ainvoke({
|
||||||
|
"checkpoint_id": "c1",
|
||||||
|
"is_approved": 1,
|
||||||
|
})
|
||||||
|
data = json.loads(result)
|
||||||
|
assert data["action"] == "update_record"
|
||||||
|
assert data["data"]["id"] == "c1"
|
||||||
|
assert data["data"]["updates"]["isApproved"] == 1
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_checkpoint_empty_updates(self) -> None:
    """Calling update_checkpoint with only an id produces an empty ``updates`` mapping."""
    from app.agents.checkpoint_agent import update_checkpoint

    payload = json.loads(await update_checkpoint.ainvoke({"checkpoint_id": "c1"}))

    assert payload["data"]["updates"] == {}
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_delete_checkpoint(self) -> None:
    """delete_checkpoint returns a delete_record on the checkpoints table for the given id."""
    from app.agents.checkpoint_agent import delete_checkpoint

    payload = json.loads(await delete_checkpoint.ainvoke({"checkpoint_id": "c1"}))

    assert payload["action"] == "delete_record"
    assert payload["table"] == "checkpoints"
    assert payload["data"]["id"] == "c1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── ProjectAgent ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectAgent:
    """Checks ProjectAgent's metadata, its tool list, and handle() with ChatOpenAI patched."""

    def test_name(self) -> None:
        """get_name() returns "project_agent"."""
        agent = ProjectAgent()
        assert agent.get_name() == "project_agent"

    def test_description(self) -> None:
        """get_description() returns the exact expected description string."""
        expected = "Manages projects: list, get, create, update, archive, delete"
        assert ProjectAgent().get_description() == expected

    def test_get_tools_count(self) -> None:
        """The agent exposes exactly six tools."""
        tools = ProjectAgent().get_tools()
        assert len(tools) == 6

    def test_tool_names(self) -> None:
        """The set of tool names matches the expected six names."""
        tool_names = {tool.name for tool in ProjectAgent().get_tools()}
        expected_names = {
            "list_projects",
            "list_all_projects",
            "get_project",
            "create_project",
            "update_project",
            "delete_project",
        }
        assert tool_names == expected_names

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        """handle() returns the text given to the ``_mock_llm`` stub."""
        with patch("app.agents.project_agent.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("Project Alpha is active.")
            reply = await ProjectAgent().handle("show my projects", {})
        assert reply == "Project Alpha is active."

    @pytest.mark.asyncio
    async def test_handle_with_create_project_tool_call(self) -> None:
        """handle() returns the final text configured on the tool-calling stub."""
        with patch("app.agents.project_agent.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm_with_tool_call(
                "create_project",
                {"name": "Pippo"},
                "Project 'Pippo' created.",
            )
            reply = await ProjectAgent().handle("create project Pippo", {})
        assert reply == "Project 'Pippo' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        """handle() called with an empty context dict still returns a str."""
        with patch("app.agents.project_agent.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("Done.")
            reply = await ProjectAgent().handle("archive old project", {})
        assert isinstance(reply, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProjectAgentTools:
    """Invokes each project tool directly and checks the JSON it returns."""

    @pytest.mark.asyncio
    async def test_list_projects_defaults(self) -> None:
        """With no arguments, list_projects lists the projects table with includeArchived False."""
        from app.agents.project_agent import list_projects

        payload = json.loads(await list_projects.ainvoke({}))

        assert payload["action"] == "list"
        assert payload["table"] == "projects"
        assert payload["filters"]["includeArchived"] is False

    @pytest.mark.asyncio
    async def test_list_projects_include_archived(self) -> None:
        """include_archived=1 comes back as includeArchived True in the filters."""
        from app.agents.project_agent import list_projects

        payload = json.loads(await list_projects.ainvoke({"include_archived": 1}))

        assert payload["filters"]["includeArchived"] is True

    @pytest.mark.asyncio
    async def test_list_all_projects(self) -> None:
        """list_all_projects returns a list_all action on the projects table."""
        from app.agents.project_agent import list_all_projects

        payload = json.loads(await list_all_projects.ainvoke({}))

        assert payload["action"] == "list_all"
        assert payload["table"] == "projects"

    @pytest.mark.asyncio
    async def test_get_project(self) -> None:
        """get_project returns a get action on the projects table for the given id."""
        from app.agents.project_agent import get_project

        payload = json.loads(await get_project.ainvoke({"project_id": "p1"}))

        assert payload["action"] == "get"
        assert payload["table"] == "projects"
        assert payload["data"]["id"] == "p1"

    @pytest.mark.asyncio
    async def test_create_project_name_only(self) -> None:
        """Creating a project with only a name leaves clientId as None."""
        from app.agents.project_agent import create_project

        payload = json.loads(await create_project.ainvoke({"name": "Alpha"}))

        assert payload["action"] == "create_record"
        record = payload["data"]
        assert record["name"] == "Alpha"
        assert record["clientId"] is None

    @pytest.mark.asyncio
    async def test_create_project_with_client(self) -> None:
        """A supplied client_id is returned as clientId in the record."""
        from app.agents.project_agent import create_project

        arguments = {"name": "Beta", "client_id": "cl1"}
        payload = json.loads(await create_project.ainvoke(arguments))

        assert payload["data"]["clientId"] == "cl1"

    @pytest.mark.asyncio
    async def test_update_project_archive(self) -> None:
        """Updating status to "archived" returns an update_record carrying that status."""
        from app.agents.project_agent import update_project

        arguments = {"project_id": "p1", "status": "archived"}
        payload = json.loads(await update_project.ainvoke(arguments))

        assert payload["action"] == "update_record"
        record = payload["data"]
        assert record["id"] == "p1"
        assert record["updates"]["status"] == "archived"

    @pytest.mark.asyncio
    async def test_update_project_empty_updates(self) -> None:
        """Calling update_project with only an id produces an empty ``updates`` mapping."""
        from app.agents.project_agent import update_project

        payload = json.loads(await update_project.ainvoke({"project_id": "p1"}))

        assert payload["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_project(self) -> None:
        """delete_project returns a delete_record action for the given id."""
        from app.agents.project_agent import delete_project

        payload = json.loads(await delete_project.ainvoke({"project_id": "p1"}))

        assert payload["action"] == "delete_record"
        assert payload["data"]["id"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── NoteAgent ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoteAgent:
    """Checks NoteAgent's metadata, its tool list, and handle() with ChatOpenAI patched."""

    def test_name(self) -> None:
        """get_name() returns "note_agent"."""
        agent = NoteAgent()
        assert agent.get_name() == "note_agent"

    def test_description(self) -> None:
        """get_description() returns the exact expected description string."""
        expected = "Manages notes: list, get, create, update, delete"
        assert NoteAgent().get_description() == expected

    def test_get_tools_count(self) -> None:
        """The agent exposes exactly five tools."""
        tools = NoteAgent().get_tools()
        assert len(tools) == 5

    def test_tool_names(self) -> None:
        """The set of tool names matches the expected five names."""
        tool_names = {tool.name for tool in NoteAgent().get_tools()}
        expected_names = {
            "list_notes",
            "get_note",
            "create_note",
            "update_note",
            "delete_note",
        }
        assert tool_names == expected_names

    @pytest.mark.asyncio
    async def test_handle_no_tool_calls(self) -> None:
        """handle() returns the text given to the ``_mock_llm`` stub."""
        with patch("app.agents.note_agent.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("Note created.")
            reply = await NoteAgent().handle("create a note", {})
        assert reply == "Note created."

    @pytest.mark.asyncio
    async def test_handle_with_create_note_tool_call(self) -> None:
        """handle() returns the final text configured on the tool-calling stub."""
        with patch("app.agents.note_agent.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm_with_tool_call(
                "create_note",
                {"title": "Daily log", "content": "# Today\nAll good."},
                "Note 'Daily log' created.",
            )
            reply = await NoteAgent().handle("log today's progress", {})
        assert reply == "Note 'Daily log' created."

    @pytest.mark.asyncio
    async def test_handle_accepts_empty_context(self) -> None:
        """handle() called with an empty context dict still returns a str."""
        with patch("app.agents.note_agent.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("Done.")
            reply = await NoteAgent().handle("show notes", {})
        assert isinstance(reply, str)
|
||||||
|
|
||||||
|
|
||||||
|
class TestNoteAgentTools:
    """Invokes each note tool directly and checks the JSON it returns."""

    @pytest.mark.asyncio
    async def test_list_notes_no_project(self) -> None:
        """With no arguments, list_notes lists the notes table with projectId None."""
        from app.agents.note_agent import list_notes

        payload = json.loads(await list_notes.ainvoke({}))

        assert payload["action"] == "list"
        assert payload["table"] == "notes"
        assert payload["filters"]["projectId"] is None

    @pytest.mark.asyncio
    async def test_list_notes_with_project(self) -> None:
        """A supplied project_id is returned as the projectId filter."""
        from app.agents.note_agent import list_notes

        payload = json.loads(await list_notes.ainvoke({"project_id": "p1"}))

        assert payload["filters"]["projectId"] == "p1"

    @pytest.mark.asyncio
    async def test_get_note(self) -> None:
        """get_note returns a get action on the notes table for the given id."""
        from app.agents.note_agent import get_note

        payload = json.loads(await get_note.ainvoke({"note_id": "n1"}))

        assert payload["action"] == "get"
        assert payload["table"] == "notes"
        assert payload["data"]["id"] == "n1"

    @pytest.mark.asyncio
    async def test_create_note_minimal(self) -> None:
        """Title and content alone create a notes record whose projectId is None."""
        from app.agents.note_agent import create_note

        arguments = {
            "title": "Daily log",
            "content": "# Today\nAll good.",
        }
        payload = json.loads(await create_note.ainvoke(arguments))

        assert payload["action"] == "create_record"
        assert payload["table"] == "notes"
        record = payload["data"]
        assert record["title"] == "Daily log"
        assert record["content"] == "# Today\nAll good."
        assert record["projectId"] is None

    @pytest.mark.asyncio
    async def test_create_note_with_project(self) -> None:
        """A supplied project_id is returned as projectId in the record."""
        from app.agents.note_agent import create_note

        arguments = {
            "title": "Sprint notes",
            "content": "## Sprint 1",
            "project_id": "p1",
        }
        payload = json.loads(await create_note.ainvoke(arguments))

        assert payload["data"]["projectId"] == "p1"

    @pytest.mark.asyncio
    async def test_update_note_content_only(self) -> None:
        """Updating only content leaves "title" out of the updates mapping."""
        from app.agents.note_agent import update_note

        arguments = {
            "note_id": "n1",
            "content": "# Updated content",
        }
        payload = json.loads(await update_note.ainvoke(arguments))

        assert payload["action"] == "update_record"
        record = payload["data"]
        assert record["id"] == "n1"
        assert record["updates"]["content"] == "# Updated content"
        assert "title" not in record["updates"]

    @pytest.mark.asyncio
    async def test_update_note_empty_updates(self) -> None:
        """Calling update_note with only an id produces an empty ``updates`` mapping."""
        from app.agents.note_agent import update_note

        payload = json.loads(await update_note.ainvoke({"note_id": "n1"}))

        assert payload["data"]["updates"] == {}

    @pytest.mark.asyncio
    async def test_delete_note(self) -> None:
        """delete_note returns a delete_record on the notes table for the given id."""
        from app.agents.note_agent import delete_note

        payload = json.loads(await delete_note.ainvoke({"note_id": "n1"}))

        assert payload["action"] == "delete_record"
        assert payload["table"] == "notes"
        assert payload["data"]["id"] == "n1"
|
||||||
286
tests/test_execution_plan.py
Normal file
286
tests/test_execution_plan.py
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.execution_plan import (
|
||||||
|
ExecutionPlanBuilder,
|
||||||
|
PlanCache,
|
||||||
|
PromptTemplateRegistry,
|
||||||
|
plan_cache,
|
||||||
|
template_registry,
|
||||||
|
)
|
||||||
|
from app.schemas import ExecutionPlan
|
||||||
|
|
||||||
|
|
||||||
|
# ── PromptTemplateRegistry ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPromptTemplateRegistry:
    """Exercises register/get/has/list_ids on fresh PromptTemplateRegistry instances."""

    def test_register_and_get(self) -> None:
        """A registered template is returned verbatim by get()."""
        registry = PromptTemplateRegistry()
        registry.register("tpl_foo", "You are a foo agent.")
        fetched = registry.get("tpl_foo")
        assert fetched == "You are a foo agent."

    def test_get_unknown_raises_key_error(self) -> None:
        """get() on an unregistered id raises KeyError mentioning that id."""
        registry = PromptTemplateRegistry()
        with pytest.raises(KeyError, match="tpl_missing"):
            registry.get("tpl_missing")

    def test_has_returns_true_for_registered(self) -> None:
        """has() is True for an id that was registered."""
        registry = PromptTemplateRegistry()
        registry.register("tpl_x", "prompt text")
        assert registry.has("tpl_x") is True

    def test_has_returns_false_for_unregistered(self) -> None:
        """has() is False for an id that was never registered."""
        registry = PromptTemplateRegistry()
        assert registry.has("tpl_missing") is False

    def test_list_ids_returns_all_registered_ids(self) -> None:
        """list_ids() contains exactly the registered ids (order not checked)."""
        registry = PromptTemplateRegistry()
        registry.register("tpl_a", "a")
        registry.register("tpl_b", "b")
        listed = set(registry.list_ids())
        assert listed == {"tpl_a", "tpl_b"}

    def test_list_ids_does_not_return_prompt_text(self) -> None:
        """The prompt text itself must not appear among the listed ids."""
        registry = PromptTemplateRegistry()
        registry.register("tpl_secret", "top secret prompt")
        assert "top secret prompt" not in registry.list_ids()

    def test_overwrite_existing_template(self) -> None:
        """Registering the same id twice keeps the later text."""
        registry = PromptTemplateRegistry()
        registry.register("tpl_x", "v1")
        registry.register("tpl_x", "v2")
        assert registry.get("tpl_x") == "v2"

    def test_empty_registry_has_no_ids(self) -> None:
        """A fresh registry lists no ids."""
        listed = PromptTemplateRegistry().list_ids()
        assert listed == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestExecutionPlanBuilder:
    """Exercises ExecutionPlanBuilder step construction, chaining and build() validation."""

    def test_builds_empty_plan(self) -> None:
        """A builder with no steps builds a plan with the agent name and an empty step list."""
        builder = ExecutionPlanBuilder("task_agent")
        empty_plan = builder.build()
        assert empty_plan.agent == "task_agent"
        assert empty_plan.steps == []

    def test_add_step_basic(self) -> None:
        """add_step records the action and variables; template and data source stay None."""
        builder = ExecutionPlanBuilder("task_agent")
        built = builder.add_step("create_task", {"priority": "high"}).build()

        assert len(built.steps) == 1
        only_step = built.steps[0]
        assert only_step.action == "create_task"
        assert only_step.variables == {"priority": "high"}
        assert only_step.prompt_template is None
        assert only_step.data_from_step is None

    def test_add_step_no_params(self) -> None:
        """add_step without variables leaves ``variables`` as None."""
        built = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
        first_step = built.steps[0]
        assert first_step.variables is None

    def test_add_llm_step(self) -> None:
        """add_llm_step creates an "llm" action carrying the template id and variables."""
        builder = ExecutionPlanBuilder("task_agent")
        built = builder.add_llm_step("tpl_task_default", {"message": "hi"}).build()

        llm_step = built.steps[0]
        assert llm_step.action == "llm"
        assert llm_step.prompt_template == "tpl_task_default"
        assert llm_step.variables == {"message": "hi"}

    def test_add_llm_step_no_variables(self) -> None:
        """add_llm_step without variables leaves ``variables`` as None."""
        built = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
        first_step = built.steps[0]
        assert first_step.variables is None

    def test_add_data_step(self) -> None:
        """add_data_step records the action and the index of the step it reads from."""
        chain = ExecutionPlanBuilder("task_agent").add_step("fetch_data")
        built = chain.add_data_step("transform", data_from_step=0).build()

        second_step = built.steps[1]
        assert second_step.action == "transform"
        assert second_step.data_from_step == 0

    def test_fluent_chaining_returns_builder(self) -> None:
        """add_step returns the very same builder object."""
        builder = ExecutionPlanBuilder("analytics_agent")
        assert builder.add_step("a") is builder

    def test_fluent_chain_multiple_steps(self) -> None:
        """Three chained step calls produce a plan with three steps."""
        chain = ExecutionPlanBuilder("analytics_agent")
        chain = chain.add_llm_step("tpl_analytics_default")
        chain = chain.add_step("format_output")
        chain = chain.add_data_step("store", data_from_step=0)
        built = chain.build()
        assert len(built.steps) == 3

    def test_build_validates_data_from_step_out_of_range(self) -> None:
        """A data_from_step pointing past the available steps raises ValueError."""
        # The whole chain stays inside the context manager: the test does not
        # pin down which call in the chain raises.
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step(
                "bad", data_from_step=5
            ).build()

    def test_build_validates_data_from_step_self_reference(self) -> None:
        """data_from_step=0 on the first step (index 0) is invalid."""
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step(
                "bad", data_from_step=0
            ).build()

    def test_build_validates_data_from_step_negative(self) -> None:
        """A negative data_from_step raises ValueError."""
        with pytest.raises(ValueError, match="data_from_step"):
            ExecutionPlanBuilder("task_agent").add_data_step(
                "bad", data_from_step=-1
            ).build()

    def test_valid_data_from_step_at_index_two(self) -> None:
        """The third step may read from the second step (index 1)."""
        chain = ExecutionPlanBuilder("task_agent")
        chain = chain.add_step("step0")
        chain = chain.add_step("step1")
        built = chain.add_data_step("step2", data_from_step=1).build()
        assert built.steps[2].data_from_step == 1

    def test_data_from_step_zero_valid_at_index_one(self) -> None:
        """The second step may read from the first step (index 0)."""
        chain = ExecutionPlanBuilder("task_agent").add_step("step0")
        built = chain.add_data_step("step1", data_from_step=0).build()
        assert built.steps[1].data_from_step == 0

    def test_build_returns_new_plan_each_call(self) -> None:
        """Two build() calls give distinct plan objects with equal steps."""
        builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
        first_build = builder.build()
        second_build = builder.build()
        assert first_build is not second_build
        assert first_build.steps == second_build.steps

    def test_plan_is_execution_plan_instance(self) -> None:
        """build() returns an ExecutionPlan."""
        built = ExecutionPlanBuilder("task_agent").build()
        assert isinstance(built, ExecutionPlan)
|
||||||
|
|
||||||
|
|
||||||
|
# ── PlanCache ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestPlanCache:
    """Exercises PlanCache storage, lookup, overwrite and eviction at a small maxsize."""

    def _plan(self, agent: str = "a") -> ExecutionPlan:
        """Build an empty plan for *agent* to use as a cache value."""
        builder = ExecutionPlanBuilder(agent)
        return builder.build()

    def test_cache_and_get(self) -> None:
        """A cached plan is returned as the same object."""
        store = PlanCache()
        stored = self._plan()
        store.cache_plan("key1", stored)
        assert store.get_plan("key1") is stored

    def test_get_missing_returns_none(self) -> None:
        """Looking up an unknown key returns None."""
        assert PlanCache().get_plan("nonexistent") is None

    def test_get_all_playbooks_empty(self) -> None:
        """A fresh cache has no playbooks."""
        assert PlanCache().get_all_playbooks() == []

    def test_get_all_playbooks_returns_all_stored(self) -> None:
        """get_all_playbooks() contains every cached plan."""
        store = PlanCache()
        plan_a = self._plan("a")
        plan_b = self._plan("b")
        store.cache_plan("k1", plan_a)
        store.cache_plan("k2", plan_b)

        everything = store.get_all_playbooks()
        assert len(everything) == 2
        assert plan_a in everything
        assert plan_b in everything

    def test_lru_evicts_oldest_entry(self) -> None:
        """With maxsize=2, adding a third key drops the first one inserted."""
        store = PlanCache(maxsize=2)
        plan_a = self._plan("a")
        plan_b = self._plan("b")
        plan_c = self._plan("c")
        store.cache_plan("k1", plan_a)
        store.cache_plan("k2", plan_b)
        store.cache_plan("k3", plan_c)  # k1 should be evicted

        assert store.get_plan("k1") is None
        assert store.get_plan("k2") is plan_b
        assert store.get_plan("k3") is plan_c

    def test_lru_access_updates_recency(self) -> None:
        """Reading k1 before inserting k3 makes k2 the entry that is evicted."""
        store = PlanCache(maxsize=2)
        plan_a = self._plan("a")
        plan_b = self._plan("b")
        plan_c = self._plan("c")
        store.cache_plan("k1", plan_a)
        store.cache_plan("k2", plan_b)
        store.get_plan("k1")  # k1 is now most-recently used
        store.cache_plan("k3", plan_c)  # k2 should be evicted (LRU)

        assert store.get_plan("k1") is plan_a
        assert store.get_plan("k2") is None
        assert store.get_plan("k3") is plan_c

    def test_overwrite_existing_key(self) -> None:
        """Caching under an existing key replaces the plan without adding an entry."""
        store = PlanCache()
        original = self._plan("a")
        replacement = self._plan("b")
        store.cache_plan("same_key", original)
        store.cache_plan("same_key", replacement)

        assert store.get_plan("same_key") is replacement
        assert len(store.get_all_playbooks()) == 1

    def test_overwrite_does_not_consume_capacity(self) -> None:
        """Overwriting a key at maxsize=2 still leaves room for a second key."""
        store = PlanCache(maxsize=2)
        plan_a = self._plan("a")
        plan_b = self._plan("b")
        store.cache_plan("k1", plan_a)
        store.cache_plan("k1", plan_b)  # overwrite, not a new slot
        store.cache_plan("k2", plan_a)  # should fit without eviction

        assert store.get_plan("k1") is plan_b
        assert store.get_plan("k2") is plan_a
|
||||||
|
|
||||||
|
|
||||||
|
# ── Module-level singletons ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestModuleSingletons:
    """Checks the module-level ``template_registry`` and ``plan_cache`` objects."""

    def test_template_registry_has_all_agent_defaults(self) -> None:
        """Each of the four listed agents has a ``tpl_<agent>_default`` template."""
        agent_names = ("task_agent", "checkpoint_agent", "project_agent", "note_agent")
        for agent in agent_names:
            is_registered = template_registry.has(f"tpl_{agent}_default")
            assert is_registered, f"Missing template: tpl_{agent}_default"

    def test_template_registry_has_operation_templates(self) -> None:
        """The two operation templates are registered."""
        assert template_registry.has("tpl_task_extract_from_project")
        assert template_registry.has("tpl_note_weekly_summary")

    def test_template_registry_get_returns_non_empty_string(self) -> None:
        """The task agent default template is a non-empty str."""
        prompt = template_registry.get("tpl_task_agent_default")
        assert isinstance(prompt, str)
        assert len(prompt) > 0

    def test_plan_cache_has_prebuilt_playbooks(self) -> None:
        """The shared cache holds at least two playbooks."""
        playbooks = plan_cache.get_all_playbooks()
        assert len(playbooks) >= 2

    def test_playbook_create_tasks_from_project(self) -> None:
        """The create_tasks_from_project playbook has two steps; the second reads from step 0."""
        playbook = plan_cache.get_plan("create_tasks_from_project")
        assert playbook is not None
        assert playbook.agent == "project_agent"
        assert len(playbook.steps) == 2

        first_step, second_step = playbook.steps[0], playbook.steps[1]
        assert first_step.prompt_template == "tpl_task_extract_from_project"
        assert second_step.data_from_step == 0

    def test_playbook_generate_weekly_note(self) -> None:
        """The generate_weekly_note playbook has two steps; the second reads from step 0."""
        playbook = plan_cache.get_plan("generate_weekly_note")
        assert playbook is not None
        assert playbook.agent == "note_agent"
        assert len(playbook.steps) == 2

        first_step, second_step = playbook.steps[0], playbook.steps[1]
        assert first_step.prompt_template == "tpl_note_weekly_summary"
        assert second_step.data_from_step == 0

    def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
        """Plans must not embed prompt text — only template IDs."""
        for playbook in plan_cache.get_all_playbooks():
            for step in playbook.steps:
                # Steps without a template have nothing to check.
                if step.prompt_template is None:
                    continue
                looks_like_id = step.prompt_template.startswith("tpl_")
                assert looks_like_id, (
                    f"prompt_template looks like raw text: {step.prompt_template!r}"
                )
|
||||||
348
tests/test_orchestrator.py
Normal file
348
tests/test_orchestrator.py
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
"""Integration tests for the orchestrator module."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.agent_registry import AgentRegistry, ChatAgent
|
||||||
|
from app.core.orchestrator import (
|
||||||
|
classify_intent,
|
||||||
|
orchestrate,
|
||||||
|
orchestrate_stream,
|
||||||
|
route_pipeline,
|
||||||
|
route_single,
|
||||||
|
)
|
||||||
|
from app.schemas import ChatContext, ChatRequest, ChatResponse, ExecutionPlan
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stub agents ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _TaskAgent(ChatAgent):
    """Stub agent registered as "task_agent"; handle() echoes the query with a "task: " prefix."""

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return the query prefixed with "task: "; *context* is not used."""
        return f"task: {query}"

    def get_tools(self) -> list[Any]:
        """Return an empty tool list."""
        no_tools: list[Any] = []
        return no_tools

    def get_description(self) -> str:
        """Return the fixed description string for this stub."""
        return "Manages tasks: create, update, list, suggest"

    def get_name(self) -> str:
        """Return the agent name "task_agent"."""
        return "task_agent"
|
||||||
|
|
||||||
|
|
||||||
|
class _CalendarAgent(ChatAgent):
    """Stub agent registered as "calendar_agent"; handle() echoes the query with a "calendar: " prefix."""

    async def handle(self, query: str, context: dict[str, Any]) -> str:
        """Return the query prefixed with "calendar: "; *context* is not used."""
        return f"calendar: {query}"

    def get_tools(self) -> list[Any]:
        """Return an empty tool list."""
        no_tools: list[Any] = []
        return no_tools

    def get_description(self) -> str:
        """Return the fixed description string for this stub."""
        return "Calendar management: events, conflicts, scheduling"

    def get_name(self) -> str:
        """Return the agent name "calendar_agent"."""
        return "calendar_agent"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _mock_llm(response_text: str) -> MagicMock:
|
||||||
|
"""Return a mock LLM that always produces *response_text*."""
|
||||||
|
msg = MagicMock()
|
||||||
|
msg.content = response_text
|
||||||
|
llm = MagicMock()
|
||||||
|
llm.ainvoke = AsyncMock(return_value=msg)
|
||||||
|
return llm
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _fresh_registry():
    """Clear ``AgentRegistry._instance`` before and after every test.

    Autouse, so each test constructs its own registry instead of reusing one
    left behind by an earlier test.
    """

    def _clear_singleton() -> None:
        AgentRegistry._instance = None

    _clear_singleton()
    yield
    _clear_singleton()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
def reg() -> AgentRegistry:
    """Provide an AgentRegistry with the task and calendar stub agents registered, in that order."""
    registry = AgentRegistry()
    for agent_cls in (_TaskAgent, _CalendarAgent):
        registry.register(agent_cls)
    return registry
|
||||||
|
|
||||||
|
|
||||||
|
# ── classify_intent ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestClassifyIntent:
    """Checks classify_intent against a patched ChatOpenAI that returns a fixed agent name."""

    @pytest.mark.asyncio
    async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
        """An LLM reply of "task_agent" is returned as the chosen agent."""
        stub = _mock_llm("task_agent")
        with patch("app.core.orchestrator.ChatOpenAI", return_value=stub):
            chosen = await classify_intent("add a task", {}, reg)
        assert chosen == "task_agent"

    @pytest.mark.asyncio
    async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
        """An LLM reply of "calendar_agent" is returned as the chosen agent."""
        stub = _mock_llm("calendar_agent")
        with patch("app.core.orchestrator.ChatOpenAI", return_value=stub):
            chosen = await classify_intent("schedule a meeting", {}, reg)
        assert chosen == "calendar_agent"

    @pytest.mark.asyncio
    async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
        """When the LLM names an agent that is not registered, "task_agent" is returned."""
        stub = _mock_llm("nonexistent_agent")
        with patch("app.core.orchestrator.ChatOpenAI", return_value=stub):
            chosen = await classify_intent("do something", {}, reg)
        assert chosen == "task_agent"

    @pytest.mark.asyncio
    async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
        """With no agents registered, "task_agent" is returned and ChatOpenAI is never called."""
        registry_without_agents = AgentRegistry()
        # No LLM should be instantiated — early return path
        with patch("app.core.orchestrator.ChatOpenAI") as llm_factory:
            chosen = await classify_intent("anything", {}, registry_without_agents)
            llm_factory.assert_not_called()
        assert chosen == "task_agent"

    @pytest.mark.asyncio
    async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
        """Whitespace around the LLM reply does not prevent matching the agent name."""
        stub = _mock_llm(" task_agent \n")
        with patch("app.core.orchestrator.ChatOpenAI", return_value=stub):
            chosen = await classify_intent("create task", {}, reg)
        assert chosen == "task_agent"
|
||||||
|
|
||||||
|
|
||||||
|
# ── route_single ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRouteSingle:
    """Checks route_single against the stub agents in the ``reg`` fixture."""

    @pytest.mark.asyncio
    async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
        """The result is a ChatResponse."""
        outcome = await route_single("task_agent", "create a task", {}, reg)
        assert isinstance(outcome, ChatResponse)

    @pytest.mark.asyncio
    async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
        """The response text is exactly what the stub agent's handle() returned."""
        outcome = await route_single("task_agent", "create a task", {}, reg)
        assert outcome.response == "task: create a task"

    @pytest.mark.asyncio
    async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
        """Routing to an unregistered agent name raises KeyError."""
        with pytest.raises(KeyError):
            await route_single("nonexistent", "hello", {}, reg)

    @pytest.mark.asyncio
    async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
        """The response's ``actions`` list is empty."""
        outcome = await route_single("task_agent", "hi", {}, reg)
        assert outcome.actions == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── route_pipeline ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRoutePipeline:
    """Checks route_pipeline with stub agents and a patched ChatOpenAI."""

    @pytest.mark.asyncio
    async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
        """A two-agent pipeline returns a ChatResponse."""
        stub = _mock_llm("synthesized result")
        agent_chain = ["task_agent", "calendar_agent"]
        with patch("app.core.orchestrator.ChatOpenAI", return_value=stub):
            outcome = await route_pipeline(agent_chain, "plan my week", {}, reg)
        assert isinstance(outcome, ChatResponse)

    @pytest.mark.asyncio
    async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
        """The response text is the patched LLM's reply."""
        stub = _mock_llm("synthesized result")
        agent_chain = ["task_agent", "calendar_agent"]
        with patch("app.core.orchestrator.ChatOpenAI", return_value=stub):
            outcome = await route_pipeline(agent_chain, "plan my week", {}, reg)
        assert outcome.response == "synthesized result"

    @pytest.mark.asyncio
    async def test_passes_previous_results_to_subsequent_agents(
        self, reg: AgentRegistry
    ) -> None:
        """Each agent after the first should receive prior outputs in context."""
        seen_contexts: list[dict[str, Any]] = []

        class _CapturingAgent(ChatAgent):
            """Stub agent that records a copy of every context passed to handle()."""

            def get_name(self) -> str:
                return "capture"

            def get_description(self) -> str:
                return "captures context for testing"

            def get_tools(self) -> list[Any]:
                return []

            async def handle(self, query: str, context: dict[str, Any]) -> str:
                # Store a copy so the assertions below see the context as it
                # was at call time.
                seen_contexts.append(dict(context))
                return "captured"

        reg.register(_CapturingAgent)

        stub = _mock_llm("done")
        with patch("app.core.orchestrator.ChatOpenAI", return_value=stub):
            await route_pipeline(["task_agent", "capture"], "hi", {}, reg)

        # The second agent (capture) must have received previous results
        assert len(seen_contexts) == 1
        capture_context = seen_contexts[0]
        assert "previous_results" in capture_context
        assert capture_context["previous_results"] == ["task: hi"]

    @pytest.mark.asyncio
    async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
        """A one-agent pipeline still returns the patched LLM's reply as the response."""
        stub = _mock_llm("single result")
        with patch("app.core.orchestrator.ChatOpenAI", return_value=stub):
            outcome = await route_pipeline(["task_agent"], "one agent", {}, reg)
        assert outcome.response == "single result"
|
||||||
|
|
||||||
|
|
||||||
|
# ── orchestrate ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestrate:
    """Exercise ``orchestrate`` in ``direct`` and ``plan`` execution modes.

    ``ChatOpenAI`` is patched in every test; the text passed to ``_mock_llm``
    is the agent name the tests expect to be selected.
    """

    @pytest.mark.asyncio
    async def test_direct_mode_returns_chat_response(
        self, reg: AgentRegistry
    ) -> None:
        """Direct mode produces a ``ChatResponse``."""
        req = ChatRequest(message="add a task", execution_mode="direct")
        with patch("app.core.orchestrator.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("task_agent")
            outcome = await orchestrate(req, reg)
        assert isinstance(outcome, ChatResponse)

    @pytest.mark.asyncio
    async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
        """Direct mode returns the text ``"task: add a task"`` for this request."""
        req = ChatRequest(message="add a task", execution_mode="direct")
        with patch("app.core.orchestrator.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("task_agent")
            outcome = await orchestrate(req, reg)
        assert isinstance(outcome, ChatResponse)
        assert outcome.response == "task: add a task"

    @pytest.mark.asyncio
    async def test_plan_mode_returns_execution_plan(
        self, reg: AgentRegistry
    ) -> None:
        """Plan mode produces an ``ExecutionPlan``."""
        req = ChatRequest(message="plan my tasks", execution_mode="plan")
        with patch("app.core.orchestrator.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("task_agent")
            outcome = await orchestrate(req, reg)
        assert isinstance(outcome, ExecutionPlan)

    @pytest.mark.asyncio
    async def test_plan_mode_agent_matches_classified(
        self, reg: AgentRegistry
    ) -> None:
        """The plan's ``agent`` is the name the mocked LLM returned."""
        req = ChatRequest(message="schedule something", execution_mode="plan")
        with patch("app.core.orchestrator.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("calendar_agent")
            outcome = await orchestrate(req, reg)
        assert isinstance(outcome, ExecutionPlan)
        assert outcome.agent == "calendar_agent"

    @pytest.mark.asyncio
    async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
        """A plan contains at least one step."""
        req = ChatRequest(message="plan tasks", execution_mode="plan")
        with patch("app.core.orchestrator.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("task_agent")
            outcome = await orchestrate(req, reg)
        assert isinstance(outcome, ExecutionPlan)
        assert len(outcome.steps) >= 1

    @pytest.mark.asyncio
    async def test_plan_mode_template_id_contains_agent_name(
        self, reg: AgentRegistry
    ) -> None:
        """The first step's ``prompt_template`` is set and mentions the agent name."""
        req = ChatRequest(message="plan tasks", execution_mode="plan")
        with patch("app.core.orchestrator.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("task_agent")
            outcome = await orchestrate(req, reg)
        assert isinstance(outcome, ExecutionPlan)
        first_template = outcome.steps[0].prompt_template
        assert first_template is not None
        assert "task_agent" in first_template

    @pytest.mark.asyncio
    async def test_default_execution_mode_is_direct(
        self, reg: AgentRegistry
    ) -> None:
        """Omitting ``execution_mode`` yields a ``ChatResponse``, as direct mode does."""
        # execution_mode defaults to "direct"
        req = ChatRequest(message="help me")
        with patch("app.core.orchestrator.ChatOpenAI") as llm_factory:
            llm_factory.return_value = _mock_llm("task_agent")
            outcome = await orchestrate(req, reg)
        assert isinstance(outcome, ChatResponse)
|
||||||
|
|
||||||
|
|
||||||
|
# ── orchestrate_stream ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOrchestrateStream:
    """Exercise ``orchestrate_stream``: the chunks it yields and its final frame.

    ``ChatOpenAI`` is patched in every test. The last chunk is expected to be
    a JSON object with ``done``, ``response`` and ``actions`` keys.
    """

    @pytest.mark.asyncio
    async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
        """The stream is never empty for a direct-mode request."""
        with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
            mock_cls.return_value = _mock_llm("task_agent")
            request = ChatRequest(message="add a task", execution_mode="direct")
            chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
        assert len(chunks) >= 1

    @pytest.mark.asyncio
    async def test_last_chunk_is_final_json_frame(
        self, reg: AgentRegistry
    ) -> None:
        """The last chunk parses as JSON with ``done`` true plus ``response``/``actions``."""
        with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
            mock_cls.return_value = _mock_llm("task_agent")
            request = ChatRequest(message="add a task", execution_mode="direct")
            chunks = [chunk async for chunk in orchestrate_stream(request, reg)]

        last = json.loads(chunks[-1])
        assert last["done"] is True
        assert "response" in last
        assert "actions" in last

    @pytest.mark.asyncio
    async def test_final_frame_response_matches_agent_output(
        self, reg: AgentRegistry
    ) -> None:
        """The final frame's ``response`` is ``"task: create a task"`` for this request."""
        with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
            mock_cls.return_value = _mock_llm("task_agent")
            request = ChatRequest(message="create a task", execution_mode="direct")
            chunks = [chunk async for chunk in orchestrate_stream(request, reg)]

        final = json.loads(chunks[-1])
        assert final["response"] == "task: create a task"

    @pytest.mark.asyncio
    async def test_text_chunks_before_final_frame(
        self, reg: AgentRegistry
    ) -> None:
        """No chunk before the last one is a ``done`` frame."""
        with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
            mock_cls.return_value = _mock_llm("task_agent")
            request = ChatRequest(
                message="x" * 200, execution_mode="direct"
            )  # long enough to produce multiple chunks
            chunks = [chunk async for chunk in orchestrate_stream(request, reg)]

        # All but the last chunk should be plain text (not valid final JSON)
        for chunk in chunks[:-1]:
            try:
                parsed = json.loads(chunk)
            except json.JSONDecodeError:
                continue  # plain text chunk — expected
            # A text chunk can itself be valid JSON (a number, a quoted string);
            # only a JSON object can carry the "done" flag, so skip the rest
            # instead of calling .get() on a non-dict.
            if isinstance(parsed, dict):
                assert parsed.get("done") is not True
|
||||||
385
tests/test_storage.py
Normal file
385
tests/test_storage.py
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
"""Tests for the storage layer: encryption, BlobStore, and VectorStore."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import pytest
|
||||||
|
from botocore.exceptions import ClientError
|
||||||
|
from moto import mock_aws
|
||||||
|
|
||||||
|
from app.storage.encryption import reject_if_tampered, verify_checksum
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.vector_store import VectorStore, _blob_to_vector
|
||||||
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Mocked S3 location shared by the storage tests.
_BUCKET = "test-bucket"
_REGION = "us-east-1"

# Sample payload and its SHA-256 hex digest, used wherever a blob/checksum
# pair is needed.
_BLOB = b"encrypted-payload-opaque-to-server"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def s3_bucket():
    """Create a mocked S3 bucket and expose its name.

    Inside ``mock_aws`` the fixture creates ``_BUCKET`` in ``_REGION`` and
    patches ``app.storage.blob_store.settings`` to point at it.

    Yields:
        The bucket name (``_BUCKET``).
    """
    # Dummy credentials are forced (not setdefault) so real ones in the
    # developer's environment are never used, and patch.dict restores the
    # environment when the fixture finishes instead of leaking into other tests.
    fake_env = {
        "AWS_ACCESS_KEY_ID": "testing",
        "AWS_SECRET_ACCESS_KEY": "testing",
        "AWS_DEFAULT_REGION": _REGION,
    }
    with mock_aws(), patch.dict(os.environ, fake_env):
        client = boto3.client("s3", region_name=_REGION)
        client.create_bucket(Bucket=_BUCKET)
        with patch("app.storage.blob_store.settings") as mock_settings:
            mock_settings.S3_BUCKET = _BUCKET
            mock_settings.S3_REGION = _REGION
            mock_settings.AWS_ACCESS_KEY_ID = "testing"
            mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
            yield _BUCKET
|
||||||
|
|
||||||
|
|
||||||
|
def _pinecone_mock():
    """Build a stand-in for the ``Pinecone`` class and the index it hands out.

    Returns:
        A ``(mock_pc, mock_index)`` tuple. ``mock_pc(...).Index(...)`` returns
        ``mock_index``, whose ``query`` returns a single match with id
        ``"v1"``, score ``0.95`` and metadata holding a base64 blob, its
        SHA-256 checksum and a ``user_id``.
    """
    payload = b"result-blob"
    single_match = {
        "id": "v1",
        "score": 0.95,
        "metadata": {
            "blob": base64.b64encode(payload).decode(),
            "checksum": hashlib.sha256(payload).hexdigest(),
            "user_id": "u1",
        },
    }

    mock_index = MagicMock()
    mock_index.query.return_value = {"matches": [single_match]}

    mock_pc = MagicMock()
    mock_pc.return_value.Index.return_value = mock_index
    return mock_pc, mock_index
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestEncryption ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestEncryption:
    """Tests for ``verify_checksum`` and ``reject_if_tampered``."""

    def test_verify_checksum_correct(self) -> None:
        """A blob paired with its own SHA-256 digest verifies."""
        outcome = verify_checksum(_BLOB, _CHECKSUM)
        assert outcome is True

    def test_verify_checksum_wrong(self) -> None:
        """An all-zero 64-character digest does not verify."""
        outcome = verify_checksum(_BLOB, "0" * 64)
        assert outcome is False

    def test_verify_checksum_empty_checksum(self) -> None:
        """An empty checksum string does not verify."""
        outcome = verify_checksum(_BLOB, "")
        assert outcome is False

    def test_verify_checksum_empty_blob(self) -> None:
        """An empty blob verifies against the SHA-256 digest of ``b""``."""
        digest_of_nothing = hashlib.sha256(b"").hexdigest()
        assert verify_checksum(b"", digest_of_nothing) is True

    def test_verify_checksum_tampered_blob(self) -> None:
        """Appending one byte to the blob breaks verification."""
        altered = _BLOB + b"\x00"
        assert verify_checksum(altered, _CHECKSUM) is False

    def test_reject_if_tampered_passes_when_valid(self) -> None:
        """A matching blob/checksum pair is accepted."""
        # Should not raise
        reject_if_tampered(_BLOB, _CHECKSUM)

    def test_reject_if_tampered_raises_400_on_mismatch(self) -> None:
        """A mismatching checksum raises ``HTTPException`` with status 400."""
        from fastapi import HTTPException

        with pytest.raises(HTTPException) as raised:
            reject_if_tampered(_BLOB, "bad" * 20)
        assert raised.value.status_code == 400

    def test_reject_if_tampered_detail_mentions_checksum(self) -> None:
        """The raised exception's detail contains the word "checksum"."""
        from fastapi import HTTPException

        with pytest.raises(HTTPException) as raised:
            reject_if_tampered(_BLOB, "bad" * 20)
        assert "checksum" in raised.value.detail.lower()

    def test_checksum_is_sha256_hex(self) -> None:
        """A SHA-256 hex digest is 64 lowercase hexadecimal characters."""
        digest = hashlib.sha256(_BLOB).hexdigest()
        assert len(digest) == 64
        assert all(ch in "0123456789abcdef" for ch in digest)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBlobStore ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlobStore:
    """Tests for ``BlobStore`` against the mocked bucket from ``s3_bucket``.

    Keys observed in these tests have the form ``"<user>/<table>/<record>"``,
    e.g. ``"u1/tasks/r1"``.
    """

    @pytest.mark.asyncio
    async def test_upload_returns_correct_key(self, s3_bucket: str) -> None:
        """``upload`` returns the slash-joined key for the stored object."""
        store = BlobStore()
        key = await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        assert key == "u1/tasks/r1"

    @pytest.mark.asyncio
    async def test_upload_object_exists_in_s3(self, s3_bucket: str) -> None:
        """An uploaded object can be downloaded again."""
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        # Verify by downloading — no exception means object exists
        retrieved = await store.download("u1", "u1/tasks/r1")
        assert retrieved == _BLOB

    @pytest.mark.asyncio
    async def test_download_retrieves_same_bytes(self, s3_bucket: str) -> None:
        """``download`` returns exactly the bytes that were uploaded."""
        store = BlobStore()
        payload = b"note-data"
        await store.upload(
            "u1", "notes", "n1", payload, hashlib.sha256(payload).hexdigest()
        )
        result = await store.download("u1", "u1/notes/n1")
        assert result == b"note-data"

    @pytest.mark.asyncio
    async def test_delete_removes_object(self, s3_bucket: str) -> None:
        """After ``delete``, downloading the key raises ``NoSuchKey``."""
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.delete("u1", "u1/tasks/r1")
        with pytest.raises(ClientError) as exc_info:
            await store.download("u1", "u1/tasks/r1")
        assert exc_info.value.response["Error"]["Code"] == "NoSuchKey"

    @pytest.mark.asyncio
    async def test_delete_is_idempotent(self, s3_bucket: str) -> None:
        """Deleting a key that was never stored does not raise."""
        store = BlobStore()
        # Delete a key that never existed — should not raise
        await store.delete("u1", "u1/tasks/nonexistent")

    @pytest.mark.asyncio
    async def test_list_keys_returns_correct_keys(self, s3_bucket: str) -> None:
        """``list_keys`` returns every key uploaded for that user and table."""
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u1", "tasks", "r2", _BLOB, _CHECKSUM)
        keys = await store.list_keys("u1", "tasks")
        assert set(keys) == {"u1/tasks/r1", "u1/tasks/r2"}

    @pytest.mark.asyncio
    async def test_list_keys_scoped_to_table(self, s3_bucket: str) -> None:
        """Keys from another table of the same user are not listed."""
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u1", "notes", "n1", _BLOB, _CHECKSUM)
        keys = await store.list_keys("u1", "tasks")
        assert "u1/notes/n1" not in keys
        assert "u1/tasks/r1" in keys

    @pytest.mark.asyncio
    async def test_list_keys_no_cross_user_leakage(self, s3_bucket: str) -> None:
        """Keys belonging to another user are not listed."""
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        await store.upload("u2", "tasks", "r1", _BLOB, _CHECKSUM)
        keys_u1 = await store.list_keys("u1", "tasks")
        assert "u2/tasks/r1" not in keys_u1

    @pytest.mark.asyncio
    async def test_list_keys_empty_table(self, s3_bucket: str) -> None:
        """Listing a table with no uploads returns an empty list."""
        store = BlobStore()
        keys = await store.list_keys("u1", "tasks")
        assert keys == []

    @pytest.mark.asyncio
    async def test_upload_uses_sse_s3_encryption(self, s3_bucket: str) -> None:
        """The stored object reports ``AES256`` server-side encryption."""
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        # Inspect the object directly with boto3. The s3_bucket fixture already
        # patches the blob-store settings, and this client does not read them,
        # so no further patching is needed here.
        client = boto3.client("s3", region_name=_REGION)
        response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
        assert response.get("ServerSideEncryption") == "AES256"

    @pytest.mark.asyncio
    async def test_upload_stores_checksum_in_metadata(self, s3_bucket: str) -> None:
        """The checksum passed to ``upload`` is stored as object metadata."""
        store = BlobStore()
        await store.upload("u1", "tasks", "r1", _BLOB, _CHECKSUM)
        client = boto3.client("s3", region_name=_REGION)
        response = client.head_object(Bucket=_BUCKET, Key="u1/tasks/r1")
        assert response["Metadata"]["checksum"] == _CHECKSUM
|
||||||
|
|
||||||
|
|
||||||
|
# ── _blob_to_vector helper ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBlobToVector:
    """Tests for the ``_blob_to_vector`` helper."""

    def test_returns_32_floats(self) -> None:
        """The result has 32 components."""
        vector = _blob_to_vector(b"test")
        assert len(vector) == 32

    def test_all_values_in_range(self) -> None:
        """Every component lies within ``[-1.0, 1.0]``."""
        vector = _blob_to_vector(b"test")
        assert all(-1.0 <= component <= 1.0 for component in vector)

    def test_deterministic(self) -> None:
        """The same input bytes give the same vector."""
        first = _blob_to_vector(b"same")
        second = _blob_to_vector(b"same")
        assert first == second

    def test_different_blobs_different_vectors(self) -> None:
        """Two different inputs give different vectors."""
        vector_a = _blob_to_vector(b"aaa")
        vector_b = _blob_to_vector(b"bbb")
        assert vector_a != vector_b
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestVectorStorePinecone ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestVectorStorePinecone:
    """Tests for ``VectorStore`` with its Pinecone branch forced on.

    ``app.storage.vector_store.Pinecone`` is replaced by the mock from
    ``_pinecone_mock`` so the calls made on the index can be inspected.
    """

    def _store(self) -> VectorStore:
        """Return a ``VectorStore`` whose ``_use_pinecone`` always answers ``True``."""
        vector_store = VectorStore()
        vector_store._use_pinecone = lambda: True  # type: ignore[method-assign]
        return vector_store

    @pytest.mark.asyncio
    async def test_upsert_calls_index_upsert(self) -> None:
        """``upsert`` calls the index once, with the user id as namespace."""
        mock_pc, mock_index = _pinecone_mock()
        blob = b"enc-blob"
        items = [VectorItem(id="v1", blob=blob, checksum=hashlib.sha256(blob).hexdigest())]
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            await store.upsert("u1", items)
        mock_index.upsert.assert_called_once()
        sent = mock_index.upsert.call_args.kwargs
        assert sent.get("namespace") == "u1"

    @pytest.mark.asyncio
    async def test_upsert_encodes_blob_as_base64_in_metadata(self) -> None:
        """The blob is placed in vector metadata as a base64 string."""
        mock_pc, mock_index = _pinecone_mock()
        blob = b"secret"
        items = [VectorItem(id="v1", blob=blob, checksum=hashlib.sha256(blob).hexdigest())]
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            await store.upsert("u1", items)
        sent_vectors = mock_index.upsert.call_args.kwargs["vectors"]
        expected_blob = base64.b64encode(b"secret").decode()
        assert sent_vectors[0]["metadata"]["blob"] == expected_blob

    @pytest.mark.asyncio
    async def test_search_calls_index_query(self) -> None:
        """``search`` queries once with namespace, ``top_k`` and metadata enabled."""
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            await store.search("u1", b"query-blob", top_k=5)
        mock_index.query.assert_called_once()
        sent = mock_index.query.call_args.kwargs
        assert sent.get("namespace") == "u1"
        assert sent.get("top_k") == 5
        assert sent.get("include_metadata") is True

    @pytest.mark.asyncio
    async def test_search_returns_vector_search_results(self) -> None:
        """The mocked match comes back as one decoded ``VectorSearchResult``."""
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            results = await store.search("u1", b"query", top_k=10)
        assert len(results) == 1
        hit = results[0]
        assert isinstance(hit, VectorSearchResult)
        assert hit.id == "v1"
        assert hit.score == 0.95
        assert hit.blob == b"result-blob"

    @pytest.mark.asyncio
    async def test_search_uses_derived_query_vector(self) -> None:
        """The query vector equals ``_blob_to_vector`` of the query blob."""
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            await store.search("u1", b"query-blob", top_k=3)
        expected_vector = _blob_to_vector(b"query-blob")
        sent_vector = mock_index.query.call_args.kwargs.get("vector")
        assert sent_vector == expected_vector

    @pytest.mark.asyncio
    async def test_delete_calls_index_delete(self) -> None:
        """``delete`` calls the index once with the namespace and the given ids."""
        mock_pc, mock_index = _pinecone_mock()
        with patch("app.storage.vector_store.Pinecone", mock_pc):
            store = self._store()
            await store.delete("u1", ["v1", "v2"])
        mock_index.delete.assert_called_once()
        sent = mock_index.delete.call_args.kwargs
        assert sent.get("namespace") == "u1"
        assert set(sent.get("ids", [])) == {"v1", "v2"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestVectorStoreQdrant ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestVectorStoreQdrant:
    """Tests for ``VectorStore`` with its Pinecone branch forced off.

    ``app.storage.vector_store.QdrantClient`` is patched to return a mock
    client so the calls made on it can be inspected.
    """

    def _store(self) -> VectorStore:
        """Return a ``VectorStore`` whose ``_use_pinecone`` always answers ``False``."""
        vector_store = VectorStore()
        vector_store._use_pinecone = lambda: False  # type: ignore[method-assign]
        return vector_store

    def _qdrant_mock(self) -> MagicMock:
        """Return a mock client whose ``search`` yields a single hit.

        The hit has id ``"v1"``, score ``0.88`` and a payload holding a
        base64 blob and a ``user_id``.
        """
        hit = MagicMock()
        hit.id = "v1"
        hit.score = 0.88
        hit.payload = {
            "blob": base64.b64encode(b"qdrant-result").decode(),
            "user_id": "u1",
        }
        client = MagicMock()
        client.search.return_value = [hit]
        return client

    @pytest.mark.asyncio
    async def test_upsert_calls_client_upsert(self) -> None:
        """``upsert`` calls the client's ``upsert`` exactly once."""
        client = MagicMock()
        items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
        with patch("app.storage.vector_store.QdrantClient", return_value=client):
            store = self._store()
            await store.upsert("u1", items)
        client.upsert.assert_called_once()

    @pytest.mark.asyncio
    async def test_upsert_uses_correct_collection(self) -> None:
        """``upsert`` targets the ``adiuva_vectors`` collection."""
        client = MagicMock()
        items = [VectorItem(id="v1", blob=b"enc", checksum=hashlib.sha256(b"enc").hexdigest())]
        with patch("app.storage.vector_store.QdrantClient", return_value=client):
            store = self._store()
            await store.upsert("u1", items)
        sent = client.upsert.call_args.kwargs
        assert sent["collection_name"] == "adiuva_vectors"

    @pytest.mark.asyncio
    async def test_search_calls_client_search(self) -> None:
        """``search`` calls the client's ``search`` exactly once."""
        client = self._qdrant_mock()
        with patch("app.storage.vector_store.QdrantClient", return_value=client):
            store = self._store()
            await store.search("u1", b"query", top_k=5)
        client.search.assert_called_once()

    @pytest.mark.asyncio
    async def test_search_passes_limit(self) -> None:
        """``top_k`` is passed to the client as ``limit``."""
        client = self._qdrant_mock()
        with patch("app.storage.vector_store.QdrantClient", return_value=client):
            store = self._store()
            await store.search("u1", b"query", top_k=7)
        sent = client.search.call_args.kwargs
        assert sent.get("limit") == 7

    @pytest.mark.asyncio
    async def test_search_returns_vector_search_results(self) -> None:
        """The mocked hit comes back as one decoded ``VectorSearchResult``."""
        client = self._qdrant_mock()
        with patch("app.storage.vector_store.QdrantClient", return_value=client):
            store = self._store()
            results = await store.search("u1", b"query", top_k=5)
        assert len(results) == 1
        found = results[0]
        assert isinstance(found, VectorSearchResult)
        assert found.id == "v1"
        assert found.score == 0.88
        assert found.blob == b"qdrant-result"

    @pytest.mark.asyncio
    async def test_delete_calls_client_delete(self) -> None:
        """``delete`` calls the client's ``delete`` exactly once."""
        client = MagicMock()
        with patch("app.storage.vector_store.QdrantClient", return_value=client):
            store = self._store()
            await store.delete("u1", ["v1", "v2"])
        client.delete.assert_called_once()

    @pytest.mark.asyncio
    async def test_delete_uses_correct_collection(self) -> None:
        """``delete`` targets the ``adiuva_vectors`` collection."""
        client = MagicMock()
        with patch("app.storage.vector_store.QdrantClient", return_value=client):
            store = self._store()
            await store.delete("u1", ["v1"])
        sent = client.delete.call_args.kwargs
        assert sent["collection_name"] == "adiuva_vectors"
|
||||||
Reference in New Issue
Block a user