Compare commits
50 Commits
8f7bc25611
...
feature/de
| Author | SHA1 | Date | |
|---|---|---|---|
| 47bf1881e5 | |||
| 24a9c1b752 | |||
| 706bf88883 | |||
| 4ff0b27084 | |||
| 61d2a18234 | |||
| b3687719b6 | |||
| f80bdfa8f7 | |||
| 617a17db40 | |||
| 92716cb89a | |||
| cfc9d7a942 | |||
| 2de67213f8 | |||
| f6ed383b3a | |||
| 9332e29e53 | |||
| 618076193a | |||
| 34f01234c9 | |||
| 0bd46937d3 | |||
| e6b5bc2e7d | |||
| c90ed58078 | |||
| 76c8f2bdad | |||
| 393b3befd6 | |||
| 2c08275934 | |||
| 7cb384fa63 | |||
| 7efaeba283 | |||
| b61ded8458 | |||
| ac71d99f9a | |||
| 3b3b3baf25 | |||
| 45415bb9ee | |||
| a775a2da18 | |||
| 24772f2b67 | |||
| fd1396a710 | |||
| 914f70bd85 | |||
| 608d6c784f | |||
| 19ad5be97f | |||
| 1dfd088e18 | |||
| c6e1e4e7fd | |||
| cc603aba06 | |||
| 6d9a16e513 | |||
| 27c087d5d8 | |||
|
|
4d7fd519c5 | ||
| 06de7c7ab0 | |||
| e3c7547c75 | |||
| 314780d59a | |||
| 091787a6da | |||
| 7f278c6f63 | |||
| 8bfce9da00 | |||
| 480e7ac5bd | |||
| d0b303e745 | |||
| 5d485b3665 | |||
| 9787befd4a | |||
| 9119474e71 |
32
.env.example
32
.env.example
@@ -10,18 +10,34 @@ JWT_ALGORITHM=HS256
|
|||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
# ── OpenAI ────────────────────────────────────────────────────────────────────
|
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||||
OPENAI_API_KEY=sk-...
|
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||||
|
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
GOOGLE_API_KEY=
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
# ── Stripe ────────────────────────────────────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=sk_test_...
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=whsec_...
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
||||||
S3_BUCKET=adiuva-backups
|
S3_BUCKET=adiuva
|
||||||
S3_REGION=us-east-1
|
S3_REGION=us-east-1
|
||||||
AWS_ACCESS_KEY_ID=AKIA...
|
S3_ENDPOINT_URL=
|
||||||
AWS_SECRET_ACCESS_KEY=...
|
AWS_ACCESS_KEY_ID=
|
||||||
|
AWS_SECRET_ACCESS_KEY=
|
||||||
|
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
|
||||||
|
# ── Vector Store ──────────────────────────────────────────────────────────────
|
||||||
|
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
|
||||||
|
PINECONE_API_KEY=
|
||||||
|
PINECONE_INDEX=adiuva
|
||||||
|
QDRANT_URL=
|
||||||
|
QDRANT_API_KEY=
|
||||||
|
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
|
|||||||
@@ -1,21 +1,93 @@
|
|||||||
name: Deploy to Proxmox Docker
|
name: Test & Deploy API
|
||||||
run-name: Deploying ${{ gitea.sha }}
|
run-name: ${{ gitea.ref_name }} → Docker LXC
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
tags:
|
||||||
- main # O il nome del tuo branch principale
|
- 'v*'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
Deploy:
|
# ── 1. Run tests in an isolated Python container ──────────────────
|
||||||
runs-on: ubuntu-latest # Questo dipende dalle label che hai dato al tuo act_runner
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: python:3.12-slim
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Deploying via SSH
|
- name: Install git
|
||||||
|
run: apt-get update && apt-get install -y --no-install-recommends git
|
||||||
|
|
||||||
|
- name: Checkout Code
|
||||||
|
run: |
|
||||||
|
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||||
|
"http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . || \
|
||||||
|
git clone --depth 1 "http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . && \
|
||||||
|
git checkout "${GITHUB_SHA}"
|
||||||
|
|
||||||
|
- name: Install Dependencies
|
||||||
|
run: pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
- name: Run Linter
|
||||||
|
run: ruff check app/ tests/
|
||||||
|
|
||||||
|
- name: Run Tests
|
||||||
|
run: pytest tests/ -v --tb=short
|
||||||
|
|
||||||
|
# ── 2. Deploy to Docker LXC via SSH ─────────────────────────────────
|
||||||
|
deploy:
|
||||||
|
needs: test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: gitea.event_name == 'push'
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Deploy via SSH
|
||||||
uses: appleboy/ssh-action@v1.0.0
|
uses: appleboy/ssh-action@v1.0.0
|
||||||
with:
|
with:
|
||||||
host: ${{ secrets.SSH_HOST }}
|
host: ${{ secrets.SSH_HOST }}
|
||||||
username: ${{ secrets.SSH_USER }}
|
username: ${{ secrets.SSH_USER }}
|
||||||
key: ${{ secrets.SSH_KEY }}
|
key: ${{ secrets.SSH_KEY }}
|
||||||
script: |
|
script: |
|
||||||
cd /opt/adiuva-api
|
set -e
|
||||||
git pull origin main
|
DEPLOY_DIR="/opt/adiuva-api"
|
||||||
|
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||||
|
TAG="${{ gitea.ref_name }}"
|
||||||
|
|
||||||
|
# ── Pull latest code ──
|
||||||
|
cd /tmp && rm -rf adiuva-api-deploy
|
||||||
|
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
|
||||||
|
|
||||||
|
# ── Sync source (preserve .env) ──
|
||||||
|
cp -rf /tmp/adiuva-api-deploy/app/ \
|
||||||
|
/tmp/adiuva-api-deploy/alembic/ \
|
||||||
|
/tmp/adiuva-api-deploy/alembic.ini \
|
||||||
|
/tmp/adiuva-api-deploy/Dockerfile \
|
||||||
|
/tmp/adiuva-api-deploy/docker-compose.yml \
|
||||||
|
/tmp/adiuva-api-deploy/requirements.txt \
|
||||||
|
"$DEPLOY_DIR/"
|
||||||
|
rm -rf /tmp/adiuva-api-deploy
|
||||||
|
|
||||||
|
# ── Verify .env ──
|
||||||
|
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||||
|
echo "❌ $DEPLOY_DIR/.env not found. Create it before deploying."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Build & restart ──
|
||||||
|
cd "$DEPLOY_DIR"
|
||||||
|
docker compose down --remove-orphans || true
|
||||||
docker compose up -d --build
|
docker compose up -d --build
|
||||||
|
|
||||||
|
# ── Migrations ──
|
||||||
|
docker compose exec -T app alembic upgrade head
|
||||||
|
|
||||||
|
# ── Health check ──
|
||||||
|
echo "Waiting for app..."
|
||||||
|
sleep 5
|
||||||
|
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/api/v1/health)
|
||||||
|
if [ "$HTTP_CODE" -eq 200 ]; then
|
||||||
|
echo "✅ API is healthy (HTTP ${HTTP_CODE})"
|
||||||
|
else
|
||||||
|
echo "❌ Health check failed (HTTP ${HTTP_CODE})"
|
||||||
|
docker compose logs app --tail=50
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
64
.github/workflows/ci.yml
vendored
Normal file
64
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install ruff
|
||||||
|
run: pip install ruff>=0.8.0
|
||||||
|
|
||||||
|
- name: Ruff check
|
||||||
|
run: ruff check .
|
||||||
|
|
||||||
|
- name: Ruff format check
|
||||||
|
run: ruff format --check .
|
||||||
|
|
||||||
|
test:
|
||||||
|
name: Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: lint
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Cache pip
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip
|
||||||
|
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
|
||||||
|
restore-keys: ${{ runner.os }}-pip-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -r requirements.txt
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: pytest -v --tb=short
|
||||||
|
|
||||||
|
docker:
|
||||||
|
name: Docker Build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: test
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Build image
|
||||||
|
run: docker build -t adiuva-api:ci .
|
||||||
|
|
||||||
|
- name: Verify gunicorn installed
|
||||||
|
run: docker run --rm adiuva-api:ci gunicorn --version
|
||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -31,3 +31,4 @@ Thumbs.db
|
|||||||
|
|
||||||
# Claude Code
|
# Claude Code
|
||||||
.claude/
|
.claude/
|
||||||
|
logs/
|
||||||
|
|||||||
530
BACKEND_PLAN.md
530
BACKEND_PLAN.md
@@ -1,530 +0,0 @@
|
|||||||
# Backend Plan — Adiuva Cloud API
|
|
||||||
|
|
||||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
|
||||||
>
|
|
||||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
|
||||||
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── app/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── main.py # FastAPI entry + CORS + lifespan + router includes
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── agent_registry.py # Base classes + singleton registry
|
|
||||||
│ │ ├── orchestrator.py # LLM-based intent router
|
|
||||||
│ │ ├── execution_plan.py # Plan builder + cache
|
|
||||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
|
||||||
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
|
||||||
│ │ ├── __init__.py # Auto-registers all agents
|
|
||||||
│ │ ├── task_agent.py
|
|
||||||
│ │ ├── calendar_agent.py
|
|
||||||
│ │ ├── email_agent.py
|
|
||||||
│ │ └── analytics_agent.py
|
|
||||||
│ ├── api/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── routes/
|
|
||||||
│ │ │ ├── __init__.py
|
|
||||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
|
||||||
│ │ │ ├── plans.py # GET /plans/playbook
|
|
||||||
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
|
||||||
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
|
||||||
│ │ │ ├── backup.py # PUT/GET /backup
|
|
||||||
│ │ │ ├── plugins.py # Plugin marketplace
|
|
||||||
│ │ │ ├── auth.py # Register/login/refresh
|
|
||||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
|
||||||
│ │ └── middleware/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── auth.py # JWT validation
|
|
||||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
|
||||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
|
||||||
│ ├── storage/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
|
||||||
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
|
||||||
│ │ └── encryption.py # Integrity verification only — NO decryption
|
|
||||||
│ ├── marketplace/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
|
||||||
│ │ ├── plugin_review.py # Review queue + approval workflow
|
|
||||||
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
|
||||||
│ ├── billing/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
|
||||||
│ │ └── tier_manager.py # Feature matrix per tier
|
|
||||||
│ └── config/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ └── settings.py # Pydantic BaseSettings (env-based)
|
|
||||||
├── tests/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── conftest.py # Fixtures: test client, mock agents, mock LLM
|
|
||||||
│ ├── test_orchestrator.py
|
|
||||||
│ ├── test_agents.py
|
|
||||||
│ ├── test_auth.py
|
|
||||||
│ ├── test_backup.py
|
|
||||||
│ ├── test_storage.py
|
|
||||||
│ └── test_plugins.py
|
|
||||||
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
|
|
||||||
│ ├── alembic.ini
|
|
||||||
│ └── versions/
|
|
||||||
├── requirements.txt
|
|
||||||
├── Dockerfile
|
|
||||||
├── docker-compose.yml # App + PostgreSQL + Redis (dev)
|
|
||||||
├── .env.example
|
|
||||||
└── README.md
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step-by-Step Implementation
|
|
||||||
|
|
||||||
### Step 1 — Project scaffolding ✅
|
|
||||||
- [x] Initialize repo with the directory structure above
|
|
||||||
- [x] Write `requirements.txt`:
|
|
||||||
```
|
|
||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
langchain>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
stripe>=11.0.0
|
|
||||||
boto3>=1.35.0
|
|
||||||
slowapi>=0.1.9
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
alembic>=1.14.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
httpx>=0.28.0
|
|
||||||
websockets>=14.0
|
|
||||||
pytest>=8.0.0
|
|
||||||
pytest-asyncio>=0.24.0
|
|
||||||
```
|
|
||||||
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
|
||||||
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
|
||||||
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
|
||||||
- [x] Write `.env.example`
|
|
||||||
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
|
||||||
|
|
||||||
### Step 2 — Pydantic schemas (API contracts) ✅
|
|
||||||
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
|
||||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
|
||||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
|
||||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
|
||||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
|
||||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
|
||||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
|
||||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
|
||||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
|
||||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
|
||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
|
||||||
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
|
||||||
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
|
||||||
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
|
||||||
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
|
||||||
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
|
||||||
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
|
||||||
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
|
||||||
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
|
||||||
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
|
||||||
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
|
||||||
- `PluginInstallRequest`: `plugin_id: str`
|
|
||||||
- **Outcome:** All request/response models defined and validated.
|
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes ✅
|
|
||||||
- [x] `app/core/agent_registry.py`:
|
|
||||||
- `BaseAgent(ABC)`:
|
|
||||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
|
||||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
|
||||||
- `ChatAgent(BaseAgent)`:
|
|
||||||
- Abstract `async handle(query: str, context: dict) -> str`
|
|
||||||
- Abstract `get_tools() -> list` (LangChain tool definitions)
|
|
||||||
- Concrete `_tool_loop(llm, messages, tools, max_iter=5) -> str` — shared tool-calling loop
|
|
||||||
- `AgentRegistry` (singleton):
|
|
||||||
- `_agents: dict[str, ChatAgent]`
|
|
||||||
- `register(agent_class)` — decorator pattern
|
|
||||||
- `get(name) -> ChatAgent`
|
|
||||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
|
||||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
|
||||||
- [x] Unit tests: register, get, list, call_agent with mock
|
|
||||||
- **Outcome:** Pluggable agent framework.
|
|
||||||
|
|
||||||
### Step 4 — Orchestrator ✅
|
|
||||||
- [x] `app/core/orchestrator.py`:
|
|
||||||
- `async classify_intent(message, context, registry) -> str`:
|
|
||||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
|
||||||
- Uses gpt-4o-mini via LangChain for low latency
|
|
||||||
- Falls back to `task_agent` if no clear match
|
|
||||||
- `async route_single(agent_name, message, context) -> ChatResponse`:
|
|
||||||
- Instantiates agent from registry
|
|
||||||
- Calls `agent.handle(message, context)`
|
|
||||||
- Returns response + any actions the agent produced
|
|
||||||
- `async route_pipeline(agent_names, message, context) -> ChatResponse`:
|
|
||||||
- Executes agents in sequence
|
|
||||||
- Each agent receives `{...context, previous_results: [...]}`
|
|
||||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
|
||||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
|
||||||
- Main entry point
|
|
||||||
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
|
|
||||||
- Classifies intent
|
|
||||||
- If `execution_mode == 'direct'`: route + return response
|
|
||||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
|
||||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
|
||||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
|
||||||
- [x] Integration tests with mocked LLM and mocked agents
|
|
||||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
|
||||||
|
|
||||||
### Step 5 — Execution Plan generator ✅
|
|
||||||
- [x] `app/core/execution_plan.py`:
|
|
||||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
|
||||||
- `ExecutionPlanBuilder`:
|
|
||||||
- `add_step(action, params) -> self`
|
|
||||||
- `add_llm_step(template_id, variables) -> self`
|
|
||||||
- `add_data_step(action, data_from_step) -> self`
|
|
||||||
- `build() -> ExecutionPlan` — validates step references
|
|
||||||
- `PlanCache`:
|
|
||||||
- In-memory LRU (maxsize=1000)
|
|
||||||
- `cache_plan(key, plan)`, `get_plan(key)`, `get_all_playbooks() -> list[ExecutionPlan]`
|
|
||||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
|
||||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
|
||||||
|
|
||||||
### Step 6 — Chat Agents ✅
|
|
||||||
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
|
||||||
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
|
||||||
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
|
||||||
- [x] `app/agents/checkpoint_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages project checkpoints (milestones): list, create, update, delete"
|
|
||||||
- Tools (4): `list_checkpoints(project_id)`, `create_checkpoint(project_id, title, date, is_ai_suggested, is_approved)`, `update_checkpoint(checkpoint_id, ...)`, `delete_checkpoint(checkpoint_id)`
|
|
||||||
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
|
||||||
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
|
||||||
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
|
||||||
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages notes: list, get, create, update, delete"
|
|
||||||
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
|
||||||
- content is Markdown; `get_note` should be called before update to preserve existing content
|
|
||||||
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
|
||||||
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
|
||||||
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Checkpoints, Projects, Notes), all registered and tested.
|
|
||||||
|
|
||||||
### Step 7 — Storage Layer ✅
|
|
||||||
- [x] `app/storage/blob_store.py`:
|
|
||||||
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
|
||||||
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
|
||||||
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
|
||||||
- [x] `app/storage/vector_store.py`:
|
|
||||||
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
|
||||||
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
|
||||||
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
|
|
||||||
- ANN on encrypted data: known accuracy trade-off, documented
|
|
||||||
- [x] `app/storage/encryption.py`:
|
|
||||||
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
|
||||||
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
|
||||||
- Backend NEVER holds decryption keys
|
|
||||||
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
|
||||||
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
|
||||||
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
|
||||||
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
|
|
||||||
|
|
||||||
### Step 8 — API Routes ✅
|
|
||||||
|
|
||||||
#### 8a — Chat endpoint
|
|
||||||
- [x] `app/api/routes/chat.py`:
|
|
||||||
- `POST /api/v1/chat`:
|
|
||||||
- Request: `ChatRequest`
|
|
||||||
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
|
||||||
- Response: `ChatResponse` or `ExecutionPlan`
|
|
||||||
- `WebSocket /api/v1/chat/stream`:
|
|
||||||
- Client sends `ChatRequest` as first JSON frame
|
|
||||||
- Server yields token strings via `orchestrate_stream()`
|
|
||||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
|
||||||
- Heartbeat ping every 30s to keep connection alive
|
|
||||||
|
|
||||||
#### 8b — Plans endpoint
|
|
||||||
- [x] `app/api/routes/plans.py`:
|
|
||||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
|
||||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
|
||||||
|
|
||||||
#### 8c — Storage endpoint (cloud records)
|
|
||||||
- [x] `app/api/routes/storage.py`:
|
|
||||||
- `POST /api/v1/storage/records`: Create encrypted record
|
|
||||||
- Request: `StorageRecordCreate`
|
|
||||||
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
|
||||||
- Response: `{id: str, created_at: int}`
|
|
||||||
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
|
||||||
- Query params: `table: str`, `page: int`, `limit: int`
|
|
||||||
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
|
||||||
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
|
||||||
- Response: blob bytes + `X-Checksum` header
|
|
||||||
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
|
||||||
- Request: `StorageRecordUpdate`
|
|
||||||
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
|
||||||
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
|
|
||||||
|
|
||||||
#### 8d — Vectors endpoint (cloud vector store)
|
|
||||||
- [x] `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/upsert`:
|
|
||||||
- Request: `VectorUpsertRequest`
|
|
||||||
- Verifies checksums, delegates to `VectorStore.upsert()`
|
|
||||||
- Response: `{upserted: int}`
|
|
||||||
- `POST /api/v1/storage/vectors/search`:
|
|
||||||
- Request: `VectorSearchRequest`
|
|
||||||
- Delegates to `VectorStore.search()`
|
|
||||||
- Response: `VectorSearchResponse`
|
|
||||||
- `DELETE /api/v1/storage/vectors`:
|
|
||||||
- Request: `{ids: list[str]}`
|
|
||||||
|
|
||||||
#### 8e — Backup endpoint
|
|
||||||
- [x] `app/api/routes/backup.py`:
|
|
||||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
|
||||||
- Free: 0 (no backup)
|
|
||||||
- Pro: 5 GB
|
|
||||||
- Power: 25 GB
|
|
||||||
- Team: unlimited
|
|
||||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
|
||||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
|
||||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
|
||||||
|
|
||||||
#### 8f — Plugins endpoint
|
|
||||||
- [x] `app/api/routes/plugins.py`:
|
|
||||||
- `GET /api/v1/plugins`:
|
|
||||||
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
|
||||||
- Response: `PluginListResponse`
|
|
||||||
- Available from Power tier and above
|
|
||||||
- `GET /api/v1/plugins/{id}`:
|
|
||||||
- Response: `PluginManifest` + ratings + install count
|
|
||||||
- `POST /api/v1/plugins/{id}/install`:
|
|
||||||
- Request: `PluginInstallRequest`
|
|
||||||
- Records installation for the user (billing tracking, analytics)
|
|
||||||
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
|
||||||
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
|
||||||
- `DELETE /api/v1/plugins/{id}/install`:
|
|
||||||
- Unregisters installation
|
|
||||||
|
|
||||||
#### 8g — Auth endpoint
|
|
||||||
- [x] `app/api/routes/auth.py`:
|
|
||||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
|
||||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
|
||||||
|
|
||||||
#### 8h — Billing endpoint
|
|
||||||
- [x] `app/api/routes/billing.py`:
|
|
||||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
|
||||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
|
||||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
|
||||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
|
||||||
|
|
||||||
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
|
|
||||||
|
|
||||||
### Step 9 — Middleware
|
|
||||||
|
|
||||||
#### 9a — Auth middleware
|
|
||||||
- [x] `app/api/middleware/auth.py`:
|
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
|
||||||
- Raises `401` on invalid/expired token
|
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
#### 9b — Rate limiter
|
|
||||||
- [x] `app/api/middleware/rate_limit.py`:
|
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
|
||||||
- Tier-based limits:
|
|
||||||
- Free: 20 req/min
|
|
||||||
- Pro: 60 req/min
|
|
||||||
- Power: 120 req/min
|
|
||||||
- Team: 200 req/seat/min
|
|
||||||
- Custom 429 response with `Retry-After` header
|
|
||||||
|
|
||||||
#### 9c — Sanitizer
|
|
||||||
- [x] `app/api/middleware/sanitizer.py`:
|
|
||||||
- Response middleware that scans response bodies
|
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
|
||||||
- Pattern-based detection + exact match against known prompt fingerprints
|
|
||||||
- Logs sanitization events for monitoring
|
|
||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
|
||||||
|
|
||||||
### Step 10 — Plugin Marketplace ✅
|
|
||||||
- [x] `app/marketplace/plugin_registry.py`:
|
|
||||||
- `PluginRegistry`:
|
|
||||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
|
||||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
|
||||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
|
||||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
|
||||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
|
||||||
- [x] `app/marketplace/plugin_review.py`:
|
|
||||||
- `ReviewQueue`:
|
|
||||||
- `async get_pending() -> list[dict]`
|
|
||||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
|
||||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
|
||||||
- [x] `app/marketplace/revenue_share.py`:
|
|
||||||
- `RevenueShare`:
|
|
||||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
|
||||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
|
||||||
- `async get_earnings(developer_id, period) -> dict`
|
|
||||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
|
||||||
|
|
||||||
### Step 11 — Billing & Tier management
|
|
||||||
- [ ] `app/billing/stripe_service.py`:
|
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
|
||||||
- `get_subscription(user_id) -> dict | None`
|
|
||||||
- `cancel_subscription(user_id) -> None`
|
|
||||||
- [ ] `app/billing/tier_manager.py`:
|
|
||||||
- `TierManager`:
|
|
||||||
- Feature matrix:
|
|
||||||
```python
|
|
||||||
FEATURES = {
|
|
||||||
'free': {
|
|
||||||
'agents': 3,
|
|
||||||
'batch_active': 2,
|
|
||||||
'cloud_storage_gb': 0,
|
|
||||||
'backup_gb': 0,
|
|
||||||
'providers': 1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'pro': {
|
|
||||||
'agents': -1, # unlimited
|
|
||||||
'batch_active': 10,
|
|
||||||
'cloud_storage_gb': 5,
|
|
||||||
'backup_gb': 5,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'power': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1, # unlimited
|
|
||||||
'cloud_storage_gb': 25,
|
|
||||||
'backup_gb': 25,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'team': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1,
|
|
||||||
'cloud_storage_gb': -1,
|
|
||||||
'backup_gb': -1,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- `get_tier(user_id) -> BillingTier`
|
|
||||||
- `check_feature(user_id, feature) -> bool`
|
|
||||||
- `get_rate_limit(tier) -> int`
|
|
||||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
|
||||||
|
|
||||||
### Step 12 — Database (auth/billing/marketplace only)
|
|
||||||
- [ ] PostgreSQL schema via Alembic:
|
|
||||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
|
||||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
|
||||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
|
||||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
|
||||||
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
|
||||||
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
|
||||||
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
|
||||||
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
|
||||||
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
|
||||||
- [ ] Initial Alembic migration
|
|
||||||
- [ ] SQLAlchemy models in `app/models.py`
|
|
||||||
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
|
||||||
|
|
||||||
### Step 13 — Testing & deployment
|
|
||||||
- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
|
||||||
- [ ] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
|
||||||
- [ ] `tests/test_agents.py`: each agent with mocked tools
|
|
||||||
- [ ] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
|
||||||
- [ ] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
|
||||||
- [ ] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
|
||||||
- [ ] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
|
||||||
- [ ] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
|
||||||
- [ ] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
|
||||||
- **Outcome:** Fully tested, deployable backend.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Contract Summary
|
|
||||||
|
|
||||||
| Method | Endpoint | Auth | Request | Response |
|
|
||||||
|--------|----------|------|---------|----------|
|
|
||||||
| POST | `/api/v1/auth/register` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/login` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/refresh` | No | `{refresh_token}` | `AuthTokens` |
|
|
||||||
| GET | `/api/v1/auth/me` | JWT | — | `UserProfile` |
|
|
||||||
| POST | `/api/v1/chat` | JWT | `ChatRequest` | `ChatResponse \| ExecutionPlan` |
|
|
||||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
|
||||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
|
||||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
|
||||||
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
|
||||||
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
|
||||||
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
|
||||||
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
|
||||||
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
|
||||||
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
|
||||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
|
||||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
|
||||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
|
||||||
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
|
||||||
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
|
||||||
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
|
||||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
|
||||||
| DELETE | `/api/v1/billing/subscription` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/health` | No | — | `{status, version}` |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stack
|
|
||||||
|
|
||||||
| Layer | Technology |
|
|
||||||
|-------|-----------|
|
|
||||||
| Framework | FastAPI + Uvicorn |
|
|
||||||
| LLM | LangChain + langchain-openai |
|
|
||||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
|
||||||
| Billing | stripe-python + Stripe Connect |
|
|
||||||
| Blob storage | boto3 (S3) |
|
|
||||||
| Vector store | Pinecone or Qdrant (configurable) |
|
|
||||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
|
||||||
| Rate limiting | slowapi |
|
|
||||||
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
|
||||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Development Rules
|
|
||||||
|
|
||||||
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
|
||||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
|
||||||
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
|
||||||
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
|
||||||
5. **Type hints everywhere.** All functions have full type annotations.
|
|
||||||
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
|
||||||
7. **Structured logging.** JSON logs with request ID correlation.
|
|
||||||
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
|
||||||
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
|
||||||
10
Dockerfile
10
Dockerfile
@@ -21,6 +21,10 @@ COPY --from=builder /install /usr/local
|
|||||||
# Copy application source
|
# Copy application source
|
||||||
COPY app/ app/
|
COPY app/ app/
|
||||||
|
|
||||||
|
# Copy Alembic migration files
|
||||||
|
COPY alembic/ alembic/
|
||||||
|
COPY alembic.ini .
|
||||||
|
|
||||||
# Ensure appuser owns the working directory
|
# Ensure appuser owns the working directory
|
||||||
RUN chown -R appuser:appgroup /app
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
@@ -28,4 +32,8 @@ USER appuser
|
|||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "4", \
|
||||||
|
"--timeout", "120"]
|
||||||
|
|||||||
793
README.md
Normal file
793
README.md
Normal file
@@ -0,0 +1,793 @@
|
|||||||
|
# Adiuva Cloud API
|
||||||
|
|
||||||
|
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
|
||||||
|
|
||||||
|
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Overview](#overview)
|
||||||
|
- [Architecture](#architecture)
|
||||||
|
- [Key Features](#key-features)
|
||||||
|
- [Tech Stack](#tech-stack)
|
||||||
|
- [Getting Started](#getting-started)
|
||||||
|
- [Docker Deployment](#docker-deployment)
|
||||||
|
- [Environment Variables](#environment-variables)
|
||||||
|
- [API Reference](#api-reference)
|
||||||
|
- [Data Model](#data-model)
|
||||||
|
- [AI Agent System](#ai-agent-system)
|
||||||
|
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
||||||
|
- [Middleware](#middleware)
|
||||||
|
- [Storage Layer](#storage-layer)
|
||||||
|
- [Billing & Tiers](#billing--tiers)
|
||||||
|
- [Plugin Marketplace](#plugin-marketplace)
|
||||||
|
- [Testing](#testing)
|
||||||
|
- [Project Structure](#project-structure)
|
||||||
|
- [License](#license)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
|
||||||
|
|
||||||
|
### Design Principles
|
||||||
|
|
||||||
|
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
|
||||||
|
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
||||||
|
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
|
||||||
|
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
||||||
|
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
|
||||||
|
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
|
||||||
|
│ Desktop App │────▶│ │
|
||||||
|
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
|
||||||
|
└──────────────┘ │ │
|
||||||
|
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
||||||
|
│ │ Auth Routes │ │ Chat Routes │ │
|
||||||
|
│ │ Billing Routes │ │ ↓ │ │
|
||||||
|
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
||||||
|
│ │ Backup Routes │ │ ↓ classify intent │ │
|
||||||
|
│ │ Plugin Routes │ │ Agent Registry │ │
|
||||||
|
│ │ Vector Routes │ │ ↓ │ │
|
||||||
|
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
|
||||||
|
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
|
||||||
|
│ │ (GPT-4o + LangChain) │ │
|
||||||
|
│ └────────────────────────────┘ │
|
||||||
|
└────────────────────────────────────────────────────────┘
|
||||||
|
│ │ │
|
||||||
|
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
|
||||||
|
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
|
||||||
|
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
|
||||||
|
│ Billing, │ │ backups) │ │ (Vectors) │
|
||||||
|
│ Metadata) │ └───────────────┘ └────────────────┘
|
||||||
|
└────────────┘
|
||||||
|
│
|
||||||
|
┌────────▼───┐
|
||||||
|
│ Stripe │
|
||||||
|
│ (Billing, │
|
||||||
|
│ Connect) │
|
||||||
|
└────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
||||||
|
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
||||||
|
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
||||||
|
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
||||||
|
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
||||||
|
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
|
||||||
|
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
|
||||||
|
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
||||||
|
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
||||||
|
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
||||||
|
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
||||||
|
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
|
||||||
|
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
||||||
|
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
|
||||||
|
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
|
||||||
|
| Package | Version | Purpose |
|
||||||
|
|---|---|---|
|
||||||
|
| `fastapi` | ≥ 0.115.0 | Web framework |
|
||||||
|
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
|
||||||
|
| `gunicorn` | ≥ 22.0.0 | Production process manager |
|
||||||
|
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
|
||||||
|
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
|
||||||
|
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
|
||||||
|
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
|
||||||
|
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
||||||
|
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
||||||
|
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
||||||
|
| `boto3` | ≥ 1.35.0 | AWS S3 client |
|
||||||
|
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
||||||
|
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
||||||
|
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
||||||
|
| `alembic` | ≥ 1.14.0 | Database migration management |
|
||||||
|
| `bcrypt` | ≥ 4.2.0 | Password hashing |
|
||||||
|
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
|
||||||
|
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
||||||
|
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
||||||
|
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
||||||
|
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
|
||||||
|
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
|
||||||
|
| `pytest` | ≥ 8.0.0 | Test framework |
|
||||||
|
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
||||||
|
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
||||||
|
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
|
||||||
|
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.12+
|
||||||
|
- PostgreSQL 16+
|
||||||
|
- An OpenAI API key (for LLM features)
|
||||||
|
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
||||||
|
- AWS credentials (optional — needed for S3 storage in production)
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone the repository
|
||||||
|
git clone <repo-url> && cd adiuva-api
|
||||||
|
|
||||||
|
# Create a virtual environment
|
||||||
|
python -m venv .venv && source .venv/bin/activate
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Configure environment
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start PostgreSQL (or use the Docker Compose database)
|
||||||
|
docker compose up db -d
|
||||||
|
|
||||||
|
# Run migrations
|
||||||
|
alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run the Development Server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Docker Deployment
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
This starts two services:
|
||||||
|
|
||||||
|
- **app** — FastAPI server on port `8000`
|
||||||
|
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
||||||
|
|
||||||
|
The compose file also includes optional services for fully local deployments:
|
||||||
|
|
||||||
|
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
|
||||||
|
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
|
||||||
|
|
||||||
|
### Dockerfile Details
|
||||||
|
|
||||||
|
The Dockerfile uses a multi-stage build:
|
||||||
|
|
||||||
|
1. **Builder stage** — Installs Python dependencies into a virtual environment.
|
||||||
|
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
|
||||||
|
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Production command (run by the container)
|
||||||
|
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Homelab / Self-Hosted Deployment
|
||||||
|
|
||||||
|
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
|
||||||
|
|
||||||
|
### 1. Start all services
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
|
||||||
|
|
||||||
|
### 2. Create the MinIO bucket
|
||||||
|
|
||||||
|
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
|
||||||
|
docker compose exec minio mc mb local/adiuva
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Configure your `.env`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Database (uses the compose PostgreSQL)
|
||||||
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
|
||||||
|
# S3 → MinIO
|
||||||
|
S3_BUCKET=adiuva
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
AWS_ACCESS_KEY_ID=minioadmin
|
||||||
|
AWS_SECRET_ACCESS_KEY=minioadmin
|
||||||
|
|
||||||
|
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
|
||||||
|
QDRANT_URL=http://qdrant:6333
|
||||||
|
QDRANT_API_KEY=
|
||||||
|
PINECONE_API_KEY=
|
||||||
|
|
||||||
|
# Billing — leave empty to stub (no Stripe needed)
|
||||||
|
STRIPE_SECRET_KEY=
|
||||||
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
|
# LLM — the only external service
|
||||||
|
OPENAI_API_KEY=sk-...
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Auth
|
||||||
|
JWT_SECRET=your-secret-here
|
||||||
|
ENV=dev
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Run migrations
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec app alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### What runs where
|
||||||
|
|
||||||
|
| Service | Runs on | Port | Notes |
|
||||||
|
|---|---|---|---|
|
||||||
|
| FastAPI app | Docker | 8000 | API server |
|
||||||
|
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
|
||||||
|
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
|
||||||
|
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
|
||||||
|
| Stripe | — | — | Stubbed when keys are empty |
|
||||||
|
| OpenAI / LLM | Cloud | — | Only external dependency |
|
||||||
|
|
||||||
|
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
|
||||||
|
|
||||||
|
| Variable | Type | Default | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
|
||||||
|
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
|
||||||
|
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
|
||||||
|
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
||||||
|
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
||||||
|
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
||||||
|
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
|
||||||
|
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
|
||||||
|
| `S3_REGION` | `str` | `us-east-1` | AWS region |
|
||||||
|
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
|
||||||
|
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
|
||||||
|
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
|
||||||
|
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
|
||||||
|
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
|
||||||
|
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
|
||||||
|
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
|
||||||
|
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
|
||||||
|
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
||||||
|
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
||||||
|
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
||||||
|
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
|
||||||
|
|
||||||
|
### Health
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
|
||||||
|
|
||||||
|
### Auth
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
|
||||||
|
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
|
||||||
|
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
|
||||||
|
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
|
||||||
|
|
||||||
|
### Chat
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
||||||
|
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
||||||
|
|
||||||
|
### Plans
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
||||||
|
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
||||||
|
|
||||||
|
### Storage (Cloud Records)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
|
||||||
|
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
|
||||||
|
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
|
||||||
|
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
|
||||||
|
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
|
||||||
|
|
||||||
|
### Vectors (Cloud Vector Store)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
|
||||||
|
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
|
||||||
|
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
|
||||||
|
|
||||||
|
### Backup
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
|
||||||
|
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
|
||||||
|
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
|
||||||
|
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
|
||||||
|
|
||||||
|
### Plugins (Marketplace)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
|
||||||
|
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
|
||||||
|
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
|
||||||
|
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
|
||||||
|
|
||||||
|
### Billing
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
|
||||||
|
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
|
||||||
|
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
|
||||||
|
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Model
|
||||||
|
|
||||||
|
9 tables managed by Alembic migrations. Source: `app/models.py`
|
||||||
|
|
||||||
|
### Tables
|
||||||
|
|
||||||
|
| Table | Primary Key | Key Columns | Purpose |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
||||||
|
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
||||||
|
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
||||||
|
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
|
||||||
|
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
|
||||||
|
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
|
||||||
|
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
|
||||||
|
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
|
||||||
|
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
|
||||||
|
|
||||||
|
### Enum Types
|
||||||
|
|
||||||
|
| Enum | Values |
|
||||||
|
|---|---|
|
||||||
|
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
||||||
|
| `plugin_status` | `pending_review`, `approved`, `rejected` |
|
||||||
|
| `review_decision` | `approved`, `rejected` |
|
||||||
|
|
||||||
|
### Migrations
|
||||||
|
|
||||||
|
| Version | Description |
|
||||||
|
|---|---|
|
||||||
|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
|
||||||
|
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## AI Agent System
|
||||||
|
|
||||||
|
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
|
||||||
|
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
||||||
|
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
||||||
|
|
||||||
|
### Registered Agents
|
||||||
|
|
||||||
|
| Agent | Registry Name | Tools | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
||||||
|
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
||||||
|
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
|
||||||
|
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
||||||
|
|
||||||
|
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
||||||
|
|
||||||
|
### Switching LLM Providers
|
||||||
|
|
||||||
|
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# OpenAI (default)
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Anthropic
|
||||||
|
LLM_MODEL=anthropic/claude-3.5-sonnet
|
||||||
|
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
|
||||||
|
|
||||||
|
# Google Gemini
|
||||||
|
LLM_MODEL=gemini/gemini-pro
|
||||||
|
LLM_ROUTER_MODEL=gemini/gemini-flash
|
||||||
|
|
||||||
|
# Local Ollama
|
||||||
|
LLM_MODEL=ollama/llama3
|
||||||
|
LLM_ROUTER_MODEL=ollama/llama3
|
||||||
|
|
||||||
|
# AWS Bedrock
|
||||||
|
LLM_MODEL=bedrock/anthropic.claude-v2
|
||||||
|
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Orchestration & Execution Plans
|
||||||
|
|
||||||
|
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
||||||
|
|
||||||
|
### Orchestrator
|
||||||
|
|
||||||
|
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
|
||||||
|
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
|
||||||
|
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
|
||||||
|
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
|
||||||
|
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
|
||||||
|
|
||||||
|
### Execution Plans
|
||||||
|
|
||||||
|
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
|
||||||
|
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
|
||||||
|
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
|
||||||
|
|
||||||
|
### Built-in Templates (6)
|
||||||
|
|
||||||
|
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
||||||
|
|
||||||
|
### Built-in Playbooks (2)
|
||||||
|
|
||||||
|
| Playbook | Description |
|
||||||
|
|---|---|
|
||||||
|
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
|
||||||
|
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Middleware
|
||||||
|
|
||||||
|
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
|
||||||
|
|
||||||
|
### JWT Authentication
|
||||||
|
|
||||||
|
Source: `app/api/middleware/auth.py`
|
||||||
|
|
||||||
|
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
|
||||||
|
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
|
||||||
|
- Falls back to `free` when no subscription row exists.
|
||||||
|
- Raises `401 Unauthorized` on invalid or expired tokens.
|
||||||
|
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||||
|
|
||||||
|
### Tier-Based Rate Limiter
|
||||||
|
|
||||||
|
Source: `app/api/middleware/rate_limit.py`
|
||||||
|
|
||||||
|
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
|
||||||
|
- Per-user 60-second window sized by subscription tier:
|
||||||
|
|
||||||
|
| Tier | Requests / Minute |
|
||||||
|
|---|---|
|
||||||
|
| Free | 20 |
|
||||||
|
| Pro | 60 |
|
||||||
|
| Power | 120 |
|
||||||
|
| Team | 200 |
|
||||||
|
|
||||||
|
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
|
||||||
|
- **Exempt paths:** register, login, webhook, health
|
||||||
|
|
||||||
|
### Response Sanitizer
|
||||||
|
|
||||||
|
Source: `app/api/middleware/sanitizer.py`
|
||||||
|
|
||||||
|
- Runs only on `/api/v1/chat` endpoints.
|
||||||
|
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
||||||
|
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
||||||
|
- Logs sanitization events as `WARNING`.
|
||||||
|
- Binary responses (storage, backup) are never touched.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Storage Layer
|
||||||
|
|
||||||
|
### Blob Store
|
||||||
|
|
||||||
|
Source: `app/storage/blob_store.py`
|
||||||
|
|
||||||
|
- S3-backed storage for E2E encrypted blobs.
|
||||||
|
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
|
||||||
|
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
|
||||||
|
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
|
||||||
|
- The backend **never inspects or decrypts blob content**.
|
||||||
|
|
||||||
|
### Vector Store
|
||||||
|
|
||||||
|
Source: `app/storage/vector_store.py`
|
||||||
|
|
||||||
|
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
|
||||||
|
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
|
||||||
|
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
|
||||||
|
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
|
||||||
|
- Methods: `upsert()`, `search()`, `delete()`
|
||||||
|
|
||||||
|
### Encryption Utilities
|
||||||
|
|
||||||
|
Source: `app/storage/encryption.py`
|
||||||
|
|
||||||
|
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
|
||||||
|
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
|
||||||
|
- **No decryption key ever reaches the backend.**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Billing & Tiers
|
||||||
|
|
||||||
|
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
||||||
|
|
||||||
|
### Feature Matrix
|
||||||
|
|
||||||
|
| Feature | Free | Pro | Power | Team |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
||||||
|
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
||||||
|
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
||||||
|
| Batch Builder | — | — | ✓ | ✓ |
|
||||||
|
| Plugin Marketplace | — | — | ✓ | ✓ |
|
||||||
|
| SSO | — | — | — | ✓ |
|
||||||
|
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
||||||
|
|
||||||
|
### Stripe Integration
|
||||||
|
|
||||||
|
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
|
||||||
|
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
|
||||||
|
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
|
||||||
|
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
|
||||||
|
|
||||||
|
### Tier Manager
|
||||||
|
|
||||||
|
- `get_tier(user_id)` — Returns the user's current billing tier.
|
||||||
|
- `check_feature(tier, feature)` — Boolean feature gate check.
|
||||||
|
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
||||||
|
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Plugin Marketplace
|
||||||
|
|
||||||
|
Source: `app/marketplace/`
|
||||||
|
|
||||||
|
### Plugin Registry
|
||||||
|
|
||||||
|
- PostgreSQL-backed catalog of submitted and approved plugins.
|
||||||
|
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
|
||||||
|
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
|
||||||
|
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
|
||||||
|
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
|
||||||
|
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
|
||||||
|
|
||||||
|
### Review Queue
|
||||||
|
|
||||||
|
- Automated security checklist before human review:
|
||||||
|
- Plugin ID must match `^[a-z0-9-]+$`
|
||||||
|
- Permissions must be from the allowed set only
|
||||||
|
- No binary blobs in the manifest
|
||||||
|
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
|
||||||
|
- `get_pending(db)` — Lists plugins awaiting review.
|
||||||
|
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
||||||
|
|
||||||
|
### Revenue Sharing
|
||||||
|
|
||||||
|
- **70% developer / 30% platform** split on all paid plugin sales.
|
||||||
|
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
|
||||||
|
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
|
||||||
|
- Gracefully stubs transfers when Stripe is not configured.
|
||||||
|
|
||||||
|
### Seed Plugins
|
||||||
|
|
||||||
|
| Plugin | Category | Price |
|
||||||
|
|---|---|---|
|
||||||
|
| GitHub Sync | Productivity | Free |
|
||||||
|
| Slack Notifier | Communication | €4.99 |
|
||||||
|
| Time Tracker | Productivity | €9.99 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
pytest
|
||||||
|
|
||||||
|
# Run a specific test file
|
||||||
|
pytest tests/test_auth.py
|
||||||
|
|
||||||
|
# Run with verbose output
|
||||||
|
pytest -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test Infrastructure
|
||||||
|
|
||||||
|
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
||||||
|
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
|
||||||
|
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
||||||
|
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
||||||
|
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
|
||||||
|
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
||||||
|
- **No external dependencies** — all tests run fully offline.
|
||||||
|
|
||||||
|
### Test Coverage
|
||||||
|
|
||||||
|
| File | Coverage |
|
||||||
|
|---|---|
|
||||||
|
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
||||||
|
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
|
||||||
|
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
|
||||||
|
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
|
||||||
|
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
|
||||||
|
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
|
||||||
|
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
|
||||||
|
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
|
||||||
|
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
adiuva-api/
|
||||||
|
├── alembic.ini # Alembic configuration
|
||||||
|
├── BACKEND_PLAN.md # Architecture & design decisions
|
||||||
|
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
||||||
|
├── Dockerfile # Multi-stage production build
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
│
|
||||||
|
├── alembic/ # Database migrations
|
||||||
|
│ ├── env.py # Alembic environment config
|
||||||
|
│ ├── script.py.mako # Migration template
|
||||||
|
│ └── versions/
|
||||||
|
│ ├── 001_initial_schema.py # Tables, indexes, FKs
|
||||||
|
│ └── 002_seed_plugins.py # Seed marketplace plugins
|
||||||
|
│
|
||||||
|
├── app/ # Application source
|
||||||
|
│ ├── main.py # FastAPI app factory, middleware, routes
|
||||||
|
│ ├── db.py # Async SQLAlchemy engine & session
|
||||||
|
│ ├── models.py # SQLAlchemy ORM models (9 tables)
|
||||||
|
│ ├── schemas.py # Pydantic request/response schemas
|
||||||
|
│ │
|
||||||
|
│ ├── config/
|
||||||
|
│ │ └── settings.py # Pydantic Settings (env vars)
|
||||||
|
│ │
|
||||||
|
│ ├── agents/ # LLM-powered domain agents
|
||||||
|
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
||||||
|
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
||||||
|
│ │ ├── timeline_agent.py # Milestones (4 tools)
|
||||||
|
│ │ └── note_agent.py # Markdown notes (5 tools)
|
||||||
|
│ │
|
||||||
|
│ ├── core/ # Orchestration engine
|
||||||
|
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||||
|
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
||||||
|
│ │ ├── orchestrator.py # Intent classification & routing
|
||||||
|
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||||
|
│ │
|
||||||
|
│ ├── api/ # HTTP layer
|
||||||
|
│ │ ├── deps.py # Shared FastAPI dependencies
|
||||||
|
│ │ ├── middleware/
|
||||||
|
│ │ │ ├── auth.py # JWT validation, live tier lookup
|
||||||
|
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
||||||
|
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
||||||
|
│ │ └── routes/
|
||||||
|
│ │ ├── auth.py # Register, login, refresh, me
|
||||||
|
│ │ ├── chat.py # Chat + WebSocket streaming
|
||||||
|
│ │ ├── plans.py # Execution plan playbooks
|
||||||
|
│ │ ├── storage.py # E2E encrypted record CRUD
|
||||||
|
│ │ ├── vectors.py # Vector upsert, search, delete
|
||||||
|
│ │ ├── backup.py # Encrypted backup management
|
||||||
|
│ │ ├── plugins.py # Marketplace browse & install
|
||||||
|
│ │ └── billing.py # Stripe checkout & webhooks
|
||||||
|
│ │
|
||||||
|
│ ├── storage/ # Storage backends
|
||||||
|
│ │ ├── blob_store.py # S3 blob storage
|
||||||
|
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
|
||||||
|
│ │ └── encryption.py # Checksum verification utilities
|
||||||
|
│ │
|
||||||
|
│ ├── billing/ # Subscription management
|
||||||
|
│ │ ├── stripe_service.py # Stripe API integration
|
||||||
|
│ │ └── tier_manager.py # Feature matrix & quota enforcement
|
||||||
|
│ │
|
||||||
|
│ └── marketplace/ # Plugin ecosystem
|
||||||
|
│ ├── plugin_registry.py # Catalog CRUD & search
|
||||||
|
│ ├── plugin_review.py # Security checklist & review queue
|
||||||
|
│ └── revenue_share.py # 70/30 split & Stripe Connect
|
||||||
|
│
|
||||||
|
└── tests/ # Test suite
|
||||||
|
├── conftest.py # Fixtures: DB, S3, auth, seeds
|
||||||
|
├── test_auth.py
|
||||||
|
├── test_orchestrator.py
|
||||||
|
├── test_agents.py
|
||||||
|
├── test_storage.py
|
||||||
|
├── test_backup.py
|
||||||
|
├── test_plugins.py
|
||||||
|
├── test_agent_registry.py
|
||||||
|
├── test_execution_plan.py
|
||||||
|
└── test_middleware.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
*To be determined.*
|
||||||
47
alembic.ini
Normal file
47
alembic.ini
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Alembic configuration file.
|
||||||
|
# The async app uses postgresql+asyncpg:// at runtime.
|
||||||
|
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
prepend_sys_path = .
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
93
alembic/env.py
Normal file
93
alembic/env.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Alembic migration environment — async-compatible.
|
||||||
|
|
||||||
|
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
|
||||||
|
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
|
||||||
|
env var by replacing the driver prefix.
|
||||||
|
|
||||||
|
Run migrations with:
|
||||||
|
alembic upgrade head
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Set up Python logging from alembic.ini.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# Import the Base so that Alembic can detect model changes for --autogenerate.
|
||||||
|
from app.models import Base # noqa: E402
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_url(async_url: str) -> str:
|
||||||
|
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
|
||||||
|
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url() -> str:
|
||||||
|
db_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not db_url:
|
||||||
|
# Fall back to settings if env var not set directly.
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
db_url = settings.DATABASE_URL
|
||||||
|
return _sync_url(db_url)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Emit SQL without a live DB connection."""
|
||||||
|
url = _get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection): # type: ignore[no-untyped-def]
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations_online_async() -> None:
|
||||||
|
"""Run migrations against a live DB using the async engine."""
|
||||||
|
async_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not async_url:
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
async_url = settings.DATABASE_URL
|
||||||
|
|
||||||
|
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
asyncio.run(run_migrations_online_async())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
209
alembic/versions/001_initial_schema.py
Normal file
209
alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||||
|
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||||
|
|
||||||
|
Revision ID: 001
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "001"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enum types — idempotent creation via exception handling ───────────
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE billing_tier AS ENUM ('free', 'pro', 'power', 'team');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("email", sa.String(255), nullable=False),
|
||||||
|
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||||
|
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||||
|
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("email"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_users_email", "users", ["email"])
|
||||||
|
|
||||||
|
# ── refresh_tokens ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"refresh_tokens",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("token_hash"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
|
||||||
|
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
|
||||||
|
|
||||||
|
# ── subscriptions ─────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"subscriptions",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||||
|
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||||
|
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("user_id"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
|
# ── storage_records ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"storage_records",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("table_name", sa.String(100), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||||
|
|
||||||
|
# ── backup_metadata ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"backup_metadata",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer, nullable=False),
|
||||||
|
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugins ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugins",
|
||||||
|
sa.Column("id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||||
|
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||||
|
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||||
|
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
||||||
|
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||||
|
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||||
|
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── plugin_installations ──────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_installations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||||
|
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_reviews",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
||||||
|
sa.Column("notes", sa.Text, nullable=True),
|
||||||
|
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||||
|
|
||||||
|
# ── revenue_events ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"revenue_events",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||||
|
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("revenue_events")
|
||||||
|
op.drop_table("plugin_reviews")
|
||||||
|
op.drop_table("plugin_installations")
|
||||||
|
op.drop_table("plugins")
|
||||||
|
op.drop_table("backup_metadata")
|
||||||
|
op.drop_table("storage_records")
|
||||||
|
op.drop_table("subscriptions")
|
||||||
|
op.drop_table("refresh_tokens")
|
||||||
|
op.drop_table("users")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||||
|
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
92
alembic/versions/002_seed_plugins.py
Normal file
92
alembic/versions/002_seed_plugins.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2026-03-03
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "002"
|
||||||
|
down_revision: Union[str, None] = "001"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
{
|
||||||
|
"id": "plugin-github-sync",
|
||||||
|
"name": "GitHub Sync",
|
||||||
|
"description": "Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 0,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-slack-notify",
|
||||||
|
"name": "Slack Notifier",
|
||||||
|
"description": "Post task and timeline updates to Slack channels.",
|
||||||
|
"version": "1.2.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "communication",
|
||||||
|
"price_cents": 499,
|
||||||
|
"permissions": json.dumps(["read:tasks", "read:timelines"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-time-tracker",
|
||||||
|
"name": "Time Tracker",
|
||||||
|
"description": "Track time spent on tasks with automatic reporting.",
|
||||||
|
"version": "0.9.1",
|
||||||
|
"author_name": "Third Party",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 999,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
plugins = sa.table(
|
||||||
|
"plugins",
|
||||||
|
sa.column("id", sa.String),
|
||||||
|
sa.column("name", sa.String),
|
||||||
|
sa.column("description", sa.Text),
|
||||||
|
sa.column("version", sa.String),
|
||||||
|
sa.column("author_name", sa.String),
|
||||||
|
sa.column("category", sa.String),
|
||||||
|
sa.column("price_cents", sa.Integer),
|
||||||
|
sa.column("permissions", sa.Text),
|
||||||
|
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
||||||
|
sa.column("s3_package_key", sa.String),
|
||||||
|
sa.column("install_count", sa.Integer),
|
||||||
|
sa.column("avg_rating", sa.Float),
|
||||||
|
)
|
||||||
|
op.bulk_insert(plugins, _SEED_PLUGINS)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"DELETE FROM plugins WHERE id IN ("
|
||||||
|
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
||||||
|
")"
|
||||||
|
)
|
||||||
127
alembic/versions/003_agent_tables.py
Normal file
127
alembic/versions/003_agent_tables.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Add agent config and run log tables: local_agent_configs, cloud_agent_configs, agent_run_logs.
|
||||||
|
|
||||||
|
Revision ID: 003
|
||||||
|
Revises: 002
|
||||||
|
Create Date: 2026-03-05
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "003"
|
||||||
|
down_revision: Union[str, None] = "002"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enum types — idempotent creation ──────────────────────────────────
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
# ── local_agent_configs ───────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
# ── cloud_agent_configs ───────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
# ── agent_run_logs ─────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"agent_run_logs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
# Plain string — not a FK because it references either local_agent_configs or
|
||||||
|
# cloud_agent_configs depending on agent_type.
|
||||||
|
sa.Column("agent_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"agent_type",
|
||||||
|
postgresql.ENUM("local", "cloud", name="agent_type", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
postgresql.ENUM("running", "success", "error", "partial", name="agent_run_status", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
server_default="running",
|
||||||
|
),
|
||||||
|
sa.Column("items_processed", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("items_created", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("errors", sa.JSON, nullable=True),
|
||||||
|
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_agent_run_logs_user_id", "agent_run_logs", ["user_id"])
|
||||||
|
op.create_index("ix_agent_run_logs_agent_id", "agent_run_logs", ["agent_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("agent_run_logs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS cloud_provider;")
|
||||||
|
op.execute("DROP TYPE IF EXISTS agent_run_status;")
|
||||||
|
op.execute("DROP TYPE IF EXISTS agent_type;")
|
||||||
144
alembic/versions/004_add_memory_tables.py
Normal file
144
alembic/versions/004_add_memory_tables.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""Add memory tables and user encryption_key column.
|
||||||
|
|
||||||
|
Memory tables:
|
||||||
|
memory_core — per-user key/value preferences (encrypted)
|
||||||
|
memory_associative — semantic memory with pgvector embedding (encrypted)
|
||||||
|
memory_episodic — session summaries (encrypted)
|
||||||
|
memory_proactive — behavioral patterns (encrypted)
|
||||||
|
|
||||||
|
Also adds encryption_key column to users table.
|
||||||
|
|
||||||
|
Revision ID: 004
|
||||||
|
Revises: 003
|
||||||
|
Create Date: 2026-03-08
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "004"
|
||||||
|
down_revision: Union[str, None] = "003"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enable pgvector extension (idempotent) ────────────────────────────────
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||||
|
|
||||||
|
# ── Add encryption_key to users ───────────────────────────────────────────
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column("encryption_key", sa.String(64), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── memory_core ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_core",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("key", sa.String(255), nullable=False),
|
||||||
|
sa.Column("value_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"])
|
||||||
|
|
||||||
|
# ── memory_associative ────────────────────────────────────────────────────
|
||||||
|
# The embedding column uses pgvector's vector(1536) type.
|
||||||
|
op.create_table(
|
||||||
|
"memory_associative",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("content_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("entity_type", sa.String(100), nullable=True),
|
||||||
|
sa.Column("entity_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Add the pgvector column separately (not supported by generic sa types)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);"
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"])
|
||||||
|
# IVFFlat index for approximate nearest-neighbour search
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX ix_memory_associative_embedding "
|
||||||
|
"ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── memory_episodic ───────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_episodic",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("summary_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("session_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"])
|
||||||
|
op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"])
|
||||||
|
|
||||||
|
# ── memory_proactive ──────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_proactive",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("pattern_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"),
|
||||||
|
sa.Column("source", sa.String(50), nullable=False, server_default="inferred"),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("memory_proactive")
|
||||||
|
op.drop_table("memory_episodic")
|
||||||
|
op.drop_index("ix_memory_associative_embedding", "memory_associative")
|
||||||
|
op.drop_table("memory_associative")
|
||||||
|
op.drop_table("memory_core")
|
||||||
|
op.drop_column("users", "encryption_key")
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
"""add name and surname to users table
|
||||||
|
|
||||||
|
Revision ID: 818478c251dc
|
||||||
|
Revises: 004
|
||||||
|
Create Date: 2026-03-10 15:10:42.811947
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '818478c251dc'
|
||||||
|
down_revision: Union[str, None] = '004'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column('users', sa.Column('name', sa.String(length=100), nullable=True))
|
||||||
|
op.add_column('users', sa.Column('surname', sa.String(length=100), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column('users', 'surname')
|
||||||
|
op.drop_column('users', 'name')
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Import all agent modules to trigger @registry.register decorators."""
|
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
||||||
|
|
||||||
from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
|
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
@@ -1,122 +0,0 @@
|
|||||||
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
|
|
||||||
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
|
||||||
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all checkpoints across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_checkpoints(project_id: str = "") -> str:
|
|
||||||
"""List checkpoints. Provide project_id to scope to a specific project."""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "list",
|
|
||||||
"table": "checkpoints",
|
|
||||||
"filters": {"projectId": project_id or None},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def create_checkpoint(
|
|
||||||
project_id: str,
|
|
||||||
title: str,
|
|
||||||
date: int,
|
|
||||||
is_ai_suggested: int = 0,
|
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
|
||||||
"""Create a project checkpoint (milestone).
|
|
||||||
project_id: REQUIRED UUID of the parent project
|
|
||||||
title: descriptive name for the milestone
|
|
||||||
date: Unix timestamp in milliseconds
|
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
|
||||||
is_approved: 0 until the user confirms
|
|
||||||
"""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "create_record",
|
|
||||||
"table": "checkpoints",
|
|
||||||
"data": {
|
|
||||||
"projectId": project_id,
|
|
||||||
"title": title,
|
|
||||||
"date": date,
|
|
||||||
"isAiSuggested": is_ai_suggested,
|
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def update_checkpoint(
|
|
||||||
checkpoint_id: str,
|
|
||||||
title: str = "",
|
|
||||||
date: int = -1,
|
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
|
||||||
"""Update a checkpoint. Only pass fields that should change.
|
|
||||||
checkpoint_id: UUID of the checkpoint (required)
|
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
|
||||||
"""
|
|
||||||
updates: dict[str, Any] = {}
|
|
||||||
if title:
|
|
||||||
updates["title"] = title
|
|
||||||
if date != -1:
|
|
||||||
updates["date"] = date
|
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
return json.dumps({
|
|
||||||
"action": "update_record",
|
|
||||||
"table": "checkpoints",
|
|
||||||
"data": {"id": checkpoint_id, "updates": updates},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_checkpoint(checkpoint_id: str) -> str:
|
|
||||||
"""Delete a checkpoint permanently by its UUID."""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "delete_record",
|
|
||||||
"table": "checkpoints",
|
|
||||||
"data": {"id": checkpoint_id},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class CheckpointAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "checkpoint_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages project checkpoints (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
@@ -1,50 +1,38 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
"""Note agent — tool definitions for Markdown note CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.core.llm import embed
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "list",
|
action="select",
|
||||||
"table": "notes",
|
table="notes",
|
||||||
"filters": {"projectId": project_id or None},
|
filters={"projectId": project_id or None},
|
||||||
})
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No notes found."
|
||||||
|
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def get_note(note_id: str) -> str:
|
async def get_note(note_id: str) -> str:
|
||||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||||
return json.dumps({
|
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||||
"action": "get",
|
row = result.get("row")
|
||||||
"table": "notes",
|
if not row:
|
||||||
"data": {"id": note_id},
|
return f"Note {note_id} not found."
|
||||||
})
|
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -58,15 +46,24 @@ async def create_note(
|
|||||||
content: Markdown body text (required)
|
content: Markdown body text (required)
|
||||||
project_id: optional UUID linking this note to a project
|
project_id: optional UUID linking this note to a project
|
||||||
"""
|
"""
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "create_record",
|
action="insert",
|
||||||
"table": "notes",
|
table="notes",
|
||||||
"data": {
|
data={
|
||||||
"title": title,
|
"title": title,
|
||||||
"content": content,
|
"content": content,
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
},
|
},
|
||||||
})
|
)
|
||||||
|
row = result["row"]
|
||||||
|
# Index the note content in the vector store.
|
||||||
|
vector = await embed(content)
|
||||||
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -84,40 +81,28 @@ async def update_note(
|
|||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if content:
|
if content:
|
||||||
updates["content"] = content
|
updates["content"] = content
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "update_record",
|
action="update",
|
||||||
"table": "notes",
|
table="notes",
|
||||||
"data": {"id": note_id, "updates": updates},
|
data={"id": note_id, "updates": updates},
|
||||||
})
|
)
|
||||||
|
row = result["row"]
|
||||||
|
# Re-index if content changed.
|
||||||
|
if content:
|
||||||
|
vector = await embed(content)
|
||||||
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_note(note_id: str) -> str:
|
async def delete_note(note_id: str) -> str:
|
||||||
"""Delete a note permanently by its UUID."""
|
"""Delete a note permanently by its UUID."""
|
||||||
return json.dumps({
|
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||||
"action": "delete_record",
|
return f"Note {note_id} deleted."
|
||||||
"table": "notes",
|
|
||||||
"data": {"id": note_id},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class NoteAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "note_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages notes: list, get, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_notes, get_note, create_note, update_note, delete_note]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -1,32 +1,12 @@
|
|||||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
"""Project agent — tool definitions for project lifecycle CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.core.ws_context import execute_on_client
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -37,14 +17,19 @@ async def list_projects(
|
|||||||
"""List projects, optionally filtered by client_id.
|
"""List projects, optionally filtered by client_id.
|
||||||
include_archived: 1 to include archived projects, 0 for active only (default).
|
include_archived: 1 to include archived projects, 0 for active only (default).
|
||||||
"""
|
"""
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "list",
|
action="select",
|
||||||
"table": "projects",
|
table="projects",
|
||||||
"filters": {
|
filters={
|
||||||
"clientId": client_id or None,
|
"clientId": client_id or None,
|
||||||
"includeArchived": bool(include_archived),
|
"includeArchived": bool(include_archived),
|
||||||
},
|
},
|
||||||
})
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No projects found."
|
||||||
|
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -52,20 +37,25 @@ async def list_all_projects() -> str:
|
|||||||
"""List every project regardless of client or status.
|
"""List every project regardless of client or status.
|
||||||
Use only when the user wants a complete cross-client overview.
|
Use only when the user wants a complete cross-client overview.
|
||||||
"""
|
"""
|
||||||
return json.dumps({
|
result = await execute_on_client(action="select", table="projects")
|
||||||
"action": "list_all",
|
rows = result.get("rows", [])
|
||||||
"table": "projects",
|
if not rows:
|
||||||
})
|
return "No projects found."
|
||||||
|
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def get_project(project_id: str) -> str:
|
async def get_project(project_id: str) -> str:
|
||||||
"""Fetch a single project by its UUID."""
|
"""Fetch a single project by its UUID."""
|
||||||
return json.dumps({
|
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
||||||
"action": "get",
|
row = result.get("row")
|
||||||
"table": "projects",
|
if not row:
|
||||||
"data": {"id": project_id},
|
return f"Project {project_id} not found."
|
||||||
})
|
return (
|
||||||
|
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
||||||
|
f"clientId: {row.get('clientId', 'none')})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -77,14 +67,13 @@ async def create_project(
|
|||||||
name: human-readable project name (required)
|
name: human-readable project name (required)
|
||||||
client_id: optional UUID of the owning client
|
client_id: optional UUID of the owning client
|
||||||
"""
|
"""
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "create_record",
|
action="insert",
|
||||||
"table": "projects",
|
table="projects",
|
||||||
"data": {
|
data={"name": name, "clientId": client_id or None},
|
||||||
"name": name,
|
)
|
||||||
"clientId": client_id or None,
|
row = result["row"]
|
||||||
},
|
return f"Project created: '{row['name']}' (id: {row['id']})"
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -109,11 +98,13 @@ async def update_project(
|
|||||||
updates["status"] = status
|
updates["status"] = status
|
||||||
if ai_summary:
|
if ai_summary:
|
||||||
updates["aiSummary"] = ai_summary
|
updates["aiSummary"] = ai_summary
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "update_record",
|
action="update",
|
||||||
"table": "projects",
|
table="projects",
|
||||||
"data": {"id": project_id, "updates": updates},
|
data={"id": project_id, "updates": updates},
|
||||||
})
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -122,37 +113,8 @@ async def delete_project(project_id: str) -> str:
|
|||||||
IMPORTANT: prefer update_project(status='archived') unless the user
|
IMPORTANT: prefer update_project(status='archived') unless the user
|
||||||
has explicitly confirmed they want permanent deletion.
|
has explicitly confirmed they want permanent deletion.
|
||||||
"""
|
"""
|
||||||
return json.dumps({
|
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
||||||
"action": "delete_record",
|
return f"Project {project_id} permanently deleted."
|
||||||
"table": "projects",
|
|
||||||
"data": {"id": project_id},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class ProjectAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "project_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [
|
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
create_project,
|
|
||||||
update_project,
|
|
||||||
delete_project,
|
|
||||||
]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
@@ -1,33 +1,13 @@
|
|||||||
"""Task agent — full CRUD for tasks and task comments."""
|
"""Task agent — tool definitions for task and task comment CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.core.ws_context import execute_on_client
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a task management assistant for a project workspace.\n"
|
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
@@ -42,16 +22,24 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "list",
|
action="select",
|
||||||
"table": "tasks",
|
table="tasks",
|
||||||
"filters": {
|
filters={
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
},
|
},
|
||||||
})
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No tasks found matching the given filters."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -77,10 +65,10 @@ async def create_task(
|
|||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
is_approved: 0 until the user confirms; 1 when confirmed
|
is_approved: 0 until the user confirms; 1 when confirmed
|
||||||
"""
|
"""
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "create_record",
|
action="insert",
|
||||||
"table": "tasks",
|
table="tasks",
|
||||||
"data": {
|
data={
|
||||||
"title": title,
|
"title": title,
|
||||||
"description": description or None,
|
"description": description or None,
|
||||||
"status": status,
|
"status": status,
|
||||||
@@ -91,7 +79,12 @@ async def create_task(
|
|||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
"isApproved": is_approved,
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
})
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return (
|
||||||
|
f"Task created: '{row['title']}' "
|
||||||
|
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -128,30 +121,41 @@ async def update_task(
|
|||||||
updates["projectId"] = project_id
|
updates["projectId"] = project_id
|
||||||
if is_approved != -1:
|
if is_approved != -1:
|
||||||
updates["isApproved"] = is_approved
|
updates["isApproved"] = is_approved
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "update_record",
|
action="update",
|
||||||
"table": "tasks",
|
table="tasks",
|
||||||
"data": {"id": task_id, "updates": updates},
|
data={"id": task_id, "updates": updates},
|
||||||
})
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_task(task_id: str) -> str:
|
async def delete_task(task_id: str) -> str:
|
||||||
"""Delete a task permanently by its UUID."""
|
"""Delete a task permanently by its UUID."""
|
||||||
return json.dumps({
|
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||||
"action": "delete_record",
|
return f"Task {task_id} deleted."
|
||||||
"table": "tasks",
|
|
||||||
"data": {"id": task_id},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_tasks_due_today() -> str:
|
async def list_tasks_due_today() -> str:
|
||||||
"""List all tasks whose due date falls on today's date."""
|
"""List all tasks whose due date falls on today's date."""
|
||||||
return json.dumps({
|
now = datetime.now(tz=timezone.utc)
|
||||||
"action": "list_due_today",
|
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||||
"table": "tasks",
|
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
||||||
})
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="tasks",
|
||||||
|
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No tasks are due today."
|
||||||
|
lines = [
|
||||||
|
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
# ── Task comment tools ────────────────────────────────────────────────
|
# ── Task comment tools ────────────────────────────────────────────────
|
||||||
@@ -160,11 +164,16 @@ async def list_tasks_due_today() -> str:
|
|||||||
@tool
|
@tool
|
||||||
async def list_task_comments(task_id: str) -> str:
|
async def list_task_comments(task_id: str) -> str:
|
||||||
"""List all comments on a task by its UUID."""
|
"""List all comments on a task by its UUID."""
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "list",
|
action="select",
|
||||||
"table": "taskComments",
|
table="taskComments",
|
||||||
"filters": {"taskId": task_id},
|
filters={"taskId": task_id},
|
||||||
})
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return f"No comments found for task {task_id}."
|
||||||
|
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -174,56 +183,20 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
author: name or ID of the comment author
|
author: name or ID of the comment author
|
||||||
content: comment text
|
content: comment text
|
||||||
"""
|
"""
|
||||||
return json.dumps({
|
result = await execute_on_client(
|
||||||
"action": "create_record",
|
action="insert",
|
||||||
"table": "taskComments",
|
table="taskComments",
|
||||||
"data": {
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
"taskId": task_id,
|
)
|
||||||
"author": author,
|
row = result["row"]
|
||||||
"content": content,
|
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_task_comment(comment_id: str) -> str:
|
async def delete_task_comment(comment_id: str) -> str:
|
||||||
"""Delete a task comment by its UUID."""
|
"""Delete a task comment by its UUID."""
|
||||||
return json.dumps({
|
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||||
"action": "delete_record",
|
return f"Comment {comment_id} deleted."
|
||||||
"table": "taskComments",
|
|
||||||
"data": {"id": comment_id},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class TaskAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "task_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [
|
|
||||||
list_tasks,
|
|
||||||
create_task,
|
|
||||||
update_task,
|
|
||||||
delete_task,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
add_task_comment,
|
|
||||||
delete_task_comment,
|
|
||||||
]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
|
|||||||
92
app/agents/timeline_agent.py
Normal file
92
app/agents/timeline_agent.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Timeline agent — tool definitions for project milestone CRUD."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table="timelines",
|
||||||
|
filters={"projectId": project_id or None},
|
||||||
|
)
|
||||||
|
rows = result.get("rows", [])
|
||||||
|
if not rows:
|
||||||
|
return "No timelines found."
|
||||||
|
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||||
|
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def create_timeline(
|
||||||
|
project_id: str,
|
||||||
|
title: str,
|
||||||
|
date: int,
|
||||||
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
|
) -> str:
|
||||||
|
"""Create a project timeline (milestone).
|
||||||
|
project_id: REQUIRED UUID of the parent project
|
||||||
|
title: descriptive name for the milestone
|
||||||
|
date: Unix timestamp in milliseconds
|
||||||
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="timelines",
|
||||||
|
data={
|
||||||
|
"projectId": project_id,
|
||||||
|
"title": title,
|
||||||
|
"date": date,
|
||||||
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_timeline(
|
||||||
|
timeline_id: str,
|
||||||
|
title: str = "",
|
||||||
|
date: int = -1,
|
||||||
|
is_approved: int = -1,
|
||||||
|
) -> str:
|
||||||
|
"""Update a timeline. Only pass fields that should change.
|
||||||
|
timeline_id: UUID of the timeline (required)
|
||||||
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
||||||
|
"""
|
||||||
|
updates: dict[str, Any] = {}
|
||||||
|
if title:
|
||||||
|
updates["title"] = title
|
||||||
|
if date != -1:
|
||||||
|
updates["date"] = date
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="update",
|
||||||
|
table="timelines",
|
||||||
|
data={"id": timeline_id, "updates": updates},
|
||||||
|
)
|
||||||
|
row = result["row"]
|
||||||
|
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def delete_timeline(timeline_id: str) -> str:
|
||||||
|
"""Delete a timeline permanently by its UUID."""
|
||||||
|
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||||
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Auth middleware — JWT validation dependency.
|
"""Auth middleware — JWT validation dependency.
|
||||||
|
|
||||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||||
It decodes the Bearer JWT, validates signature and expiry, and returns a
|
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||||
``UserProfile`` carrying ``id``, ``email``, and ``tier``.
|
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||||
|
without requiring token re-issue.
|
||||||
|
|
||||||
Exempt routes (no JWT required):
|
Exempt routes (no JWT required):
|
||||||
- POST /api/v1/auth/register
|
- POST /api/v1/auth/register
|
||||||
@@ -15,8 +16,11 @@ from __future__ import annotations
|
|||||||
from fastapi import Depends, HTTPException, status
|
from fastapi import Depends, HTTPException, status
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
from app.schemas import UserProfile
|
from app.schemas import UserProfile
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
@@ -24,12 +28,15 @@ oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
|||||||
|
|
||||||
async def get_current_user(
|
async def get_current_user(
|
||||||
token: str = Depends(oauth2_scheme),
|
token: str = Depends(oauth2_scheme),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> UserProfile:
|
) -> UserProfile:
|
||||||
"""Validate a Bearer JWT and return the authenticated user.
|
"""Validate a Bearer JWT and return the authenticated user.
|
||||||
|
|
||||||
|
The JWT is used for identity and expiry only. The tier is fetched live
|
||||||
|
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||||
|
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
|
||||||
Raises HTTP 401 on any invalid or expired token.
|
Raises HTTP 401 on any invalid or expired token.
|
||||||
The tier embedded in the JWT is used for feature-gating until Step 12
|
|
||||||
adds a live DB lookup.
|
|
||||||
"""
|
"""
|
||||||
credentials_exc = HTTPException(
|
credentials_exc = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -42,10 +49,29 @@ async def get_current_user(
|
|||||||
)
|
)
|
||||||
user_id: str | None = payload.get("sub")
|
user_id: str | None = payload.get("sub")
|
||||||
email: str | None = payload.get("email")
|
email: str | None = payload.get("email")
|
||||||
tier: str = payload.get("tier", "free")
|
|
||||||
if not user_id or not email:
|
if not user_id or not email:
|
||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
except JWTError:
|
except JWTError:
|
||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
|
|
||||||
|
# Fetch name/surname from user row.
|
||||||
|
user_result = await db.execute(
|
||||||
|
select(User.name, User.surname).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user_id,
|
||||||
|
email=email,
|
||||||
|
name=user_row.name if user_row else None,
|
||||||
|
surname=user_row.surname if user_row else None,
|
||||||
|
tier=tier,
|
||||||
|
) # type: ignore[arg-type]
|
||||||
|
|||||||
317
app/api/routes/agent_setup.py
Normal file
317
app/api/routes/agent_setup.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
POST /agents/journey/start — start a new journey session
|
||||||
|
POST /agents/journey/message — continue the conversation
|
||||||
|
|
||||||
|
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
||||||
|
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
||||||
|
|
||||||
|
Journey flow:
|
||||||
|
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
||||||
|
2. Server creates a session, calls the LLM with a contextual system prompt,
|
||||||
|
and returns the first question.
|
||||||
|
3. Client sends follow-up messages to ``/message``.
|
||||||
|
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
||||||
|
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
|
5. Server parses the block, sets ``done=True``, and returns the template.
|
||||||
|
|
||||||
|
The ``prompt_template`` from the final response is meant to be stored in
|
||||||
|
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
||||||
|
by the Electron client (via the agent CRUD endpoints).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import CloudAgentConfig, LocalAgentConfig
|
||||||
|
from app.schemas import (
|
||||||
|
JourneyMessageRequest,
|
||||||
|
JourneyResponse,
|
||||||
|
JourneyStartRequest,
|
||||||
|
UserProfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
||||||
|
|
||||||
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
|
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
||||||
|
_MAX_TURNS: int = 5
|
||||||
|
|
||||||
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _JourneySession:
|
||||||
|
session_id: str
|
||||||
|
user_id: str
|
||||||
|
agent_type: str # "local" | "cloud"
|
||||||
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||||
|
|
||||||
|
|
||||||
|
# session_id → session
|
||||||
|
_sessions: dict[str, _JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
||||||
|
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
||||||
|
s = _sessions.get(session_id)
|
||||||
|
if s is None or s.is_expired():
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
|
if s.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_LOCAL_PREAMBLE = """\
|
||||||
|
What kind of files are in the directories you want to monitor? \
|
||||||
|
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
||||||
|
|
||||||
|
_CLOUD_PREAMBLE = """\
|
||||||
|
What kind of emails or messages should I look for? \
|
||||||
|
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
|
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
||||||
|
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
||||||
|
|
||||||
|
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
|
1. The type and format of the source content.
|
||||||
|
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
||||||
|
3. How fields should be mapped (e.g. email subject → task title).
|
||||||
|
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
5. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
|
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
||||||
|
these exact markers on their own lines:
|
||||||
|
|
||||||
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
|
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
||||||
|
and must return a JSON array of records in this shape:
|
||||||
|
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
||||||
|
|
||||||
|
Rules for the generated template:
|
||||||
|
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
||||||
|
- Include concrete examples of mappings.
|
||||||
|
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
||||||
|
- Set isAiSuggested: true and isApproved: false on every record.
|
||||||
|
{existing_section}\
|
||||||
|
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
||||||
|
source_description = (
|
||||||
|
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
||||||
|
)
|
||||||
|
existing_section = (
|
||||||
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
|
f"---\n{existing_template}\n---\n"
|
||||||
|
if existing_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
|
source_description=source_description,
|
||||||
|
template_start=_TEMPLATE_START,
|
||||||
|
template_end=_TEMPLATE_END,
|
||||||
|
existing_section=existing_section,
|
||||||
|
max_turns=_MAX_TURNS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _first_question(agent_type: str) -> str:
|
||||||
|
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
||||||
|
|
||||||
|
|
||||||
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_template(text: str) -> str | None:
|
||||||
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM call ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM."""
|
||||||
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
|
for turn in history:
|
||||||
|
if turn["role"] == "user":
|
||||||
|
messages.append(HumanMessage(content=turn["content"]))
|
||||||
|
else:
|
||||||
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
return response.content # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Existing-config loader ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_existing_template(
|
||||||
|
agent_id: str,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> str | None:
|
||||||
|
"""Return the prompt_template of an existing agent config, or None."""
|
||||||
|
# Try local first, then cloud.
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local = local_result.scalar_one_or_none()
|
||||||
|
if local is not None:
|
||||||
|
return local.prompt_template
|
||||||
|
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud = cloud_result.scalar_one_or_none()
|
||||||
|
return cloud.prompt_template if cloud is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
|
async def start_journey(
|
||||||
|
body: JourneyStartRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Start a new Chatbot Journey session.
|
||||||
|
|
||||||
|
If ``agent_id`` is provided the session is pre-seeded with the existing
|
||||||
|
agent's ``prompt_template`` so the user can refine it.
|
||||||
|
"""
|
||||||
|
# Load existing template (may be None).
|
||||||
|
existing_template: str | None = None
|
||||||
|
if body.agent_id:
|
||||||
|
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
||||||
|
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
||||||
|
# the user may be starting a fresh journey for a not-yet-persisted config).
|
||||||
|
|
||||||
|
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
||||||
|
first_question = _first_question(body.agent_type)
|
||||||
|
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
session = _JourneySession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
agent_type=body.agent_type,
|
||||||
|
# Seed history with the AI's first question so it stays consistent.
|
||||||
|
history=[{"role": "assistant", "content": first_question}],
|
||||||
|
)
|
||||||
|
# Store the system prompt inside the session for reuse in /message.
|
||||||
|
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
||||||
|
_sessions[session_id] = session
|
||||||
|
|
||||||
|
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
||||||
|
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
|
async def send_journey_message(
|
||||||
|
body: JourneyMessageRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Send a message in an existing Chatbot Journey session.
|
||||||
|
|
||||||
|
The server appends the user's message to the conversation history,
|
||||||
|
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
||||||
|
``prompt_template`` block the response includes ``done=True`` and the
|
||||||
|
extracted template.
|
||||||
|
"""
|
||||||
|
session = _get_session(body.session_id, current_user.id)
|
||||||
|
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Append user turn to history.
|
||||||
|
session.history.append({"role": "user", "content": body.message})
|
||||||
|
|
||||||
|
# Call the LLM with the full conversation so far.
|
||||||
|
ai_reply = await _call_llm(system_prompt, session.history)
|
||||||
|
|
||||||
|
# Append AI turn.
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
# Check if the LLM produced the final template.
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
# Strip the sentinel markers from the message shown to the user.
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
|
||||||
|
if done:
|
||||||
|
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
||||||
|
# Clean up the session immediately on completion.
|
||||||
|
_sessions.pop(body.session_id, None)
|
||||||
|
else:
|
||||||
|
# Nudge the LLM to wrap up after max turns.
|
||||||
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
# Add a system-level nudge as a hidden user message.
|
||||||
|
session.history.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
return JourneyResponse(
|
||||||
|
session_id=body.session_id,
|
||||||
|
message=display_message,
|
||||||
|
done=done,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
)
|
||||||
452
app/api/routes/agents.py
Normal file
452
app/api/routes/agents.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
GET /agents/catalog — hardcoded agent type catalog
|
||||||
|
GET /agents/local — list user's local agent configs
|
||||||
|
POST /agents/local — create local agent (tier-gated)
|
||||||
|
PUT /agents/local/{agent_id} — partial update (ownership check)
|
||||||
|
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
||||||
|
GET /agents/cloud — list user's cloud agent configs
|
||||||
|
POST /agents/cloud — create cloud agent (tier-gated)
|
||||||
|
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
||||||
|
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
||||||
|
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
||||||
|
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, or_, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import FEATURES
|
||||||
|
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||||
|
from app.core.device_manager import device_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from app.schemas import (
|
||||||
|
AgentCatalogItem,
|
||||||
|
AgentRunLogResponse,
|
||||||
|
CloudAgentConfigCreate,
|
||||||
|
CloudAgentConfigResponse,
|
||||||
|
CloudAgentConfigUpdate,
|
||||||
|
LocalAgentConfigCreate,
|
||||||
|
LocalAgentConfigResponse,
|
||||||
|
LocalAgentConfigUpdate,
|
||||||
|
UserProfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Datetime helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _dt_ms(dt: datetime) -> int:
|
||||||
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Model → schema converters ─────────────────────────────────────────
|
||||||
|
|
||||||
|
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
||||||
|
return LocalAgentConfigResponse(
|
||||||
|
id=a.id,
|
||||||
|
name=a.name,
|
||||||
|
device_id=a.device_id,
|
||||||
|
directory_paths=a.directory_paths,
|
||||||
|
data_types=a.data_types,
|
||||||
|
prompt_template=a.prompt_template,
|
||||||
|
file_extensions=a.file_extensions,
|
||||||
|
schedule_cron=a.schedule_cron,
|
||||||
|
enabled=a.enabled,
|
||||||
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
||||||
|
return CloudAgentConfigResponse(
|
||||||
|
id=a.id,
|
||||||
|
provider=a.provider, # type: ignore[arg-type]
|
||||||
|
name=a.name,
|
||||||
|
data_types=a.data_types,
|
||||||
|
prompt_template=a.prompt_template,
|
||||||
|
schedule_cron=a.schedule_cron,
|
||||||
|
filter_config=a.filter_config,
|
||||||
|
enabled=a.enabled,
|
||||||
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
|
return AgentRunLogResponse(
|
||||||
|
id=log.id,
|
||||||
|
agent_id=log.agent_id,
|
||||||
|
agent_type=log.agent_type, # type: ignore[arg-type]
|
||||||
|
status=log.status, # type: ignore[arg-type]
|
||||||
|
items_processed=log.items_processed,
|
||||||
|
items_created=log.items_created,
|
||||||
|
errors=log.errors or [],
|
||||||
|
started_at=_dt_ms(log.started_at),
|
||||||
|
completed_at=_dt_ms_opt(log.completed_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Ownership-checked lookups ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_local_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> LocalAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_cloud_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> CloudAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier limit helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return combined enabled local + cloud agent count for the user."""
|
||||||
|
local_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(LocalAgentConfig.id)).where(
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
LocalAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
cloud_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(CloudAgentConfig.id)).where(
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
CloudAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
return local_count + cloud_count
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
|
if limit != -1 and current_count >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
||||||
|
|
||||||
|
class _RunsPage(BaseModel):
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
items: list[AgentRunLogResponse]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/catalog", response_model=list[AgentCatalogItem])
|
||||||
|
async def get_agent_catalog(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> list[AgentCatalogItem]:
|
||||||
|
"""Return the static list of available agent types and their descriptions."""
|
||||||
|
return [
|
||||||
|
AgentCatalogItem(
|
||||||
|
type="local_directory",
|
||||||
|
name="Local Directory Monitor",
|
||||||
|
description="Watches local directories, extracts data from files using AI",
|
||||||
|
),
|
||||||
|
AgentCatalogItem(
|
||||||
|
type="gmail",
|
||||||
|
name="Gmail Connector",
|
||||||
|
description="Scans Gmail inbox, extracts tasks/notes from emails",
|
||||||
|
),
|
||||||
|
AgentCatalogItem(
|
||||||
|
type="teams",
|
||||||
|
name="Microsoft Teams Connector",
|
||||||
|
description="Monitors Teams messages, extracts action items",
|
||||||
|
),
|
||||||
|
AgentCatalogItem(
|
||||||
|
type="outlook",
|
||||||
|
name="Outlook Connector",
|
||||||
|
description="Scans Outlook inbox, extracts tasks/notes",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent CRUD ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
||||||
|
async def list_local_agents(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[LocalAgentConfigResponse]:
|
||||||
|
"""List all local directory agent configs owned by the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
return [_to_local_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_local_agent(
|
||||||
|
body: LocalAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Create a new local directory agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = LocalAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
name=body.name,
|
||||||
|
device_id=body.device_id,
|
||||||
|
directory_paths=body.directory_paths,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
file_extensions=body.file_extensions,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
||||||
|
async def update_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: LocalAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Partially update a local agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/local/{agent_id}", response_model=dict)
|
||||||
|
async def delete_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
||||||
|
async def list_cloud_agents(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[CloudAgentConfigResponse]:
|
||||||
|
"""List all cloud connector agent configs owned by the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
return [_to_cloud_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_cloud_agent(
|
||||||
|
body: CloudAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Create a new cloud connector agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = CloudAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
provider=body.provider,
|
||||||
|
name=body.name,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
oauth_token_encrypted=body.oauth_token_encrypted,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
filter_config=body.filter_config,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
||||||
|
async def update_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: CloudAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Partially update a cloud agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/cloud/{agent_id}", response_model=dict)
|
||||||
|
async def delete_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Run logs ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/runs", response_model=_RunsPage)
|
||||||
|
async def list_run_logs(
|
||||||
|
agent_id: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=20, ge=1, le=100),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _RunsPage:
|
||||||
|
"""Return paginated run logs for the authenticated user.
|
||||||
|
|
||||||
|
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
||||||
|
"""
|
||||||
|
base_filter = [AgentRunLog.user_id == current_user.id]
|
||||||
|
if agent_id:
|
||||||
|
base_filter.append(AgentRunLog.agent_id == agent_id)
|
||||||
|
|
||||||
|
total = (
|
||||||
|
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
||||||
|
).scalar_one()
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog)
|
||||||
|
.where(*base_filter)
|
||||||
|
.order_by(AgentRunLog.started_at.desc())
|
||||||
|
.offset((page - 1) * limit)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
||||||
|
|
||||||
|
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Manual trigger stub ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
|
async def trigger_agent_run(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AgentRunLogResponse:
|
||||||
|
"""Manually trigger an agent run.
|
||||||
|
|
||||||
|
Looks up the agent config (local or cloud) by ID with ownership check,
|
||||||
|
creates a run log entry with ``status="running"``, and returns it.
|
||||||
|
|
||||||
|
Actual dispatch to the agent runner is wired in Step 3.4 once
|
||||||
|
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||||
|
"""
|
||||||
|
# Determine agent type by trying local first, then cloud.
|
||||||
|
# Keep the full config object so we can pass it to the agent runner.
|
||||||
|
local_config: LocalAgentConfig | None = None
|
||||||
|
cloud_config: CloudAgentConfig | None = None
|
||||||
|
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local_config = local_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if local_config is not None:
|
||||||
|
agent_type = "local"
|
||||||
|
else:
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud_config = cloud_result.scalar_one_or_none()
|
||||||
|
if cloud_config is not None:
|
||||||
|
agent_type = "cloud"
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
user_id=current_user.id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
|
||||||
|
# Dispatch the run as a background task — returns 202 immediately.
|
||||||
|
if agent_type == "local" and local_config is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||||
|
)
|
||||||
|
elif agent_type == "cloud" and cloud_config is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||||
|
)
|
||||||
|
|
||||||
|
return _to_run_log_response(run_log)
|
||||||
@@ -1,33 +1,37 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
Users and refresh tokens are kept in an in-memory dict until Step 12
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
migrates them to PostgreSQL.
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import RefreshToken, User
|
||||||
from app.schemas import AuthTokens, UserProfile
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
|
||||||
_users: dict[str, dict[str, Any]] = {} # email → user record
|
|
||||||
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _hash_password(password: str) -> str:
|
def _hash_password(password: str) -> str:
|
||||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
@@ -36,33 +40,34 @@ def _verify_password(password: str, hashed: str) -> bool:
|
|||||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
|
def _hash_token(plain_token: str) -> str:
|
||||||
|
"""SHA-256 of the plain refresh token string."""
|
||||||
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||||
|
"""Return (signed JWT, expires_at_ms)."""
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||||
access_payload = {
|
payload = {
|
||||||
"sub": user_id,
|
"sub": user_id,
|
||||||
"email": email,
|
"email": email,
|
||||||
"tier": tier,
|
"tier": tier,
|
||||||
"exp": access_exp,
|
"exp": exp,
|
||||||
"iat": now,
|
"iat": now,
|
||||||
}
|
}
|
||||||
access_token = jwt.encode(
|
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
return token, exp * 1000 # ms for client
|
||||||
)
|
|
||||||
refresh_token = str(uuid.uuid4())
|
|
||||||
_refresh_tokens[refresh_token] = user_id
|
|
||||||
return AuthTokens(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
expires_at=access_exp * 1000, # milliseconds for client
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ────────────────────────────────────────────────────
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class _RegisterRequest(BaseModel):
|
class _RegisterRequest(BaseModel):
|
||||||
email: str
|
email: str
|
||||||
password: str
|
password: str
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class _LoginRequest(BaseModel):
|
class _LoginRequest(BaseModel):
|
||||||
@@ -76,43 +81,155 @@ class _RefreshRequest(BaseModel):
|
|||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||||
async def register(body: _RegisterRequest) -> AuthTokens:
|
async def register(
|
||||||
|
body: _RegisterRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Create a new account and return JWT tokens."""
|
"""Create a new account and return JWT tokens."""
|
||||||
if body.email in _users:
|
existing = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||||
user_id = str(uuid.uuid4())
|
|
||||||
_users[body.email] = {
|
user = User(
|
||||||
"id": user_id,
|
id=str(uuid.uuid4()),
|
||||||
"email": body.email,
|
email=body.email,
|
||||||
"password_hash": _hash_password(body.password),
|
name=body.name,
|
||||||
"tier": "free",
|
surname=body.surname,
|
||||||
}
|
password_hash=_hash_password(body.password),
|
||||||
return _make_tokens(user_id, body.email, "free")
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush() # get user.id without committing
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=AuthTokens)
|
@router.post("/login", response_model=AuthTokens)
|
||||||
async def login(body: _LoginRequest) -> AuthTokens:
|
async def login(
|
||||||
|
body: _LoginRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Validate credentials and return JWT tokens."""
|
"""Validate credentials and return JWT tokens."""
|
||||||
user = _users.get(body.email)
|
result = await db.execute(select(User).where(User.email == body.email))
|
||||||
if not user or not _verify_password(body.password, user["password_hash"]):
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not _verify_password(body.password, user.password_hash):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=AuthTokens)
|
@router.post("/refresh", response_model=AuthTokens)
|
||||||
async def refresh(body: _RefreshRequest) -> AuthTokens:
|
async def refresh(
|
||||||
|
body: _RefreshRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
"""Rotate a refresh token and return a new token pair."""
|
"""Rotate a refresh token and return a new token pair."""
|
||||||
user_id = _refresh_tokens.pop(body.refresh_token, None)
|
token_hash = _hash_token(body.refresh_token)
|
||||||
if user_id is None:
|
result = await db.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
rt = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||||
user = next((u for u in _users.values() if u["id"] == user_id), None)
|
|
||||||
|
# Rotate: delete old token, issue new one.
|
||||||
|
await db.delete(rt)
|
||||||
|
|
||||||
|
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||||
|
user = user_result.scalar_one_or_none()
|
||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
new_rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=new_expires,
|
||||||
|
)
|
||||||
|
db.add(new_rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateProfileRequest(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserProfile)
|
@router.get("/me", response_model=UserProfile)
|
||||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||||
"""Return the profile for the authenticated user."""
|
"""Return the profile for the authenticated user."""
|
||||||
return current_user
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me", response_model=UserProfile)
|
||||||
|
async def update_profile(
|
||||||
|
body: _UpdateProfileRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's name and surname."""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if body.name is not None:
|
||||||
|
user.name = body.name
|
||||||
|
if body.surname is not None:
|
||||||
|
user.surname = body.surname
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
name=user.name,
|
||||||
|
surname=user.surname,
|
||||||
|
tier=current_user.tier,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||||
|
|
||||||
Blobs are stored in S3 via BlobStore. Backup metadata is kept in an
|
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
||||||
in-memory dict until Step 12 migrates it to PostgreSQL (backup_metadata table).
|
PostgreSQL ``backup_metadata`` table.
|
||||||
|
|
||||||
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||||
treating "history" as a ``{backup_id}`` path parameter.
|
treating "history" as a ``{backup_id}`` path parameter.
|
||||||
@@ -9,13 +9,17 @@ treating "history" as a ``{backup_id}`` path parameter.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import uuid
|
||||||
from email.utils import parsedate_to_datetime
|
from email.utils import parsedate_to_datetime
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import BackupMetadata as BackupMetadataModel
|
||||||
from app.schemas import BackupMetadata, UserProfile
|
from app.schemas import BackupMetadata, UserProfile
|
||||||
from app.storage.blob_store import BlobStore
|
from app.storage.blob_store import BlobStore
|
||||||
from app.storage.encryption import reject_if_tampered
|
from app.storage.encryption import reject_if_tampered
|
||||||
@@ -24,35 +28,25 @@ router = APIRouter(prefix="/backup", tags=["backup"])
|
|||||||
|
|
||||||
_blob_store = BlobStore()
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
|
||||||
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
|
||||||
|
|
||||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
|
"""Return total backup bytes stored by *user_id*."""
|
||||||
"free": 0,
|
result = await db.execute(
|
||||||
"pro": 5,
|
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
||||||
"power": 25,
|
BackupMetadataModel.user_id == user_id
|
||||||
"team": -1, # unlimited
|
)
|
||||||
}
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
|
async def _check_backup_quota(
|
||||||
|
user: UserProfile, size_bytes: int, db: AsyncSession
|
||||||
|
) -> None:
|
||||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||||
limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)
|
current = await _current_backup_bytes(user.id, db)
|
||||||
if limit_gb == 0:
|
tier_manager.enforce_backup_quota(
|
||||||
raise HTTPException(
|
user.tier, current_bytes=current, additional_bytes=size_bytes
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
)
|
||||||
detail="Backup is not available on the free tier",
|
|
||||||
)
|
|
||||||
if limit_gb == -1:
|
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024**3
|
|
||||||
used = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
|
||||||
if used + size_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("")
|
@router.put("")
|
||||||
@@ -62,6 +56,7 @@ async def upload_backup(
|
|||||||
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
||||||
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Upload an E2E-encrypted backup blob.
|
"""Upload an E2E-encrypted backup blob.
|
||||||
|
|
||||||
@@ -69,24 +64,23 @@ async def upload_backup(
|
|||||||
"""
|
"""
|
||||||
blob = await request.body()
|
blob = await request.body()
|
||||||
reject_if_tampered(blob, x_backup_checksum)
|
reject_if_tampered(blob, x_backup_checksum)
|
||||||
_check_backup_quota(current_user.id, current_user.tier, len(blob))
|
await _check_backup_quota(current_user, len(blob), db)
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
s3_key = await _blob_store.upload(
|
||||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||||
)
|
)
|
||||||
|
|
||||||
backup_record: dict[str, Any] = {
|
row = BackupMetadataModel(
|
||||||
"id": str(x_backup_timestamp),
|
id=str(uuid.uuid4()),
|
||||||
"s3_key": s3_key,
|
user_id=current_user.id,
|
||||||
"version": x_backup_version,
|
s3_key=s3_key,
|
||||||
"timestamp": x_backup_timestamp,
|
version=x_backup_version,
|
||||||
"checksum": x_backup_checksum,
|
timestamp=x_backup_timestamp,
|
||||||
"size_bytes": len(blob),
|
checksum=x_backup_checksum,
|
||||||
}
|
size_bytes=len(blob),
|
||||||
|
)
|
||||||
user_backups = _backups.setdefault(current_user.id, [])
|
db.add(row)
|
||||||
user_backups.append(backup_record)
|
await db.commit()
|
||||||
user_backups.sort(key=lambda b: b["timestamp"], reverse=True)
|
|
||||||
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
@@ -94,16 +88,23 @@ async def upload_backup(
|
|||||||
@router.get("/history", response_model=list[BackupMetadata])
|
@router.get("/history", response_model=list[BackupMetadata])
|
||||||
async def backup_history(
|
async def backup_history(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> list[BackupMetadata]:
|
) -> list[BackupMetadata]:
|
||||||
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
return [
|
return [
|
||||||
BackupMetadata(
|
BackupMetadata(
|
||||||
version=b["version"],
|
version=r.version,
|
||||||
timestamp=b["timestamp"],
|
timestamp=r.timestamp,
|
||||||
checksum=b["checksum"],
|
checksum=r.checksum,
|
||||||
chunk_count=1, # single-chunk uploads for now — TODO(Step12): track real count
|
chunk_count=1,
|
||||||
)
|
)
|
||||||
for b in _backups.get(current_user.id, [])
|
for r in rows
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -111,32 +112,37 @@ async def backup_history(
|
|||||||
async def download_backup(
|
async def download_backup(
|
||||||
request: Request,
|
request: Request,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
||||||
user_backups = _backups.get(current_user.id, [])
|
result = await db.execute(
|
||||||
if not user_backups:
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
latest = result.scalar_one_or_none()
|
||||||
|
if latest is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
||||||
|
|
||||||
latest = user_backups[0]
|
|
||||||
|
|
||||||
ims_header = request.headers.get("If-Modified-Since")
|
ims_header = request.headers.get("If-Modified-Since")
|
||||||
if ims_header:
|
if ims_header:
|
||||||
try:
|
try:
|
||||||
ims_dt = parsedate_to_datetime(ims_header)
|
ims_dt = parsedate_to_datetime(ims_header)
|
||||||
ims_ms = int(ims_dt.timestamp() * 1000)
|
ims_ms = int(ims_dt.timestamp() * 1000)
|
||||||
if latest["timestamp"] <= ims_ms:
|
if latest.timestamp <= ims_ms:
|
||||||
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
||||||
except Exception:
|
except Exception:
|
||||||
pass # malformed header — ignore and serve the blob
|
pass # malformed header — ignore and serve the blob
|
||||||
|
|
||||||
blob = await _blob_store.download(current_user.id, latest["s3_key"])
|
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
||||||
return Response(
|
return Response(
|
||||||
content=blob,
|
content=blob,
|
||||||
media_type="application/octet-stream",
|
media_type="application/octet-stream",
|
||||||
headers={
|
headers={
|
||||||
"X-Backup-Version": str(latest["version"]),
|
"X-Backup-Version": str(latest.version),
|
||||||
"X-Backup-Timestamp": str(latest["timestamp"]),
|
"X-Backup-Timestamp": str(latest.timestamp),
|
||||||
"X-Checksum": latest["checksum"],
|
"X-Checksum": latest.checksum,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -145,14 +151,21 @@ async def download_backup(
|
|||||||
async def delete_backup(
|
async def delete_backup(
|
||||||
backup_id: str,
|
backup_id: str,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Delete a specific backup by ID."""
|
"""Delete a specific backup by ID."""
|
||||||
user_backups = _backups.get(current_user.id, [])
|
result = await db.execute(
|
||||||
target = next((b for b in user_backups if b["id"] == backup_id), None)
|
select(BackupMetadataModel).where(
|
||||||
|
BackupMetadataModel.id == backup_id,
|
||||||
|
BackupMetadataModel.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
target = result.scalar_one_or_none()
|
||||||
if target is None:
|
if target is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
||||||
|
|
||||||
await _blob_store.delete(current_user.id, target["s3_key"])
|
await _blob_store.delete(current_user.id, target.s3_key)
|
||||||
_backups[current_user.id] = [b for b in user_backups if b["id"] != backup_id]
|
await db.delete(target)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|||||||
@@ -1,44 +1,25 @@
|
|||||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||||
|
|
||||||
Subscription records are kept in-memory until Step 12 migrates them to
|
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||||
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
|
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||||
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
|
and delegates everything else to the service singleton.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import stripe as stripe_lib
|
from fastapi import APIRouter, Depends, Header, Request, status
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.db import get_session
|
||||||
from app.schemas import BillingTier, UserProfile
|
from app.schemas import BillingTier, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
|
|
||||||
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
|
|
||||||
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
|
|
||||||
|
|
||||||
_TIER_PRICE_IDS: dict[str, str] = {
|
|
||||||
"pro": "price_pro_monthly", # replace with real Stripe price IDs
|
|
||||||
"power": "price_power_monthly",
|
|
||||||
"team": "price_team_monthly",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _stripe_configured() -> bool:
|
|
||||||
return bool(settings.STRIPE_SECRET_KEY)
|
|
||||||
|
|
||||||
|
|
||||||
def _stripe() -> Any:
|
|
||||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
|
||||||
return stripe_lib
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ─────────────────────────────────────────────────────
|
# ── Request bodies ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -57,40 +38,15 @@ async def create_checkout(
|
|||||||
|
|
||||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||||
"""
|
"""
|
||||||
if body.tier == "free":
|
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||||
raise HTTPException(
|
return {"checkout_url": url}
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Cannot create a checkout session for the free tier",
|
|
||||||
)
|
|
||||||
|
|
||||||
if _stripe_configured():
|
|
||||||
price_id = _TIER_PRICE_IDS.get(body.tier)
|
|
||||||
if not price_id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Unknown tier: {body.tier}",
|
|
||||||
)
|
|
||||||
s = _stripe()
|
|
||||||
session = s.checkout.Session.create(
|
|
||||||
payment_method_types=["card"],
|
|
||||||
mode="subscription",
|
|
||||||
line_items=[{"price": price_id, "quantity": 1}],
|
|
||||||
success_url=(
|
|
||||||
"https://app.adiuva.app/billing/success"
|
|
||||||
"?session_id={CHECKOUT_SESSION_ID}"
|
|
||||||
),
|
|
||||||
cancel_url="https://app.adiuva.app/billing/cancel",
|
|
||||||
metadata={"user_id": current_user.id, "tier": body.tier},
|
|
||||||
)
|
|
||||||
return {"checkout_url": session.url}
|
|
||||||
|
|
||||||
return {"checkout_url": "https://stripe.com/stub-checkout"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/webhook", response_model=dict)
|
@router.post("/webhook", response_model=dict)
|
||||||
async def stripe_webhook(
|
async def stripe_webhook(
|
||||||
request: Request,
|
request: Request,
|
||||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Handle Stripe webhook events.
|
"""Handle Stripe webhook events.
|
||||||
|
|
||||||
@@ -98,57 +54,17 @@ async def stripe_webhook(
|
|||||||
Returns 200 immediately when Stripe is not configured (local dev).
|
Returns 200 immediately when Stripe is not configured (local dev).
|
||||||
"""
|
"""
|
||||||
payload = await request.body()
|
payload = await request.body()
|
||||||
|
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||||
if not _stripe_configured():
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
try:
|
|
||||||
s = _stripe()
|
|
||||||
event = s.Webhook.construct_event(
|
|
||||||
payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
|
|
||||||
)
|
|
||||||
except stripe_lib.error.SignatureVerificationError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Invalid Stripe signature",
|
|
||||||
)
|
|
||||||
|
|
||||||
event_type: str = event["type"]
|
|
||||||
data: dict[str, Any] = event["data"]["object"]
|
|
||||||
|
|
||||||
if event_type == "checkout.session.completed":
|
|
||||||
user_id = data.get("metadata", {}).get("user_id")
|
|
||||||
tier = data.get("metadata", {}).get("tier", "free")
|
|
||||||
sub_id = data.get("subscription")
|
|
||||||
if user_id:
|
|
||||||
_subscriptions[user_id] = {
|
|
||||||
"tier": tier,
|
|
||||||
"stripe_subscription_id": sub_id,
|
|
||||||
"status": "active",
|
|
||||||
"current_period_end": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
elif event_type == "customer.subscription.updated":
|
|
||||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif event_type == "customer.subscription.deleted":
|
|
||||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif event_type == "invoice.payment_failed":
|
|
||||||
# TODO(Step12): flag subscription as past_due, notify user
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/subscription", response_model=dict)
|
@router.get("/subscription", response_model=dict)
|
||||||
async def get_subscription(
|
async def get_subscription(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Return the current subscription info for the authenticated user."""
|
"""Return the current subscription info for the authenticated user."""
|
||||||
sub = _subscriptions.get(current_user.id)
|
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||||
if sub is None:
|
if sub is None:
|
||||||
return {
|
return {
|
||||||
"tier": current_user.tier,
|
"tier": current_user.tier,
|
||||||
@@ -159,26 +75,11 @@ async def get_subscription(
|
|||||||
return sub
|
return sub
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/subscription", response_model=dict)
|
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||||
async def cancel_subscription(
|
async def cancel_subscription(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Cancel the active subscription."""
|
"""Cancel the active subscription."""
|
||||||
sub = _subscriptions.get(current_user.id)
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
if sub is None or not sub.get("stripe_subscription_id"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="No active subscription found",
|
|
||||||
)
|
|
||||||
|
|
||||||
if _stripe_configured():
|
|
||||||
s = _stripe()
|
|
||||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
|
||||||
|
|
||||||
_subscriptions[current_user.id] = {
|
|
||||||
**sub,
|
|
||||||
"tier": "free",
|
|
||||||
"status": "canceled",
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|||||||
@@ -1,78 +1,42 @@
|
|||||||
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
"""Chat routes: POST /chat (REST fallback).
|
||||||
|
|
||||||
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
from fastapi import APIRouter, Depends
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.config.settings import settings
|
from app.core.deep_agent import run_home
|
||||||
from app.core.orchestrator import orchestrate, orchestrate_stream
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.db import async_session
|
||||||
|
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""Route a chat message through the orchestrator.
|
"""Route a chat message through the Home deep agent (non-streaming)."""
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(current_user.id, body.message)
|
||||||
|
|
||||||
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
context = {
|
||||||
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
**body.context.model_dump(),
|
||||||
"""
|
**memory_context,
|
||||||
result = await orchestrate(body)
|
}
|
||||||
|
|
||||||
|
response_text = await run_home(
|
||||||
|
user_id=current_user.id,
|
||||||
|
message=body.message,
|
||||||
|
context=context,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
result = ChatResponse(response=response_text)
|
||||||
return JSONResponse(content=result.model_dump())
|
return JSONResponse(content=result.model_dump())
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/stream")
|
|
||||||
async def chat_stream(websocket: WebSocket) -> None:
|
|
||||||
"""Streaming chat via WebSocket.
|
|
||||||
|
|
||||||
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
|
|
||||||
|
|
||||||
Protocol:
|
|
||||||
1. Client sends ``ChatRequest`` as the first JSON text frame.
|
|
||||||
2. Server streams response text chunks.
|
|
||||||
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
|
|
||||||
4. Server pings every 30 s to keep the connection alive.
|
|
||||||
"""
|
|
||||||
# Authenticate before accepting the connection
|
|
||||||
token = websocket.query_params.get("token", "")
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
if not user_id:
|
|
||||||
raise JWTError("missing sub")
|
|
||||||
except JWTError:
|
|
||||||
await websocket.close(code=1008) # 1008 = Policy Violation
|
|
||||||
return
|
|
||||||
|
|
||||||
await websocket.accept()
|
|
||||||
|
|
||||||
try:
|
|
||||||
raw = await websocket.receive_text()
|
|
||||||
body = ChatRequest.model_validate_json(raw)
|
|
||||||
|
|
||||||
async def _heartbeat() -> None:
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
|
||||||
await websocket.send_text(json.dumps({"ping": True}))
|
|
||||||
|
|
||||||
heartbeat_task = asyncio.create_task(_heartbeat())
|
|
||||||
try:
|
|
||||||
async for chunk in orchestrate_stream(body):
|
|
||||||
await websocket.send_text(chunk)
|
|
||||||
finally:
|
|
||||||
heartbeat_task.cancel()
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
|
|||||||
365
app/api/routes/device_ws.py
Normal file
365
app/api/routes/device_ws.py
Normal file
@@ -0,0 +1,365 @@
|
|||||||
|
"""Device WebSocket endpoint.
|
||||||
|
|
||||||
|
Persistent connection from Electron devices to the backend.
|
||||||
|
|
||||||
|
WS /api/v1/ws/device?token=<jwt>
|
||||||
|
|
||||||
|
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
|
||||||
|
available during the WebSocket handshake).
|
||||||
|
|
||||||
|
Protocol:
|
||||||
|
1. Client connects → JWT validated → connection accepted.
|
||||||
|
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
|
||||||
|
3. Backend registers the connection in ``DeviceConnectionManager``.
|
||||||
|
4. Session enters message dispatch loop + heartbeat.
|
||||||
|
|
||||||
|
Incoming frame dispatch:
|
||||||
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
|
- ``agent_data`` → enqueued in the per-run agent data queue.
|
||||||
|
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
||||||
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
|
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||||
|
|
||||||
|
On disconnect:
|
||||||
|
- Unregisters from DeviceConnectionManager.
|
||||||
|
- Marks all in-progress AgentRunLog rows for this user as ``error``
|
||||||
|
with message "device disconnected".
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
|
from app.core.device_manager import device_manager
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.core.deep_agent import run_home_stream, run_floating_stream
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import AgentRunLog
|
||||||
|
from app.schemas import WsFrameType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||||
|
|
||||||
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
|
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/device")
|
||||||
|
async def device_ws(websocket: WebSocket) -> None:
|
||||||
|
"""Persistent WebSocket endpoint for Electron device connections.
|
||||||
|
|
||||||
|
Authentication is via ``?token=<jwt>`` query parameter.
|
||||||
|
"""
|
||||||
|
# ── 1. Authenticate before accepting ─────────────────────────────
|
||||||
|
token = websocket.query_params.get("token", "")
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
if not user_id:
|
||||||
|
raise JWTError("missing sub")
|
||||||
|
except JWTError:
|
||||||
|
await websocket.close(code=1008) # Policy Violation
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
# ── 2. Await device_hello frame ───────────────────────────────────
|
||||||
|
try:
|
||||||
|
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||||
|
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
hello = json.loads(raw)
|
||||||
|
if hello.get("type") != WsFrameType.device_hello:
|
||||||
|
raise ValueError("expected device_hello as first frame")
|
||||||
|
device_id: str = hello["device_id"]
|
||||||
|
agent_ids: list[str] = hello.get("agent_ids", [])
|
||||||
|
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||||
|
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Register connection ────────────────────────────────────────
|
||||||
|
device_manager.register(user_id, device_id, websocket)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: connected user=%s device=%s agents=%s",
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
agent_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger any overdue agent runs now that the device is connected.
|
||||||
|
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||||
|
|
||||||
|
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||||
|
try:
|
||||||
|
await asyncio.gather(
|
||||||
|
_message_loop(websocket, user_id),
|
||||||
|
_heartbeat_loop(websocket),
|
||||||
|
)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
|
||||||
|
finally:
|
||||||
|
device_manager.unregister(user_id)
|
||||||
|
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
|
||||||
|
await _mark_runs_disconnected(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Message dispatch loop ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||||
|
"""Receive frames from Electron and dispatch to the appropriate handler."""
|
||||||
|
async for raw in websocket.iter_text():
|
||||||
|
try:
|
||||||
|
frame: dict = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("device_ws: invalid JSON from user=%s", user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame_type = frame.get("type")
|
||||||
|
|
||||||
|
if frame_type == WsFrameType.tool_result:
|
||||||
|
call_id = frame.get("id")
|
||||||
|
if call_id:
|
||||||
|
device_manager.resolve_pending_call(user_id, call_id, frame)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_data:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
await queue.put(frame)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data for unknown run user=%s run=%s",
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_complete:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
# Sentinel: signals the agent data stream is finished.
|
||||||
|
await queue.put(None)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.home_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_home_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.floating_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == "pong":
|
||||||
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
|
||||||
|
|
||||||
|
|
||||||
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
|
async def _executor(payload: dict) -> dict:
|
||||||
|
payload["type"] = WsFrameType.tool_call
|
||||||
|
call_id = payload["id"]
|
||||||
|
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
||||||
|
await websocket.send_text(json.dumps(payload))
|
||||||
|
future = device_manager.create_pending_call(user_id, call_id)
|
||||||
|
try:
|
||||||
|
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
|
||||||
|
call_id, payload.get("action"), user_id,
|
||||||
|
)
|
||||||
|
# Clean up the pending future so it doesn't leak
|
||||||
|
conn = device_manager._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.pending_calls.pop(call_id, None)
|
||||||
|
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
|
||||||
|
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
||||||
|
call_id, type(result).__name__,
|
||||||
|
list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
if result is None:
|
||||||
|
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
||||||
|
return result
|
||||||
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_home_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
|
||||||
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
try:
|
||||||
|
event_stream = run_home_stream(
|
||||||
|
user_id, message, context, db_session_factory=async_session
|
||||||
|
)
|
||||||
|
formatter = HomeFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: home_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Memory: store episode after response ──────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_floating_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
scope: dict = frame.get("scope", {})
|
||||||
|
|
||||||
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
|
|
||||||
|
context: dict = {"scope": scope, **memory_context}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
try:
|
||||||
|
event_stream = run_floating_stream(
|
||||||
|
user_id, message, context, scope=scope,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
formatter = FloatingFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: floating_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Memory: store episode after response ──────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
|
"""Send a ping frame every 30 s to keep the connection alive."""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||||
|
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Disconnect cleanup ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _mark_runs_disconnected(user_id: str) -> None:
|
||||||
|
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
await db.execute(
|
||||||
|
update(AgentRunLog)
|
||||||
|
.where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.status == "running",
|
||||||
|
)
|
||||||
|
.values(
|
||||||
|
status="error",
|
||||||
|
errors=["device disconnected"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: failed to mark runs as disconnected for user=%s: %s",
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.execution_plan import plan_cache
|
|
||||||
from app.schemas import ExecutionPlan, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook", response_model=list[ExecutionPlan])
|
|
||||||
async def list_playbooks(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached execution plan playbooks for the authenticated user.
|
|
||||||
|
|
||||||
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
|
|
||||||
"""
|
|
||||||
return plan_cache.get_all_playbooks()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
|
|
||||||
async def get_playbook(
|
|
||||||
plan_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> ExecutionPlan:
|
|
||||||
"""Return a specific execution plan playbook by ID."""
|
|
||||||
plan = plan_cache.get_plan(plan_id)
|
|
||||||
if plan is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Plan not found: {plan_id}",
|
|
||||||
)
|
|
||||||
return plan
|
|
||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Plugins routes: browse and install plugins from the marketplace.
|
"""Plugins routes: browse and install plugins from the marketplace.
|
||||||
|
|
||||||
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes introduced
|
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
||||||
in Step 10. Step 12 will swap those services' in-memory stores for
|
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
||||||
PostgreSQL persistence.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -11,10 +10,14 @@ from typing import Any, Literal
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.db import get_session
|
||||||
from app.marketplace.plugin_registry import registry
|
from app.marketplace.plugin_registry import registry
|
||||||
from app.marketplace.revenue_share import revenue_share
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
||||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||||
@@ -36,7 +39,7 @@ def _require_plugin_tier(user: UserProfile) -> None:
|
|||||||
class _PluginDetail(BaseModel):
|
class _PluginDetail(BaseModel):
|
||||||
plugin: PluginManifest
|
plugin: PluginManifest
|
||||||
install_count: int
|
install_count: int
|
||||||
ratings: list[Any] # Step 12 populates from plugin_reviews table
|
ratings: list[Any]
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
@@ -48,26 +51,44 @@ async def list_plugins(
|
|||||||
page: int = Query(default=1, ge=1),
|
page: int = Query(default=1, ge=1),
|
||||||
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> PluginListResponse:
|
) -> PluginListResponse:
|
||||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||||
_require_plugin_tier(current_user)
|
_require_plugin_tier(current_user)
|
||||||
return await registry.list_plugins(category=category, query=q, page=page, sort=sort)
|
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||||
async def get_plugin(
|
async def get_plugin(
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> _PluginDetail:
|
) -> _PluginDetail:
|
||||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||||
_require_plugin_tier(current_user)
|
_require_plugin_tier(current_user)
|
||||||
entry = await registry.get_plugin(plugin_id)
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Fetch review ratings for this plugin
|
||||||
|
review_result = await db.execute(
|
||||||
|
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
||||||
|
)
|
||||||
|
reviews = review_result.scalars().all()
|
||||||
|
ratings = [
|
||||||
|
{
|
||||||
|
"reviewer_id": r.reviewer_id,
|
||||||
|
"decision": r.decision,
|
||||||
|
"notes": r.notes,
|
||||||
|
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
||||||
|
}
|
||||||
|
for r in reviews
|
||||||
|
]
|
||||||
|
|
||||||
return _PluginDetail(
|
return _PluginDetail(
|
||||||
plugin=entry["manifest"],
|
plugin=entry["manifest"],
|
||||||
install_count=entry["install_count"],
|
install_count=entry["install_count"],
|
||||||
ratings=[], # Step 12 populates from plugin_reviews table
|
ratings=ratings,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -76,17 +97,27 @@ async def install_plugin(
|
|||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||||
|
|
||||||
Requires Power tier or above.
|
Requires Power tier or above.
|
||||||
"""
|
"""
|
||||||
_require_plugin_tier(current_user)
|
_require_plugin_tier(current_user)
|
||||||
entry = await registry.get_plugin(plugin_id)
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
if entry is None:
|
if entry is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Record the installation in plugin_installations
|
||||||
|
installation = PluginInstallation(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
)
|
||||||
|
db.add(installation)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
await revenue_share.record_install(
|
await revenue_share.record_install(
|
||||||
|
db,
|
||||||
plugin_id=plugin_id,
|
plugin_id=plugin_id,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
amount_cents=entry["manifest"].price_cents,
|
amount_cents=entry["manifest"].price_cents,
|
||||||
@@ -100,7 +131,18 @@ async def install_plugin(
|
|||||||
async def uninstall_plugin(
|
async def uninstall_plugin(
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Unregister a plugin installation."""
|
"""Unregister a plugin installation."""
|
||||||
await registry.record_uninstall(plugin_id)
|
result = await db.execute(
|
||||||
|
select(PluginInstallation).where(
|
||||||
|
PluginInstallation.plugin_id == plugin_id,
|
||||||
|
PluginInstallation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
installation = result.scalar_one_or_none()
|
||||||
|
if installation is not None:
|
||||||
|
await db.delete(installation)
|
||||||
|
await db.commit()
|
||||||
|
await registry.record_uninstall(db, plugin_id)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|||||||
@@ -1,19 +1,22 @@
|
|||||||
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||||
|
|
||||||
Blobs are stored in S3 via BlobStore. Record metadata is kept in an
|
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
||||||
in-memory dict until Step 12 migrates it to PostgreSQL (storage_records table).
|
PostgreSQL ``storage_records`` table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import StorageRecord
|
||||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||||
from app.storage.blob_store import BlobStore
|
from app.storage.blob_store import BlobStore
|
||||||
from app.storage.encryption import reject_if_tampered
|
from app.storage.encryption import reject_if_tampered
|
||||||
@@ -22,17 +25,6 @@ router = APIRouter(prefix="/storage", tags=["storage"])
|
|||||||
|
|
||||||
_blob_store = BlobStore()
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
|
||||||
_records: dict[str, dict[str, Any]] = {}
|
|
||||||
|
|
||||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
|
||||||
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
|
|
||||||
"free": 0,
|
|
||||||
"pro": 5,
|
|
||||||
"power": 25,
|
|
||||||
"team": -1, # unlimited
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local response schemas ─────────────────────────────────────────────
|
# ── Local response schemas ─────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -51,25 +43,34 @@ class _RecordMeta(BaseModel):
|
|||||||
|
|
||||||
# ── Helpers ────────────────────────────────────────────────────────────
|
# ── Helpers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
|
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
|
"""Return total bytes stored by *user_id*."""
|
||||||
limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
|
result = await db.execute(
|
||||||
if limit_gb == -1:
|
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
||||||
return # unlimited
|
StorageRecord.user_id == user_id
|
||||||
limit_bytes = limit_gb * 1024**3
|
|
||||||
used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
|
||||||
if used + additional_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
||||||
"""Look up a record and verify ownership. Always returns 404 on mismatch
|
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
||||||
|
current = await _current_usage_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_record_for_user(
|
||||||
|
record_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> StorageRecord:
|
||||||
|
"""Look up a record and verify ownership. Returns 404 on mismatch
|
||||||
to prevent user enumeration attacks."""
|
to prevent user enumeration attacks."""
|
||||||
record = _records.get(record_id)
|
result = await db.execute(
|
||||||
if record is None or record["user_id"] != user_id:
|
select(StorageRecord).where(
|
||||||
|
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||||
return record
|
return record
|
||||||
|
|
||||||
@@ -80,30 +81,32 @@ def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
|||||||
async def create_record(
|
async def create_record(
|
||||||
body: StorageRecordCreate,
|
body: StorageRecordCreate,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> _CreateResponse:
|
) -> _CreateResponse:
|
||||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
_check_quota(current_user.id, current_user.tier, len(body.blob))
|
await _check_quota(current_user, len(body.blob), db)
|
||||||
|
|
||||||
record_id = str(uuid.uuid4())
|
record_id = str(uuid.uuid4())
|
||||||
now = int(time.time() * 1000)
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
s3_key = await _blob_store.upload(
|
||||||
current_user.id, body.table, record_id, body.blob, body.checksum
|
current_user.id, body.table, record_id, body.blob, body.checksum
|
||||||
)
|
)
|
||||||
|
|
||||||
_records[record_id] = {
|
record = StorageRecord(
|
||||||
"id": record_id,
|
id=record_id,
|
||||||
"user_id": current_user.id,
|
user_id=current_user.id,
|
||||||
"table": body.table,
|
table_name=body.table,
|
||||||
"s3_key": s3_key,
|
s3_key=s3_key,
|
||||||
"checksum": body.checksum,
|
checksum=body.checksum,
|
||||||
"size_bytes": len(body.blob),
|
size_bytes=len(body.blob),
|
||||||
"created_at": now,
|
)
|
||||||
"updated_at": now,
|
db.add(record)
|
||||||
}
|
await db.commit()
|
||||||
|
await db.refresh(record)
|
||||||
|
|
||||||
return _CreateResponse(id=record_id, created_at=now)
|
created_at_ms = int(record.created_at.timestamp() * 1000)
|
||||||
|
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
||||||
|
|
||||||
|
|
||||||
@router.get("/records", response_model=list[_RecordMeta])
|
@router.get("/records", response_model=list[_RecordMeta])
|
||||||
@@ -112,23 +115,26 @@ async def list_records(
|
|||||||
page: int = Query(default=1, ge=1),
|
page: int = Query(default=1, ge=1),
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> list[_RecordMeta]:
|
) -> list[_RecordMeta]:
|
||||||
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
||||||
all_records = [
|
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
||||||
r for r in _records.values()
|
if table is not None:
|
||||||
if r["user_id"] == current_user.id and (table is None or r["table"] == table)
|
query = query.where(StorageRecord.table_name == table)
|
||||||
]
|
query = query.offset((page - 1) * limit).limit(limit)
|
||||||
start = (page - 1) * limit
|
|
||||||
page_records = all_records[start : start + limit]
|
result = await db.execute(query)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
return [
|
return [
|
||||||
_RecordMeta(
|
_RecordMeta(
|
||||||
id=r["id"],
|
id=r.id,
|
||||||
table=r["table"],
|
table=r.table_name,
|
||||||
checksum=r["checksum"],
|
checksum=r.checksum,
|
||||||
created_at=r["created_at"],
|
created_at=int(r.created_at.timestamp() * 1000),
|
||||||
updated_at=r["updated_at"],
|
updated_at=int(r.updated_at.timestamp() * 1000),
|
||||||
)
|
)
|
||||||
for r in page_records
|
for r in rows
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@@ -136,14 +142,15 @@ async def list_records(
|
|||||||
async def download_record(
|
async def download_record(
|
||||||
record_id: str,
|
record_id: str,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> Response:
|
) -> Response:
|
||||||
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
||||||
record = _get_record_for_user(record_id, current_user.id)
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
blob = await _blob_store.download(current_user.id, record["s3_key"])
|
blob = await _blob_store.download(current_user.id, record.s3_key)
|
||||||
return Response(
|
return Response(
|
||||||
content=blob,
|
content=blob,
|
||||||
media_type="application/octet-stream",
|
media_type="application/octet-stream",
|
||||||
headers={"X-Checksum": record["checksum"]},
|
headers={"X-Checksum": record.checksum},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -152,23 +159,24 @@ async def update_record(
|
|||||||
record_id: str,
|
record_id: str,
|
||||||
body: StorageRecordUpdate,
|
body: StorageRecordUpdate,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
||||||
record = _get_record_for_user(record_id, current_user.id)
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
|
||||||
delta = len(body.blob) - record["size_bytes"]
|
delta = len(body.blob) - record.size_bytes
|
||||||
if delta > 0:
|
if delta > 0:
|
||||||
_check_quota(current_user.id, current_user.tier, delta)
|
await _check_quota(current_user, delta, db)
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
s3_key = await _blob_store.upload(
|
||||||
current_user.id, record["table"], record_id, body.blob, body.checksum
|
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
||||||
)
|
)
|
||||||
|
|
||||||
record["s3_key"] = s3_key
|
record.s3_key = s3_key
|
||||||
record["checksum"] = body.checksum
|
record.checksum = body.checksum
|
||||||
record["size_bytes"] = len(body.blob)
|
record.size_bytes = len(body.blob)
|
||||||
record["updated_at"] = int(time.time() * 1000)
|
await db.commit()
|
||||||
|
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
@@ -177,9 +185,11 @@ async def update_record(
|
|||||||
async def delete_record(
|
async def delete_record(
|
||||||
record_id: str,
|
record_id: str,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> dict[str, bool]:
|
) -> dict[str, bool]:
|
||||||
"""Delete a record and its S3 blob."""
|
"""Delete a record and its S3 blob."""
|
||||||
record = _get_record_for_user(record_id, current_user.id)
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
await _blob_store.delete(current_user.id, record["s3_key"])
|
await _blob_store.delete(current_user.id, record.s3_key)
|
||||||
del _records[record_id]
|
await db.delete(record)
|
||||||
|
await db.commit()
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Vectors routes: upsert, search, and delete cloud vector store entries."""
|
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Depends
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.llm import embed
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
UserProfile,
|
UserProfile,
|
||||||
VectorSearchRequest,
|
VectorSearchRequest,
|
||||||
@@ -24,6 +25,14 @@ class _VectorDeleteRequest(BaseModel):
|
|||||||
ids: list[str]
|
ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedResponse(BaseModel):
|
||||||
|
vector: list[float]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/upsert", response_model=dict)
|
@router.post("/vectors/upsert", response_model=dict)
|
||||||
async def upsert_vectors(
|
async def upsert_vectors(
|
||||||
body: VectorUpsertRequest,
|
body: VectorUpsertRequest,
|
||||||
@@ -54,3 +63,17 @@ async def delete_vectors(
|
|||||||
"""Delete vectors by ID, scoped to the authenticated user."""
|
"""Delete vectors by ID, scoped to the authenticated user."""
|
||||||
await _vector_store.delete(current_user.id, body.ids)
|
await _vector_store.delete(current_user.id, body.ids)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/embed", response_model=_EmbedResponse)
|
||||||
|
async def embed_text(
|
||||||
|
body: _EmbedRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _EmbedResponse:
|
||||||
|
"""Generate a 1536-dim embedding vector for the given text.
|
||||||
|
|
||||||
|
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||||
|
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
|
||||||
|
"""
|
||||||
|
vector = await embed(body.text)
|
||||||
|
return _EmbedResponse(vector=vector)
|
||||||
|
|||||||
@@ -0,0 +1,4 @@
|
|||||||
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
|
||||||
|
__all__ = ["stripe_service", "tier_manager"]
|
||||||
|
|||||||
256
app/billing/stripe_service.py
Normal file
256
app/billing/stripe_service.py
Normal file
@@ -0,0 +1,256 @@
|
|||||||
|
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||||
|
|
||||||
|
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||||
|
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||||
|
configured, enabling local development without live credentials.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||||
|
TIER_PRICE_IDS: dict[str, str] = {
|
||||||
|
"pro": "price_pro_monthly",
|
||||||
|
"power": "price_power_monthly",
|
||||||
|
"team": "price_team_monthly",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StripeService:
|
||||||
|
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||||
|
|
||||||
|
# ── Internal helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _configured(self) -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Public API ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_checkout_session(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
tier: str,
|
||||||
|
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
|
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||||
|
) -> str:
|
||||||
|
"""Create a Stripe checkout session and return the URL.
|
||||||
|
|
||||||
|
Returns a stub URL when Stripe is not configured.
|
||||||
|
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
||||||
|
"""
|
||||||
|
if tier == "free":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Cannot create a checkout session for the free tier",
|
||||||
|
)
|
||||||
|
|
||||||
|
price_id = TIER_PRICE_IDS.get(tier)
|
||||||
|
if not price_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unknown tier: {tier}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._configured():
|
||||||
|
return "https://stripe.com/stub-checkout"
|
||||||
|
|
||||||
|
s = self._client()
|
||||||
|
session = s.checkout.Session.create(
|
||||||
|
payment_method_types=["card"],
|
||||||
|
mode="subscription",
|
||||||
|
line_items=[{"price": price_id, "quantity": 1}],
|
||||||
|
success_url=success_url,
|
||||||
|
cancel_url=cancel_url,
|
||||||
|
metadata={"user_id": user_id, "tier": tier},
|
||||||
|
)
|
||||||
|
return session.url
|
||||||
|
|
||||||
|
async def handle_webhook(
|
||||||
|
self,
|
||||||
|
payload: bytes,
|
||||||
|
sig_header: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Process a Stripe webhook event.
|
||||||
|
|
||||||
|
Verifies the signature, then dispatches on event type.
|
||||||
|
Raises ``HTTP 400`` on signature mismatch.
|
||||||
|
No-ops when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
event = s.Webhook.construct_event(
|
||||||
|
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||||
|
)
|
||||||
|
except stripe_lib.error.SignatureVerificationError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid Stripe signature",
|
||||||
|
)
|
||||||
|
|
||||||
|
event_type: str = event["type"]
|
||||||
|
data: dict[str, Any] = event["data"]["object"]
|
||||||
|
|
||||||
|
if event_type == "checkout.session.completed":
|
||||||
|
user_id = data.get("metadata", {}).get("user_id")
|
||||||
|
tier = data.get("metadata", {}).get("tier", "free")
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
await self._upsert_subscription(
|
||||||
|
db, user_id, sub_id, tier, "active", period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.updated":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
new_status = data.get("status", "active")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status=new_status, current_period_end=period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.deleted":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, tier="free", status="canceled"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "invoice.payment_failed":
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status="past_due"
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def get_subscription(
|
||||||
|
self, user_id: str, db: AsyncSession
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"tier": sub.tier,
|
||||||
|
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||||
|
"status": sub.status,
|
||||||
|
"current_period_end": (
|
||||||
|
int(sub.current_period_end.timestamp() * 1000)
|
||||||
|
if sub.current_period_end
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||||
|
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||||
|
|
||||||
|
Raises ``HTTP 404`` when no active subscription exists.
|
||||||
|
"""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None or not sub.stripe_subscription_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="No active subscription found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._configured():
|
||||||
|
s = self._client()
|
||||||
|
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||||
|
|
||||||
|
sub.tier = "free"
|
||||||
|
sub.status = "canceled"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _upsert_subscription(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
stripe_subscription_id: str | None,
|
||||||
|
tier: str,
|
||||||
|
sub_status: str,
|
||||||
|
current_period_end: datetime | None,
|
||||||
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
sub = Subscription(user_id=user_id)
|
||||||
|
db.add(sub)
|
||||||
|
sub.stripe_subscription_id = stripe_subscription_id
|
||||||
|
sub.tier = tier
|
||||||
|
sub.status = sub_status
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
async def _update_subscription_by_stripe_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
stripe_subscription_id: str,
|
||||||
|
*,
|
||||||
|
tier: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
current_period_end: datetime | None = None,
|
||||||
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(
|
||||||
|
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
if tier is not None:
|
||||||
|
sub.tier = tier
|
||||||
|
if status is not None:
|
||||||
|
sub.status = status
|
||||||
|
if current_period_end is not None:
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared across the app.
|
||||||
|
stripe_service = StripeService()
|
||||||
189
app/billing/tier_manager.py
Normal file
189
app/billing/tier_manager.py
Normal file
@@ -0,0 +1,189 @@
|
|||||||
|
"""Tier manager: feature matrix and quota enforcement.
|
||||||
|
|
||||||
|
``TierManager`` is the single source of truth for what each billing tier
|
||||||
|
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||||
|
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||||
|
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.schemas import BillingTier
|
||||||
|
|
||||||
|
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||||
|
FEATURES: dict[str, dict[str, Any]] = {
|
||||||
|
"free": {
|
||||||
|
"agents": 3,
|
||||||
|
"batch_active": 2,
|
||||||
|
"cloud_storage_gb": 0,
|
||||||
|
"backup_gb": 0,
|
||||||
|
"providers": 1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"pro": {
|
||||||
|
"agents": -1, # unlimited
|
||||||
|
"batch_active": 10,
|
||||||
|
"cloud_storage_gb": 5,
|
||||||
|
"backup_gb": 5,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"power": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1, # unlimited
|
||||||
|
"cloud_storage_gb": 25,
|
||||||
|
"backup_gb": 25,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"team": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1,
|
||||||
|
"cloud_storage_gb": -1, # unlimited
|
||||||
|
"backup_gb": -1, # unlimited
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Requests-per-minute limit per tier.
|
||||||
|
RATE_LIMITS: dict[str, int] = {
|
||||||
|
"free": 20,
|
||||||
|
"pro": 60,
|
||||||
|
"power": 120,
|
||||||
|
"team": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TierManager:
|
||||||
|
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||||
|
|
||||||
|
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
|
Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
"""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str | None = result.scalar_one_or_none()
|
||||||
|
if tier is None or tier not in FEATURES:
|
||||||
|
return "free"
|
||||||
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||||
|
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||||
|
|
||||||
|
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||||
|
"""
|
||||||
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
return value != 0
|
||||||
|
|
||||||
|
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||||
|
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||||
|
if not self.check_feature(tier, feature):
|
||||||
|
detail = (
|
||||||
|
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||||
|
if tier_name
|
||||||
|
else f"Feature '{feature}' is not available on your current tier."
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
# ── Rate limiting ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
|
"""Return the requests-per-minute limit for ``tier``."""
|
||||||
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
|
# ── Storage quota ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def enforce_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||||
|
|
||||||
|
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||||
|
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||||
|
"""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def enforce_backup_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
return False
|
||||||
|
if limit_gb == -1:
|
||||||
|
return True
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton shared across the app.
|
||||||
|
tier_manager = TierManager()
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
from typing import Literal
|
from typing import Literal
|
||||||
from pydantic_settings import BaseSettings
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
@@ -14,6 +14,7 @@ class Settings(BaseSettings):
|
|||||||
|
|
||||||
S3_BUCKET: str = ""
|
S3_BUCKET: str = ""
|
||||||
S3_REGION: str = "us-east-1"
|
S3_REGION: str = "us-east-1"
|
||||||
|
S3_ENDPOINT_URL: str = ""
|
||||||
AWS_ACCESS_KEY_ID: str = ""
|
AWS_ACCESS_KEY_ID: str = ""
|
||||||
AWS_SECRET_ACCESS_KEY: str = ""
|
AWS_SECRET_ACCESS_KEY: str = ""
|
||||||
|
|
||||||
@@ -23,14 +24,37 @@ class Settings(BaseSettings):
|
|||||||
QDRANT_API_KEY: str = ""
|
QDRANT_API_KEY: str = ""
|
||||||
|
|
||||||
OPENAI_API_KEY: str = ""
|
OPENAI_API_KEY: str = ""
|
||||||
|
ANTHROPIC_API_KEY: str = ""
|
||||||
|
GOOGLE_API_KEY: str = ""
|
||||||
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
|
||||||
|
LLM_MODEL: str = "gpt-4o"
|
||||||
|
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
||||||
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
|
# GitHub Copilot OAuth token storage directory.
|
||||||
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
|
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
||||||
|
|
||||||
|
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
|
||||||
|
GMAIL_CLIENT_ID: str = ""
|
||||||
|
GMAIL_CLIENT_SECRET: str = ""
|
||||||
|
MS_CLIENT_ID: str = ""
|
||||||
|
MS_CLIENT_SECRET: str = ""
|
||||||
|
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||||
|
MS_TENANT_ID: str = "common"
|
||||||
|
|
||||||
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
|
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||||
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
class Config:
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
env_file = ".env"
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,137 +0,0 @@
|
|||||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
|
||||||
"""Common base for all agents."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
user_id: str = "",
|
|
||||||
shared_memory: dict[str, Any] | None = None,
|
|
||||||
vector_store_context: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.user_id = user_id
|
|
||||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
|
||||||
self.vector_store_context: list[str] = vector_store_context or []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_name(self) -> str: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_description(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def skills(self) -> list[str]:
|
|
||||||
"""Override in subclasses to advertise capabilities."""
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(BaseAgent):
|
|
||||||
"""Base class for LLM-powered chat agents."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
"""Process a user query and return a text response."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
"""Return LangChain tool definitions available to this agent."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def _tool_loop(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> str:
|
|
||||||
"""Shared tool-calling loop.
|
|
||||||
|
|
||||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
|
||||||
requesting tool calls or *max_iter* is reached, and returns the
|
|
||||||
final text response.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — ask model for a final answer without tools
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry:
|
|
||||||
"""Singleton registry for ChatAgent subclasses."""
|
|
||||||
|
|
||||||
_instance: AgentRegistry | None = None
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._agents: dict[str, type[ChatAgent]] = {}
|
|
||||||
|
|
||||||
def __new__(cls) -> AgentRegistry:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
cls._instance._agents = {}
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
# ── public API ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
|
||||||
"""Class decorator — registers an agent by its name."""
|
|
||||||
instance = agent_class()
|
|
||||||
name = instance.get_name()
|
|
||||||
self._agents[name] = agent_class
|
|
||||||
return agent_class
|
|
||||||
|
|
||||||
def get(self, name: str) -> ChatAgent:
|
|
||||||
"""Return a fresh instance of the named agent."""
|
|
||||||
cls = self._agents.get(name)
|
|
||||||
if cls is None:
|
|
||||||
raise KeyError(f"Agent not found: {name}")
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
def list_agents(self) -> list[dict[str, str]]:
|
|
||||||
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
|
||||||
result: list[dict[str, str]] = []
|
|
||||||
for cls in self._agents.values():
|
|
||||||
inst = cls()
|
|
||||||
result.append(
|
|
||||||
{"name": inst.get_name(), "description": inst.get_description()}
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def call_agent(
|
|
||||||
self, name: str, query: str, context: dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""Instantiate the named agent and call its ``handle`` method."""
|
|
||||||
agent = self.get(name)
|
|
||||||
return await agent.handle(query, context)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = AgentRegistry()
|
|
||||||
718
app/core/agent_runner.py
Normal file
718
app/core/agent_runner.py
Normal file
@@ -0,0 +1,718 @@
|
|||||||
|
"""Agent run manager.
|
||||||
|
|
||||||
|
Drives two agent types:
|
||||||
|
|
||||||
|
* **Local directory agent** — sends an ``agent_run`` frame to the connected
|
||||||
|
Electron device, waits for the device to stream back file contents via
|
||||||
|
``agent_data`` frames, then calls the LLM to extract structured items from
|
||||||
|
each file and pushes inserts to Electron via tool-call round-trips.
|
||||||
|
|
||||||
|
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
||||||
|
Teams, Outlook) and pushes extracted items to Electron. **This path is
|
||||||
|
a stub** — provider integrations are implemented in Step 3.6.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
Background tasks are spawned with ``asyncio.create_task()``::
|
||||||
|
|
||||||
|
asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager))
|
||||||
|
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||||
|
|
||||||
|
The ``trigger_pending_runs`` function is called by the device WS endpoint
|
||||||
|
when Electron sends ``device_hello``, so any overdue runs fire immediately
|
||||||
|
when the device reconnects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from croniter import croniter
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Max seconds to wait for Electron to finish streaming file data.
|
||||||
|
_FILE_READ_TIMEOUT: int = 120
|
||||||
|
# Max seconds to wait for Electron to acknowledge a single tool-call insert.
|
||||||
|
_INSERT_TIMEOUT: int = 30
|
||||||
|
|
||||||
|
# ── Allowed tables & extraction schema hints ───────────────────────────────
|
||||||
|
|
||||||
|
_ALLOWED_TABLES: frozenset[str] = frozenset(
|
||||||
|
{"tasks", "notes", "timelines", "projects", "taskComments"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Field descriptions fed to the extraction LLM as concise schema references.
|
||||||
|
_TABLE_SCHEMAS: dict[str, str] = {
|
||||||
|
"tasks": (
|
||||||
|
"title (str, required), description (str), "
|
||||||
|
"status (todo|in_progress|done, default todo), "
|
||||||
|
"priority (high|medium|low, default medium), "
|
||||||
|
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
|
||||||
|
),
|
||||||
|
"notes": "title (str, required), content (str, markdown), projectId (str)",
|
||||||
|
"timelines": (
|
||||||
|
"title (str, required), projectId (str, required), date (ms timestamp int)"
|
||||||
|
),
|
||||||
|
"projects": "name (str, required), clientId (str)",
|
||||||
|
"taskComments": "taskId (str, required), author (str), content (str, required)",
|
||||||
|
}
|
||||||
|
|
||||||
|
_EXTRACTION_SYSTEM_PROMPT = """\
|
||||||
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
Given a document, extract structured records matching the user's instructions.
|
||||||
|
|
||||||
|
Output a JSON array (no markdown fences, no explanation) of objects shaped:
|
||||||
|
[{{"table": "<table_name>", "data": {{...fields}}}}, ...]
|
||||||
|
|
||||||
|
Allowed table names and their fields:
|
||||||
|
{table_schemas}
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Only extract tables listed in the "data_types" instructions.
|
||||||
|
- Use camelCase field names exactly as shown above.
|
||||||
|
- Omit optional fields you cannot determine; do not invent data.
|
||||||
|
- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved.
|
||||||
|
- If nothing relevant is found, return an empty JSON array: []
|
||||||
|
- Return ONLY the JSON array.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cron helper ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
|
||||||
|
"""Return ``True`` if the next scheduled run time has already passed.
|
||||||
|
|
||||||
|
Always validates the cron expression first — an invalid expression returns
|
||||||
|
``False`` (fail-safe: never trigger an unparseable schedule).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if last_run_at is None:
|
||||||
|
# Validate the expression before deciding this is overdue.
|
||||||
|
croniter(schedule_cron, now)
|
||||||
|
return True
|
||||||
|
ts = last_run_at
|
||||||
|
if ts.tzinfo is None:
|
||||||
|
ts = ts.replace(tzinfo=timezone.utc)
|
||||||
|
cron = croniter(schedule_cron, ts)
|
||||||
|
next_run: datetime = cron.get_next(datetime)
|
||||||
|
return now >= next_run
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
|
||||||
|
return False # Fail-safe: don't trigger if expression is invalid.
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM extraction ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_items_from_content(
|
||||||
|
prompt_template: str,
|
||||||
|
file_content: str,
|
||||||
|
data_types: list[str],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Call the LLM to extract structured records from *file_content*.
|
||||||
|
|
||||||
|
Returns a validated list of ``{table: str, data: dict}`` objects.
|
||||||
|
Items referencing tables not in *data_types* are discarded.
|
||||||
|
"""
|
||||||
|
allowed = [t for t in data_types if t in _ALLOWED_TABLES]
|
||||||
|
if not allowed:
|
||||||
|
return []
|
||||||
|
|
||||||
|
schema_text = "\n".join(
|
||||||
|
f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed
|
||||||
|
)
|
||||||
|
system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text)
|
||||||
|
user_prompt = (
|
||||||
|
f"User instructions: {prompt_template}\n\n"
|
||||||
|
f"Extract these record types: {', '.join(allowed)}\n\n"
|
||||||
|
f"Document:\n{file_content[:8000]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_llm()
|
||||||
|
raw = ""
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
|
||||||
|
)
|
||||||
|
raw = str(response.content).strip()
|
||||||
|
items: list[dict] = json.loads(raw)
|
||||||
|
if not isinstance(items, list):
|
||||||
|
raise ValueError("LLM response is not a JSON array")
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r",
|
||||||
|
exc,
|
||||||
|
raw,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
# Other exceptions (LLM API errors, network errors) propagate to the
|
||||||
|
# caller (run_local_agent) which records them per-file in the run log.
|
||||||
|
|
||||||
|
validated: list[dict[str, Any]] = []
|
||||||
|
for item in items:
|
||||||
|
table = item.get("table")
|
||||||
|
data = item.get("data")
|
||||||
|
if not isinstance(table, str) or table not in allowed:
|
||||||
|
continue
|
||||||
|
if not isinstance(data, dict) or not data:
|
||||||
|
continue
|
||||||
|
# Strip any server-generated or forbidden fields.
|
||||||
|
for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"):
|
||||||
|
data.pop(_field, None)
|
||||||
|
validated.append({"table": table, "data": data})
|
||||||
|
return validated
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool-call insert helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_insert_to_client(
|
||||||
|
user_id: str,
|
||||||
|
table: str,
|
||||||
|
data: dict[str, Any],
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send an ``insert`` tool_call frame to Electron and await the tool_result.
|
||||||
|
|
||||||
|
All inserts include ``isAiSuggested=1, isApproved=0`` so the user can
|
||||||
|
review AI-produced records before they are treated as confirmed.
|
||||||
|
|
||||||
|
Raises ``asyncio.TimeoutError`` if Electron does not respond within
|
||||||
|
``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device
|
||||||
|
disconnects before the frame can be sent.
|
||||||
|
"""
|
||||||
|
call_id = str(uuid.uuid4())
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": call_id,
|
||||||
|
"action": "insert",
|
||||||
|
"table": table,
|
||||||
|
"data": {**data, "isAiSuggested": 1, "isApproved": 0},
|
||||||
|
}
|
||||||
|
fut = device_mgr.create_pending_call(user_id, call_id)
|
||||||
|
await device_mgr.send_frame(user_id, payload)
|
||||||
|
return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_local_agent(
|
||||||
|
user_id: str,
|
||||||
|
config: LocalAgentConfig,
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a local directory agent run end-to-end.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
|
||||||
|
1. Verify the device identified by ``config.device_id`` is currently online.
|
||||||
|
2. Pre-create the agent_data queue so no incoming frames are lost.
|
||||||
|
3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types).
|
||||||
|
4. Consume ``agent_data`` frames until the ``None`` sentinel from
|
||||||
|
``agent_complete``.
|
||||||
|
5. For each received file call the LLM to extract ``{table, data}`` items.
|
||||||
|
6. Push each item to Electron as an ``insert`` tool-call; include
|
||||||
|
``isAiSuggested=1, isApproved=0`` so users can review AI suggestions.
|
||||||
|
7. Persist the run outcome (status, counts, errors) and update
|
||||||
|
``config.last_run_at``.
|
||||||
|
"""
|
||||||
|
run_id = run_log.id
|
||||||
|
|
||||||
|
# ── 1. Device online check ─────────────────────────────────────────
|
||||||
|
if not device_mgr.is_online(user_id, config.device_id):
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: skip run=%s — device %r offline for user=%s",
|
||||||
|
run_id,
|
||||||
|
config.device_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Device {config.device_id!r} is not connected"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 2. Pre-create agent_data queue ────────────────────────────────
|
||||||
|
try:
|
||||||
|
device_mgr.get_agent_data_queue(user_id, run_id)
|
||||||
|
except RuntimeError:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=["Device disconnected before agent run could start"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Send agent_run frame ────────────────────────────────────────
|
||||||
|
frame: dict[str, Any] = {
|
||||||
|
"type": "agent_run",
|
||||||
|
"run_id": run_id,
|
||||||
|
"agent_id": config.id,
|
||||||
|
"config": {
|
||||||
|
"paths": config.directory_paths,
|
||||||
|
"file_extensions": config.file_extensions,
|
||||||
|
"prompt_template": config.prompt_template,
|
||||||
|
"data_types": config.data_types,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
await device_mgr.send_frame(user_id, frame)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to send agent_run frame: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: sent agent_run run=%s agent=%s user=%s",
|
||||||
|
run_id,
|
||||||
|
config.id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 4. Consume agent_data frames ──────────────────────────────────
|
||||||
|
files: list[dict[str, Any]] = []
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
queue = device_mgr.get_agent_data_queue(user_id, run_id)
|
||||||
|
deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT
|
||||||
|
while True:
|
||||||
|
remaining = deadline - asyncio.get_event_loop().time()
|
||||||
|
if remaining <= 0:
|
||||||
|
errors.append("Timed out waiting for file data from device")
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
frame_data = await asyncio.wait_for(queue.get(), timeout=remaining)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append("Timed out waiting for file data from device")
|
||||||
|
break
|
||||||
|
if frame_data is None:
|
||||||
|
# Sentinel from agent_complete — stream is done.
|
||||||
|
break
|
||||||
|
files.extend(frame_data.get("files", []))
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Queue error reading agent data: {exc}")
|
||||||
|
|
||||||
|
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
for file_info in files:
|
||||||
|
file_path: str = file_info.get("path", "<unknown>")
|
||||||
|
content: str = file_info.get("content", "")
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
try:
|
||||||
|
extracted = await _extract_items_from_content(
|
||||||
|
config.prompt_template, content, config.data_types
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM extraction error for {file_path!r}: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for item in extracted:
|
||||||
|
try:
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
user_id, item["table"], item["data"], device_mgr
|
||||||
|
)
|
||||||
|
if result.get("error"):
|
||||||
|
errors.append(
|
||||||
|
f"Insert failed ({item['table']}, {file_path!r}): {result['error']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
items_created += 1
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append(
|
||||||
|
f"Timed out awaiting insert ack ({item['table']}, {file_path!r})"
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}")
|
||||||
|
|
||||||
|
# ── 7. Finalise ────────────────────────────────────────────────────
|
||||||
|
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
||||||
|
|
||||||
|
if errors and items_created == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="local",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
|
run_id,
|
||||||
|
final_status,
|
||||||
|
items_processed,
|
||||||
|
items_created,
|
||||||
|
len(errors),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Default lookback window when an agent has never run before.
|
||||||
|
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||||
|
|
||||||
|
|
||||||
|
async def run_cloud_agent(
|
||||||
|
user_id: str,
|
||||||
|
config: CloudAgentConfig,
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a cloud connector agent run end-to-end.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
|
||||||
|
1. Verify the user's device is online — results are pushed to Electron
|
||||||
|
via WS tool-call frames. If no device is connected, abort.
|
||||||
|
2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``.
|
||||||
|
3. Instantiate the provider client (Gmail or MS Graph).
|
||||||
|
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
|
||||||
|
the first run) applying ``config.filter_config`` filters.
|
||||||
|
5. For each message/email call ``_extract_items_from_content`` with
|
||||||
|
``config.prompt_template`` to get structured ``{table, data}`` items.
|
||||||
|
6. Push each item to Electron as an ``insert`` tool-call.
|
||||||
|
7. If the provider refreshed its access token, re-encrypt and write it
|
||||||
|
back to ``config.oauth_token_encrypted``.
|
||||||
|
8. Persist the run outcome via ``_finalize_run``.
|
||||||
|
"""
|
||||||
|
run_id = run_log.id
|
||||||
|
|
||||||
|
# ── 1. Device online check ─────────────────────────────────────────
|
||||||
|
if not device_mgr.is_online(user_id):
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: skip cloud run=%s — no device online for user=%s",
|
||||||
|
run_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=["No connected device — cloud agent results cannot be delivered"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 2. Decrypt OAuth token ─────────────────────────────────────────
|
||||||
|
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||||
|
|
||||||
|
if not config.oauth_token_encrypted:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Instantiate provider client ────────────────────────────────
|
||||||
|
try:
|
||||||
|
provider = get_provider(config.provider, credentials_info)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[str(exc)],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 4. Fetch messages ─────────────────────────────────────────────
|
||||||
|
since: datetime | None = config.last_run_at
|
||||||
|
if since is None:
|
||||||
|
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||||
|
if since.tzinfo is None:
|
||||||
|
since = since.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.provider == "gmail":
|
||||||
|
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "outlook":
|
||||||
|
raw_messages = await provider.fetch_emails( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "teams":
|
||||||
|
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_messages = []
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: provider fetch failed for cloud agent %s: %s",
|
||||||
|
config.id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Provider fetch failed: {exc}"],
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
|
||||||
|
config.id,
|
||||||
|
len(raw_messages),
|
||||||
|
config.provider,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
||||||
|
for msg in raw_messages:
|
||||||
|
content_text = msg.as_text
|
||||||
|
if not content_text:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
try:
|
||||||
|
extracted = await _extract_items_from_content(
|
||||||
|
config.prompt_template, content_text, config.data_types
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM extraction error for message {msg.id!r}: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for item in extracted:
|
||||||
|
try:
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
user_id, item["table"], item["data"], device_mgr
|
||||||
|
)
|
||||||
|
if result.get("error"):
|
||||||
|
errors.append(
|
||||||
|
f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
items_created += 1
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append(
|
||||||
|
f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})"
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}")
|
||||||
|
|
||||||
|
# ── 7. Persist refreshed token (if any) ───────────────────────────
|
||||||
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
|
if refreshed:
|
||||||
|
try:
|
||||||
|
new_encrypted = encrypt_token(refreshed)
|
||||||
|
async with async_session() as db:
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||||
|
)
|
||||||
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg_row:
|
||||||
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
|
await db.commit()
|
||||||
|
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to persist refreshed token for agent %s: %s", config.id, exc)
|
||||||
|
|
||||||
|
# ── 8. Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_created == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
|
run_id,
|
||||||
|
final_status,
|
||||||
|
items_processed,
|
||||||
|
items_created,
|
||||||
|
len(errors),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pending-run trigger ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def trigger_pending_runs(
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Dispatch any overdue agent runs after an Electron device connects.
|
||||||
|
|
||||||
|
Called as a background task from the device WS endpoint on ``device_hello``.
|
||||||
|
|
||||||
|
Scheduling rules:
|
||||||
|
|
||||||
|
* **Local agents**: only triggered when ``config.device_id == device_id``.
|
||||||
|
* **Cloud agents**: triggered on any connected device (no device binding).
|
||||||
|
* Runs execute **sequentially** to avoid flooding the WS connection.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
LocalAgentConfig.enabled == True, # noqa: E712
|
||||||
|
LocalAgentConfig.device_id == device_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
|
||||||
|
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
CloudAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
|
||||||
|
|
||||||
|
# Build ordered list of overdue (type, config) pairs.
|
||||||
|
pending: list[tuple[str, Any]] = []
|
||||||
|
for cfg in local_configs:
|
||||||
|
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||||
|
pending.append(("local", cfg))
|
||||||
|
for cfg in cloud_configs:
|
||||||
|
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||||
|
pending.append(("cloud", cfg))
|
||||||
|
|
||||||
|
if not pending:
|
||||||
|
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
for agent_type, cfg in pending:
|
||||||
|
# Create a fresh run log for this scheduled dispatch.
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=cfg.id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
|
||||||
|
if agent_type == "local":
|
||||||
|
await run_local_agent(user_id, cfg, run_log, device_mgr)
|
||||||
|
else:
|
||||||
|
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_run(
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
items_processed: int = 0,
|
||||||
|
items_created: int = 0,
|
||||||
|
errors: list[str] | None = None,
|
||||||
|
update_config_last_run: bool = False,
|
||||||
|
config_id: str | None = None,
|
||||||
|
config_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Persist the run outcome and optionally update ``LocalAgentConfig.last_run_at``.
|
||||||
|
|
||||||
|
Uses a fresh DB session so this is safe to call from background tasks
|
||||||
|
after the original request session has closed.
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
managed = await db.merge(run_log)
|
||||||
|
managed.status = status
|
||||||
|
managed.items_processed = items_processed
|
||||||
|
managed.items_created = items_created
|
||||||
|
managed.errors = errors or []
|
||||||
|
managed.completed_at = now
|
||||||
|
|
||||||
|
if update_config_last_run and config_id:
|
||||||
|
if config_type == "local":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
elif config_type == "cloud":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
|
||||||
|
)
|
||||||
489
app/core/deep_agent.py
Normal file
489
app/core/deep_agent.py
Normal file
@@ -0,0 +1,489 @@
|
|||||||
|
"""Deep Agent — ``create_deep_agent`` supervisors for home and floating modes.
|
||||||
|
|
||||||
|
Two supervisor graphs (via ``deepagents.create_deep_agent``):
|
||||||
|
* **HomeSupervisor** — gathers data from multiple domains, presents
|
||||||
|
structured overview with entity/chart tags.
|
||||||
|
* **FloatingSupervisor** — focused, scoped assistant for a single entity/domain.
|
||||||
|
|
||||||
|
Each supervisor delegates to four sub-agents (task, project, note, timeline)
|
||||||
|
via the built-in ``task`` tool provided by ``SubAgentMiddleware``.
|
||||||
|
The sub-agents talk to Electron via ``execute_on_client``.
|
||||||
|
|
||||||
|
Built-in middleware provides: todo-list tracking, virtual filesystem,
|
||||||
|
automatic context summarisation, prompt-caching, and tool-call patching.
|
||||||
|
|
||||||
|
Streaming uses ``astream(stream_mode=["messages", "updates"])`` so that
|
||||||
|
callers can sniff:
|
||||||
|
* ``("messages", (token, metadata))`` — text tokens for streaming
|
||||||
|
* ``("updates", ...)`` — tool call results for mutations
|
||||||
|
|
||||||
|
An ``update_core_memory`` tool is available to both supervisors for
|
||||||
|
persisting user preferences mid-conversation (MemGPT-style).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
from deepagents import create_deep_agent
|
||||||
|
from langchain_core.messages import AIMessage, AIMessageChunk, HumanMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.core.ws_context import (
|
||||||
|
clear_tool_result_collector,
|
||||||
|
set_tool_result_collector,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Sub-agent tool imports ────────────────────────────────────────────
|
||||||
|
|
||||||
|
from app.agents.task_agent import ( # noqa: E402
|
||||||
|
add_task_comment,
|
||||||
|
create_task,
|
||||||
|
delete_task,
|
||||||
|
delete_task_comment,
|
||||||
|
list_task_comments,
|
||||||
|
list_tasks,
|
||||||
|
list_tasks_due_today,
|
||||||
|
update_task,
|
||||||
|
)
|
||||||
|
from app.agents.note_agent import ( # noqa: E402
|
||||||
|
create_note,
|
||||||
|
delete_note,
|
||||||
|
get_note,
|
||||||
|
list_notes,
|
||||||
|
update_note,
|
||||||
|
)
|
||||||
|
from app.agents.project_agent import ( # noqa: E402
|
||||||
|
create_project,
|
||||||
|
delete_project,
|
||||||
|
get_project,
|
||||||
|
list_all_projects,
|
||||||
|
list_projects,
|
||||||
|
update_project,
|
||||||
|
)
|
||||||
|
from app.agents.timeline_agent import ( # noqa: E402
|
||||||
|
create_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
list_timelines,
|
||||||
|
update_timeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Sub-agent definitions ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
_TASK_TOOLS = [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
|
|
||||||
|
_NOTE_TOOLS = [list_notes, get_note, create_note, update_note, delete_note]
|
||||||
|
|
||||||
|
_PROJECT_TOOLS = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
|
|
||||||
|
_TIMELINE_TOOLS = [list_timelines, create_timeline, update_timeline, delete_timeline]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_subagent_specs() -> list[dict[str, Any]]:
|
||||||
|
"""Return SubAgent dicts for the four workspace domains.
|
||||||
|
|
||||||
|
Each dict follows the ``deepagents`` ``SubAgent`` TypedDict:
|
||||||
|
name, description, system_prompt, tools, model
|
||||||
|
The model and middleware are filled in by ``create_deep_agent`` automatically.
|
||||||
|
"""
|
||||||
|
llm = get_llm()
|
||||||
|
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"name": "task_agent",
|
||||||
|
"description": (
|
||||||
|
"Manages tasks and comments: list, create, update, delete, "
|
||||||
|
"due-today, and comments. Use when the user asks about tasks, "
|
||||||
|
"to-dos, assignments, deadlines, or anything task-related."
|
||||||
|
),
|
||||||
|
"system_prompt": (
|
||||||
|
"You are a task management assistant. You create, update, list, "
|
||||||
|
"and track tasks and their comments.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: todo, in_progress, done\n"
|
||||||
|
" - priority must be one of: high, medium, low\n"
|
||||||
|
" - due_date is a Unix timestamp in milliseconds\n"
|
||||||
|
" - assignees is a JSON-encoded array of strings\n"
|
||||||
|
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
||||||
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
|
),
|
||||||
|
"tools": _TASK_TOOLS
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "note_agent",
|
||||||
|
"description": (
|
||||||
|
"Manages notes: list, get, create, update, delete. "
|
||||||
|
"Use when the user asks about notes, documents, or written content."
|
||||||
|
),
|
||||||
|
"system_prompt": (
|
||||||
|
"You are a note-taking assistant. You help users create, retrieve, "
|
||||||
|
"update, and delete Markdown notes in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - content is always Markdown; preserve formatting when updating\n"
|
||||||
|
" - When updating, call get_note first if you need to read existing "
|
||||||
|
"content before appending or replacing sections\n"
|
||||||
|
" - Do not fabricate note content."
|
||||||
|
),
|
||||||
|
"tools": _NOTE_TOOLS
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "project_agent",
|
||||||
|
"description": (
|
||||||
|
"Manages projects: list, get, create, update, archive, delete. "
|
||||||
|
"Use when the user asks about projects, workspaces, or project status."
|
||||||
|
),
|
||||||
|
"system_prompt": (
|
||||||
|
"You are a project management assistant. You help users create, "
|
||||||
|
"find, update, and archive projects.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: active, archived\n"
|
||||||
|
" - Prefer archiving over deletion\n"
|
||||||
|
" - ai_summary is populated only when the user asks for a summary."
|
||||||
|
),
|
||||||
|
"tools": _PROJECT_TOOLS
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "timeline_agent",
|
||||||
|
"description": (
|
||||||
|
"Manages project timelines and milestones: list, create, update, "
|
||||||
|
"delete. Use when the user asks about timelines, milestones, "
|
||||||
|
"deadlines, or project scheduling."
|
||||||
|
),
|
||||||
|
"system_prompt": (
|
||||||
|
"You are a project timeline assistant. Timelines are milestone "
|
||||||
|
"dates that track progress on a project.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - project_id is REQUIRED for every create\n"
|
||||||
|
" - date is a Unix timestamp in milliseconds\n"
|
||||||
|
" - For update_timeline, use -1 for integer fields you do not "
|
||||||
|
"want to change."
|
||||||
|
),
|
||||||
|
"tools": _TIMELINE_TOOLS
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Update core memory tool ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_update_core_memory_tool(user_id: str, db_session_factory):
|
||||||
|
"""Create a tool that persists a key/value preference in core memory."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_core_memory(key: str, value: str) -> str:
|
||||||
|
"""Save a user preference or fact to long-term core memory.
|
||||||
|
key: short label for the memory (e.g. 'preferred_language', 'timezone')
|
||||||
|
value: the value to remember
|
||||||
|
Use this when the user states a preference or fact worth remembering.
|
||||||
|
"""
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
|
||||||
|
async with db_session_factory() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, key, value)
|
||||||
|
return f"Remembered: {key} = {value}"
|
||||||
|
|
||||||
|
return update_core_memory
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompts ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_HOME_SYSTEM = (
|
||||||
|
"You are Adiuva, a smart workspace assistant on the Home dashboard.\n"
|
||||||
|
"Your job is to help the user by gathering data from their workspace and "
|
||||||
|
"presenting a comprehensive overview.\n\n"
|
||||||
|
"You have sub-agents (task_agent, note_agent, project_agent, "
|
||||||
|
"timeline_agent) accessible via the `task` tool. Delegate to "
|
||||||
|
"the appropriate sub-agent(s) based on the user's request. You can call "
|
||||||
|
"multiple sub-agents in parallel if needed.\n\n"
|
||||||
|
"You also have an update_core_memory tool — use it when the user states "
|
||||||
|
"a preference or important fact worth remembering long-term.\n\n"
|
||||||
|
"IMPORTANT: You do NOT have direct access to workspace data. Always "
|
||||||
|
"delegate to your subagents using the task() tool. Do not attempt to "
|
||||||
|
"answer workspace queries yourself — the subagents have the tools to "
|
||||||
|
"fetch and modify data. You can call multiple subagents in parallel "
|
||||||
|
"when the request spans multiple domains.\n\n"
|
||||||
|
"## Entity References\n"
|
||||||
|
"When your response mentions specific workspace entities, embed them "
|
||||||
|
"inline using entity tags so the UI can render interactive components.\n"
|
||||||
|
"Format: <type>[comma-separated UUIDs]</type>\n"
|
||||||
|
"Supported types: task, project, note, timeline\n\n"
|
||||||
|
"Example response:\n"
|
||||||
|
" Here is your project:\n"
|
||||||
|
" <project>[abc-123-def]</project>\n"
|
||||||
|
" It has these pending tasks:\n"
|
||||||
|
" <task>[def-456,ghi-789]</task>\n\n"
|
||||||
|
"IMPORTANT: Only include IDs of entities that are directly relevant to "
|
||||||
|
"the user's question. Do NOT dump all entity IDs returned by a tool — "
|
||||||
|
"filter to only the ones the user asked about or that matter for the answer.\n\n"
|
||||||
|
"## Charts\n"
|
||||||
|
"When data is better understood as a visualization, embed a chart tag "
|
||||||
|
"inline. The frontend renders it using shadcn/ui Recharts components.\n"
|
||||||
|
"Format: <chart>{{JSON}}</chart>\n\n"
|
||||||
|
"JSON shape:\n"
|
||||||
|
' {{"chartType":"<type>","title":"...","data":[...],"config":{{...}}}}\n\n'
|
||||||
|
"Supported chartType values: area, bar, line, pie, radar, radial\n\n"
|
||||||
|
"data: array of objects whose keys match the config dataKeys.\n"
|
||||||
|
"config: {{ dataKey: {{ label, color }} }} — follows shadcn ChartConfig.\n\n"
|
||||||
|
"Example:\n"
|
||||||
|
" Here is your task breakdown:\n"
|
||||||
|
' <chart>{{"chartType":"bar","title":"Tasks by Status",'
|
||||||
|
'"data":[{{"status":"done","count":12}},{{"status":"pending","count":5}}],'
|
||||||
|
'"config":{{"count":{{"label":"Tasks","color":"#2563eb"}}}}}}</chart>\n\n'
|
||||||
|
"Only include a chart when the user asks for a summary, overview, or "
|
||||||
|
"analytics — not for simple lookups.\n\n"
|
||||||
|
"Memory context:\n{memory_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_SYSTEM = (
|
||||||
|
"You are Adiuva, a focused workspace assistant in the floating panel.\n"
|
||||||
|
"The user is currently working in the '{scope_type}' section"
|
||||||
|
"{scope_detail}.\n\n"
|
||||||
|
"You have sub-agents (task_agent, note_agent, project_agent, "
|
||||||
|
"timeline_agent) accessible via the `task` tool. Focus your "
|
||||||
|
"help on the user's current scope, but you can use other sub-agents "
|
||||||
|
"if the request requires it.\n\n"
|
||||||
|
"You also have an update_core_memory tool — use it when the user states "
|
||||||
|
"a preference or important fact worth remembering long-term.\n\n"
|
||||||
|
"IMPORTANT: You do NOT have direct access to workspace data. Always "
|
||||||
|
"delegate to your subagents using the task() tool. Do not attempt to "
|
||||||
|
"answer workspace queries yourself — the subagents have the tools to "
|
||||||
|
"fetch and modify data.\n\n"
|
||||||
|
"Provide direct, conversational responses.\n\n"
|
||||||
|
"Memory context:\n{memory_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_memory_context(memory: dict[str, Any]) -> str:
|
||||||
|
"""Format the memory dict into a readable string for the system prompt."""
|
||||||
|
if not memory:
|
||||||
|
return "(no memory available)"
|
||||||
|
parts = []
|
||||||
|
if memory.get("core_memory"):
|
||||||
|
parts.append("Preferences: " + json.dumps(memory["core_memory"]))
|
||||||
|
if memory.get("associative_memory"):
|
||||||
|
parts.append("Related memories: " + "; ".join(memory["associative_memory"][:3]))
|
||||||
|
if memory.get("episodic_memory"):
|
||||||
|
parts.append("Recent sessions: " + "; ".join(memory["episodic_memory"][:3]))
|
||||||
|
if memory.get("proactive_hints"):
|
||||||
|
parts.append("Patterns: " + "; ".join(memory["proactive_hints"][:3]))
|
||||||
|
return "\n".join(parts) if parts else "(no memory available)"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Graph builders ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def build_home_graph(
|
||||||
|
user_id: str,
|
||||||
|
memory_context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
):
|
||||||
|
"""Build the Home supervisor graph."""
|
||||||
|
subagent_specs = _make_subagent_specs()
|
||||||
|
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
|
||||||
|
|
||||||
|
prompt = _HOME_SYSTEM.format(
|
||||||
|
memory_context=_format_memory_context(memory_context),
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_deep_agent(
|
||||||
|
model=get_llm(),
|
||||||
|
tools=[memory_tool],
|
||||||
|
system_prompt=prompt,
|
||||||
|
subagents=subagent_specs,
|
||||||
|
name="home_supervisor",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_floating_graph(
|
||||||
|
user_id: str,
|
||||||
|
memory_context: dict[str, Any],
|
||||||
|
scope: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
):
|
||||||
|
"""Build the Floating supervisor graph."""
|
||||||
|
subagent_specs = _make_subagent_specs()
|
||||||
|
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
|
||||||
|
|
||||||
|
scope_type = scope.get("type", "general")
|
||||||
|
scope_id = scope.get("id")
|
||||||
|
scope_detail = f" (id: {scope_id})" if scope_id else ""
|
||||||
|
|
||||||
|
prompt = _FLOATING_SYSTEM.format(
|
||||||
|
scope_type=scope_type,
|
||||||
|
scope_detail=scope_detail,
|
||||||
|
memory_context=_format_memory_context(memory_context),
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_deep_agent(
|
||||||
|
model=get_llm(),
|
||||||
|
tools=[memory_tool],
|
||||||
|
system_prompt=prompt,
|
||||||
|
subagents=subagent_specs,
|
||||||
|
name="floating_supervisor",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stream event type ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Events yielded by run_*_stream:
|
||||||
|
# ("token", str) — text token for streaming
|
||||||
|
# ("tool_start", dict) — {"name": "task_agent", "args": {...}}
|
||||||
|
# ("tool_end", dict) — {"name": "task_agent", "result": "..."}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stream runners ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _run_graph_stream(
|
||||||
|
graph,
|
||||||
|
message: str,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run a supervisor graph with streaming, yielding event tuples.
|
||||||
|
|
||||||
|
Uses ``stream_mode=["messages", "updates"]`` to get both token-level
|
||||||
|
streaming and update events for tool calls.
|
||||||
|
"""
|
||||||
|
inputs = {"messages": [HumanMessage(content=message)]}
|
||||||
|
|
||||||
|
collector: list[dict] = []
|
||||||
|
set_tool_result_collector(collector)
|
||||||
|
try:
|
||||||
|
async for stream_mode, chunk in graph.astream(
|
||||||
|
inputs,
|
||||||
|
stream_mode=["messages", "updates"],
|
||||||
|
):
|
||||||
|
if stream_mode == "messages":
|
||||||
|
msg, metadata = chunk
|
||||||
|
agent_name = (
|
||||||
|
metadata.get("lc_agent_name", "?")
|
||||||
|
if isinstance(metadata, dict) else "?"
|
||||||
|
)
|
||||||
|
node = (
|
||||||
|
metadata.get("langgraph_node", "?")
|
||||||
|
if isinstance(metadata, dict) else "?"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Log every message event with agent attribution
|
||||||
|
if isinstance(msg, (AIMessage, AIMessageChunk)) and msg.content:
|
||||||
|
logger.info(
|
||||||
|
"[%s] %s node=%s content=%s",
|
||||||
|
agent_name,
|
||||||
|
type(msg).__name__,
|
||||||
|
node,
|
||||||
|
str(msg.content),
|
||||||
|
)
|
||||||
|
elif isinstance(msg, (AIMessage, AIMessageChunk)) and msg.tool_calls:
|
||||||
|
tool_names = [tc["name"] for tc in msg.tool_calls]
|
||||||
|
logger.info(
|
||||||
|
"[%s] %s node=%s tool_calls=%s",
|
||||||
|
agent_name,
|
||||||
|
type(msg).__name__,
|
||||||
|
node,
|
||||||
|
tool_names,
|
||||||
|
)
|
||||||
|
elif hasattr(msg, "name") and hasattr(msg, "content") and msg.content:
|
||||||
|
# ToolMessage — log tool result
|
||||||
|
logger.info(
|
||||||
|
"[%s] ToolMessage tool=%s node=%s result=%s",
|
||||||
|
agent_name,
|
||||||
|
getattr(msg, "name", "?"),
|
||||||
|
node,
|
||||||
|
str(msg.content),
|
||||||
|
)
|
||||||
|
# Only yield tokens from the supervisor's final response
|
||||||
|
# (not from sub-agent internal LLM calls).
|
||||||
|
# Accept both AIMessageChunk (streamed tokens) and AIMessage
|
||||||
|
# (full response from non-streaming providers).
|
||||||
|
# create_deep_agent names the LLM node "model".
|
||||||
|
if (
|
||||||
|
isinstance(msg, (AIMessage, AIMessageChunk))
|
||||||
|
and msg.content
|
||||||
|
and not msg.tool_calls
|
||||||
|
and isinstance(metadata, dict)
|
||||||
|
and metadata.get("langgraph_node") == "model"
|
||||||
|
):
|
||||||
|
yield ("token", str(msg.content))
|
||||||
|
|
||||||
|
elif stream_mode == "updates":
|
||||||
|
# Updates is a dict of {node_name: state_update}
|
||||||
|
if not isinstance(chunk, dict):
|
||||||
|
continue
|
||||||
|
for node_name, state_update in chunk.items():
|
||||||
|
if node_name != "tools":
|
||||||
|
continue
|
||||||
|
# Tool node executed — extract tool call results
|
||||||
|
tool_messages = state_update.get("messages", [])
|
||||||
|
for tool_msg in tool_messages:
|
||||||
|
if hasattr(tool_msg, "name") and hasattr(tool_msg, "content"):
|
||||||
|
yield (
|
||||||
|
"tool_end",
|
||||||
|
{"name": tool_msg.name, "result": str(tool_msg.content)},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
# Yield the collected mutations so callers can attach them to stream_end
|
||||||
|
yield ("mutations", collector)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run the Home supervisor and yield streaming events."""
|
||||||
|
graph = build_home_graph(user_id, context, db_session_factory)
|
||||||
|
async for event in _run_graph_stream(graph, message):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
scope: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run the Floating supervisor and yield streaming events."""
|
||||||
|
graph = build_floating_graph(user_id, context, scope, db_session_factory)
|
||||||
|
async for event in _run_graph_stream(graph, message):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> str:
|
||||||
|
"""Run the Home supervisor (non-streaming) and return full response text."""
|
||||||
|
graph = build_home_graph(user_id, context, db_session_factory)
|
||||||
|
result = await graph.ainvoke(
|
||||||
|
{"messages": [HumanMessage(content=message)]}
|
||||||
|
)
|
||||||
|
messages = result["messages"]
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None):
|
||||||
|
return str(msg.content)
|
||||||
|
return ""
|
||||||
183
app/core/device_manager.py
Normal file
183
app/core/device_manager.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Device connection manager.
|
||||||
|
|
||||||
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
|
The manager participates in two interaction patterns:
|
||||||
|
|
||||||
|
1. **Tool-call round-trip** (bidirectional CRUD):
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
||||||
|
``tool_result`` frame.
|
||||||
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
|
receive the result dict from Electron.
|
||||||
|
|
||||||
|
2. **Agent-data streaming** (local directory agent runs):
|
||||||
|
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
||||||
|
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
||||||
|
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
||||||
|
a specific ``run_id`` so the agent runner can iterate frames.
|
||||||
|
|
||||||
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
|
device WS route and the agent runner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeviceConnection:
|
||||||
|
"""State for a single connected Electron device."""
|
||||||
|
|
||||||
|
ws: WebSocket
|
||||||
|
device_id: str
|
||||||
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
# Per-run queues for agent_data / agent_complete frames.
|
||||||
|
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceConnectionManager:
|
||||||
|
"""Singleton registry of active Electron WebSocket connections.
|
||||||
|
|
||||||
|
Thread/task safety note: asyncio is single-threaded by design. All
|
||||||
|
mutations happen inside await-points on the main event loop, so no
|
||||||
|
locking is required for the in-memory dicts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._connections: dict[str, DeviceConnection] = {}
|
||||||
|
|
||||||
|
# ── Registration ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
|
||||||
|
"""Store the active connection for *user_id*, replacing any previous one."""
|
||||||
|
if user_id in self._connections:
|
||||||
|
old = self._connections[user_id]
|
||||||
|
logger.info(
|
||||||
|
"device_manager: replacing existing connection for user=%s device=%s",
|
||||||
|
user_id,
|
||||||
|
old.device_id,
|
||||||
|
)
|
||||||
|
# Cancel any futures that were waiting on the old connection.
|
||||||
|
for fut in old.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
|
||||||
|
logger.info(
|
||||||
|
"device_manager: registered user=%s device=%s", user_id, device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def unregister(self, user_id: str) -> None:
|
||||||
|
"""Remove the connection for *user_id* and cancel any pending futures."""
|
||||||
|
conn = self._connections.pop(user_id, None)
|
||||||
|
if conn is None:
|
||||||
|
return
|
||||||
|
for fut in conn.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
logger.info("device_manager: unregistered user=%s", user_id)
|
||||||
|
|
||||||
|
# ── Presence queries ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_ws(self, user_id: str) -> WebSocket | None:
|
||||||
|
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
return conn.ws if conn else None
|
||||||
|
|
||||||
|
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
|
||||||
|
"""Return ``True`` if the user has an active connection.
|
||||||
|
|
||||||
|
If *device_id* is provided also checks that it matches the connected device.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
return False
|
||||||
|
if device_id is not None:
|
||||||
|
return conn.device_id == device_id
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── Frame sending ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def send_frame(self, user_id: str, frame: dict) -> None:
|
||||||
|
"""Send *frame* as a JSON text message to the device.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if the user is not connected.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"send_frame: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
await conn.ws.send_text(json.dumps(frame))
|
||||||
|
|
||||||
|
# ── Tool-call round-trip ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_pending_call(
|
||||||
|
self, user_id: str, call_id: str
|
||||||
|
) -> asyncio.Future[dict]:
|
||||||
|
"""Register a Future that will be resolved when the tool_result arrives.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if the user is not connected.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"create_pending_call: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut: asyncio.Future[dict] = loop.create_future()
|
||||||
|
conn.pending_calls[call_id] = fut
|
||||||
|
return fut
|
||||||
|
|
||||||
|
def resolve_pending_call(
|
||||||
|
self, user_id: str, call_id: str, result: dict
|
||||||
|
) -> None:
|
||||||
|
"""Fulfil the Future registered under *call_id* with the Electron result.
|
||||||
|
|
||||||
|
No-ops if the call_id is unknown (already timed out or cancelled).
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
return
|
||||||
|
fut = conn.pending_calls.pop(call_id, None)
|
||||||
|
if fut is not None and not fut.done():
|
||||||
|
fut.set_result(result)
|
||||||
|
|
||||||
|
# ── Agent-data queue ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_agent_data_queue(
|
||||||
|
self, user_id: str, run_id: str
|
||||||
|
) -> asyncio.Queue[dict | None]:
|
||||||
|
"""Return (creating if absent) the queue for *run_id* agent frames.
|
||||||
|
|
||||||
|
The agent runner reads from this queue. The device WS handler writes
|
||||||
|
to it. ``None`` is the sentinel that signals the stream is finished.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"get_agent_data_queue: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
if run_id not in conn.agent_data_queues:
|
||||||
|
conn.agent_data_queues[run_id] = asyncio.Queue()
|
||||||
|
return conn.agent_data_queues[run_id]
|
||||||
|
|
||||||
|
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
||||||
|
"""Remove the queue for *run_id* once a run has completed."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.agent_data_queues.pop(run_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton — import this everywhere.
|
||||||
|
device_manager = DeviceConnectionManager()
|
||||||
@@ -1,222 +0,0 @@
|
|||||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.schemas import ExecutionPlan, PlanStep
|
|
||||||
|
|
||||||
|
|
||||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateRegistry:
|
|
||||||
"""Server-side store mapping template IDs to prompt text.
|
|
||||||
|
|
||||||
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
|
|
||||||
The actual prompt text is resolved here on the server, keeping prompt IP
|
|
||||||
out of API responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._templates: dict[str, str] = {}
|
|
||||||
|
|
||||||
def register(self, template_id: str, prompt_text: str) -> None:
|
|
||||||
self._templates[template_id] = prompt_text
|
|
||||||
|
|
||||||
def get(self, template_id: str) -> str:
|
|
||||||
"""Resolve a template ID to its prompt text.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the template is not registered.
|
|
||||||
"""
|
|
||||||
text = self._templates.get(template_id)
|
|
||||||
if text is None:
|
|
||||||
raise KeyError(f"Template not found: {template_id!r}")
|
|
||||||
return text
|
|
||||||
|
|
||||||
def has(self, template_id: str) -> bool:
|
|
||||||
return template_id in self._templates
|
|
||||||
|
|
||||||
def list_ids(self) -> list[str]:
|
|
||||||
"""Return all registered template IDs (never the text)."""
|
|
||||||
return list(self._templates.keys())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlanBuilder:
|
|
||||||
"""Fluent builder for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent: str) -> None:
|
|
||||||
self._agent = agent
|
|
||||||
self._steps: list[PlanStep] = []
|
|
||||||
|
|
||||||
# ── step adders ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def add_step(
|
|
||||||
self, action: str, params: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a generic action step with optional parameters."""
|
|
||||||
self._steps.append(PlanStep(action=action, variables=params))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_llm_step(
|
|
||||||
self, template_id: str, variables: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append an LLM step referencing a server-side template by ID."""
|
|
||||||
self._steps.append(
|
|
||||||
PlanStep(action="llm", prompt_template=template_id, variables=variables)
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a step whose input comes from the output of an earlier step."""
|
|
||||||
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
|
|
||||||
return self
|
|
||||||
|
|
||||||
# ── build ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def build(self) -> ExecutionPlan:
|
|
||||||
"""Validate step references and return the ``ExecutionPlan``.
|
|
||||||
|
|
||||||
Raises ``ValueError`` if any ``data_from_step`` references a
|
|
||||||
non-existent or future step index.
|
|
||||||
"""
|
|
||||||
for i, step in enumerate(self._steps):
|
|
||||||
if step.data_from_step is not None:
|
|
||||||
if not (0 <= step.data_from_step < i):
|
|
||||||
raise ValueError(
|
|
||||||
f"Step {i}: data_from_step={step.data_from_step} must "
|
|
||||||
f"reference a preceding step index in range 0..{i - 1}"
|
|
||||||
)
|
|
||||||
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PlanCache:
|
|
||||||
"""In-memory LRU cache for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
|
|
||||||
The cache also serves as a runtime memoisation layer so that repeated
|
|
||||||
identical intent classifications can skip re-building the plan.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, maxsize: int = 1000) -> None:
|
|
||||||
self._maxsize = maxsize
|
|
||||||
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
|
|
||||||
|
|
||||||
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
|
|
||||||
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
|
|
||||||
if key in self._cache:
|
|
||||||
del self._cache[key] # remove so re-insertion places it at the end
|
|
||||||
elif len(self._cache) >= self._maxsize:
|
|
||||||
self._cache.popitem(last=False) # evict least-recently-used
|
|
||||||
self._cache[key] = plan
|
|
||||||
|
|
||||||
def get_plan(self, key: str) -> ExecutionPlan | None:
|
|
||||||
"""Return the cached plan for *key*, or ``None`` if not present.
|
|
||||||
|
|
||||||
Accessing a plan marks it as most-recently used.
|
|
||||||
"""
|
|
||||||
if key not in self._cache:
|
|
||||||
return None
|
|
||||||
self._cache.move_to_end(key)
|
|
||||||
return self._cache[key]
|
|
||||||
|
|
||||||
def get_all_playbooks(self) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached plans (most-recently used last)."""
|
|
||||||
return list(self._cache.values())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Module-level singletons ───────────────────────────────────────────
|
|
||||||
|
|
||||||
template_registry = PromptTemplateRegistry()
|
|
||||||
plan_cache = PlanCache()
|
|
||||||
|
|
||||||
|
|
||||||
def _register_builtin_templates() -> None:
|
|
||||||
"""Register the built-in server-side prompt templates.
|
|
||||||
|
|
||||||
These strings never leave the server. Clients only receive the IDs.
|
|
||||||
"""
|
|
||||||
_tpls: dict[str, str] = {
|
|
||||||
"tpl_task_agent_default": (
|
|
||||||
"You are a task management assistant. Help the user create, update, "
|
|
||||||
"list, and track tasks. Use correct status values (todo, in_progress, "
|
|
||||||
"done) and priority values (high, medium, low) from the workspace model."
|
|
||||||
),
|
|
||||||
"tpl_checkpoint_agent_default": (
|
|
||||||
"You are a project checkpoint assistant. Help the user create and manage "
|
|
||||||
"milestone checkpoints on their projects. Every checkpoint requires a "
|
|
||||||
"project_id and a date expressed as a Unix timestamp in milliseconds."
|
|
||||||
),
|
|
||||||
"tpl_project_agent_default": (
|
|
||||||
"You are a project management assistant. Help the user create, find, "
|
|
||||||
"update, and archive projects. Projects have a name, an optional client, "
|
|
||||||
"and a status of either active or archived."
|
|
||||||
),
|
|
||||||
"tpl_note_agent_default": (
|
|
||||||
"You are a note-taking assistant. Help the user create, retrieve, update, "
|
|
||||||
"and delete Markdown notes. Notes can optionally be linked to a project."
|
|
||||||
),
|
|
||||||
"tpl_task_extract_from_project": (
|
|
||||||
"Extract all actionable tasks from the provided project context. "
|
|
||||||
"Return a structured list of tasks, each with a title, inferred priority "
|
|
||||||
"(high, medium, or low), suggested status (todo), and a due_date in "
|
|
||||||
"milliseconds where a deadline can be inferred."
|
|
||||||
),
|
|
||||||
"tpl_note_weekly_summary": (
|
|
||||||
"Generate a weekly project summary note from the provided workspace data. "
|
|
||||||
"Include: tasks completed this week, tasks due soon, active projects, "
|
|
||||||
"and upcoming checkpoints. Format the output as clean Markdown."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for tid, text in _tpls.items():
|
|
||||||
template_registry.register(tid, text)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_playbooks() -> None:
|
|
||||||
"""Pre-build and cache the built-in playbooks."""
|
|
||||||
playbooks: list[tuple[str, ExecutionPlan]] = [
|
|
||||||
(
|
|
||||||
"create_tasks_from_project",
|
|
||||||
ExecutionPlanBuilder("project_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_task_extract_from_project",
|
|
||||||
{"source": "project_context"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"generate_weekly_note",
|
|
||||||
ExecutionPlanBuilder("note_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_note_weekly_summary",
|
|
||||||
{"period": "last_7_days"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for key, plan in playbooks:
|
|
||||||
plan_cache.cache_plan(key, plan)
|
|
||||||
|
|
||||||
|
|
||||||
# Initialise on module load
|
|
||||||
_register_builtin_templates()
|
|
||||||
_load_playbooks()
|
|
||||||
116
app/core/llm.py
Normal file
116
app/core/llm.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
||||||
|
instead of directly constructing a provider-specific class. The model string
|
||||||
|
follows the `LiteLLM model naming convention
|
||||||
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
|
|
||||||
|
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
|
||||||
|
* Anthropic: ``anthropic/claude-3.5-sonnet``
|
||||||
|
* Google: ``gemini/gemini-pro``
|
||||||
|
* Ollama: ``ollama/llama3``
|
||||||
|
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||||
|
|
||||||
|
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
||||||
|
— no code changes required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
|
||||||
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||||
|
# No API key is required; returning None lets LiteLLM handle auth.
|
||||||
|
return None
|
||||||
|
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
"""Return a LangChain chat model backed by LiteLLM.
|
||||||
|
|
||||||
|
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
|
||||||
|
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
|
||||||
|
``openai`` client transparently when the model string contains a provider
|
||||||
|
prefix (``anthropic/…``, ``gemini/…``, etc.).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model:
|
||||||
|
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
|
||||||
|
temperature:
|
||||||
|
Sampling temperature. ``0`` = deterministic.
|
||||||
|
"""
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
# Point LiteLLM to the custom token directory when configured.
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
|
||||||
|
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature)
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_router_llm(
|
||||||
|
*,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
"""Return the lighter model used for intent classification / routing."""
|
||||||
|
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
"""Return an embedding vector for *text*.
|
||||||
|
|
||||||
|
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
|
||||||
|
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
|
||||||
|
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
|
||||||
|
model names to preserve existing behaviour.
|
||||||
|
"""
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
|
||||||
|
# so the provider's auth mechanism is applied correctly.
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
|
||||||
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
|
return response.data[0].embedding
|
||||||
231
app/core/memory_middleware.py
Normal file
231
app/core/memory_middleware.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
"""Memory Middleware — enrich requests with memory context and store interactions.
|
||||||
|
|
||||||
|
Four-tier memory model (MemGPT-style):
|
||||||
|
core — persistent key/value user preferences, always injected
|
||||||
|
associative — semantic similarity search via pgvector (top-k)
|
||||||
|
episodic — recent session summaries (last N)
|
||||||
|
proactive — behavioral patterns above confidence threshold
|
||||||
|
|
||||||
|
All memory content is encrypted at rest using the per-user Fernet key
|
||||||
|
stored in User.encryption_key. Decryption happens in-memory only.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
memory = MemoryMiddleware(db_session)
|
||||||
|
context = await memory.enrich_context(user_id, message)
|
||||||
|
# ... run agent ...
|
||||||
|
await memory.store_episode(user_id, session_id, message, response)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import (
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tuning constants
|
||||||
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
|
_EPISODIC_RECENT_N = 10
|
||||||
|
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryMiddleware:
|
||||||
|
"""Enrich agent context with memory and persist interactions after."""
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
||||||
|
"""Build memory context dict to inject into the agent before LLM call.
|
||||||
|
|
||||||
|
Returns a dict with keys:
|
||||||
|
core_memory — {key: plaintext_value, ...}
|
||||||
|
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||||
|
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||||
|
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||||
|
"""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
core = await self._load_core(user_id, fernet)
|
||||||
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
|
episodic = await self._load_episodic(user_id, fernet)
|
||||||
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"core_memory": core,
|
||||||
|
"associative_memory": associative,
|
||||||
|
"episodic_memory": episodic,
|
||||||
|
"proactive_hints": proactive,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def store_episode(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
message: str,
|
||||||
|
response: str,
|
||||||
|
) -> None:
|
||||||
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
|
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||||
|
latency low. Full LLM summarisation can be added in a later step.
|
||||||
|
"""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
summary_encrypted=encrypted,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
||||||
|
"""Upsert a core memory key/value for a user."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, value)
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
existing.value_encrypted = encrypted
|
||||||
|
else:
|
||||||
|
self._db.add(MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
key=key,
|
||||||
|
value_encrypted=encrypted,
|
||||||
|
))
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
|
"""Load the user's Fernet key from DB. Returns None if missing."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||||
|
return None
|
||||||
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: dict[str, str] = {}
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out[row.key] = plaintext
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_associative(
|
||||||
|
self, user_id: str, message: str, fernet: Fernet
|
||||||
|
) -> list[str]:
|
||||||
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
|
Production: uses pgvector cosine similarity on the message embedding.
|
||||||
|
Current implementation: keyword-based fallback (no external embedding call)
|
||||||
|
so tests pass without a live OpenAI key.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(_EPISODIC_RECENT_N)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryProactive)
|
||||||
|
.where(
|
||||||
|
MemoryProactive.user_id == user_id,
|
||||||
|
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||||
|
)
|
||||||
|
.order_by(MemoryProactive.confidence.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ── Encryption helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||||
|
return fernet.encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||||
|
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
|
||||||
|
try:
|
||||||
|
return fernet.decrypt(ciphertext.encode()).decode()
|
||||||
|
except (InvalidToken, Exception) as exc:
|
||||||
|
logger.warning("memory: decrypt failed: %s", exc)
|
||||||
|
return None
|
||||||
@@ -1,169 +0,0 @@
|
|||||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, AsyncGenerator
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.agent_registry import AgentRegistry
|
|
||||||
from app.core.agent_registry import registry as _default_registry
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
|
||||||
|
|
||||||
_FALLBACK_AGENT = "task_agent"
|
|
||||||
|
|
||||||
_CLASSIFY_SYSTEM = (
|
|
||||||
"You are an intent classifier. Given the user message and context, decide "
|
|
||||||
"which agent to route to.\n"
|
|
||||||
"Available agents: {agents}\n"
|
|
||||||
"Respond with just the agent name, nothing else."
|
|
||||||
)
|
|
||||||
|
|
||||||
_SYNTHESIZE_HUMAN = (
|
|
||||||
"Combine the following agent results into one coherent response.\n\n"
|
|
||||||
"Agent results:\n{results}\n\n"
|
|
||||||
"Original message: {message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_llm(model: str = "gpt-4o-mini") -> ChatOpenAI:
|
|
||||||
return ChatOpenAI(model=model, temperature=0, api_key=settings.OPENAI_API_KEY)
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_intent(
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> str:
|
|
||||||
"""Use gpt-4o-mini to classify intent and return the matching agent name.
|
|
||||||
|
|
||||||
Falls back to ``task_agent`` when the registry is empty or the model
|
|
||||||
returns a name that is not registered.
|
|
||||||
"""
|
|
||||||
agents = reg.list_agents()
|
|
||||||
if not agents:
|
|
||||||
return _FALLBACK_AGENT
|
|
||||||
|
|
||||||
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
|
|
||||||
# Truncate context to keep the classification prompt short
|
|
||||||
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
|
|
||||||
|
|
||||||
llm = _make_llm()
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[SystemMessage(content=system), HumanMessage(content=human)]
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = str(response.content).strip().lower()
|
|
||||||
known = {a["name"] for a in agents}
|
|
||||||
return agent_name if agent_name in known else _FALLBACK_AGENT
|
|
||||||
|
|
||||||
|
|
||||||
async def route_single(
|
|
||||||
agent_name: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
|
|
||||||
response_text = await reg.call_agent(agent_name, message, context)
|
|
||||||
return ChatResponse(response=response_text)
|
|
||||||
|
|
||||||
|
|
||||||
async def route_pipeline(
|
|
||||||
agent_names: list[str],
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Execute agents sequentially; each agent receives previous results in context.
|
|
||||||
|
|
||||||
A final LLM synthesis call merges all results into one coherent response.
|
|
||||||
"""
|
|
||||||
previous_results: list[str] = []
|
|
||||||
|
|
||||||
for agent_name in agent_names:
|
|
||||||
ctx = {**context, "previous_results": list(previous_results)}
|
|
||||||
result = await reg.call_agent(agent_name, message, ctx)
|
|
||||||
previous_results.append(result)
|
|
||||||
|
|
||||||
results_str = "\n\n".join(
|
|
||||||
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
|
|
||||||
)
|
|
||||||
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
|
|
||||||
llm = _make_llm()
|
|
||||||
synthesis = await llm.ainvoke([HumanMessage(content=human)])
|
|
||||||
return ChatResponse(response=str(synthesis.content))
|
|
||||||
|
|
||||||
|
|
||||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
|
|
||||||
"""Build an ``ExecutionPlan`` for the resolved agent.
|
|
||||||
|
|
||||||
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
|
|
||||||
If a default template exists for the agent, an LLM step is emitted;
|
|
||||||
otherwise a plain ``handle`` action step is used.
|
|
||||||
"""
|
|
||||||
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
|
|
||||||
|
|
||||||
template_id = f"tpl_{agent_name}_default"
|
|
||||||
builder = ExecutionPlanBuilder(agent_name)
|
|
||||||
if template_registry.has(template_id):
|
|
||||||
builder.add_llm_step(template_id, {"message": message})
|
|
||||||
else:
|
|
||||||
builder.add_step("handle", {"message": message})
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> ChatResponse | ExecutionPlan:
|
|
||||||
"""Main orchestration entry point.
|
|
||||||
|
|
||||||
* Classifies the user's intent to select an agent.
|
|
||||||
* ``execution_mode == 'direct'``: routes to the agent and returns a
|
|
||||||
``ChatResponse``.
|
|
||||||
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
|
|
||||||
resolved agent and a template-ID-only step (prompt IP stays server-side).
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
|
|
||||||
if request.execution_mode == "direct":
|
|
||||||
return await route_single(agent_name, request.message, context, reg)
|
|
||||||
|
|
||||||
# plan mode — return plan, do not execute
|
|
||||||
return _build_plan(agent_name, request.message)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_stream(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming orchestration — yields text chunks then a final JSON frame.
|
|
||||||
|
|
||||||
The final frame is a JSON object:
|
|
||||||
``{"done": true, "response": "...", "actions": []}``.
|
|
||||||
|
|
||||||
Agents do not yet support token-level streaming; the full response is
|
|
||||||
fetched first, then emitted in fixed-size chunks. Token-level streaming
|
|
||||||
will be wired in Step 6 when agents expose ``astream()``.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
response_text = await reg.call_agent(agent_name, request.message, context)
|
|
||||||
|
|
||||||
chunk_size = 50
|
|
||||||
for i in range(0, len(response_text), chunk_size):
|
|
||||||
yield response_text[i : i + chunk_size]
|
|
||||||
|
|
||||||
final = ChatResponse(response=response_text)
|
|
||||||
yield json.dumps({"done": True, **final.model_dump()})
|
|
||||||
141
app/core/output_formatter.py
Normal file
141
app/core/output_formatter.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
||||||
|
|
||||||
|
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
||||||
|
* ``("token", str)`` — supervisor text token
|
||||||
|
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
||||||
|
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||||
|
|
||||||
|
HomeFormatter:
|
||||||
|
* Streams text tokens as-is → emits ``WsStreamText``
|
||||||
|
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
||||||
|
for the frontend to parse and render as interactive components)
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
|
||||||
|
FloatingFormatter:
|
||||||
|
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
||||||
|
* Streams text tokens → emits ``WsStreamText``
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Map sub-agent tool name → floating domain / entity type
|
||||||
|
_AGENT_DOMAIN: dict[str, str] = {
|
||||||
|
"task_agent": "tasks",
|
||||||
|
"timeline_agent": "timelines",
|
||||||
|
"note_agent": "notes",
|
||||||
|
"project_agent": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
|
class HomeFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||||
|
|
||||||
|
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
||||||
|
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
||||||
|
is responsible for parsing those and rendering interactive components.
|
||||||
|
Mutations are attached to ``WsStreamEnd``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "token":
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FloatingFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
||||||
|
|
||||||
|
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
||||||
|
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
||||||
|
``WsStreamText``. No block parsing for floating context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
domain_sent = False
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "tool_end" and not domain_sent:
|
||||||
|
# Sniff domain from the first sub-agent that completes
|
||||||
|
name = data.get("name", "")
|
||||||
|
domain = _AGENT_DOMAIN.get(name, "tasks")
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain=domain, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
|
||||||
|
elif event_type == "token":
|
||||||
|
if not domain_sent:
|
||||||
|
# First token arrived before any tool_end — default domain
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
# If no events triggered domain_sent (edge case), still emit structure
|
||||||
|
if not domain_sent:
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
100
app/core/ws_context.py
Normal file
100
app/core/ws_context.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""WebSocket client executor context.
|
||||||
|
|
||||||
|
Holds a per-request async callback that tools call to execute CRUD
|
||||||
|
operations on the Electron client's local SQLite / LanceDB databases.
|
||||||
|
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Callable, Coroutine
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Holds the execute callback for the current WS session.
|
||||||
|
# Set by the chat WS handler before the deep agent runs; cleared after.
|
||||||
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
|
"_client_executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional collector that captures raw execute_on_client results.
|
||||||
|
# Set by the deep agent tool loop to capture CRUD mutations.
|
||||||
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
|
"_tool_result_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||||
|
"""Register *lst* as the collector for this async context."""
|
||||||
|
_tool_result_collector.set(lst)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tool_result_collector() -> None:
|
||||||
|
"""Clear the collector (best-effort)."""
|
||||||
|
_tool_result_collector.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||||
|
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||||
|
_client_executor.set(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_client_executor() -> None:
|
||||||
|
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
|
||||||
|
try:
|
||||||
|
_client_executor.set(None) # type: ignore[arg-type]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_on_client(
|
||||||
|
action: str,
|
||||||
|
table: str | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
vector: list[float] | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a CRUD/vector operation to the Electron client and return the result.
|
||||||
|
|
||||||
|
Builds a ``tool_call`` payload, invokes the per-session WS callback,
|
||||||
|
and returns the ``tool_result`` dict from Electron.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
|
||||||
|
"""
|
||||||
|
callback = _client_executor.get(None)
|
||||||
|
if callback is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"execute_on_client() called outside a WebSocket session — "
|
||||||
|
"no client executor is set."
|
||||||
|
)
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
|
||||||
|
if table is not None:
|
||||||
|
payload["table"] = table
|
||||||
|
if data is not None:
|
||||||
|
payload["data"] = data
|
||||||
|
if filters is not None:
|
||||||
|
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||||
|
if vector is not None:
|
||||||
|
payload["vector"] = vector
|
||||||
|
if limit is not None:
|
||||||
|
payload["limit"] = limit
|
||||||
|
|
||||||
|
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
result = await callback(payload)
|
||||||
|
if result is None:
|
||||||
|
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
else:
|
||||||
|
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
collector = _tool_result_collector.get(None)
|
||||||
|
if collector is not None and action in ("insert", "update", "delete"):
|
||||||
|
collector.append({
|
||||||
|
"action": action,
|
||||||
|
"table": table,
|
||||||
|
"data": data or {},
|
||||||
|
})
|
||||||
|
return result
|
||||||
40
app/db.py
Normal file
40
app/db.py
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
"""Database engine, session factory, and base model.
|
||||||
|
|
||||||
|
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||||
|
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||||
|
|
||||||
|
Usage in routes:
|
||||||
|
from app.db import get_session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||||
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
engine = create_async_engine(
|
||||||
|
settings.DATABASE_URL,
|
||||||
|
pool_pre_ping=True,
|
||||||
|
echo=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
||||||
|
|
||||||
|
|
||||||
|
class Base(DeclarativeBase):
|
||||||
|
"""Shared declarative base for all ORM models."""
|
||||||
|
|
||||||
|
|
||||||
|
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""FastAPI dependency that yields an async DB session per request."""
|
||||||
|
async with async_session() as session:
|
||||||
|
yield session
|
||||||
164
app/integrations/__init__.py
Normal file
164
app/integrations/__init__.py
Normal file
@@ -0,0 +1,164 @@
|
|||||||
|
"""Cloud provider integration utilities.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
|
||||||
|
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
|
||||||
|
* ``get_provider()`` — factory that returns the correct client given a
|
||||||
|
provider name and decrypted OAuth credentials dict.
|
||||||
|
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
|
||||||
|
encryption for OAuth tokens stored in ``cloud_agent_configs``.
|
||||||
|
|
||||||
|
Encryption rationale
|
||||||
|
--------------------
|
||||||
|
Unlike user content (which is E2E-encrypted client-side and **never**
|
||||||
|
decrypted server-side), OAuth tokens *must* be decrypted server-side
|
||||||
|
because the backend makes provider API calls on behalf of the user.
|
||||||
|
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
|
||||||
|
is never returned to clients.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Shared message types ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmailMessage:
|
||||||
|
"""A single email message fetched from Gmail or Outlook."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
subject: str
|
||||||
|
sender: str
|
||||||
|
body_text: str
|
||||||
|
date: datetime
|
||||||
|
labels: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
"""Return a human-readable text representation for LLM extraction."""
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{labels_str}\n"
|
||||||
|
f"Subject: {self.subject}\n\n"
|
||||||
|
f"{self.body_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
"""A single Teams chat or channel message fetched from MS Graph."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
content: str
|
||||||
|
sender: str
|
||||||
|
channel: str | None
|
||||||
|
date: datetime
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
"""Return a human-readable text representation for LLM extraction."""
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{channel_str}\n\n"
|
||||||
|
f"{self.content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fernet helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
|
||||||
|
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set — callers
|
||||||
|
must ensure this is configured before persisting OAuth tokens.
|
||||||
|
"""
|
||||||
|
key = settings.OAUTH_ENCRYPTION_KEY
|
||||||
|
if not key:
|
||||||
|
raise RuntimeError(
|
||||||
|
"OAUTH_ENCRYPTION_KEY is not set. "
|
||||||
|
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
||||||
|
)
|
||||||
|
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_token(token_info: dict) -> str:
|
||||||
|
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
|
||||||
|
|
||||||
|
Stores the full ``{access_token, refresh_token, token_uri, client_id,
|
||||||
|
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||||
|
ValueError: ``token_info`` is not a non-empty dict.
|
||||||
|
"""
|
||||||
|
if not isinstance(token_info, dict) or not token_info:
|
||||||
|
raise ValueError("token_info must be a non-empty dict")
|
||||||
|
plaintext = json.dumps(token_info).encode("utf-8")
|
||||||
|
return _get_fernet().encrypt(plaintext).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_token(encrypted: str) -> dict:
|
||||||
|
"""Decrypt a Fernet-encrypted token string and return the credential dict.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||||
|
ValueError: The encrypted string is invalid or was encrypted with a
|
||||||
|
different key.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||||
|
return json.loads(plaintext)
|
||||||
|
except (InvalidToken, json.JSONDecodeError) as exc:
|
||||||
|
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ── Provider factory ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
provider: str,
|
||||||
|
credentials_info: dict,
|
||||||
|
) -> "GmailClient | MSGraphClient":
|
||||||
|
"""Return the correct provider client for *provider*.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
provider:
|
||||||
|
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
|
||||||
|
credentials_info:
|
||||||
|
Decrypted OAuth credential dict (Google or Microsoft shape).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Unknown provider name.
|
||||||
|
"""
|
||||||
|
if provider == "gmail":
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
return GmailClient(credentials_info)
|
||||||
|
if provider in {"outlook", "teams"}:
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(credentials_info)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown cloud provider {provider!r}. "
|
||||||
|
"Supported: 'gmail', 'outlook', 'teams'."
|
||||||
|
)
|
||||||
335
app/integrations/gmail.py
Normal file
335
app/integrations/gmail.py
Normal file
@@ -0,0 +1,335 @@
|
|||||||
|
"""Gmail API client for cloud agent integration.
|
||||||
|
|
||||||
|
Wraps the Google Gmail REST API to fetch email messages matching a
|
||||||
|
``filter_config`` dict. Uses the official ``google-api-python-client``
|
||||||
|
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
||||||
|
blocking the event loop.
|
||||||
|
|
||||||
|
Token refresh is handled transparently: when the stored access token has
|
||||||
|
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||||
|
token to obtain a fresh one. The caller is responsible for persisting
|
||||||
|
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
|
||||||
|
(see ``agent_runner.run_cloud_agent``).
|
||||||
|
|
||||||
|
Credential dict shape (Google OAuth2):
|
||||||
|
{
|
||||||
|
"token": "<access_token>",
|
||||||
|
"refresh_token": "<refresh_token>",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"client_id": "<client_id>",
|
||||||
|
"client_secret": "<client_secret>",
|
||||||
|
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||||
|
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import email
|
||||||
|
import html
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.integrations import EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Gmail search date format — e.g. "after:2025/01/01"
|
||||||
|
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||||
|
|
||||||
|
# Maximum characters of body text forwarded to the LLM.
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
# Maximum messages retrieved per run (prevents runaway quota usage).
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gmail_query(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build a Gmail search query string from *filter_config* and *since*.
|
||||||
|
|
||||||
|
Supported ``filter_config`` keys:
|
||||||
|
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
|
||||||
|
senders (list[str]): Sender addresses or domains to include
|
||||||
|
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
|
||||||
|
|
||||||
|
A hard ``since`` date (from last run) always overrides ``date_range.from``
|
||||||
|
when it is earlier.
|
||||||
|
"""
|
||||||
|
parts: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
# Labels — joined with OR when multiple given.
|
||||||
|
labels: list[str] = cfg.get("labels", [])
|
||||||
|
if labels:
|
||||||
|
if len(labels) == 1:
|
||||||
|
parts.append(f"label:{labels[0]}")
|
||||||
|
else:
|
||||||
|
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||||
|
parts.append(f"({label_expr})")
|
||||||
|
|
||||||
|
# Senders — each prefixed with "from:".
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
for sender in senders:
|
||||||
|
parts.append(f"from:{sender}")
|
||||||
|
|
||||||
|
# Date range.
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
|
||||||
|
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw_html: str) -> str:
|
||||||
|
"""Remove HTML tags and decode entities to get plain text."""
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||||
|
decoded = html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_body(payload: dict[str, Any]) -> str:
|
||||||
|
"""Recursively extract the plain-text body from a Gmail message payload.
|
||||||
|
|
||||||
|
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
|
||||||
|
Returns an empty string if no body can be extracted.
|
||||||
|
"""
|
||||||
|
mime_type: str = payload.get("mimeType", "")
|
||||||
|
body: dict = payload.get("body", {})
|
||||||
|
parts: list[dict] = payload.get("parts", [])
|
||||||
|
|
||||||
|
if mime_type == "text/plain":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if mime_type == "text/html":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return _strip_html(raw)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
# Multipart — prefer text/plain part, fall back to text/html.
|
||||||
|
plain_fallback = ""
|
||||||
|
for part in parts:
|
||||||
|
part_mime = part.get("mimeType", "")
|
||||||
|
if part_mime == "text/plain":
|
||||||
|
return _parse_body(part)
|
||||||
|
if part_mime == "text/html" and not plain_fallback:
|
||||||
|
plain_fallback = _parse_body(part)
|
||||||
|
if part_mime.startswith("multipart/"):
|
||||||
|
nested = _parse_body(part)
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return plain_fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_date(raw: str) -> datetime:
|
||||||
|
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
|
||||||
|
try:
|
||||||
|
parsed = email.utils.parsedate_to_datetime(raw)
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||||
|
return parsed.astimezone(timezone.utc)
|
||||||
|
except Exception:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class GmailClient:
|
||||||
|
"""Fetch email messages from a Gmail account via the Gmail REST API.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
credentials_info:
|
||||||
|
Decrypted OAuth2 credential dict. Must contain at minimum
|
||||||
|
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
|
||||||
|
``client_id`` + ``client_secret``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
expiry_str: str | None = credentials_info.get("expiry")
|
||||||
|
expiry: datetime | None = None
|
||||||
|
if expiry_str:
|
||||||
|
try:
|
||||||
|
expiry = datetime.fromisoformat(
|
||||||
|
expiry_str.replace("Z", "+00:00")
|
||||||
|
).replace(tzinfo=timezone.utc)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._credentials = Credentials(
|
||||||
|
token=credentials_info.get("token"),
|
||||||
|
refresh_token=credentials_info.get("refresh_token"),
|
||||||
|
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||||
|
client_id=credentials_info.get("client_id"),
|
||||||
|
client_secret=credentials_info.get("client_secret"),
|
||||||
|
scopes=credentials_info.get("scopes"),
|
||||||
|
expiry=expiry,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Public API ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
|
||||||
|
|
||||||
|
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
|
||||||
|
to avoid blocking the async event loop.
|
||||||
|
|
||||||
|
Token refresh is performed automatically when the access token has
|
||||||
|
expired. After the call, ``self.refreshed_credentials`` may be
|
||||||
|
consulted to detect whether new credentials should be persisted.
|
||||||
|
"""
|
||||||
|
query = _build_gmail_query(filter_config, since)
|
||||||
|
logger.debug("gmail: executing search query %r", query)
|
||||||
|
return await asyncio.to_thread(self._fetch_sync, query)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
"""Return updated credential dict if the access token was refreshed.
|
||||||
|
|
||||||
|
If the credentials were refreshed during ``fetch_messages()``, returns
|
||||||
|
a new dict that should be re-encrypted and written back to the DB.
|
||||||
|
Returns ``None`` if no refresh occurred.
|
||||||
|
"""
|
||||||
|
creds = self._credentials
|
||||||
|
if not creds.valid and creds.expired:
|
||||||
|
return None
|
||||||
|
# Check whether the token changed from what was stored.
|
||||||
|
if creds.token != self._credentials_info.get("token"):
|
||||||
|
result = {
|
||||||
|
"token": creds.token,
|
||||||
|
"refresh_token": creds.refresh_token,
|
||||||
|
"token_uri": creds.token_uri,
|
||||||
|
"client_id": creds.client_id,
|
||||||
|
"client_secret": creds.client_secret,
|
||||||
|
"scopes": list(creds.scopes or []),
|
||||||
|
}
|
||||||
|
if creds.expiry:
|
||||||
|
result["expiry"] = creds.expiry.isoformat()
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ── Internal sync worker ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||||
|
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
|
||||||
|
import googleapiclient.discovery
|
||||||
|
import googleapiclient.errors
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
# Refresh token if needed before building the service.
|
||||||
|
if self._credentials.expired and self._credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
self._credentials.refresh(Request())
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||||
|
|
||||||
|
service = googleapiclient.discovery.build(
|
||||||
|
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||||
|
)
|
||||||
|
user_api = service.users() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# ── List matching message IDs ──────────────────────────────────────
|
||||||
|
ids: list[str] = []
|
||||||
|
page_token: str | None = None
|
||||||
|
while len(ids) < _MAX_MESSAGES:
|
||||||
|
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"userId": "me",
|
||||||
|
"maxResults": batch_size,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
kwargs["q"] = query
|
||||||
|
if page_token:
|
||||||
|
kwargs["pageToken"] = page_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = user_api.messages().list(**kwargs).execute()
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||||
|
|
||||||
|
for msg in resp.get("messages", []):
|
||||||
|
ids.append(msg["id"])
|
||||||
|
|
||||||
|
page_token = resp.get("nextPageToken")
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
logger.debug("gmail: no messages matched query %r", query)
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||||
|
|
||||||
|
# ── Fetch individual message details ──────────────────────────────
|
||||||
|
messages: list[EmailMessage] = []
|
||||||
|
for msg_id in ids:
|
||||||
|
try:
|
||||||
|
msg = user_api.messages().get(
|
||||||
|
userId="me", id=msg_id, format="full"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
h["name"].lower(): h["value"]
|
||||||
|
for h in msg.get("payload", {}).get("headers", [])
|
||||||
|
}
|
||||||
|
subject = headers.get("subject", "(no subject)")
|
||||||
|
sender = headers.get("from", "unknown")
|
||||||
|
date_raw = headers.get("date", "")
|
||||||
|
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||||
|
labels = msg.get("labelIds", [])
|
||||||
|
|
||||||
|
messages.append(EmailMessage(
|
||||||
|
id=msg_id,
|
||||||
|
subject=subject,
|
||||||
|
sender=sender,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
labels=labels,
|
||||||
|
))
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
|
||||||
|
|
||||||
|
logger.info("gmail: returned %d message(s)", len(messages))
|
||||||
|
return messages
|
||||||
352
app/integrations/ms_graph.py
Normal file
352
app/integrations/ms_graph.py
Normal file
@@ -0,0 +1,352 @@
|
|||||||
|
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
|
||||||
|
|
||||||
|
Handles two data sources:
|
||||||
|
|
||||||
|
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
|
||||||
|
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
|
||||||
|
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
|
||||||
|
``/me/chats/getAllMessages`` filtered by date.
|
||||||
|
|
||||||
|
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
|
||||||
|
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
|
||||||
|
dependency) is used for all API calls.
|
||||||
|
|
||||||
|
Credential dict shape (Microsoft OAuth2 / MSAL):
|
||||||
|
{
|
||||||
|
"access_token": "<access_token>",
|
||||||
|
"refresh_token": "<refresh_token>",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
|
||||||
|
"expires_in": 3600
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.integrations import ChatMessage, EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
# Max items fetched per run.
|
||||||
|
_MAX_EMAILS = 200
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
# Max characters of body forwarded to the LLM.
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw: str) -> str:
|
||||||
|
"""Strip HTML tags and collapse whitespace."""
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||||
|
import html as _html
|
||||||
|
decoded = _html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _odata_datetime(dt: datetime) -> str:
|
||||||
|
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
|
||||||
|
utc = dt.astimezone(timezone.utc)
|
||||||
|
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_email_filter(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
|
||||||
|
|
||||||
|
Supported ``filter_config`` keys:
|
||||||
|
senders (list[str]): Sender email addresses.
|
||||||
|
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
|
||||||
|
folders (list[str]): Folder display names (not directly filterable
|
||||||
|
via OData, so ignored here — callers iterate
|
||||||
|
folder IDs separately if needed; listed for
|
||||||
|
completeness).
|
||||||
|
|
||||||
|
A hard ``since`` date always overrides ``date_range.from`` when it is
|
||||||
|
earlier.
|
||||||
|
"""
|
||||||
|
clauses: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
# Senders.
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
if senders:
|
||||||
|
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||||
|
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||||
|
|
||||||
|
# Date range.
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||||
|
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
if to_dt.tzinfo is None:
|
||||||
|
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||||
|
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " and ".join(clauses)
|
||||||
|
|
||||||
|
|
||||||
|
class MSGraphClient:
|
||||||
|
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
credentials_info:
|
||||||
|
Decrypted MSAL credential dict.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
self._access_token: str = credentials_info.get("access_token", "")
|
||||||
|
self._original_access_token: str = self._access_token
|
||||||
|
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||||
|
|
||||||
|
# ── Token management ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _auth_headers(self) -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {self._access_token}"}
|
||||||
|
|
||||||
|
async def _refresh_access_token(self) -> None:
|
||||||
|
"""Use MSAL to exchange the refresh token for a fresh access token.
|
||||||
|
|
||||||
|
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: MSAL reports an auth error.
|
||||||
|
"""
|
||||||
|
import msal
|
||||||
|
|
||||||
|
app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=settings.MS_CLIENT_ID,
|
||||||
|
client_credential=settings.MS_CLIENT_SECRET,
|
||||||
|
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
|
||||||
|
)
|
||||||
|
scopes: list[str] = self._credentials_info.get("scope", "").split()
|
||||||
|
if not scopes:
|
||||||
|
scopes = ["https://graph.microsoft.com/.default"]
|
||||||
|
|
||||||
|
result = app.acquire_token_by_refresh_token(
|
||||||
|
self._refresh_token,
|
||||||
|
scopes=scopes,
|
||||||
|
)
|
||||||
|
if "access_token" not in result:
|
||||||
|
error = result.get("error_description", result.get("error", "unknown"))
|
||||||
|
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||||
|
|
||||||
|
self._access_token = result["access_token"]
|
||||||
|
# MSAL may issue a new refresh token.
|
||||||
|
if "refresh_token" in result:
|
||||||
|
self._refresh_token = result["refresh_token"]
|
||||||
|
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||||
|
self._credentials_info["access_token"] = self._access_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
"""Return updated credential dict if the access token was refreshed.
|
||||||
|
|
||||||
|
Returns ``None`` if no change was made.
|
||||||
|
"""
|
||||||
|
if self._access_token != self._original_access_token:
|
||||||
|
return {**self._credentials_info, "access_token": self._access_token}
|
||||||
|
return None
|
||||||
|
|
||||||
|
# ── HTTP helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get(
|
||||||
|
self,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
url: str,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
retry_on_401: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""GET *url* with auth; refresh token on 401 and retry once."""
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||||
|
logger.debug("ms_graph: 401 on %s — refreshing token", url)
|
||||||
|
await self._refresh_access_token()
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 429:
|
||||||
|
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
# ── Public API ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def fetch_emails(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filter_config:
|
||||||
|
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
|
||||||
|
since:
|
||||||
|
Hard lower-bound on email date (from last agent run).
|
||||||
|
"""
|
||||||
|
odata_filter = _build_email_filter(filter_config, since)
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"$top": 50,
|
||||||
|
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
|
||||||
|
"$orderby": "receivedDateTime desc",
|
||||||
|
}
|
||||||
|
if odata_filter:
|
||||||
|
params["$filter"] = odata_filter
|
||||||
|
|
||||||
|
emails: list[EmailMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/messages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(emails) < _MAX_EMAILS:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
for item in data.get("value", []):
|
||||||
|
emails.append(self._parse_email(item))
|
||||||
|
if len(emails) >= _MAX_EMAILS:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {} # nextLink already contains encoded params.
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||||
|
return emails
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[ChatMessage]:
|
||||||
|
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
|
||||||
|
|
||||||
|
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
|
||||||
|
The ``filter_config.channels`` key is checked as a text-filter on
|
||||||
|
the channel name post-fetch (the API doesn't support channel OData
|
||||||
|
filter directly on ``getAllMessages``).
|
||||||
|
"""
|
||||||
|
cfg = filter_config or {}
|
||||||
|
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||||
|
params: dict[str, Any] = {"$top": 50}
|
||||||
|
if since:
|
||||||
|
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
|
||||||
|
|
||||||
|
messages: list[ChatMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(messages) < _MAX_MESSAGES:
|
||||||
|
try:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
# getAllMessages requires specific licensing; degrade gracefully.
|
||||||
|
if exc.response.status_code in (403, 404):
|
||||||
|
logger.warning(
|
||||||
|
"ms_graph: /me/chats/getAllMessages not available (%d) — "
|
||||||
|
"check Teams license or permissions",
|
||||||
|
exc.response.status_code,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
raise
|
||||||
|
|
||||||
|
for item in data.get("value", []):
|
||||||
|
msg = self._parse_teams_message(item)
|
||||||
|
if channel_filter and msg.channel:
|
||||||
|
if not any(c in msg.channel.lower() for c in channel_filter):
|
||||||
|
continue
|
||||||
|
messages.append(msg)
|
||||||
|
if len(messages) >= _MAX_MESSAGES:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
# ── Parsers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||||
|
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||||
|
sender_block = item.get("from", {}) or {}
|
||||||
|
sender_addr = (
|
||||||
|
(sender_block.get("emailAddress") or {}).get("address", "unknown")
|
||||||
|
)
|
||||||
|
date_str: str = item.get("receivedDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_body: str = body_block.get("content", "")
|
||||||
|
if content_type == "html":
|
||||||
|
body_text = _strip_html(raw_body)
|
||||||
|
else:
|
||||||
|
body_text = raw_body or item.get("bodyPreview", "")
|
||||||
|
body_text = body_text[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return EmailMessage(
|
||||||
|
id=item.get("id", ""),
|
||||||
|
subject=subject,
|
||||||
|
sender=sender_addr,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
|
||||||
|
msg_id: str = item.get("id", "")
|
||||||
|
sender_block = (item.get("from") or {}).get("user") or {}
|
||||||
|
sender: str = sender_block.get("displayName", "unknown")
|
||||||
|
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
|
||||||
|
|
||||||
|
date_str: str = item.get("createdDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_content: str = body_block.get("content", "")
|
||||||
|
content = _strip_html(raw_content) if content_type == "html" else raw_content
|
||||||
|
content = content[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return ChatMessage(
|
||||||
|
id=msg_id,
|
||||||
|
content=content,
|
||||||
|
sender=sender,
|
||||||
|
channel=channel,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
37
app/main.py
37
app/main.py
@@ -1,8 +1,16 @@
|
|||||||
from contextlib import asynccontextmanager
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
|
||||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
@@ -10,13 +18,12 @@ from app.config.settings import settings
|
|||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: initialise DB connection pool and agent registry
|
# Startup: initialise DB connection pool
|
||||||
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
# Shutdown: nothing to clean up for now
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
|
from app.db import engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
@@ -41,16 +48,18 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(plans.router, prefix="/api/v1")
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
app.include_router(storage.router, prefix="/api/v1")
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
app.include_router(vectors.router, prefix="/api/v1")
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
app.include_router(backup.router, prefix="/api/v1")
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(plugins.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
|
app.include_router(agent_setup.router, prefix="/api/v1")
|
||||||
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
|
|||||||
@@ -1,8 +1,7 @@
|
|||||||
"""Plugin catalog registry.
|
"""Plugin catalog registry backed by PostgreSQL.
|
||||||
|
|
||||||
Maintains the authoritative list of plugins, their review status, and
|
Maintains the authoritative list of plugins, their review status, and
|
||||||
aggregate install counts. Storage is in-memory until Step 12 migrates to
|
aggregate install counts. All data is persisted in the ``plugins`` table.
|
||||||
the ``plugins`` PostgreSQL table.
|
|
||||||
|
|
||||||
Module-level singleton::
|
Module-level singleton::
|
||||||
|
|
||||||
@@ -11,144 +10,103 @@ Module-level singleton::
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import json
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import Plugin
|
||||||
from app.schemas import PluginListResponse, PluginManifest
|
from app.schemas import PluginListResponse, PluginManifest
|
||||||
|
|
||||||
# ── Pre-seeded approved plugins (mirrors the Step 8 stub catalog) ─────
|
|
||||||
|
|
||||||
_SEED_PLUGINS: list[dict[str, Any]] = [
|
|
||||||
{
|
|
||||||
"manifest": PluginManifest(
|
|
||||||
id="plugin-github-sync",
|
|
||||||
name="GitHub Sync",
|
|
||||||
description="Sync tasks with GitHub Issues and pull requests.",
|
|
||||||
version="1.0.0",
|
|
||||||
author="Adiuva",
|
|
||||||
permissions=["read:tasks", "write:tasks"],
|
|
||||||
category="productivity",
|
|
||||||
price_cents=0,
|
|
||||||
),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
"rejection_reason": None,
|
|
||||||
"submitted_at": int(time.time()),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"manifest": PluginManifest(
|
|
||||||
id="plugin-slack-notify",
|
|
||||||
name="Slack Notifier",
|
|
||||||
description="Post task and checkpoint updates to Slack channels.",
|
|
||||||
version="1.2.0",
|
|
||||||
author="Adiuva",
|
|
||||||
permissions=["read:tasks", "read:checkpoints"],
|
|
||||||
category="communication",
|
|
||||||
price_cents=499,
|
|
||||||
),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
"rejection_reason": None,
|
|
||||||
"submitted_at": int(time.time()),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"manifest": PluginManifest(
|
|
||||||
id="plugin-time-tracker",
|
|
||||||
name="Time Tracker",
|
|
||||||
description="Track time spent on tasks with automatic reporting.",
|
|
||||||
version="0.9.1",
|
|
||||||
author="Third Party",
|
|
||||||
permissions=["read:tasks", "write:tasks"],
|
|
||||||
category="productivity",
|
|
||||||
price_cents=999,
|
|
||||||
),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
"rejection_reason": None,
|
|
||||||
"submitted_at": int(time.time()),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
_PAGE_SIZE = 20
|
_PAGE_SIZE = 20
|
||||||
|
|
||||||
|
|
||||||
|
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
||||||
|
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
||||||
|
try:
|
||||||
|
permissions = json.loads(p.permissions) if p.permissions else []
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
permissions = []
|
||||||
|
return PluginManifest(
|
||||||
|
id=p.id,
|
||||||
|
name=p.name,
|
||||||
|
description=p.description,
|
||||||
|
version=p.version,
|
||||||
|
author=p.author_name,
|
||||||
|
permissions=permissions,
|
||||||
|
category=p.category,
|
||||||
|
price_cents=p.price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class PluginRegistry:
|
class PluginRegistry:
|
||||||
"""In-process plugin catalog.
|
"""PostgreSQL-backed plugin catalog.
|
||||||
|
|
||||||
All mutating methods are ``async`` to make the future DB swap transparent
|
All methods accept an ``AsyncSession`` parameter so the calling route
|
||||||
to callers.
|
controls the session lifecycle.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
# plugin_id → entry dict (deep-copied so each instance is independent)
|
|
||||||
self._catalog: dict[str, dict[str, Any]] = {
|
|
||||||
e["manifest"].id: copy.deepcopy(e) for e in _SEED_PLUGINS
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── Queries ──────────────────────────────────────────────────────
|
# ── Queries ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def list_plugins(
|
async def list_plugins(
|
||||||
self,
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
category: str | None = None,
|
category: str | None = None,
|
||||||
query: str | None = None,
|
query: str | None = None,
|
||||||
page: int = 1,
|
page: int = 1,
|
||||||
sort: Literal["rating", "installs", "newest"] = "newest",
|
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||||
) -> PluginListResponse:
|
) -> PluginListResponse:
|
||||||
"""Return a page of approved plugins, optionally filtered and sorted."""
|
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||||
entries = [e for e in self._catalog.values() if e["status"] == "approved"]
|
base = select(Plugin).where(Plugin.status == "approved")
|
||||||
|
|
||||||
if category:
|
if category:
|
||||||
entries = [e for e in entries if e["manifest"].category == category]
|
base = base.where(Plugin.category == category)
|
||||||
|
|
||||||
if query:
|
if query:
|
||||||
q_lower = query.lower()
|
pattern = f"%{query}%"
|
||||||
entries = [
|
base = base.where(
|
||||||
e
|
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
||||||
for e in entries
|
)
|
||||||
if q_lower in e["manifest"].name.lower()
|
|
||||||
or q_lower in e["manifest"].description.lower()
|
|
||||||
]
|
|
||||||
|
|
||||||
|
# Count
|
||||||
|
count_q = select(func.count()).select_from(base.subquery())
|
||||||
|
total = (await db.execute(count_q)).scalar_one()
|
||||||
|
|
||||||
|
# Sort
|
||||||
if sort == "installs":
|
if sort == "installs":
|
||||||
entries = sorted(entries, key=lambda e: e["install_count"], reverse=True)
|
base = base.order_by(Plugin.install_count.desc())
|
||||||
elif sort == "rating":
|
elif sort == "rating":
|
||||||
entries = sorted(entries, key=lambda e: e["avg_rating"], reverse=True)
|
base = base.order_by(Plugin.avg_rating.desc())
|
||||||
# "newest" = catalog insertion order (dict preserves insertion in Python 3.7+)
|
else: # newest
|
||||||
|
base = base.order_by(Plugin.created_at.desc())
|
||||||
|
|
||||||
total = len(entries)
|
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
||||||
start = (page - 1) * _PAGE_SIZE
|
rows = (await db.execute(base)).scalars().all()
|
||||||
page_entries = entries[start : start + _PAGE_SIZE]
|
|
||||||
|
|
||||||
return PluginListResponse(
|
return PluginListResponse(
|
||||||
plugins=[e["manifest"] for e in page_entries],
|
plugins=[_plugin_to_manifest(r) for r in rows],
|
||||||
total=total,
|
total=total,
|
||||||
page=page,
|
page=page,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def get_plugin(self, plugin_id: str) -> dict[str, Any] | None:
|
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
||||||
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||||
entry = self._catalog.get(plugin_id)
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
if entry is None:
|
p = result.scalar_one_or_none()
|
||||||
|
if p is None:
|
||||||
return None
|
return None
|
||||||
return {
|
return {
|
||||||
"manifest": entry["manifest"],
|
"manifest": _plugin_to_manifest(p),
|
||||||
"status": entry["status"],
|
"status": p.status,
|
||||||
"install_count": entry["install_count"],
|
"install_count": p.install_count,
|
||||||
"avg_rating": entry["avg_rating"],
|
"avg_rating": p.avg_rating,
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── Mutations ────────────────────────────────────────────────────
|
# ── Mutations ────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def submit_plugin(
|
async def submit_plugin(
|
||||||
self,
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
manifest: PluginManifest,
|
manifest: PluginManifest,
|
||||||
package_s3_key: str,
|
package_s3_key: str,
|
||||||
) -> str:
|
) -> str:
|
||||||
@@ -157,54 +115,97 @@ class PluginRegistry:
|
|||||||
Returns the plugin_id. If a plugin with the same id already exists
|
Returns the plugin_id. If a plugin with the same id already exists
|
||||||
it is overwritten (re-submission after rejection).
|
it is overwritten (re-submission after rejection).
|
||||||
"""
|
"""
|
||||||
plugin_id = manifest.id or str(uuid.uuid4())
|
plugin_id = manifest.id
|
||||||
self._catalog[plugin_id] = {
|
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
"manifest": manifest,
|
row = existing.scalar_one_or_none()
|
||||||
"status": "pending_review",
|
|
||||||
"s3_package_key": package_s3_key,
|
if row is not None:
|
||||||
"install_count": 0,
|
row.name = manifest.name
|
||||||
"avg_rating": 0.0,
|
row.description = manifest.description
|
||||||
"rejection_reason": None,
|
row.version = manifest.version
|
||||||
"submitted_at": int(time.time()),
|
row.author_name = manifest.author
|
||||||
}
|
row.category = manifest.category
|
||||||
|
row.price_cents = manifest.price_cents
|
||||||
|
row.permissions = json.dumps(manifest.permissions)
|
||||||
|
row.status = "pending_review"
|
||||||
|
row.s3_package_key = package_s3_key
|
||||||
|
row.rejection_reason = None
|
||||||
|
else:
|
||||||
|
row = Plugin(
|
||||||
|
id=plugin_id,
|
||||||
|
name=manifest.name,
|
||||||
|
description=manifest.description,
|
||||||
|
version=manifest.version,
|
||||||
|
author_name=manifest.author,
|
||||||
|
category=manifest.category,
|
||||||
|
price_cents=manifest.price_cents,
|
||||||
|
permissions=json.dumps(manifest.permissions),
|
||||||
|
status="pending_review",
|
||||||
|
s3_package_key=package_s3_key,
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
return plugin_id
|
return plugin_id
|
||||||
|
|
||||||
async def approve_plugin(self, plugin_id: str) -> None:
|
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
"""Set *plugin_id* status to ``'approved'``.
|
"""Set *plugin_id* status to ``'approved'``.
|
||||||
|
|
||||||
Raises ``KeyError`` if the plugin is not found.
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
"""
|
"""
|
||||||
if plugin_id not in self._catalog:
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
self._catalog[plugin_id]["status"] = "approved"
|
row.status = "approved"
|
||||||
self._catalog[plugin_id]["rejection_reason"] = None
|
row.rejection_reason = None
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
async def reject_plugin(self, plugin_id: str, reason: str) -> None:
|
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
||||||
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||||
|
|
||||||
Raises ``KeyError`` if the plugin is not found.
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
"""
|
"""
|
||||||
if plugin_id not in self._catalog:
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
self._catalog[plugin_id]["status"] = "rejected"
|
row.status = "rejected"
|
||||||
self._catalog[plugin_id]["rejection_reason"] = reason
|
row.rejection_reason = reason
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
async def record_install(self, plugin_id: str) -> None:
|
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||||
if plugin_id in self._catalog:
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
self._catalog[plugin_id]["install_count"] += 1
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = row.install_count + 1
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
async def record_uninstall(self, plugin_id: str) -> None:
|
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
"""Decrement the install count for *plugin_id*, floored at 0."""
|
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||||
if plugin_id in self._catalog:
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
current = self._catalog[plugin_id]["install_count"]
|
row = result.scalar_one_or_none()
|
||||||
self._catalog[plugin_id]["install_count"] = max(0, current - 1)
|
if row is not None:
|
||||||
|
row.install_count = max(0, row.install_count - 1)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||||
|
|
||||||
def _get_pending_entries(self) -> list[dict[str, Any]]:
|
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
"""Return all entries with status='pending_review' (synchronous helper)."""
|
"""Return all entries with status='pending_review'."""
|
||||||
return [e for e in self._catalog.values() if e["status"] == "pending_review"]
|
result = await db.execute(
|
||||||
|
select(Plugin).where(Plugin.status == "pending_review")
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"manifest": _plugin_to_manifest(r),
|
||||||
|
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
# Module-level singleton
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Plugin review workflow.
|
"""Plugin review workflow backed by PostgreSQL.
|
||||||
|
|
||||||
Manages the approval queue for newly submitted plugins and enforces a
|
Manages the approval queue for newly submitted plugins and enforces a
|
||||||
security checklist before any plugin is made visible in the marketplace.
|
security checklist before any plugin is made visible in the marketplace.
|
||||||
@@ -11,10 +11,12 @@ Module-level singleton::
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import time
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.marketplace.plugin_registry import registry
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import PluginReview as PluginReviewModel
|
||||||
from app.schemas import PluginManifest
|
from app.schemas import PluginManifest
|
||||||
|
|
||||||
# ── Security policy ───────────────────────────────────────────────────
|
# ── Security policy ───────────────────────────────────────────────────
|
||||||
@@ -27,8 +29,8 @@ ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
|||||||
"write:projects",
|
"write:projects",
|
||||||
"read:notes",
|
"read:notes",
|
||||||
"write:notes",
|
"write:notes",
|
||||||
"read:checkpoints",
|
"read:timelines",
|
||||||
"write:checkpoints",
|
"write:timelines",
|
||||||
"read:calendar",
|
"read:calendar",
|
||||||
"write:calendar",
|
"write:calendar",
|
||||||
}
|
}
|
||||||
@@ -72,20 +74,16 @@ def validate_manifest(manifest: PluginManifest) -> None:
|
|||||||
class ReviewQueue:
|
class ReviewQueue:
|
||||||
"""Approval queue for pending plugin submissions.
|
"""Approval queue for pending plugin submissions.
|
||||||
|
|
||||||
Delegates status changes to the shared ``PluginRegistry`` singleton so
|
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
||||||
there is a single source of truth for plugin state.
|
Review records are persisted in the ``plugin_reviews`` table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
# Completed reviews — Step 12 stores in plugin_reviews table
|
|
||||||
self._reviews: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def get_pending(self) -> list[dict[str, Any]]:
|
|
||||||
"""Return all plugins currently awaiting review.
|
"""Return all plugins currently awaiting review.
|
||||||
|
|
||||||
Each item is ``{plugin_id, manifest, submitted_at}``.
|
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||||
"""
|
"""
|
||||||
entries = registry._get_pending_entries()
|
entries = await registry.get_pending_entries(db)
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
"plugin_id": e["manifest"].id,
|
"plugin_id": e["manifest"].id,
|
||||||
@@ -97,6 +95,7 @@ class ReviewQueue:
|
|||||||
|
|
||||||
async def submit_review(
|
async def submit_review(
|
||||||
self,
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
reviewer_id: str,
|
reviewer_id: str,
|
||||||
decision: Literal["approved", "rejected"],
|
decision: Literal["approved", "rejected"],
|
||||||
@@ -108,19 +107,18 @@ class ReviewQueue:
|
|||||||
``KeyError`` if *plugin_id* is not found in the registry.
|
``KeyError`` if *plugin_id* is not found in the registry.
|
||||||
"""
|
"""
|
||||||
if decision == "approved":
|
if decision == "approved":
|
||||||
await registry.approve_plugin(plugin_id)
|
await registry.approve_plugin(db, plugin_id)
|
||||||
else:
|
else:
|
||||||
await registry.reject_plugin(plugin_id, reason=notes)
|
await registry.reject_plugin(db, plugin_id, reason=notes)
|
||||||
|
|
||||||
self._reviews.append(
|
review = PluginReviewModel(
|
||||||
{
|
plugin_id=plugin_id,
|
||||||
"plugin_id": plugin_id,
|
reviewer_id=reviewer_id,
|
||||||
"reviewer_id": reviewer_id,
|
decision=decision,
|
||||||
"decision": decision,
|
notes=notes,
|
||||||
"notes": notes,
|
|
||||||
"reviewed_at": int(time.time()),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
db.add(review)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
# Module-level singleton
|
||||||
|
|||||||
@@ -1,8 +1,8 @@
|
|||||||
"""Revenue share tracking and Stripe Connect payouts.
|
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
||||||
|
|
||||||
Records every plugin installation as a revenue event and facilitates
|
Records every plugin installation as a revenue event and facilitates
|
||||||
70 % / 30 % payouts to developers via Stripe Connect. Storage is
|
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
||||||
in-memory until Step 12 migrates to the ``revenue_events`` table.
|
in the ``revenue_events`` table.
|
||||||
|
|
||||||
Module-level singleton::
|
Module-level singleton::
|
||||||
|
|
||||||
@@ -12,13 +12,16 @@ Module-level singleton::
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import stripe as stripe_lib
|
import stripe as stripe_lib
|
||||||
|
from sqlalchemy import extract, func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.marketplace.plugin_registry import registry
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import Plugin, RevenueEvent
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -35,10 +38,6 @@ class RevenueShare:
|
|||||||
is not configured, consistent with the rest of the billing layer.
|
is not configured, consistent with the rest of the billing layer.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
# Step 12 replaces with revenue_events DB table
|
|
||||||
self._events: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────
|
# ── Helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@@ -54,6 +53,7 @@ class RevenueShare:
|
|||||||
|
|
||||||
async def record_install(
|
async def record_install(
|
||||||
self,
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
plugin_id: str,
|
plugin_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
amount_cents: int,
|
amount_cents: int,
|
||||||
@@ -72,11 +72,12 @@ class RevenueShare:
|
|||||||
stripe_transfer_id: str | None = None
|
stripe_transfer_id: str | None = None
|
||||||
|
|
||||||
if amount_cents > 0 and self._stripe_configured():
|
if amount_cents > 0 and self._stripe_configured():
|
||||||
plugin_entry = registry._catalog.get(plugin_id)
|
# Look up the plugin's author Stripe account from the DB
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = result.scalar_one_or_none()
|
||||||
developer_stripe_account: str | None = None
|
developer_stripe_account: str | None = None
|
||||||
if plugin_entry:
|
if plugin_row and plugin_row.author_id:
|
||||||
# Step 12: look up developer's Stripe account from DB
|
# Future: look up user.stripe_connect_account_id
|
||||||
# For now, the author field is used as a placeholder key.
|
|
||||||
developer_stripe_account = None # no real account yet
|
developer_stripe_account = None # no real account yet
|
||||||
|
|
||||||
if developer_stripe_account:
|
if developer_stripe_account:
|
||||||
@@ -103,22 +104,21 @@ class RevenueShare:
|
|||||||
plugin_id,
|
plugin_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
self._events.append(
|
event = RevenueEvent(
|
||||||
{
|
plugin_id=plugin_id,
|
||||||
"plugin_id": plugin_id,
|
user_id=user_id,
|
||||||
"user_id": user_id,
|
amount_cents=amount_cents,
|
||||||
"amount_cents": amount_cents,
|
developer_share_cents=developer_share_cents,
|
||||||
"developer_share_cents": developer_share_cents,
|
stripe_transfer_id=stripe_transfer_id,
|
||||||
"stripe_transfer_id": stripe_transfer_id,
|
|
||||||
"paid_at": None,
|
|
||||||
"created_at": int(time.time()),
|
|
||||||
}
|
|
||||||
)
|
)
|
||||||
|
db.add(event)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
await registry.record_install(plugin_id)
|
await registry.record_install(db, plugin_id)
|
||||||
|
|
||||||
async def get_earnings(
|
async def get_earnings(
|
||||||
self,
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
developer_id: str,
|
developer_id: str,
|
||||||
period: str | None = None,
|
period: str | None = None,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
@@ -136,54 +136,81 @@ class RevenueShare:
|
|||||||
"developer_share_cents": int,
|
"developer_share_cents": int,
|
||||||
}
|
}
|
||||||
"""
|
"""
|
||||||
# Find plugin ids belonging to this developer
|
# Find plugin ids belonging to this developer (by author_name match)
|
||||||
developer_plugin_ids: set[str] = {
|
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
||||||
pid
|
plugin_result = await db.execute(plugin_q)
|
||||||
for pid, entry in registry._catalog.items()
|
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
||||||
if entry["manifest"].author == developer_id
|
|
||||||
}
|
|
||||||
|
|
||||||
events = [e for e in self._events if e["plugin_id"] in developer_plugin_ids]
|
if not developer_plugin_ids:
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": 0,
|
||||||
|
"total_revenue_cents": 0,
|
||||||
|
"developer_share_cents": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
query = select(
|
||||||
|
func.count().label("total_installs"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
||||||
|
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
||||||
|
|
||||||
if period:
|
if period:
|
||||||
# Filter by YYYY-MM prefix of the created_at timestamp
|
# Filter by YYYY-MM: extract year and month from created_at
|
||||||
events = [
|
try:
|
||||||
e
|
year, month = period.split("-")
|
||||||
for e in events
|
query = query.where(
|
||||||
if time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
extract("year", RevenueEvent.created_at) == int(year),
|
||||||
]
|
extract("month", RevenueEvent.created_at) == int(month),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass # invalid period format — return all
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
row = result.one()
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"developer_id": developer_id,
|
"developer_id": developer_id,
|
||||||
"period": period,
|
"period": period,
|
||||||
"total_installs": len(events),
|
"total_installs": row.total_installs,
|
||||||
"total_revenue_cents": sum(e["amount_cents"] for e in events),
|
"total_revenue_cents": row.total_revenue,
|
||||||
"developer_share_cents": sum(e["developer_share_cents"] for e in events),
|
"developer_share_cents": row.dev_share,
|
||||||
}
|
}
|
||||||
|
|
||||||
async def payout_developer(self, plugin_id: str, period: str) -> None:
|
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
||||||
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||||
|
|
||||||
Marks processed events with ``paid_at`` timestamp.
|
Marks processed events with ``paid_at`` timestamp.
|
||||||
Stubs gracefully when Stripe is not configured.
|
Stubs gracefully when Stripe is not configured.
|
||||||
"""
|
"""
|
||||||
unpaid = [
|
try:
|
||||||
e
|
year, month = period.split("-")
|
||||||
for e in self._events
|
year_int, month_int = int(year), int(month)
|
||||||
if e["plugin_id"] == plugin_id
|
except ValueError:
|
||||||
and e["paid_at"] is None
|
logger.warning("Invalid period format: %s", period)
|
||||||
and time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
return
|
||||||
]
|
|
||||||
|
|
||||||
total_dev_share = sum(e["developer_share_cents"] for e in unpaid)
|
result = await db.execute(
|
||||||
|
select(RevenueEvent).where(
|
||||||
|
RevenueEvent.plugin_id == plugin_id,
|
||||||
|
RevenueEvent.paid_at.is_(None),
|
||||||
|
extract("year", RevenueEvent.created_at) == year_int,
|
||||||
|
extract("month", RevenueEvent.created_at) == month_int,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
unpaid = list(result.scalars().all())
|
||||||
|
|
||||||
|
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
||||||
if total_dev_share <= 0 or not unpaid:
|
if total_dev_share <= 0 or not unpaid:
|
||||||
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||||
return
|
return
|
||||||
|
|
||||||
if self._stripe_configured():
|
if self._stripe_configured():
|
||||||
plugin_entry = registry._catalog.get(plugin_id)
|
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
developer_stripe_account: str | None = None # Step 12: fetch from DB
|
plugin_row = plugin_result.scalar_one_or_none()
|
||||||
if plugin_entry and developer_stripe_account:
|
developer_stripe_account: str | None = None # Future: fetch from DB
|
||||||
|
if plugin_row and developer_stripe_account:
|
||||||
try:
|
try:
|
||||||
s = self._stripe()
|
s = self._stripe()
|
||||||
s.Transfer.create(
|
s.Transfer.create(
|
||||||
@@ -196,9 +223,10 @@ class RevenueShare:
|
|||||||
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||||
return
|
return
|
||||||
|
|
||||||
paid_ts = int(time.time())
|
paid_ts = datetime.now(timezone.utc)
|
||||||
for event in unpaid:
|
for event in unpaid:
|
||||||
event["paid_at"] = paid_ts
|
event.paid_at = paid_ts
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
# Module-level singleton
|
||||||
|
|||||||
476
app/models.py
Normal file
476
app/models.py
Normal file
@@ -0,0 +1,476 @@
|
|||||||
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
|
Only auth, billing, storage metadata, and marketplace data live here.
|
||||||
|
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||||
|
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||||
|
|
||||||
|
Table inventory:
|
||||||
|
users — account credentials + tier
|
||||||
|
refresh_tokens — hashed refresh token store
|
||||||
|
subscriptions — Stripe subscription records
|
||||||
|
storage_records — S3 blob metadata (no plaintext)
|
||||||
|
backup_metadata — encrypted backup manifests
|
||||||
|
plugins — marketplace plugin catalog
|
||||||
|
plugin_installations — per-user install records
|
||||||
|
plugin_reviews — admin review decisions
|
||||||
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
|
memory_proactive — per-user behavioral patterns (encrypted)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
|
Boolean,
|
||||||
|
DateTime,
|
||||||
|
Enum,
|
||||||
|
Float,
|
||||||
|
ForeignKey,
|
||||||
|
Integer,
|
||||||
|
JSON,
|
||||||
|
String,
|
||||||
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
|
Uuid,
|
||||||
|
func,
|
||||||
|
)
|
||||||
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
|
from app.db import Base
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _uuid() -> str:
|
||||||
|
return str(uuid.uuid4())
|
||||||
|
|
||||||
|
|
||||||
|
def _now() -> datetime:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
|
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||||
|
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||||
|
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
||||||
|
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
||||||
|
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Models ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class User(Base):
|
||||||
|
__tablename__ = "users"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
|
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
|
# Used to encrypt/decrypt all memory rows for this user.
|
||||||
|
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class RefreshToken(Base):
|
||||||
|
__tablename__ = "refresh_tokens"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
||||||
|
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class Subscription(Base):
|
||||||
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, unique=True, index=True
|
||||||
|
)
|
||||||
|
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
||||||
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
||||||
|
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecord(Base):
|
||||||
|
__tablename__ = "storage_records"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupMetadata(Base):
|
||||||
|
__tablename__ = "backup_metadata"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||||
|
# nullable until developer account system is built
|
||||||
|
author_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||||
|
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
submitted_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
reviews: Mapped[list[PluginReview]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallation(Base):
|
||||||
|
__tablename__ = "plugin_installations"
|
||||||
|
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
installed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginReview(Base):
|
||||||
|
__tablename__ = "plugin_reviews"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
reviewer_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||||
|
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
reviewed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueEvent(Base):
|
||||||
|
__tablename__ = "revenue_events"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAgentConfig(Base):
|
||||||
|
__tablename__ = "local_agent_configs"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
|
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
||||||
|
back_populates="local_agent",
|
||||||
|
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
||||||
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
overlaps="run_logs,cloud_agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfig(Base):
|
||||||
|
__tablename__ = "cloud_agent_configs"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
|
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
||||||
|
back_populates="cloud_agent",
|
||||||
|
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
||||||
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
|
cascade="all, delete-orphan",
|
||||||
|
overlaps="run_logs,local_agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AgentRunLog(Base):
|
||||||
|
__tablename__ = "agent_run_logs"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
# Plain string — not a FK because it references either local_agent_configs or cloud_agent_configs
|
||||||
|
# depending on agent_type. Query by (agent_id, agent_type) to locate the source config.
|
||||||
|
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||||
|
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||||
|
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
|
started_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
|
||||||
|
local_agent: Mapped[LocalAgentConfig | None] = relationship(
|
||||||
|
back_populates="run_logs",
|
||||||
|
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
||||||
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
|
overlaps="run_logs,cloud_agent",
|
||||||
|
)
|
||||||
|
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
|
||||||
|
back_populates="run_logs",
|
||||||
|
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
||||||
|
foreign_keys="AgentRunLog.agent_id",
|
||||||
|
overlaps="run_logs,local_agent",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryCore(Base):
|
||||||
|
"""Per-user persistent key/value preferences, encrypted at rest.
|
||||||
|
|
||||||
|
Examples: preferred_language, timezone, work_style.
|
||||||
|
Decrypted in-memory only using User.encryption_key.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_core"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryAssociative(Base):
|
||||||
|
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
|
||||||
|
|
||||||
|
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
|
||||||
|
Tests (SQLite): stored as JSON list.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_associative"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
||||||
|
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryEpisodic(Base):
|
||||||
|
"""Per-user session summaries, encrypted at rest.
|
||||||
|
|
||||||
|
One row per session interaction; used to recall recent conversations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_episodic"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryProactive(Base):
|
||||||
|
"""Per-user inferred behavioral patterns, encrypted at rest.
|
||||||
|
|
||||||
|
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
|
||||||
|
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
|
||||||
|
"""
|
||||||
|
|
||||||
|
__tablename__ = "memory_proactive"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False, index=True,
|
||||||
|
)
|
||||||
|
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
|
||||||
|
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
308
app/schemas.py
308
app/schemas.py
@@ -5,6 +5,7 @@ Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
@@ -26,6 +27,8 @@ class AuthTokens(BaseModel):
|
|||||||
class UserProfile(BaseModel):
|
class UserProfile(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
email: str
|
email: str
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
tier: BillingTier
|
tier: BillingTier
|
||||||
|
|
||||||
|
|
||||||
@@ -38,41 +41,13 @@ class ChatContext(BaseModel):
|
|||||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class PlanAction(BaseModel):
|
|
||||||
type: Literal[
|
|
||||||
"create_record",
|
|
||||||
"update_record",
|
|
||||||
"delete_record",
|
|
||||||
"index_document",
|
|
||||||
"send_notification",
|
|
||||||
]
|
|
||||||
table: str | None = None
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
class ChatRequest(BaseModel):
|
||||||
message: str
|
message: str
|
||||||
context: ChatContext = Field(default_factory=ChatContext)
|
context: ChatContext = Field(default_factory=ChatContext)
|
||||||
execution_mode: Literal["direct", "plan"] = "direct"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
class ChatResponse(BaseModel):
|
||||||
response: str
|
response: str
|
||||||
actions: list[PlanAction] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plans ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PlanStep(BaseModel):
|
|
||||||
action: str
|
|
||||||
prompt_template: str | None = None
|
|
||||||
variables: dict[str, Any] | None = None
|
|
||||||
data_from_step: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlan(BaseModel):
|
|
||||||
agent: str
|
|
||||||
steps: list[PlanStep] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Backup ───────────────────────────────────────────────────────────
|
# ── Backup ───────────────────────────────────────────────────────────
|
||||||
@@ -155,3 +130,280 @@ class PluginListResponse(BaseModel):
|
|||||||
|
|
||||||
class PluginInstallRequest(BaseModel):
|
class PluginInstallRequest(BaseModel):
|
||||||
plugin_id: str
|
plugin_id: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||||
|
|
||||||
|
class WsFrameType(str, Enum):
|
||||||
|
# ── v2 frame types (kept for backward compat) ──────────────────────
|
||||||
|
chat_request = "chat_request"
|
||||||
|
text_chunk = "text_chunk"
|
||||||
|
tool_call = "tool_call"
|
||||||
|
tool_result = "tool_result"
|
||||||
|
final = "final"
|
||||||
|
ping = "ping"
|
||||||
|
agent_run = "agent_run"
|
||||||
|
agent_data = "agent_data"
|
||||||
|
agent_complete = "agent_complete"
|
||||||
|
device_hello = "device_hello"
|
||||||
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
|
home_request = "home_request"
|
||||||
|
floating_request = "floating_request"
|
||||||
|
stream_start = "stream_start"
|
||||||
|
stream_text = "stream_text"
|
||||||
|
stream_end = "stream_end"
|
||||||
|
floating_domain = "floating_domain"
|
||||||
|
data_request = "data_request"
|
||||||
|
data_response = "data_response"
|
||||||
|
mutation = "mutation"
|
||||||
|
|
||||||
|
|
||||||
|
class WsToolCall(BaseModel):
|
||||||
|
"""Server → Client: requests a CRUD/vector operation on the local DB."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.tool_call] = WsFrameType.tool_call
|
||||||
|
id: str
|
||||||
|
action: str
|
||||||
|
table: str | None = None
|
||||||
|
data: dict[str, Any] | None = None
|
||||||
|
filters: dict[str, Any] | None = None
|
||||||
|
vector: list[float] | None = None
|
||||||
|
limit: int | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsToolResult(BaseModel):
|
||||||
|
"""Client → Server: result of a CRUD/vector operation."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.tool_result] = WsFrameType.tool_result
|
||||||
|
id: str
|
||||||
|
row: dict[str, Any] | None = None
|
||||||
|
rows: list[dict[str, Any]] | None = None
|
||||||
|
results: list[dict[str, Any]] | None = None
|
||||||
|
deleted: bool | None = None
|
||||||
|
ok: bool | None = None
|
||||||
|
error: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsTextChunk(BaseModel):
|
||||||
|
"""Server → Client: incremental LLM response text."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.text_chunk] = WsFrameType.text_chunk
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsFinal(BaseModel):
|
||||||
|
"""Server → Client: signals end of response with the complete text."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.final] = WsFrameType.final
|
||||||
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── WebSocket Agent Frame Protocol ────────────────────────────────────
|
||||||
|
|
||||||
|
class WsDeviceHello(BaseModel):
|
||||||
|
"""Client → Server: device identification on WS connect."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
|
||||||
|
device_id: str
|
||||||
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentRun(BaseModel):
|
||||||
|
"""Server → Client: trigger an agent run on the connected device."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
||||||
|
run_id: str
|
||||||
|
agent_id: str
|
||||||
|
config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentData(BaseModel):
|
||||||
|
"""Client → Server: files read by the local agent."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
||||||
|
run_id: str
|
||||||
|
files: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentComplete(BaseModel):
|
||||||
|
"""Client → Server: Electron signals it has finished reading files."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
||||||
|
run_id: str
|
||||||
|
files_read: int
|
||||||
|
errors: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
|
class WsFloatingScope(BaseModel):
|
||||||
|
"""Scope for a floating request — narrows the agent to a specific entity."""
|
||||||
|
|
||||||
|
type: Literal["task", "project", "note", "timeline"]
|
||||||
|
id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class WsHomeRequest(BaseModel):
|
||||||
|
"""Client → Server: Home chat message."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
||||||
|
message: str
|
||||||
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsFloatingRequest(BaseModel):
|
||||||
|
"""Client → Server: Floating chat message scoped to an entity."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
||||||
|
message: str
|
||||||
|
scope: WsFloatingScope
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamStart(BaseModel):
|
||||||
|
"""Server → Client: signals start of a streaming response."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_start] = WsFrameType.stream_start
|
||||||
|
request_id: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamText(BaseModel):
|
||||||
|
"""Server → Client: streamed text token."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_text] = WsFrameType.stream_text
|
||||||
|
request_id: str
|
||||||
|
chunk: str
|
||||||
|
|
||||||
|
|
||||||
|
class WsStreamEnd(BaseModel):
|
||||||
|
"""Server → Client: signals end of a streaming response."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
|
request_id: str
|
||||||
|
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsFloatingDomain(BaseModel):
|
||||||
|
"""Server → Client: domain determined for a floating request."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
|
request_id: str
|
||||||
|
domain: Literal["tasks", "timelines", "notes", "projects"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AgentCatalogItem(BaseModel):
|
||||||
|
type: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
config_schema: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local Agent Config ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class LocalAgentConfigCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
device_id: str
|
||||||
|
directory_paths: list[str]
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
file_extensions: list[str]
|
||||||
|
schedule_cron: str
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAgentConfigUpdate(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
device_id: str | None = None
|
||||||
|
directory_paths: list[str] | None = None
|
||||||
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
file_extensions: list[str] | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class LocalAgentConfigResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
device_id: str
|
||||||
|
directory_paths: list[str]
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
file_extensions: list[str]
|
||||||
|
schedule_cron: str
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Agent Config ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CloudAgentConfigCreate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
oauth_token_encrypted: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigUpdate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"] | None = None
|
||||||
|
name: str | None = None
|
||||||
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
oauth_token_encrypted: str | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigResponse(BaseModel):
|
||||||
|
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class AgentRunLogResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
agent_id: str
|
||||||
|
agent_type: Literal["local", "cloud"]
|
||||||
|
status: Literal["running", "success", "error", "partial"]
|
||||||
|
items_processed: int
|
||||||
|
items_created: int
|
||||||
|
errors: list[str]
|
||||||
|
started_at: int
|
||||||
|
completed_at: int | None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class JourneyStartRequest(BaseModel):
|
||||||
|
agent_type: Literal["local", "cloud"]
|
||||||
|
agent_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyMessageRequest(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyResponse(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
done: bool
|
||||||
|
prompt_template: str | None = None
|
||||||
|
|||||||
@@ -9,7 +9,6 @@ from __future__ import annotations
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
from botocore.exceptions import ClientError
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
|
||||||
@@ -23,12 +22,14 @@ class BlobStore:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def _client(self) -> Any:
|
def _client(self) -> Any:
|
||||||
return boto3.client(
|
kwargs: dict[str, Any] = {
|
||||||
"s3",
|
"region_name": settings.S3_REGION,
|
||||||
region_name=settings.S3_REGION,
|
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
|
||||||
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
|
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
|
||||||
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
|
}
|
||||||
)
|
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
|
||||||
|
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
|
||||||
|
return boto3.client("s3", **kwargs)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _key(user_id: str, table: str, record_id: str) -> str:
|
def _key(user_id: str, table: str, record_id: str) -> str:
|
||||||
|
|||||||
@@ -1,21 +1,23 @@
|
|||||||
version: "3.9"
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
app:
|
app:
|
||||||
build: .
|
build: .
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "8080:8000"
|
||||||
env_file:
|
env_file:
|
||||||
- .env
|
- path: .env
|
||||||
|
required: false
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||||
|
volumes:
|
||||||
|
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||||
depends_on:
|
depends_on:
|
||||||
db:
|
db:
|
||||||
condition: service_healthy
|
condition: service_healthy
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
db:
|
db:
|
||||||
image: postgres:16-alpine
|
image: pgvector/pgvector:pg16
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: postgres
|
||||||
@@ -34,5 +36,37 @@ services:
|
|||||||
# image: redis:7-alpine
|
# image: redis:7-alpine
|
||||||
# restart: unless-stopped
|
# restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local S3-compatible storage (MinIO) ──
|
||||||
|
minio:
|
||||||
|
image: minio/minio:latest
|
||||||
|
command: server /data --console-address ":9001"
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
- "9001:9001"
|
||||||
|
environment:
|
||||||
|
MINIO_ROOT_USER: minioadmin
|
||||||
|
MINIO_ROOT_PASSWORD: minioadmin
|
||||||
|
volumes:
|
||||||
|
- minio_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "mc", "ready", "local"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local vector store (Qdrant) ──
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
ports:
|
||||||
|
- "6333:6333"
|
||||||
|
- "6334:6334"
|
||||||
|
volumes:
|
||||||
|
- qdrant_data:/qdrant/storage
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
minio_data:
|
||||||
|
qdrant_data:
|
||||||
|
copilot_tokens:
|
||||||
|
|||||||
56
logging.conf
Normal file
56
logging.conf
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
[loggers]
|
||||||
|
keys=root,uvicorn,uvicorn.error,uvicorn.access,sqlalchemy,watchfiles
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys=console,file
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys=default
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level=INFO
|
||||||
|
handlers=console,file
|
||||||
|
|
||||||
|
[logger_uvicorn]
|
||||||
|
level=INFO
|
||||||
|
handlers=
|
||||||
|
qualname=uvicorn
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[logger_uvicorn.error]
|
||||||
|
level=INFO
|
||||||
|
handlers=
|
||||||
|
qualname=uvicorn.error
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[logger_uvicorn.access]
|
||||||
|
level=INFO
|
||||||
|
handlers=
|
||||||
|
qualname=uvicorn.access
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level=WARNING
|
||||||
|
handlers=
|
||||||
|
qualname=sqlalchemy
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[logger_watchfiles]
|
||||||
|
level=WARNING
|
||||||
|
handlers=
|
||||||
|
qualname=watchfiles
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class=StreamHandler
|
||||||
|
formatter=default
|
||||||
|
args=(sys.stderr,)
|
||||||
|
|
||||||
|
[handler_file]
|
||||||
|
class=logging.handlers.RotatingFileHandler
|
||||||
|
formatter=default
|
||||||
|
args=('logs/app.log', 'a', 10485760, 5, 'utf-8')
|
||||||
|
|
||||||
|
[formatter_default]
|
||||||
|
format=%(asctime)s %(levelname)s %(name)s: %(message)s
|
||||||
|
datefmt=%Y-%m-%d %H:%M:%S
|
||||||
@@ -1,7 +1,12 @@
|
|||||||
fastapi>=0.115.0
|
fastapi>=0.115.0
|
||||||
uvicorn[standard]>=0.34.0
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
langchain>=0.3.0
|
langchain>=0.3.0
|
||||||
langchain-openai>=0.3.0
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.1.0
|
||||||
|
langgraph>=0.3.0
|
||||||
|
deepagents>=0.4.10
|
||||||
|
litellm>=1.50.0
|
||||||
pydantic>=2.10.0
|
pydantic>=2.10.0
|
||||||
pydantic-settings>=2.7.0
|
pydantic-settings>=2.7.0
|
||||||
python-jose[cryptography]>=3.3.0
|
python-jose[cryptography]>=3.3.0
|
||||||
@@ -15,8 +20,18 @@ bcrypt>=4.2.0
|
|||||||
python-dotenv>=1.0.0
|
python-dotenv>=1.0.0
|
||||||
httpx>=0.28.0
|
httpx>=0.28.0
|
||||||
websockets>=14.0
|
websockets>=14.0
|
||||||
|
psycopg2-binary>=2.9.0
|
||||||
pytest>=8.0.0
|
pytest>=8.0.0
|
||||||
pytest-asyncio>=0.24.0
|
pytest-asyncio>=0.24.0
|
||||||
|
aiosqlite>=0.20.0
|
||||||
moto[s3]>=5.0.0
|
moto[s3]>=5.0.0
|
||||||
pinecone>=5.0.0
|
pinecone>=5.0.0
|
||||||
qdrant-client>=1.7.0
|
qdrant-client>=1.7.0
|
||||||
|
croniter>=3.0.0
|
||||||
|
google-api-python-client>=2.130.0
|
||||||
|
google-auth>=2.29.0
|
||||||
|
google-auth-oauthlib>=1.2.0
|
||||||
|
google-auth-httplib2>=0.2.0
|
||||||
|
msal>=1.28.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
ruff>=0.8.0
|
||||||
|
|||||||
235
tests/conftest.py
Normal file
235
tests/conftest.py
Normal file
@@ -0,0 +1,235 @@
|
|||||||
|
"""Shared test fixtures for database-backed tests.
|
||||||
|
|
||||||
|
Provides an async SQLite in-memory engine that auto-creates all tables,
|
||||||
|
a per-test session, and a FastAPI ``TestClient`` wired to use it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from jose import jwt
|
||||||
|
from moto import mock_aws
|
||||||
|
from sqlalchemy import StaticPool, event
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.db import Base, get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import Plugin, Subscription, User
|
||||||
|
|
||||||
|
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
||||||
|
|
||||||
|
TEST_USER_IDS: dict[str, str] = {
|
||||||
|
"free": "00000000-0000-0000-0000-000000000001",
|
||||||
|
"pro": "00000000-0000-0000-0000-000000000002",
|
||||||
|
"power": "00000000-0000-0000-0000-000000000003",
|
||||||
|
"team": "00000000-0000-0000-0000-000000000004",
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Async SQLite engine ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
_TEST_ENGINE = create_async_engine(
|
||||||
|
"sqlite+aiosqlite://",
|
||||||
|
connect_args={"check_same_thread": False},
|
||||||
|
poolclass=StaticPool,
|
||||||
|
)
|
||||||
|
|
||||||
|
_TestSessionLocal = async_sessionmaker(
|
||||||
|
_TEST_ENGINE,
|
||||||
|
expire_on_commit=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Enable foreign key enforcement for SQLite (off by default).
|
||||||
|
@event.listens_for(_TEST_ENGINE.sync_engine, "connect")
|
||||||
|
def _set_sqlite_pragma(dbapi_conn, _connection_record): # noqa: ANN001
|
||||||
|
cursor = dbapi_conn.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture(autouse=True)
|
||||||
|
async def _create_tables():
|
||||||
|
"""Create all tables before each test, seed test users, then drop after."""
|
||||||
|
async with _TEST_ENGINE.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.create_all)
|
||||||
|
|
||||||
|
# Seed one User + Subscription per tier so FK constraints and auth work.
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
for tier, uid in TEST_USER_IDS.items():
|
||||||
|
session.add(User(
|
||||||
|
id=uid,
|
||||||
|
email=f"{tier}@test.com",
|
||||||
|
password_hash="$2b$12$fakehashfortesting000000000000000000000000000",
|
||||||
|
tier=tier,
|
||||||
|
))
|
||||||
|
session.add(Subscription(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=uid,
|
||||||
|
tier=tier,
|
||||||
|
stripe_subscription_id=f"sub_test_{tier}",
|
||||||
|
status="active",
|
||||||
|
))
|
||||||
|
await session.commit()
|
||||||
|
|
||||||
|
yield
|
||||||
|
async with _TEST_ENGINE.begin() as conn:
|
||||||
|
await conn.run_sync(Base.metadata.drop_all)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def db_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
"""Yield a per-test async DB session."""
|
||||||
|
async with _TestSessionLocal() as session:
|
||||||
|
yield session
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # noqa: ANN001
|
||||||
|
"""FastAPI test client with ``get_session`` overridden to use the test DB."""
|
||||||
|
|
||||||
|
async def _override_get_session() -> AsyncGenerator[AsyncSession, None]:
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _override_get_session
|
||||||
|
with TestClient(app) as c:
|
||||||
|
yield c
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Seed data helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
Plugin(
|
||||||
|
id="plugin-github-sync",
|
||||||
|
name="GitHub Sync",
|
||||||
|
description="Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
version="1.0.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=0,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-slack-notify",
|
||||||
|
name="Slack Notifier",
|
||||||
|
description="Post task and timeline updates to Slack channels.",
|
||||||
|
version="1.2.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="communication",
|
||||||
|
price_cents=499,
|
||||||
|
permissions=json.dumps(["read:tasks", "read:timelines"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-time-tracker",
|
||||||
|
name="Time Tracker",
|
||||||
|
description="Track time spent on tasks with automatic reporting.",
|
||||||
|
version="0.9.1",
|
||||||
|
author_name="Third Party",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=999,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
|
||||||
|
"""Insert the 3 default approved plugins and return them."""
|
||||||
|
plugins = []
|
||||||
|
for template in _SEED_PLUGINS:
|
||||||
|
p = Plugin(
|
||||||
|
id=template.id,
|
||||||
|
name=template.name,
|
||||||
|
description=template.description,
|
||||||
|
version=template.version,
|
||||||
|
author_name=template.author_name,
|
||||||
|
category=template.category,
|
||||||
|
price_cents=template.price_cents,
|
||||||
|
permissions=template.permissions,
|
||||||
|
status=template.status,
|
||||||
|
s3_package_key=template.s3_package_key,
|
||||||
|
install_count=template.install_count,
|
||||||
|
avg_rating=template.avg_rating,
|
||||||
|
)
|
||||||
|
db_session.add(p)
|
||||||
|
plugins.append(p)
|
||||||
|
await db_session.commit()
|
||||||
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
|
# ── JWT helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def make_jwt(
|
||||||
|
tier: str = "power",
|
||||||
|
user_id: str | None = None,
|
||||||
|
email: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a signed test JWT.
|
||||||
|
|
||||||
|
Uses the fixed ``TEST_USER_IDS`` mapping so the auth middleware can
|
||||||
|
find the corresponding ``Subscription`` row in the test database.
|
||||||
|
"""
|
||||||
|
uid = user_id or TEST_USER_IDS.get(tier, str(uuid.uuid4()))
|
||||||
|
now = int(time.time())
|
||||||
|
payload = {
|
||||||
|
"sub": uid,
|
||||||
|
"email": email or f"{tier}@test.com",
|
||||||
|
"tier": tier,
|
||||||
|
"exp": now + 3600,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
|
|
||||||
|
|
||||||
|
def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, str]:
|
||||||
|
"""Return an Authorization header dict for the given tier."""
|
||||||
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
|
# ── S3 mock fixture ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
S3_TEST_BUCKET = "test-bucket"
|
||||||
|
S3_TEST_REGION = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def s3_bucket():
|
||||||
|
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
|
||||||
|
with mock_aws():
|
||||||
|
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
||||||
|
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
||||||
|
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
|
||||||
|
client = boto3.client("s3", region_name=S3_TEST_REGION)
|
||||||
|
client.create_bucket(Bucket=S3_TEST_BUCKET)
|
||||||
|
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||||
|
mock_settings.S3_BUCKET = S3_TEST_BUCKET
|
||||||
|
mock_settings.S3_REGION = S3_TEST_REGION
|
||||||
|
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||||
|
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||||
|
yield S3_TEST_BUCKET
|
||||||
@@ -1,214 +0,0 @@
|
|||||||
"""Unit tests for the agent registry, base classes, and tool loop."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _StubAgent(ChatAgent):
|
|
||||||
"""Minimal concrete agent for testing."""
|
|
||||||
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "stub"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "A stub agent for tests"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return f"echo: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
class _AnotherAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "another"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Another stub"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return "another"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _fresh_registry():
|
|
||||||
"""Reset the singleton between tests."""
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
yield
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def reg() -> AgentRegistry:
|
|
||||||
return AgentRegistry()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tests ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class TestRegisterAndGet:
|
|
||||||
def test_register_decorator(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
agent = reg.get("stub")
|
|
||||||
assert isinstance(agent, _StubAgent)
|
|
||||||
|
|
||||||
def test_get_unknown_raises(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError, match="not found"):
|
|
||||||
reg.get("nonexistent")
|
|
||||||
|
|
||||||
def test_register_multiple(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
reg.register(_AnotherAgent)
|
|
||||||
assert reg.get("stub").get_name() == "stub"
|
|
||||||
assert reg.get("another").get_name() == "another"
|
|
||||||
|
|
||||||
|
|
||||||
class TestListAgents:
|
|
||||||
def test_empty(self, reg: AgentRegistry) -> None:
|
|
||||||
assert reg.list_agents() == []
|
|
||||||
|
|
||||||
def test_list_after_register(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
agents = reg.list_agents()
|
|
||||||
assert len(agents) == 1
|
|
||||||
assert agents[0] == {"name": "stub", "description": "A stub agent for tests"}
|
|
||||||
|
|
||||||
def test_list_multiple(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
reg.register(_AnotherAgent)
|
|
||||||
names = {a["name"] for a in reg.list_agents()}
|
|
||||||
assert names == {"stub", "another"}
|
|
||||||
|
|
||||||
|
|
||||||
class TestCallAgent:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_agent(self, reg: AgentRegistry) -> None:
|
|
||||||
reg.register(_StubAgent)
|
|
||||||
result = await reg.call_agent("stub", "hello", {})
|
|
||||||
assert result == "echo: hello"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_call_unknown_raises(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
await reg.call_agent("nope", "hi", {})
|
|
||||||
|
|
||||||
|
|
||||||
class TestSingleton:
|
|
||||||
def test_singleton_identity(self) -> None:
|
|
||||||
a = AgentRegistry()
|
|
||||||
b = AgentRegistry()
|
|
||||||
assert a is b
|
|
||||||
|
|
||||||
|
|
||||||
class TestToolLoop:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_no_tool_calls(self) -> None:
|
|
||||||
"""When the LLM responds without tool calls, return content directly."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
ai_msg = MagicMock()
|
|
||||||
ai_msg.content = "final answer"
|
|
||||||
ai_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=ai_msg)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [])
|
|
||||||
assert result == "final answer"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_tool_call_then_answer(self) -> None:
|
|
||||||
"""LLM requests one tool call, gets result, then answers."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
# First response: tool call
|
|
||||||
tool_call_msg = MagicMock()
|
|
||||||
tool_call_msg.content = ""
|
|
||||||
tool_call_msg.tool_calls = [
|
|
||||||
{"id": "call_1", "name": "my_tool", "args": {"x": 1}}
|
|
||||||
]
|
|
||||||
|
|
||||||
# Second response: final answer
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "done"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
# Mock tool
|
|
||||||
tool = AsyncMock()
|
|
||||||
tool.name = "my_tool"
|
|
||||||
tool.ainvoke = AsyncMock(return_value="tool_result")
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [tool])
|
|
||||||
assert result == "done"
|
|
||||||
tool.ainvoke.assert_called_once_with({"x": 1})
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unknown_tool_handled(self) -> None:
|
|
||||||
"""Unknown tool names produce an error message instead of crashing."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
tool_call_msg = MagicMock()
|
|
||||||
tool_call_msg.content = ""
|
|
||||||
tool_call_msg.tool_calls = [
|
|
||||||
{"id": "call_1", "name": "missing", "args": {}}
|
|
||||||
]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "recovered"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm)
|
|
||||||
llm.ainvoke = AsyncMock(side_effect=[tool_call_msg, final_msg])
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [])
|
|
||||||
assert result == "recovered"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_max_iter_reached(self) -> None:
|
|
||||||
"""When max iterations are exhausted, a final no-tools call is made."""
|
|
||||||
agent = _StubAgent()
|
|
||||||
|
|
||||||
# Every response requests a tool call
|
|
||||||
loop_msg = MagicMock()
|
|
||||||
loop_msg.content = ""
|
|
||||||
loop_msg.tool_calls = [
|
|
||||||
{"id": "call_x", "name": "t", "args": {}}
|
|
||||||
]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = "gave up"
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
tool = AsyncMock()
|
|
||||||
tool.name = "t"
|
|
||||||
tool.ainvoke = AsyncMock(return_value="ok")
|
|
||||||
|
|
||||||
llm_with_tools = AsyncMock()
|
|
||||||
llm_with_tools.ainvoke = AsyncMock(return_value=loop_msg)
|
|
||||||
|
|
||||||
llm = AsyncMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=llm_with_tools)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
|
|
||||||
result = await agent._tool_loop(llm, [], [tool], max_iter=2)
|
|
||||||
assert result == "gave up"
|
|
||||||
assert llm_with_tools.ainvoke.call_count == 2
|
|
||||||
871
tests/test_agent_runner.py
Normal file
871
tests/test_agent_runner.py
Normal file
@@ -0,0 +1,871 @@
|
|||||||
|
"""Tests for Step 3.4: agent_runner module.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
Unit:
|
||||||
|
- _is_overdue — cron schedule overdue detection
|
||||||
|
- _extract_items_from_content — LLM extraction + JSON parsing + validation
|
||||||
|
- _send_insert_to_client — tool_call frame construction + timeout
|
||||||
|
- run_local_agent — end-to-end local agent happy path
|
||||||
|
- run_local_agent — device offline path
|
||||||
|
- run_local_agent — file-read timeout path
|
||||||
|
- run_local_agent — LLM extraction error path
|
||||||
|
- run_cloud_agent — stub returns error immediately
|
||||||
|
- trigger_pending_runs — overdue local + cloud dispatched
|
||||||
|
- trigger_pending_runs — non-overdue skipped
|
||||||
|
- trigger_pending_runs — device_id filter for local agents
|
||||||
|
|
||||||
|
Integration:
|
||||||
|
- POST /agents/{id}/run — 404 on unknown agent
|
||||||
|
- POST /agents/{id}/run — creates run log + dispatches background task
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.core.agent_runner import (
|
||||||
|
_extract_items_from_content,
|
||||||
|
_is_overdue,
|
||||||
|
_send_insert_to_client,
|
||||||
|
run_cloud_agent,
|
||||||
|
run_local_agent,
|
||||||
|
trigger_pending_runs,
|
||||||
|
)
|
||||||
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FREE_UID = TEST_USER_IDS["free"]
|
||||||
|
_PRO_UID = TEST_USER_IDS["pro"]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
|
||||||
|
return LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
name="Test Local Agent",
|
||||||
|
directory_paths=["/home/user/emails"],
|
||||||
|
data_types=["tasks", "notes"],
|
||||||
|
prompt_template="Extract tasks and notes from this document.",
|
||||||
|
file_extensions=[".txt", ".eml"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
|
||||||
|
return CloudAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
provider="gmail",
|
||||||
|
name="Test Gmail Agent",
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks from email.",
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
|
||||||
|
return AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
|
||||||
|
mgr = DeviceConnectionManager()
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
mgr.register(user_id, device_id, ws)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _is_overdue
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_is_overdue_never_run():
|
||||||
|
"""An agent that has never run is always overdue."""
|
||||||
|
assert _is_overdue("0 */6 * * *", None) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_very_recently_run():
|
||||||
|
"""An agent that just ran is not overdue."""
|
||||||
|
last = datetime.now(timezone.utc)
|
||||||
|
assert _is_overdue("0 */6 * * *", last) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_long_ago():
|
||||||
|
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
|
||||||
|
from datetime import timedelta
|
||||||
|
last = datetime.now(timezone.utc) - timedelta(days=2)
|
||||||
|
assert _is_overdue("0 */6 * * *", last) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_invalid_cron_returns_false():
|
||||||
|
"""Unparseable cron must not raise and should return False (fail-safe)."""
|
||||||
|
assert _is_overdue("not a cron", None) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_naive_datetime():
|
||||||
|
"""Naive datetime objects are handled without raising."""
|
||||||
|
from datetime import timedelta
|
||||||
|
last = datetime.utcnow() - timedelta(days=1) # naive
|
||||||
|
# Should not raise.
|
||||||
|
result = _is_overdue("0 */6 * * *", last)
|
||||||
|
assert isinstance(result, bool)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_items_from_content
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_happy_path():
|
||||||
|
"""LLM returns valid JSON array; items with allowed tables are returned."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
|
||||||
|
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content(
|
||||||
|
"Extract tasks and notes.",
|
||||||
|
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
|
||||||
|
["tasks", "notes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(items) == 2
|
||||||
|
assert items[0]["table"] == "tasks"
|
||||||
|
assert items[0]["data"]["title"] == "Buy milk"
|
||||||
|
assert items[1]["table"] == "notes"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_strips_forbidden_fields():
|
||||||
|
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{
|
||||||
|
"table": "tasks",
|
||||||
|
"data": {
|
||||||
|
"title": "Review PR",
|
||||||
|
"id": "should-be-removed",
|
||||||
|
"createdAt": 99999,
|
||||||
|
"isAiSuggested": 0,
|
||||||
|
"isApproved": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
|
||||||
|
|
||||||
|
assert len(items) == 1
|
||||||
|
data = items[0]["data"]
|
||||||
|
assert "id" not in data
|
||||||
|
assert "createdAt" not in data
|
||||||
|
assert "isAiSuggested" not in data
|
||||||
|
assert "isApproved" not in data
|
||||||
|
assert data["title"] == "Review PR"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_invalid_json_returns_empty():
|
||||||
|
"""LLM returning invalid JSON must return empty list without raising."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = "Sorry, I cannot extract anything."
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||||
|
|
||||||
|
assert items == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_disallowed_table_filtered():
|
||||||
|
"""Items whose table is not in data_types are discarded."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Valid task"}},
|
||||||
|
{"table": "projects", "data": {"name": "Should be filtered"}},
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
# Only "tasks" is in data_types — "projects" should be filtered.
|
||||||
|
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
|
||||||
|
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0]["table"] == "tasks"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_empty_data_types_returns_empty():
|
||||||
|
"""If no allowed data_types match, skip LLM call and return immediately."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract.", "content", [])
|
||||||
|
|
||||||
|
mock_llm.ainvoke.assert_not_called()
|
||||||
|
assert items == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_llm_error_propagates():
|
||||||
|
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
with pytest.raises(RuntimeError, match="API unavailable"):
|
||||||
|
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _send_insert_to_client
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_insert_to_client_happy_path():
|
||||||
|
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
sent_payloads: list[dict] = []
|
||||||
|
original_send = mgr.send_frame
|
||||||
|
|
||||||
|
async def _capture_send(uid: str, frame: dict) -> None:
|
||||||
|
sent_payloads.append(frame)
|
||||||
|
# Immediately resolve the pending call with a success result.
|
||||||
|
call_id = frame["id"]
|
||||||
|
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
|
||||||
|
|
||||||
|
mgr.send_frame = _capture_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(sent_payloads) == 1
|
||||||
|
payload = sent_payloads[0]
|
||||||
|
assert payload["action"] == "insert"
|
||||||
|
assert payload["table"] == "tasks"
|
||||||
|
assert payload["data"]["title"] == "Buy milk"
|
||||||
|
assert payload["data"]["isAiSuggested"] == 1
|
||||||
|
assert payload["data"]["isApproved"] == 0
|
||||||
|
assert result["row"]["title"] == "Buy milk"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_insert_to_client_timeout():
|
||||||
|
"""asyncio.TimeoutError is raised when Electron does not respond."""
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
async def _slow_send(uid: str, frame: dict) -> None:
|
||||||
|
# Never resolve the pending call.
|
||||||
|
pass
|
||||||
|
|
||||||
|
mgr.send_frame = _slow_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_local_agent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_device_offline():
|
||||||
|
"""run_local_agent marks run as error when device is offline."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = DeviceConnectionManager() # Empty — no device registered.
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("not connected" in e for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_happy_path():
|
||||||
|
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
# Build a fake agent_data frame (will be queued after send).
|
||||||
|
file_frame = {
|
||||||
|
"type": "agent_data",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
|
||||||
|
}
|
||||||
|
agent_complete_frame = None # sentinel
|
||||||
|
|
||||||
|
sent_frames: list[dict] = []
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
sent_frames.append(frame)
|
||||||
|
if frame.get("type") == "agent_run":
|
||||||
|
# Simulate Electron responding with file data then agent_complete.
|
||||||
|
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||||
|
await q.put(file_frame)
|
||||||
|
await q.put(agent_complete_frame)
|
||||||
|
elif frame.get("type") == "tool_call":
|
||||||
|
# Resolve the pending insert immediately.
|
||||||
|
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "success"
|
||||||
|
assert kwargs["items_processed"] == 1
|
||||||
|
assert kwargs["items_created"] == 1
|
||||||
|
assert kwargs["errors"] == []
|
||||||
|
assert kwargs["update_config_last_run"] is True
|
||||||
|
|
||||||
|
# Verify agent_run frame was sent.
|
||||||
|
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||||
|
assert len(agent_run_frames) == 1
|
||||||
|
assert agent_run_frames[0]["agent_id"] == config.id
|
||||||
|
assert "paths" in agent_run_frames[0]["config"]
|
||||||
|
|
||||||
|
# Verify insert frame was sent with AI flags.
|
||||||
|
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
|
||||||
|
assert len(insert_frames) == 1
|
||||||
|
assert insert_frames[0]["data"]["isAiSuggested"] == 1
|
||||||
|
assert insert_frames[0]["data"]["isApproved"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_file_read_timeout():
|
||||||
|
"""run_local_agent marks run as partial/error when device stops sending files."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
# Don't put anything in the queue — simulate stalled device.
|
||||||
|
pass
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error" # No items created, so error (not partial).
|
||||||
|
assert any("timed out" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_llm_extraction_error():
|
||||||
|
"""LLM errors per-file are recorded; run continues for remaining files."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
file_frame = {
|
||||||
|
"type": "agent_data",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"files": [
|
||||||
|
{"path": "/file1.eml", "content": "Email one."},
|
||||||
|
{"path": "/file2.eml", "content": "Email two."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
if frame.get("type") == "agent_run":
|
||||||
|
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||||
|
await q.put(file_frame)
|
||||||
|
await q.put(None) # agent_complete sentinel
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert kwargs["items_processed"] == 2 # Both files attempted.
|
||||||
|
assert kwargs["items_created"] == 0
|
||||||
|
assert len(kwargs["errors"]) == 2 # One error per file.
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_cloud_agent (stub)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_device_offline():
|
||||||
|
"""Cloud agent aborts immediately when no device is connected."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = DeviceConnectionManager() # empty — no devices registered
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_no_oauth_token():
|
||||||
|
"""Cloud agent errors when no OAuth token is stored."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = None
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_token_decrypt_failure():
|
||||||
|
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
valid_key = _Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_happy_path_gmail():
|
||||||
|
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
|
||||||
|
from app.integrations import EmailMessage, encrypt_token
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
|
||||||
|
fernet_key = _Fernet.generate_key().decode()
|
||||||
|
credentials = {
|
||||||
|
"token": "access_abc",
|
||||||
|
"refresh_token": "refresh_xyz",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"client_id": "cid",
|
||||||
|
"client_secret": "csec",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.provider = "gmail"
|
||||||
|
config.prompt_template = "Extract tasks from this email."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as ms:
|
||||||
|
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||||
|
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
sample_email = EmailMessage(
|
||||||
|
id="msg001",
|
||||||
|
subject="Action required",
|
||||||
|
sender="boss@company.com",
|
||||||
|
body_text="Please fix the bug by Friday.",
|
||||||
|
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as mock_int_settings, \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
|
||||||
|
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
|
||||||
|
patch("app.core.agent_runner.async_session"):
|
||||||
|
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
|
||||||
|
mock_gmail = AsyncMock()
|
||||||
|
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
|
||||||
|
mock_gmail.refreshed_credentials = None
|
||||||
|
|
||||||
|
with patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_gmail):
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_extract.assert_called_once()
|
||||||
|
mock_insert.assert_called_once()
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "success"
|
||||||
|
assert kwargs["items_processed"] == 1
|
||||||
|
assert kwargs["items_created"] == 1
|
||||||
|
assert kwargs["config_type"] == "cloud"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_provider_fetch_error():
|
||||||
|
"""Cloud agent records error status when provider fetch raises RuntimeError."""
|
||||||
|
credentials = {"token": "abc"}
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
|
||||||
|
config.prompt_template = "Extract tasks."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
mock_provider = AsyncMock()
|
||||||
|
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
|
||||||
|
mock_provider.refreshed_credentials = None
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||||
|
patch("app.core.agent_runner.async_session"):
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_refreshed_token_persisted():
|
||||||
|
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
||||||
|
from app.integrations import EmailMessage, encrypt_token
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
|
||||||
|
fernet_key = _Fernet.generate_key().decode()
|
||||||
|
credentials = {"token": "old_token", "refresh_token": "rt_old"}
|
||||||
|
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
|
||||||
|
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.prompt_template = "Extract tasks."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as ms:
|
||||||
|
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||||
|
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
mock_provider = AsyncMock()
|
||||||
|
mock_provider.fetch_messages = AsyncMock(return_value=[])
|
||||||
|
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
|
||||||
|
|
||||||
|
# Track DB writes via mock async_session.
|
||||||
|
mock_cfg_row = MagicMock()
|
||||||
|
mock_cfg_row.oauth_token_encrypted = None
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||||
|
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
|
||||||
|
cfg_result = MagicMock()
|
||||||
|
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
|
||||||
|
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||||
|
mock_db.commit = AsyncMock()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
|
||||||
|
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||||
|
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
|
||||||
|
patch("app.core.agent_runner.async_session", return_value=mock_db), \
|
||||||
|
patch("app.integrations.settings") as mock_int_settings:
|
||||||
|
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
# The new encrypted token should have been written to the config row.
|
||||||
|
mock_encrypt.assert_called_once_with(fresh_credentials)
|
||||||
|
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finalize_run_updates_cloud_config_last_run_at():
|
||||||
|
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
|
||||||
|
from app.core.agent_runner import _finalize_run
|
||||||
|
|
||||||
|
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
|
||||||
|
run_log.id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
mock_cfg = MagicMock()
|
||||||
|
mock_cfg.last_run_at = None
|
||||||
|
|
||||||
|
cfg_result = MagicMock()
|
||||||
|
cfg_result.scalar_one_or_none.return_value = mock_cfg
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||||
|
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_db.merge = AsyncMock(return_value=run_log)
|
||||||
|
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||||
|
mock_db.commit = AsyncMock()
|
||||||
|
|
||||||
|
config_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session", return_value=mock_db):
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="success",
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config_id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
|
||||||
|
# CloudAgentConfig.last_run_at should have been set.
|
||||||
|
assert mock_cfg.last_run_at is not None
|
||||||
|
mock_db.commit.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# trigger_pending_runs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_no_overdue():
|
||||||
|
"""If no agents are overdue trigger_pending_runs does nothing."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
config = _make_local_config()
|
||||||
|
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
||||||
|
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_device_id_filter():
|
||||||
|
"""Local agents are only triggered for the matching device_id."""
|
||||||
|
# The DB query already filters by device_id, so we verify the SELECT
|
||||||
|
# includes the device_id filter by checking that a config bound to a
|
||||||
|
# different device is never dispatched.
|
||||||
|
#
|
||||||
|
# Since trigger_pending_runs queries with device_id == "dev-001",
|
||||||
|
# simulate the DB returning an empty list (as it would for a mismatch).
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager(device_id="dev-001")
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_dispatches_overdue():
|
||||||
|
"""Overdue local agent triggers run_local_agent sequentially."""
|
||||||
|
config = _make_local_config() # last_run_at=None → always overdue
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
call_order: list[str] = []
|
||||||
|
|
||||||
|
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
||||||
|
call_order.append("run_local")
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
||||||
|
# First call: query configs. Subsequent calls: create run_log.
|
||||||
|
mock_query_ctx = AsyncMock()
|
||||||
|
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
||||||
|
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_query_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
|
||||||
|
run_log_obj = AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=_FREE_UID,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
mock_insert_ctx = AsyncMock()
|
||||||
|
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
||||||
|
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_insert_ctx.add = MagicMock()
|
||||||
|
mock_insert_ctx.commit = AsyncMock()
|
||||||
|
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
||||||
|
|
||||||
|
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
assert call_order == ["run_local"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration: POST /agents/{id}/run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
"""Route all get_session calls to the test SQLite session."""
|
||||||
|
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_run_unknown_agent(client):
|
||||||
|
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/v1/agents/{uuid.uuid4()}/run",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||||
|
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
||||||
|
# Create the local agent config in the DB.
|
||||||
|
config = LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=TEST_USER_IDS["power"],
|
||||||
|
device_id="dev-001",
|
||||||
|
name="My Agent",
|
||||||
|
directory_paths=["/home/user/docs"],
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks.",
|
||||||
|
file_extensions=[".txt"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
db_session.add(config)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
dispatched: list = []
|
||||||
|
|
||||||
|
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||||
|
dispatched.append((user_id, cfg.id))
|
||||||
|
|
||||||
|
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||||
|
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
||||||
|
patch("asyncio.create_task") as mock_create_task:
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/v1/agents/{config.id}/run",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 202
|
||||||
|
data = resp.json()
|
||||||
|
assert data["agent_id"] == config.id
|
||||||
|
assert data["status"] == "running"
|
||||||
|
assert data["agent_type"] == "local"
|
||||||
|
|
||||||
|
# Verify create_task was called (dispatching background run).
|
||||||
|
mock_create_task.assert_called_once()
|
||||||
243
tests/test_agent_setup.py
Normal file
243
tests/test_agent_setup.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""Tests for the Chatbot Journey endpoints.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
1. Start journey for local agent → session_id + first question, done=False
|
||||||
|
2. Start journey for cloud agent → contextual email-focused question
|
||||||
|
3. Start journey with existing agent_id → session seeded, first question returned
|
||||||
|
4. Start journey with non-existent agent_id → still succeeds (graceful fallback)
|
||||||
|
5. Message: continue conversation → done=False, follow-up question returned
|
||||||
|
6. Message: LLM wraps up → done=True + prompt_template extracted correctly
|
||||||
|
7. Message with max-turns nudge → no crash, returns response
|
||||||
|
8. Invalid session_id → 404
|
||||||
|
9. Expired session → 404
|
||||||
|
10. Session ownership: user B cannot access user A's session
|
||||||
|
11. No JWT on /start → 401
|
||||||
|
12. No JWT on /message → 401
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import (
|
||||||
|
_SESSION_TTL_SECONDS,
|
||||||
|
_TEMPLATE_END,
|
||||||
|
_TEMPLATE_START,
|
||||||
|
_extract_template,
|
||||||
|
_sessions,
|
||||||
|
)
|
||||||
|
from app.models import LocalAgentConfig
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _start(client: TestClient, agent_type: str = "local", agent_id: str | None = None, tier: str = "power") -> dict:
|
||||||
|
body: dict = {"agent_type": agent_type}
|
||||||
|
if agent_id:
|
||||||
|
body["agent_id"] = agent_id
|
||||||
|
resp = client.post("/api/v1/agents/journey/start", json=body, headers=auth_header(tier))
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def _message(client: TestClient, session_id: str, message: str, tier: str = "power") -> dict:
|
||||||
|
return client.post(
|
||||||
|
"/api/v1/agents/journey/message",
|
||||||
|
json={"session_id": session_id, "message": message},
|
||||||
|
headers=auth_header(tier),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: _extract_template ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_template_present():
|
||||||
|
text = f"Some preamble.\n{_TEMPLATE_START}\nExtract tasks from emails.\n{_TEMPLATE_END}\nTrailing text."
|
||||||
|
result = _extract_template(text)
|
||||||
|
assert result == "Extract tasks from emails."
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_template_absent():
|
||||||
|
assert _extract_template("No markers here.") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_template_empty_content():
|
||||||
|
text = f"{_TEMPLATE_START}\n{_TEMPLATE_END}"
|
||||||
|
assert _extract_template(text) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Start journey ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_local(client: TestClient):
|
||||||
|
resp = _start(client, agent_type="local")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert "session_id" in body
|
||||||
|
assert body["done"] is False
|
||||||
|
assert body["prompt_template"] is None
|
||||||
|
assert len(body["message"]) > 0
|
||||||
|
# Local question should be about files/directories
|
||||||
|
assert any(w in body["message"].lower() for w in ("file", "director", "document", "monitor"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_cloud(client: TestClient):
|
||||||
|
resp = _start(client, agent_type="cloud")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["done"] is False
|
||||||
|
# Cloud question should mention emails or messages
|
||||||
|
assert any(w in body["message"].lower() for w in ("email", "message", "communication"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_with_agent_id(client: TestClient, db_session: AsyncSession):
|
||||||
|
"""When agent_id is provided, session should be created even if agent doesn't exist."""
|
||||||
|
fake_agent_id = str(uuid.uuid4())
|
||||||
|
resp = _start(client, agent_type="local", agent_id=fake_agent_id)
|
||||||
|
# Should succeed gracefully even if the agent_id doesn't exist
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["done"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_with_existing_agent(client: TestClient, db_session: AsyncSession):
|
||||||
|
"""When a real local agent is provided, session is seeded with its prompt_template."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
user_id = TEST_USER_IDS["power"]
|
||||||
|
agent = LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
name="Test Agent",
|
||||||
|
device_id="device-1",
|
||||||
|
directory_paths=["/home/user/emails"],
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks from .eml files.",
|
||||||
|
file_extensions=[".eml"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _seed():
|
||||||
|
db_session.add(agent)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(_seed())
|
||||||
|
|
||||||
|
resp = _start(client, agent_type="local", agent_id=agent.id)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["done"] is False
|
||||||
|
# The session should be stored
|
||||||
|
assert body["session_id"] in _sessions
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_requires_auth(client: TestClient):
|
||||||
|
resp = client.post("/api/v1/agents/journey/start", json={"agent_type": "local"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── Message ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_continues_conversation(client: TestClient):
|
||||||
|
"""A mid-journey reply (no template markers) returns done=False."""
|
||||||
|
follow_up = "That looks good. Can you tell me more about priority rules?"
|
||||||
|
|
||||||
|
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
||||||
|
start_resp = _start(client, agent_type="local")
|
||||||
|
assert start_resp.status_code == 200
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
msg_resp = _message(client, session_id, "I have .eml and .txt files")
|
||||||
|
assert msg_resp.status_code == 200
|
||||||
|
body = msg_resp.json()
|
||||||
|
assert body["done"] is False
|
||||||
|
assert body["prompt_template"] is None
|
||||||
|
assert body["message"] == follow_up
|
||||||
|
assert body["session_id"] == session_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_produces_template(client: TestClient):
|
||||||
|
"""When the LLM includes PROMPT_TEMPLATE markers, done=True and prompt_template is set."""
|
||||||
|
final_template = "Extract tasks from email. Subject → title. 'urgent' → high priority."
|
||||||
|
llm_response = (
|
||||||
|
"Great, I have all the information I need.\n"
|
||||||
|
f"{_TEMPLATE_START}\n{final_template}\n{_TEMPLATE_END}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=llm_response)):
|
||||||
|
start_resp = _start(client, agent_type="cloud")
|
||||||
|
assert start_resp.status_code == 200
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
msg_resp = _message(client, session_id, "Only invoices from clients")
|
||||||
|
assert msg_resp.status_code == 200
|
||||||
|
body = msg_resp.json()
|
||||||
|
assert body["done"] is True
|
||||||
|
assert body["prompt_template"] == final_template
|
||||||
|
# Session should be cleaned up
|
||||||
|
assert session_id not in _sessions
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_invalid_session(client: TestClient):
|
||||||
|
resp = _message(client, "nonexistent-session-id", "hello")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_wrong_owner(client: TestClient):
|
||||||
|
"""User B cannot access user A's session."""
|
||||||
|
start_resp = _start(client, agent_type="local", tier="power")
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
# user with "pro" tier (different user_id) tries to send a message
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/agents/journey/message",
|
||||||
|
json={"session_id": session_id, "message": "hello"},
|
||||||
|
headers=auth_header("pro"), # different user
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_expired_session(client: TestClient):
|
||||||
|
"""Expired sessions return 404."""
|
||||||
|
start_resp = _start(client, agent_type="local")
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
# Manually expire the session
|
||||||
|
_sessions[session_id].created_at = time.monotonic() - _SESSION_TTL_SECONDS - 1
|
||||||
|
|
||||||
|
resp = _message(client, session_id, "hello")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_requires_auth(client: TestClient):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/agents/journey/message",
|
||||||
|
json={"session_id": "any", "message": "hello"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_max_turns_nudge(client: TestClient):
|
||||||
|
"""After _MAX_TURNS user messages, a system nudge is appended but no crash occurs."""
|
||||||
|
from app.api.routes.agent_setup import _MAX_TURNS
|
||||||
|
|
||||||
|
follow_up = "Tell me more about priority rules."
|
||||||
|
|
||||||
|
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
||||||
|
start_resp = _start(client, agent_type="local")
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
for i in range(_MAX_TURNS):
|
||||||
|
resp = _message(client, session_id, f"Answer {i + 1}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
# While no template produced, session must still exist
|
||||||
|
if resp.json()["done"]:
|
||||||
|
break # LLM decided to wrap up early — also fine
|
||||||
@@ -1,620 +0,0 @@
|
|||||||
"""Unit tests for the four domain-specific chat agents with mocked LLM."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
from app.agents.checkpoint_agent import CheckpointAgent
|
|
||||||
from app.agents.note_agent import NoteAgent
|
|
||||||
from app.agents.project_agent import ProjectAgent
|
|
||||||
from app.agents.task_agent import TaskAgent
|
|
||||||
from app.core.agent_registry import registry
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm(response_text: str) -> MagicMock:
|
|
||||||
"""Return a mock LLM that responds with *response_text* (no tool calls)."""
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.content = response_text
|
|
||||||
msg.tool_calls = []
|
|
||||||
llm = MagicMock()
|
|
||||||
bound = MagicMock()
|
|
||||||
bound.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
llm.bind_tools = MagicMock(return_value=bound)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm_with_tool_call(
|
|
||||||
tool_name: str, tool_args: dict[str, Any], final_text: str
|
|
||||||
) -> MagicMock:
|
|
||||||
"""Mock LLM that fires one tool call then returns *final_text*."""
|
|
||||||
tool_msg = MagicMock()
|
|
||||||
tool_msg.content = ""
|
|
||||||
tool_msg.tool_calls = [{"id": "call_1", "name": tool_name, "args": tool_args}]
|
|
||||||
|
|
||||||
final_msg = MagicMock()
|
|
||||||
final_msg.content = final_text
|
|
||||||
final_msg.tool_calls = []
|
|
||||||
|
|
||||||
bound = MagicMock()
|
|
||||||
bound.ainvoke = AsyncMock(side_effect=[tool_msg, final_msg])
|
|
||||||
|
|
||||||
llm = MagicMock()
|
|
||||||
llm.bind_tools = MagicMock(return_value=bound)
|
|
||||||
llm.ainvoke = AsyncMock(return_value=final_msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
# ── Registration ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestAgentRegistration:
|
|
||||||
def test_all_agents_registered(self) -> None:
|
|
||||||
names = {a["name"] for a in registry.list_agents()}
|
|
||||||
assert {
|
|
||||||
"task_agent", "checkpoint_agent", "project_agent", "note_agent"
|
|
||||||
}.issubset(names)
|
|
||||||
|
|
||||||
def test_registry_returns_correct_types(self) -> None:
|
|
||||||
assert isinstance(registry.get("task_agent"), TaskAgent)
|
|
||||||
assert isinstance(registry.get("checkpoint_agent"), CheckpointAgent)
|
|
||||||
assert isinstance(registry.get("project_agent"), ProjectAgent)
|
|
||||||
assert isinstance(registry.get("note_agent"), NoteAgent)
|
|
||||||
|
|
||||||
def test_descriptions_present(self) -> None:
|
|
||||||
for agent_info in registry.list_agents():
|
|
||||||
assert agent_info["description"], f"Empty description: {agent_info['name']}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── TaskAgent ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestTaskAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert TaskAgent().get_name() == "task_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert TaskAgent().get_description() == "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(TaskAgent().get_tools()) == 8
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in TaskAgent().get_tools()}
|
|
||||||
assert names == {
|
|
||||||
"list_tasks",
|
|
||||||
"create_task",
|
|
||||||
"update_task",
|
|
||||||
"delete_task",
|
|
||||||
"list_tasks_due_today",
|
|
||||||
"list_task_comments",
|
|
||||||
"add_task_comment",
|
|
||||||
"delete_task_comment",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_returns_string(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Task created.")
|
|
||||||
result = await TaskAgent().handle("create a task", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Here are your tasks.")
|
|
||||||
result = await TaskAgent().handle("list my tasks", {})
|
|
||||||
assert result == "Here are your tasks."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_task_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_task",
|
|
||||||
{"title": "Buy groceries", "priority": "low"},
|
|
||||||
"Task 'Buy groceries' created.",
|
|
||||||
)
|
|
||||||
result = await TaskAgent().handle("add a grocery task", {})
|
|
||||||
assert result == "Task 'Buy groceries' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await TaskAgent().handle("help", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_rich_context(self) -> None:
|
|
||||||
context = {
|
|
||||||
"user_profile": {"id": "u1", "tier": "pro"},
|
|
||||||
"recent_tasks": [{"id": "t1", "title": "Old task"}],
|
|
||||||
}
|
|
||||||
with patch("app.agents.task_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Tasks listed.")
|
|
||||||
result = await TaskAgent().handle("show tasks", context)
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestTaskAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_defaults(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks
|
|
||||||
result = await list_tasks.ainvoke({})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "list"
|
|
||||||
assert data["table"] == "tasks"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_with_status_filter(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks
|
|
||||||
result = await list_tasks.ainvoke({"status": "done"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["filters"]["status"] == "done"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_task_defaults(self) -> None:
|
|
||||||
from app.agents.task_agent import create_task
|
|
||||||
result = await create_task.ainvoke({"title": "Test task"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "create_record"
|
|
||||||
assert data["table"] == "tasks"
|
|
||||||
assert data["data"]["title"] == "Test task"
|
|
||||||
assert data["data"]["status"] == "todo"
|
|
||||||
assert data["data"]["priority"] == "medium"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_task_with_all_fields(self) -> None:
|
|
||||||
from app.agents.task_agent import create_task
|
|
||||||
result = await create_task.ainvoke({
|
|
||||||
"title": "Deploy",
|
|
||||||
"priority": "high",
|
|
||||||
"status": "in_progress",
|
|
||||||
"project_id": "p1",
|
|
||||||
"is_ai_suggested": 1,
|
|
||||||
})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["data"]["priority"] == "high"
|
|
||||||
assert data["data"]["status"] == "in_progress"
|
|
||||||
assert data["data"]["projectId"] == "p1"
|
|
||||||
assert data["data"]["isAiSuggested"] == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_task_with_status(self) -> None:
|
|
||||||
from app.agents.task_agent import update_task
|
|
||||||
result = await update_task.ainvoke({"task_id": "t1", "status": "done"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "update_record"
|
|
||||||
assert data["data"]["id"] == "t1"
|
|
||||||
assert data["data"]["updates"]["status"] == "done"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_task_empty_updates(self) -> None:
|
|
||||||
from app.agents.task_agent import update_task
|
|
||||||
result = await update_task.ainvoke({"task_id": "t1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_task(self) -> None:
|
|
||||||
from app.agents.task_agent import delete_task
|
|
||||||
result = await delete_task.ainvoke({"task_id": "t1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "delete_record"
|
|
||||||
assert data["table"] == "tasks"
|
|
||||||
assert data["data"]["id"] == "t1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_tasks_due_today(self) -> None:
|
|
||||||
from app.agents.task_agent import list_tasks_due_today
|
|
||||||
result = await list_tasks_due_today.ainvoke({})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "list_due_today"
|
|
||||||
assert data["table"] == "tasks"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_task_comments(self) -> None:
|
|
||||||
from app.agents.task_agent import list_task_comments
|
|
||||||
result = await list_task_comments.ainvoke({"task_id": "t1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "list"
|
|
||||||
assert data["table"] == "taskComments"
|
|
||||||
assert data["filters"]["taskId"] == "t1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_add_task_comment(self) -> None:
|
|
||||||
from app.agents.task_agent import add_task_comment
|
|
||||||
result = await add_task_comment.ainvoke({
|
|
||||||
"task_id": "t1",
|
|
||||||
"author": "Alice",
|
|
||||||
"content": "Looks good!",
|
|
||||||
})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "create_record"
|
|
||||||
assert data["table"] == "taskComments"
|
|
||||||
assert data["data"]["taskId"] == "t1"
|
|
||||||
assert data["data"]["author"] == "Alice"
|
|
||||||
assert data["data"]["content"] == "Looks good!"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_task_comment(self) -> None:
|
|
||||||
from app.agents.task_agent import delete_task_comment
|
|
||||||
result = await delete_task_comment.ainvoke({"comment_id": "c1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "delete_record"
|
|
||||||
assert data["table"] == "taskComments"
|
|
||||||
assert data["data"]["id"] == "c1"
|
|
||||||
|
|
||||||
|
|
||||||
# ── CheckpointAgent ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestCheckpointAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert CheckpointAgent().get_name() == "checkpoint_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert CheckpointAgent().get_description() == "Manages project checkpoints (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(CheckpointAgent().get_tools()) == 4
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in CheckpointAgent().get_tools()}
|
|
||||||
assert names == {"list_checkpoints", "create_checkpoint", "update_checkpoint", "delete_checkpoint"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("No checkpoints found.")
|
|
||||||
result = await CheckpointAgent().handle("list checkpoints", {})
|
|
||||||
assert result == "No checkpoints found."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_checkpoint",
|
|
||||||
{"project_id": "p1", "title": "MVP Launch", "date": 1700000000000},
|
|
||||||
"Checkpoint 'MVP Launch' created.",
|
|
||||||
)
|
|
||||||
result = await CheckpointAgent().handle("add MVP checkpoint", {})
|
|
||||||
assert result == "Checkpoint 'MVP Launch' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.checkpoint_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await CheckpointAgent().handle("show milestones", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestCheckpointAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_checkpoints_no_project(self) -> None:
|
|
||||||
from app.agents.checkpoint_agent import list_checkpoints
|
|
||||||
result = await list_checkpoints.ainvoke({})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "list"
|
|
||||||
assert data["table"] == "checkpoints"
|
|
||||||
assert data["filters"]["projectId"] is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_checkpoints_with_project(self) -> None:
|
|
||||||
from app.agents.checkpoint_agent import list_checkpoints
|
|
||||||
result = await list_checkpoints.ainvoke({"project_id": "p1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["filters"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_checkpoint(self) -> None:
|
|
||||||
from app.agents.checkpoint_agent import create_checkpoint
|
|
||||||
result = await create_checkpoint.ainvoke({
|
|
||||||
"project_id": "p1",
|
|
||||||
"title": "Beta release",
|
|
||||||
"date": 1700000000000,
|
|
||||||
})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "create_record"
|
|
||||||
assert data["table"] == "checkpoints"
|
|
||||||
assert data["data"]["projectId"] == "p1"
|
|
||||||
assert data["data"]["title"] == "Beta release"
|
|
||||||
assert data["data"]["date"] == 1700000000000
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_checkpoint_ai_suggested(self) -> None:
|
|
||||||
from app.agents.checkpoint_agent import create_checkpoint
|
|
||||||
result = await create_checkpoint.ainvoke({
|
|
||||||
"project_id": "p1",
|
|
||||||
"title": "Review",
|
|
||||||
"date": 1700000000000,
|
|
||||||
"is_ai_suggested": 1,
|
|
||||||
})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["data"]["isAiSuggested"] == 1
|
|
||||||
assert data["data"]["isApproved"] == 0
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_checkpoint_approve(self) -> None:
|
|
||||||
from app.agents.checkpoint_agent import update_checkpoint
|
|
||||||
result = await update_checkpoint.ainvoke({
|
|
||||||
"checkpoint_id": "c1",
|
|
||||||
"is_approved": 1,
|
|
||||||
})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "update_record"
|
|
||||||
assert data["data"]["id"] == "c1"
|
|
||||||
assert data["data"]["updates"]["isApproved"] == 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_checkpoint_empty_updates(self) -> None:
|
|
||||||
from app.agents.checkpoint_agent import update_checkpoint
|
|
||||||
result = await update_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_checkpoint(self) -> None:
|
|
||||||
from app.agents.checkpoint_agent import delete_checkpoint
|
|
||||||
result = await delete_checkpoint.ainvoke({"checkpoint_id": "c1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "delete_record"
|
|
||||||
assert data["table"] == "checkpoints"
|
|
||||||
assert data["data"]["id"] == "c1"
|
|
||||||
|
|
||||||
|
|
||||||
# ── ProjectAgent ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert ProjectAgent().get_name() == "project_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert ProjectAgent().get_description() == "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(ProjectAgent().get_tools()) == 6
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in ProjectAgent().get_tools()}
|
|
||||||
assert names == {
|
|
||||||
"list_projects",
|
|
||||||
"list_all_projects",
|
|
||||||
"get_project",
|
|
||||||
"create_project",
|
|
||||||
"update_project",
|
|
||||||
"delete_project",
|
|
||||||
}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Project Alpha is active.")
|
|
||||||
result = await ProjectAgent().handle("show my projects", {})
|
|
||||||
assert result == "Project Alpha is active."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_project_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_project",
|
|
||||||
{"name": "Pippo"},
|
|
||||||
"Project 'Pippo' created.",
|
|
||||||
)
|
|
||||||
result = await ProjectAgent().handle("create project Pippo", {})
|
|
||||||
assert result == "Project 'Pippo' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.project_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await ProjectAgent().handle("archive old project", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProjectAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_projects_defaults(self) -> None:
|
|
||||||
from app.agents.project_agent import list_projects
|
|
||||||
result = await list_projects.ainvoke({})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "list"
|
|
||||||
assert data["table"] == "projects"
|
|
||||||
assert data["filters"]["includeArchived"] is False
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_projects_include_archived(self) -> None:
|
|
||||||
from app.agents.project_agent import list_projects
|
|
||||||
result = await list_projects.ainvoke({"include_archived": 1})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["filters"]["includeArchived"] is True
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_all_projects(self) -> None:
|
|
||||||
from app.agents.project_agent import list_all_projects
|
|
||||||
result = await list_all_projects.ainvoke({})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "list_all"
|
|
||||||
assert data["table"] == "projects"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_project(self) -> None:
|
|
||||||
from app.agents.project_agent import get_project
|
|
||||||
result = await get_project.ainvoke({"project_id": "p1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "get"
|
|
||||||
assert data["table"] == "projects"
|
|
||||||
assert data["data"]["id"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_project_name_only(self) -> None:
|
|
||||||
from app.agents.project_agent import create_project
|
|
||||||
result = await create_project.ainvoke({"name": "Alpha"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "create_record"
|
|
||||||
assert data["data"]["name"] == "Alpha"
|
|
||||||
assert data["data"]["clientId"] is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_project_with_client(self) -> None:
|
|
||||||
from app.agents.project_agent import create_project
|
|
||||||
result = await create_project.ainvoke({"name": "Beta", "client_id": "cl1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["data"]["clientId"] == "cl1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_project_archive(self) -> None:
|
|
||||||
from app.agents.project_agent import update_project
|
|
||||||
result = await update_project.ainvoke({"project_id": "p1", "status": "archived"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "update_record"
|
|
||||||
assert data["data"]["id"] == "p1"
|
|
||||||
assert data["data"]["updates"]["status"] == "archived"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_project_empty_updates(self) -> None:
|
|
||||||
from app.agents.project_agent import update_project
|
|
||||||
result = await update_project.ainvoke({"project_id": "p1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_project(self) -> None:
|
|
||||||
from app.agents.project_agent import delete_project
|
|
||||||
result = await delete_project.ainvoke({"project_id": "p1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "delete_record"
|
|
||||||
assert data["data"]["id"] == "p1"
|
|
||||||
|
|
||||||
|
|
||||||
# ── NoteAgent ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestNoteAgent:
|
|
||||||
def test_name(self) -> None:
|
|
||||||
assert NoteAgent().get_name() == "note_agent"
|
|
||||||
|
|
||||||
def test_description(self) -> None:
|
|
||||||
assert NoteAgent().get_description() == "Manages notes: list, get, create, update, delete"
|
|
||||||
|
|
||||||
def test_get_tools_count(self) -> None:
|
|
||||||
assert len(NoteAgent().get_tools()) == 5
|
|
||||||
|
|
||||||
def test_tool_names(self) -> None:
|
|
||||||
names = {t.name for t in NoteAgent().get_tools()}
|
|
||||||
assert names == {"list_notes", "get_note", "create_note", "update_note", "delete_note"}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_no_tool_calls(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Note created.")
|
|
||||||
result = await NoteAgent().handle("create a note", {})
|
|
||||||
assert result == "Note created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_with_create_note_tool_call(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm_with_tool_call(
|
|
||||||
"create_note",
|
|
||||||
{"title": "Daily log", "content": "# Today\nAll good."},
|
|
||||||
"Note 'Daily log' created.",
|
|
||||||
)
|
|
||||||
result = await NoteAgent().handle("log today's progress", {})
|
|
||||||
assert result == "Note 'Daily log' created."
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_handle_accepts_empty_context(self) -> None:
|
|
||||||
with patch("app.agents.note_agent.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("Done.")
|
|
||||||
result = await NoteAgent().handle("show notes", {})
|
|
||||||
assert isinstance(result, str)
|
|
||||||
|
|
||||||
|
|
||||||
class TestNoteAgentTools:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_notes_no_project(self) -> None:
|
|
||||||
from app.agents.note_agent import list_notes
|
|
||||||
result = await list_notes.ainvoke({})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "list"
|
|
||||||
assert data["table"] == "notes"
|
|
||||||
assert data["filters"]["projectId"] is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_list_notes_with_project(self) -> None:
|
|
||||||
from app.agents.note_agent import list_notes
|
|
||||||
result = await list_notes.ainvoke({"project_id": "p1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["filters"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_get_note(self) -> None:
|
|
||||||
from app.agents.note_agent import get_note
|
|
||||||
result = await get_note.ainvoke({"note_id": "n1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "get"
|
|
||||||
assert data["table"] == "notes"
|
|
||||||
assert data["data"]["id"] == "n1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_note_minimal(self) -> None:
|
|
||||||
from app.agents.note_agent import create_note
|
|
||||||
result = await create_note.ainvoke({
|
|
||||||
"title": "Daily log",
|
|
||||||
"content": "# Today\nAll good.",
|
|
||||||
})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "create_record"
|
|
||||||
assert data["table"] == "notes"
|
|
||||||
assert data["data"]["title"] == "Daily log"
|
|
||||||
assert data["data"]["content"] == "# Today\nAll good."
|
|
||||||
assert data["data"]["projectId"] is None
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_create_note_with_project(self) -> None:
|
|
||||||
from app.agents.note_agent import create_note
|
|
||||||
result = await create_note.ainvoke({
|
|
||||||
"title": "Sprint notes",
|
|
||||||
"content": "## Sprint 1",
|
|
||||||
"project_id": "p1",
|
|
||||||
})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["data"]["projectId"] == "p1"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_note_content_only(self) -> None:
|
|
||||||
from app.agents.note_agent import update_note
|
|
||||||
result = await update_note.ainvoke({
|
|
||||||
"note_id": "n1",
|
|
||||||
"content": "# Updated content",
|
|
||||||
})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "update_record"
|
|
||||||
assert data["data"]["id"] == "n1"
|
|
||||||
assert data["data"]["updates"]["content"] == "# Updated content"
|
|
||||||
assert "title" not in data["data"]["updates"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_update_note_empty_updates(self) -> None:
|
|
||||||
from app.agents.note_agent import update_note
|
|
||||||
result = await update_note.ainvoke({"note_id": "n1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["data"]["updates"] == {}
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_delete_note(self) -> None:
|
|
||||||
from app.agents.note_agent import delete_note
|
|
||||||
result = await delete_note.ainvoke({"note_id": "n1"})
|
|
||||||
data = json.loads(result)
|
|
||||||
assert data["action"] == "delete_record"
|
|
||||||
assert data["table"] == "notes"
|
|
||||||
assert data["data"]["id"] == "n1"
|
|
||||||
206
tests/test_auth.py
Normal file
206
tests/test_auth.py
Normal file
@@ -0,0 +1,206 @@
|
|||||||
|
"""Tests for auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
|
Exercises the full auth lifecycle through the FastAPI TestClient against the
|
||||||
|
in-memory SQLite test database seeded by ``conftest.py``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from tests.conftest import auth_header, TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestRegister ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRegister:
|
||||||
|
"""POST /api/v1/auth/register"""
|
||||||
|
|
||||||
|
def test_register_success(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "new@example.com", "password": "Str0ngP@ss!"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert "refresh_token" in data
|
||||||
|
assert "expires_at" in data
|
||||||
|
# expires_at should be a future millisecond timestamp
|
||||||
|
assert data["expires_at"] > int(time.time() * 1000)
|
||||||
|
|
||||||
|
def test_register_returns_valid_jwt(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "jwt-check@example.com", "password": "P@ss1234"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
token = resp.json()["access_token"]
|
||||||
|
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
||||||
|
assert payload["email"] == "jwt-check@example.com"
|
||||||
|
assert payload["tier"] == "free"
|
||||||
|
assert "sub" in payload
|
||||||
|
|
||||||
|
def test_register_duplicate_email(self, client) -> None:
|
||||||
|
client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "dupe@example.com", "password": "Pass1234"},
|
||||||
|
)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "dupe@example.com", "password": "Pass5678"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|
||||||
|
def test_register_missing_password(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "no-pass@example.com"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
def test_register_missing_email(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"password": "OnlyPass"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 422
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestLogin ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestLogin:
|
||||||
|
"""POST /api/v1/auth/login"""
|
||||||
|
|
||||||
|
def _register(self, client, email="login@example.com", password="MyP@ss123"):
|
||||||
|
client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": email, "password": password},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_login_success(self, client) -> None:
|
||||||
|
self._register(client)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "login@example.com", "password": "MyP@ss123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert "refresh_token" in data
|
||||||
|
assert "expires_at" in data
|
||||||
|
|
||||||
|
def test_login_wrong_password(self, client) -> None:
|
||||||
|
self._register(client)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "login@example.com", "password": "WrongPass!"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_login_unknown_email(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
json={"email": "ghost@example.com", "password": "Whatever"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestRefresh ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestRefresh:
|
||||||
|
"""POST /api/v1/auth/refresh"""
|
||||||
|
|
||||||
|
def _register_and_get_tokens(self, client, email="refresh@example.com"):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": email, "password": "RefPass123!"},
|
||||||
|
)
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
def test_refresh_returns_new_tokens(self, client) -> None:
|
||||||
|
tokens = self._register_and_get_tokens(client)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
json={"refresh_token": tokens["refresh_token"]},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data
|
||||||
|
assert "refresh_token" in data
|
||||||
|
# New refresh token should differ from old one (rotation)
|
||||||
|
assert data["refresh_token"] != tokens["refresh_token"]
|
||||||
|
|
||||||
|
def test_refresh_old_token_rejected(self, client) -> None:
|
||||||
|
"""After rotation, the original refresh token must be rejected."""
|
||||||
|
tokens = self._register_and_get_tokens(client, email="rotate@example.com")
|
||||||
|
old_rt = tokens["refresh_token"]
|
||||||
|
|
||||||
|
# First refresh succeeds and rotates the token
|
||||||
|
client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt})
|
||||||
|
|
||||||
|
# Second attempt with the old token must fail
|
||||||
|
resp = client.post("/api/v1/auth/refresh", json={"refresh_token": old_rt})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_refresh_bogus_token(self, client) -> None:
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/refresh",
|
||||||
|
json={"refresh_token": "not-a-real-token"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestMe ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestMe:
|
||||||
|
"""GET /api/v1/auth/me"""
|
||||||
|
|
||||||
|
def test_me_with_valid_jwt(self, client) -> None:
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert data["id"] == TEST_USER_IDS["power"]
|
||||||
|
assert data["email"] == "power@test.com"
|
||||||
|
assert data["tier"] == "power"
|
||||||
|
|
||||||
|
def test_me_returns_correct_tier(self, client) -> None:
|
||||||
|
"""Tier comes from the live subscription row, not the JWT claim."""
|
||||||
|
resp = client.get("/api/v1/auth/me", headers=auth_header("free"))
|
||||||
|
assert resp.json()["tier"] == "free"
|
||||||
|
|
||||||
|
def test_me_missing_token(self, client) -> None:
|
||||||
|
resp = client.get("/api/v1/auth/me")
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_me_expired_token(self, client) -> None:
|
||||||
|
"""A JWT with ``exp`` in the past must be rejected."""
|
||||||
|
payload = {
|
||||||
|
"sub": TEST_USER_IDS["power"],
|
||||||
|
"email": "power@test.com",
|
||||||
|
"tier": "power",
|
||||||
|
"exp": int(time.time()) - 3600, # 1 hour ago
|
||||||
|
"iat": int(time.time()) - 7200,
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
|
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_me_invalid_signature(self, client) -> None:
|
||||||
|
payload = {
|
||||||
|
"sub": TEST_USER_IDS["power"],
|
||||||
|
"email": "power@test.com",
|
||||||
|
"tier": "power",
|
||||||
|
"exp": int(time.time()) + 3600,
|
||||||
|
"iat": int(time.time()),
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||||
|
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||||
|
assert resp.status_code == 401
|
||||||
243
tests/test_backup.py
Normal file
243
tests/test_backup.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""Tests for backup routes: upload, download, history, delete.
|
||||||
|
|
||||||
|
Exercises the backup lifecycle through the FastAPI TestClient against the
|
||||||
|
in-memory SQLite test database and moto-mocked S3 bucket.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
|
||||||
|
|
||||||
|
from tests.conftest import auth_header, TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_BLOB = b"encrypted-backup-blob-opaque-bytes"
|
||||||
|
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||||
|
_VERSION = 1
|
||||||
|
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
|
||||||
|
|
||||||
|
|
||||||
|
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
|
||||||
|
"""Return auth + backup metadata headers."""
|
||||||
|
headers = auth_header(tier)
|
||||||
|
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
|
||||||
|
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
|
||||||
|
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
|
||||||
|
headers["Content-Type"] = "application/octet-stream"
|
||||||
|
return headers
|
||||||
|
|
||||||
|
|
||||||
|
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
|
||||||
|
"""Upload a backup blob and return the response."""
|
||||||
|
return client.put(
|
||||||
|
"/api/v1/backup",
|
||||||
|
content=overrides.pop("blob", _BLOB),
|
||||||
|
headers=_backup_headers(tier, **overrides),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestUploadBackup ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestUploadBackup:
|
||||||
|
"""PUT /api/v1/backup"""
|
||||||
|
|
||||||
|
def test_upload_success(self, client, s3_bucket) -> None:
|
||||||
|
resp = _upload(client, tier="power")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert len(history) == 1
|
||||||
|
assert history[0]["version"] == _VERSION
|
||||||
|
assert history[0]["timestamp"] == _TIMESTAMP
|
||||||
|
assert history[0]["checksum"] == _CHECKSUM
|
||||||
|
|
||||||
|
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
resp = _upload(client, tier="power", checksum="0" * 64)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
|
||||||
|
"""Free tier has backup_gb=0 → should return 402."""
|
||||||
|
resp = _upload(client, tier="free")
|
||||||
|
assert resp.status_code == 402
|
||||||
|
|
||||||
|
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
|
||||||
|
"""Pro tier has backup_gb=5 → small blob succeeds."""
|
||||||
|
resp = _upload(client, tier="pro")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestDownloadBackup ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDownloadBackup:
|
||||||
|
"""GET /api/v1/backup"""
|
||||||
|
|
||||||
|
def test_download_latest(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == _BLOB
|
||||||
|
assert resp.headers["X-Checksum"] == _CHECKSUM
|
||||||
|
assert resp.headers["X-Backup-Version"] == str(_VERSION)
|
||||||
|
|
||||||
|
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
|
||||||
|
"""When If-Modified-Since is after the backup timestamp → 304."""
|
||||||
|
_upload(client, tier="power", timestamp=1700000000000)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/backup",
|
||||||
|
headers={
|
||||||
|
**auth_header("power"),
|
||||||
|
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 304
|
||||||
|
|
||||||
|
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
|
||||||
|
"""When If-Modified-Since is before the backup timestamp → serve blob."""
|
||||||
|
_upload(client, tier="power", timestamp=1700000000000)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/backup",
|
||||||
|
headers={
|
||||||
|
**auth_header("power"),
|
||||||
|
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == _BLOB
|
||||||
|
|
||||||
|
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
|
||||||
|
"""When multiple backups exist, GET returns the one with the highest timestamp."""
|
||||||
|
_upload(client, tier="power", timestamp=1000)
|
||||||
|
blob2 = b"second-encrypted-backup"
|
||||||
|
checksum2 = hashlib.sha256(blob2).hexdigest()
|
||||||
|
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
|
||||||
|
resp = client.get("/api/v1/backup", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == blob2
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestBackupHistory ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestBackupHistory:
|
||||||
|
"""GET /api/v1/backup/history"""
|
||||||
|
|
||||||
|
def test_history_empty(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
def test_history_returns_entries(self, client, s3_bucket) -> None:
|
||||||
|
_upload(client, tier="power", timestamp=1000)
|
||||||
|
_upload(client, tier="power", timestamp=2000)
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert len(history) == 2
|
||||||
|
# Ordered by timestamp descending
|
||||||
|
assert history[0]["timestamp"] == 2000
|
||||||
|
assert history[1]["timestamp"] == 1000
|
||||||
|
|
||||||
|
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
|
||||||
|
"""One user's backups should not appear in another user's history."""
|
||||||
|
_upload(client, tier="power")
|
||||||
|
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestDeleteBackup ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestDeleteBackup:
|
||||||
|
"""DELETE /api/v1/backup/{backup_id}"""
|
||||||
|
|
||||||
|
def _get_backup_id(self, client, tier="power") -> str:
|
||||||
|
"""Upload a backup and return its DB id from history."""
|
||||||
|
_upload(client, tier=tier)
|
||||||
|
client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header(tier)
|
||||||
|
).json()
|
||||||
|
# History returns BackupMetadata schema which doesn't have `id`.
|
||||||
|
# We need to look it up via a different means.
|
||||||
|
# Since there's only 1 backup, find via history length.
|
||||||
|
# Actually the schema doesn't return id — let's verify via re-download.
|
||||||
|
# We'll use a workaround: upload, then list history to confirm it exists,
|
||||||
|
# then try to delete — but we need the id...
|
||||||
|
# Let's check if history includes an id field.
|
||||||
|
# The schema is: version, timestamp, checksum, chunk_count — no id.
|
||||||
|
# We'll need to query the DB directly or use a known ID.
|
||||||
|
# For testing, we'll search history then use the DB.
|
||||||
|
return None # pragma: no cover — overridden below
|
||||||
|
|
||||||
|
def test_delete_success(self, client, s3_bucket, db_session) -> None:
|
||||||
|
_upload(client, tier="power")
|
||||||
|
|
||||||
|
# Discover the backup_id via direct DB query
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models import BackupMetadata
|
||||||
|
|
||||||
|
async def _get_id():
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(BackupMetadata.id).where(
|
||||||
|
BackupMetadata.user_id == TEST_USER_IDS["power"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
|
||||||
|
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# History should now be empty
|
||||||
|
history = client.get(
|
||||||
|
"/api/v1/backup/history", headers=auth_header("power")
|
||||||
|
).json()
|
||||||
|
assert history == []
|
||||||
|
|
||||||
|
def test_delete_nonexistent(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/backup/no-such-id", headers=auth_header("power")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
|
||||||
|
"""Cannot delete another user's backup (ownership check returns 404)."""
|
||||||
|
_upload(client, tier="power")
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from sqlalchemy import select
|
||||||
|
from app.models import BackupMetadata
|
||||||
|
|
||||||
|
async def _get_id():
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(BackupMetadata.id).where(
|
||||||
|
BackupMetadata.user_id == TEST_USER_IDS["power"]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return result.scalar_one()
|
||||||
|
|
||||||
|
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
|
||||||
|
|
||||||
|
# team user tries to delete power user's backup → 404
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
362
tests/test_device_ws.py
Normal file
362
tests/test_device_ws.py
Normal file
@@ -0,0 +1,362 @@
|
|||||||
|
"""Tests for Step 3.3: DeviceConnectionManager and device WS endpoint.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
Unit tests — DeviceConnectionManager register/unregister/is_online/
|
||||||
|
get_ws/send_frame/pending-call round-trip/agent-data queue
|
||||||
|
Integration — /api/v1/ws/device endpoint via TestClient WebSocket:
|
||||||
|
auth rejection, happy-path connect, tool_result dispatch,
|
||||||
|
agent_data queue routing, agent_complete sentinel, disconnect
|
||||||
|
cleanup (AgentRunLog marked as error)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.core.device_manager import DeviceConnection, DeviceConnectionManager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import AgentRunLog
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header, make_jwt
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FREE_UID = TEST_USER_IDS["free"]
|
||||||
|
_PRO_UID = TEST_USER_IDS["pro"]
|
||||||
|
|
||||||
|
|
||||||
|
def _device_hello(device_id: str = "dev-001", agent_ids: list[str] | None = None) -> str:
|
||||||
|
return json.dumps(
|
||||||
|
{"type": "device_hello", "device_id": device_id, "agent_ids": agent_ids or []}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DB override (shared across integration tests)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
"""Route all get_session calls to the test SQLite session."""
|
||||||
|
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# DeviceConnectionManager unit tests
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def manager() -> DeviceConnectionManager:
|
||||||
|
"""Fresh manager instance for each test."""
|
||||||
|
return DeviceConnectionManager()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def mock_ws() -> MagicMock:
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
return ws
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_register_and_is_online(manager, mock_ws):
|
||||||
|
assert not manager.is_online("user1")
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
assert manager.is_online("user1")
|
||||||
|
assert manager.is_online("user1", "dev-A")
|
||||||
|
assert not manager.is_online("user1", "dev-B")
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_get_ws_returns_none_when_offline(manager):
|
||||||
|
assert manager.get_ws("no-such-user") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_unregister(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
assert manager.is_online("user1")
|
||||||
|
manager.unregister("user1")
|
||||||
|
assert not manager.is_online("user1")
|
||||||
|
assert manager.get_ws("user1") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_unregister_unknown_is_noop(manager):
|
||||||
|
# Must not raise.
|
||||||
|
manager.unregister("ghost")
|
||||||
|
|
||||||
|
|
||||||
|
def test_manager_replace_connection_cancels_old_futures(manager):
|
||||||
|
ws_a = MagicMock()
|
||||||
|
ws_a.send_text = AsyncMock()
|
||||||
|
ws_b = MagicMock()
|
||||||
|
ws_b.send_text = AsyncMock()
|
||||||
|
|
||||||
|
# Create event loop context for Future.
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
try:
|
||||||
|
async def _run():
|
||||||
|
manager.register("user1", "dev-A", ws_a)
|
||||||
|
fut = manager.create_pending_call("user1", "call-1")
|
||||||
|
# Replace connection — old future should be cancelled.
|
||||||
|
manager.register("user1", "dev-B", ws_b)
|
||||||
|
assert fut.cancelled()
|
||||||
|
|
||||||
|
loop.run_until_complete(_run())
|
||||||
|
finally:
|
||||||
|
loop.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_send_frame(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
await manager.send_frame("user1", {"type": "ping"})
|
||||||
|
mock_ws.send_text.assert_called_once_with(json.dumps({"type": "ping"}))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_send_frame_raises_when_offline(manager):
|
||||||
|
with pytest.raises(RuntimeError, match="not connected"):
|
||||||
|
await manager.send_frame("ghost", {"type": "ping"})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_pending_call_round_trip(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
fut = manager.create_pending_call("user1", "call-42")
|
||||||
|
result = {"type": "tool_result", "id": "call-42", "rows": [{"id": "row1"}]}
|
||||||
|
manager.resolve_pending_call("user1", "call-42", result)
|
||||||
|
assert fut.done()
|
||||||
|
assert await fut == result
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_resolve_unknown_call_is_noop(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
# Should not raise.
|
||||||
|
manager.resolve_pending_call("user1", "no-such-call", {})
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
fut = manager.create_pending_call("user1", "call-1")
|
||||||
|
manager.unregister("user1")
|
||||||
|
assert fut.cancelled()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_agent_data_queue(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
q = manager.get_agent_data_queue("user1", "run-xyz")
|
||||||
|
# Put a frame and get it back.
|
||||||
|
frame = {"type": "agent_data", "run_id": "run-xyz", "files": []}
|
||||||
|
await q.put(frame)
|
||||||
|
assert await q.get() == frame
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_agent_data_queue_creates_once(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
q1 = manager.get_agent_data_queue("user1", "run-1")
|
||||||
|
q2 = manager.get_agent_data_queue("user1", "run-1")
|
||||||
|
assert q1 is q2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_agent_data_queue_raises_when_offline(manager):
|
||||||
|
with pytest.raises(RuntimeError, match="not connected"):
|
||||||
|
manager.get_agent_data_queue("ghost", "run-1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_manager_cleanup_agent_data_queue(manager, mock_ws):
|
||||||
|
manager.register("user1", "dev-A", mock_ws)
|
||||||
|
manager.get_agent_data_queue("user1", "run-1")
|
||||||
|
manager.cleanup_agent_data_queue("user1", "run-1")
|
||||||
|
# After cleanup a new queue is created (not the same object).
|
||||||
|
q_new = manager.get_agent_data_queue("user1", "run-1")
|
||||||
|
assert q_new is not None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration tests — /api/v1/ws/device endpoint
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_ws_device_rejects_without_token(client):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
# TestClient will raise or close when the server rejects.
|
||||||
|
with client.websocket_connect("/api/v1/ws/device") as ws:
|
||||||
|
ws.receive_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_rejects_invalid_token(client):
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
|
||||||
|
ws.receive_text()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_happy_path(client):
|
||||||
|
"""Connect, send device_hello, receive ping, then close."""
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
|
||||||
|
# Patch the heartbeat sleep so the test doesn't block 30 s.
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.01):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(_device_hello("dev-001"))
|
||||||
|
# Next message from server should be a heartbeat ping (interval=0.01s).
|
||||||
|
msg = ws.receive_text()
|
||||||
|
data = json.loads(msg)
|
||||||
|
assert data["type"] == "ping"
|
||||||
|
# Close gracefully.
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_invalid_first_frame_closes(client):
|
||||||
|
"""Non-device_hello first frame should close the connection."""
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({"type": "chat_request", "message": "hi"}))
|
||||||
|
ws.receive_text() # server should close after bad frame
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_tool_result_dispatched(client):
|
||||||
|
"""tool_result frame is routed to the DeviceConnectionManager."""
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
|
from app.core.device_manager import device_manager as dm
|
||||||
|
|
||||||
|
captured: list[dict] = []
|
||||||
|
|
||||||
|
original_resolve = dm.resolve_pending_call
|
||||||
|
|
||||||
|
def _spy(uid, call_id, result):
|
||||||
|
captured.append({"uid": uid, "call_id": call_id, "result": result})
|
||||||
|
original_resolve(uid, call_id, result)
|
||||||
|
|
||||||
|
with patch.object(dm, "resolve_pending_call", side_effect=_spy):
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(_device_hello("dev-001"))
|
||||||
|
# Send a tool_result frame.
|
||||||
|
ws.send_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "tool_result",
|
||||||
|
"id": "call-123",
|
||||||
|
"rows": [{"id": "task-1", "title": "Buy milk"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
assert any(c["call_id"] == "call-123" for c in captured)
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_agent_data_enqueued(client):
|
||||||
|
"""agent_data frame is placed in the per-run queue by the message loop."""
|
||||||
|
from app.core.device_manager import device_manager as dm
|
||||||
|
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
|
# Capture the queue object the message loop accesses.
|
||||||
|
captured_queue: list[asyncio.Queue] = []
|
||||||
|
original_get_queue = dm.get_agent_data_queue
|
||||||
|
|
||||||
|
def _spy_get_queue(uid, run_id):
|
||||||
|
q = original_get_queue(uid, run_id)
|
||||||
|
if not captured_queue:
|
||||||
|
captured_queue.append(q)
|
||||||
|
return q
|
||||||
|
|
||||||
|
with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue):
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(_device_hello("dev-001"))
|
||||||
|
ws.send_text(
|
||||||
|
json.dumps(
|
||||||
|
{
|
||||||
|
"type": "agent_data",
|
||||||
|
"run_id": "run-XYZ",
|
||||||
|
"files": [{"path": "/tmp/file.txt", "content": "hello"}],
|
||||||
|
}
|
||||||
|
)
|
||||||
|
)
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
# The queue should have received exactly one frame.
|
||||||
|
assert captured_queue, "queue was never accessed"
|
||||||
|
assert not captured_queue[0].empty()
|
||||||
|
|
||||||
|
|
||||||
|
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
|
||||||
|
"""On disconnect, _mark_runs_disconnected is called with the correct user_id."""
|
||||||
|
from app.api.routes import device_ws as _dws
|
||||||
|
|
||||||
|
token = make_jwt(tier="free")
|
||||||
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
|
cleanup_calls: list[str] = []
|
||||||
|
|
||||||
|
async def _fake_cleanup(uid: str) -> None:
|
||||||
|
cleanup_calls.append(uid)
|
||||||
|
|
||||||
|
with patch.object(_dws, "_mark_runs_disconnected", side_effect=_fake_cleanup):
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(_device_hello("dev-001"))
|
||||||
|
ws.close()
|
||||||
|
|
||||||
|
assert user_id in cleanup_calls
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_mark_runs_disconnected_updates_db(db_session):
|
||||||
|
"""_mark_runs_disconnected marks in-progress runs as error in the DB."""
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.api.routes.device_ws import _mark_runs_disconnected
|
||||||
|
from tests.conftest import _TestSessionLocal
|
||||||
|
|
||||||
|
user_id = TEST_USER_IDS["free"]
|
||||||
|
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=str(uuid.uuid4()),
|
||||||
|
agent_type="local",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
db_session.add(run_log)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
# Route the function to the same test-DB session factory.
|
||||||
|
with patch("app.api.routes.device_ws.async_session", _TestSessionLocal):
|
||||||
|
await _mark_runs_disconnected(user_id)
|
||||||
|
|
||||||
|
# Verify through the same session factory.
|
||||||
|
async with _TestSessionLocal() as s:
|
||||||
|
result = await s.execute(
|
||||||
|
select(AgentRunLog).where(AgentRunLog.id == run_log.id)
|
||||||
|
)
|
||||||
|
updated = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
assert updated is not None
|
||||||
|
assert updated.status == "error"
|
||||||
|
assert updated.errors and "device disconnected" in updated.errors
|
||||||
@@ -1,286 +0,0 @@
|
|||||||
"""Tests for execution_plan: PromptTemplateRegistry, ExecutionPlanBuilder, PlanCache."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.execution_plan import (
|
|
||||||
ExecutionPlanBuilder,
|
|
||||||
PlanCache,
|
|
||||||
PromptTemplateRegistry,
|
|
||||||
plan_cache,
|
|
||||||
template_registry,
|
|
||||||
)
|
|
||||||
from app.schemas import ExecutionPlan
|
|
||||||
|
|
||||||
|
|
||||||
# ── PromptTemplateRegistry ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestPromptTemplateRegistry:
|
|
||||||
def test_register_and_get(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_foo", "You are a foo agent.")
|
|
||||||
assert reg.get("tpl_foo") == "You are a foo agent."
|
|
||||||
|
|
||||||
def test_get_unknown_raises_key_error(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
with pytest.raises(KeyError, match="tpl_missing"):
|
|
||||||
reg.get("tpl_missing")
|
|
||||||
|
|
||||||
def test_has_returns_true_for_registered(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_x", "prompt text")
|
|
||||||
assert reg.has("tpl_x") is True
|
|
||||||
|
|
||||||
def test_has_returns_false_for_unregistered(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
assert reg.has("tpl_missing") is False
|
|
||||||
|
|
||||||
def test_list_ids_returns_all_registered_ids(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_a", "a")
|
|
||||||
reg.register("tpl_b", "b")
|
|
||||||
assert set(reg.list_ids()) == {"tpl_a", "tpl_b"}
|
|
||||||
|
|
||||||
def test_list_ids_does_not_return_prompt_text(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_secret", "top secret prompt")
|
|
||||||
ids = reg.list_ids()
|
|
||||||
assert "top secret prompt" not in ids
|
|
||||||
|
|
||||||
def test_overwrite_existing_template(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
reg.register("tpl_x", "v1")
|
|
||||||
reg.register("tpl_x", "v2")
|
|
||||||
assert reg.get("tpl_x") == "v2"
|
|
||||||
|
|
||||||
def test_empty_registry_has_no_ids(self) -> None:
|
|
||||||
reg = PromptTemplateRegistry()
|
|
||||||
assert reg.list_ids() == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── ExecutionPlanBuilder ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestExecutionPlanBuilder:
|
|
||||||
def test_builds_empty_plan(self) -> None:
|
|
||||||
plan = ExecutionPlanBuilder("task_agent").build()
|
|
||||||
assert plan.agent == "task_agent"
|
|
||||||
assert plan.steps == []
|
|
||||||
|
|
||||||
def test_add_step_basic(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_step("create_task", {"priority": "high"})
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert len(plan.steps) == 1
|
|
||||||
assert plan.steps[0].action == "create_task"
|
|
||||||
assert plan.steps[0].variables == {"priority": "high"}
|
|
||||||
assert plan.steps[0].prompt_template is None
|
|
||||||
assert plan.steps[0].data_from_step is None
|
|
||||||
|
|
||||||
def test_add_step_no_params(self) -> None:
|
|
||||||
plan = ExecutionPlanBuilder("task_agent").add_step("fetch").build()
|
|
||||||
assert plan.steps[0].variables is None
|
|
||||||
|
|
||||||
def test_add_llm_step(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_llm_step("tpl_task_default", {"message": "hi"})
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert plan.steps[0].action == "llm"
|
|
||||||
assert plan.steps[0].prompt_template == "tpl_task_default"
|
|
||||||
assert plan.steps[0].variables == {"message": "hi"}
|
|
||||||
|
|
||||||
def test_add_llm_step_no_variables(self) -> None:
|
|
||||||
plan = ExecutionPlanBuilder("task_agent").add_llm_step("tpl_x").build()
|
|
||||||
assert plan.steps[0].variables is None
|
|
||||||
|
|
||||||
def test_add_data_step(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_step("fetch_data")
|
|
||||||
.add_data_step("transform", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert plan.steps[1].action == "transform"
|
|
||||||
assert plan.steps[1].data_from_step == 0
|
|
||||||
|
|
||||||
def test_fluent_chaining_returns_builder(self) -> None:
|
|
||||||
builder = ExecutionPlanBuilder("analytics_agent")
|
|
||||||
result = builder.add_step("a")
|
|
||||||
assert result is builder
|
|
||||||
|
|
||||||
def test_fluent_chain_multiple_steps(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("analytics_agent")
|
|
||||||
.add_llm_step("tpl_analytics_default")
|
|
||||||
.add_step("format_output")
|
|
||||||
.add_data_step("store", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert len(plan.steps) == 3
|
|
||||||
|
|
||||||
def test_build_validates_data_from_step_out_of_range(self) -> None:
|
|
||||||
with pytest.raises(ValueError, match="data_from_step"):
|
|
||||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=5).build()
|
|
||||||
|
|
||||||
def test_build_validates_data_from_step_self_reference(self) -> None:
|
|
||||||
"""data_from_step=0 on the first step (index 0) is invalid."""
|
|
||||||
with pytest.raises(ValueError, match="data_from_step"):
|
|
||||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=0).build()
|
|
||||||
|
|
||||||
def test_build_validates_data_from_step_negative(self) -> None:
|
|
||||||
with pytest.raises(ValueError, match="data_from_step"):
|
|
||||||
ExecutionPlanBuilder("task_agent").add_data_step("bad", data_from_step=-1).build()
|
|
||||||
|
|
||||||
def test_valid_data_from_step_at_index_two(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_step("step0")
|
|
||||||
.add_step("step1")
|
|
||||||
.add_data_step("step2", data_from_step=1)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert plan.steps[2].data_from_step == 1
|
|
||||||
|
|
||||||
def test_data_from_step_zero_valid_at_index_one(self) -> None:
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_step("step0")
|
|
||||||
.add_data_step("step1", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
assert plan.steps[1].data_from_step == 0
|
|
||||||
|
|
||||||
def test_build_returns_new_plan_each_call(self) -> None:
|
|
||||||
builder = ExecutionPlanBuilder("task_agent").add_step("do_thing")
|
|
||||||
plan1 = builder.build()
|
|
||||||
plan2 = builder.build()
|
|
||||||
assert plan1 is not plan2
|
|
||||||
assert plan1.steps == plan2.steps
|
|
||||||
|
|
||||||
def test_plan_is_execution_plan_instance(self) -> None:
|
|
||||||
plan = ExecutionPlanBuilder("task_agent").build()
|
|
||||||
assert isinstance(plan, ExecutionPlan)
|
|
||||||
|
|
||||||
|
|
||||||
# ── PlanCache ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestPlanCache:
|
|
||||||
def _plan(self, agent: str = "a") -> ExecutionPlan:
|
|
||||||
return ExecutionPlanBuilder(agent).build()
|
|
||||||
|
|
||||||
def test_cache_and_get(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
plan = self._plan()
|
|
||||||
cache.cache_plan("key1", plan)
|
|
||||||
assert cache.get_plan("key1") is plan
|
|
||||||
|
|
||||||
def test_get_missing_returns_none(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
assert cache.get_plan("nonexistent") is None
|
|
||||||
|
|
||||||
def test_get_all_playbooks_empty(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
assert cache.get_all_playbooks() == []
|
|
||||||
|
|
||||||
def test_get_all_playbooks_returns_all_stored(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
p1, p2 = self._plan("a"), self._plan("b")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k2", p2)
|
|
||||||
playbooks = cache.get_all_playbooks()
|
|
||||||
assert len(playbooks) == 2
|
|
||||||
assert p1 in playbooks
|
|
||||||
assert p2 in playbooks
|
|
||||||
|
|
||||||
def test_lru_evicts_oldest_entry(self) -> None:
|
|
||||||
cache = PlanCache(maxsize=2)
|
|
||||||
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k2", p2)
|
|
||||||
cache.cache_plan("k3", p3) # k1 should be evicted
|
|
||||||
assert cache.get_plan("k1") is None
|
|
||||||
assert cache.get_plan("k2") is p2
|
|
||||||
assert cache.get_plan("k3") is p3
|
|
||||||
|
|
||||||
def test_lru_access_updates_recency(self) -> None:
|
|
||||||
cache = PlanCache(maxsize=2)
|
|
||||||
p1, p2, p3 = self._plan("a"), self._plan("b"), self._plan("c")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k2", p2)
|
|
||||||
cache.get_plan("k1") # k1 is now most-recently used
|
|
||||||
cache.cache_plan("k3", p3) # k2 should be evicted (LRU)
|
|
||||||
assert cache.get_plan("k1") is p1
|
|
||||||
assert cache.get_plan("k2") is None
|
|
||||||
assert cache.get_plan("k3") is p3
|
|
||||||
|
|
||||||
def test_overwrite_existing_key(self) -> None:
|
|
||||||
cache = PlanCache()
|
|
||||||
p1, p2 = self._plan("a"), self._plan("b")
|
|
||||||
cache.cache_plan("same_key", p1)
|
|
||||||
cache.cache_plan("same_key", p2)
|
|
||||||
assert cache.get_plan("same_key") is p2
|
|
||||||
assert len(cache.get_all_playbooks()) == 1
|
|
||||||
|
|
||||||
def test_overwrite_does_not_consume_capacity(self) -> None:
|
|
||||||
cache = PlanCache(maxsize=2)
|
|
||||||
p1, p2 = self._plan("a"), self._plan("b")
|
|
||||||
cache.cache_plan("k1", p1)
|
|
||||||
cache.cache_plan("k1", p2) # overwrite, not a new slot
|
|
||||||
cache.cache_plan("k2", p1) # should fit without eviction
|
|
||||||
assert cache.get_plan("k1") is p2
|
|
||||||
assert cache.get_plan("k2") is p1
|
|
||||||
|
|
||||||
|
|
||||||
# ── Module-level singletons ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestModuleSingletons:
|
|
||||||
def test_template_registry_has_all_agent_defaults(self) -> None:
|
|
||||||
for agent in ("task_agent", "checkpoint_agent", "project_agent", "note_agent"):
|
|
||||||
assert template_registry.has(f"tpl_{agent}_default"), (
|
|
||||||
f"Missing template: tpl_{agent}_default"
|
|
||||||
)
|
|
||||||
|
|
||||||
def test_template_registry_has_operation_templates(self) -> None:
|
|
||||||
assert template_registry.has("tpl_task_extract_from_project")
|
|
||||||
assert template_registry.has("tpl_note_weekly_summary")
|
|
||||||
|
|
||||||
def test_template_registry_get_returns_non_empty_string(self) -> None:
|
|
||||||
text = template_registry.get("tpl_task_agent_default")
|
|
||||||
assert isinstance(text, str)
|
|
||||||
assert len(text) > 0
|
|
||||||
|
|
||||||
def test_plan_cache_has_prebuilt_playbooks(self) -> None:
|
|
||||||
assert len(plan_cache.get_all_playbooks()) >= 2
|
|
||||||
|
|
||||||
def test_playbook_create_tasks_from_project(self) -> None:
|
|
||||||
plan = plan_cache.get_plan("create_tasks_from_project")
|
|
||||||
assert plan is not None
|
|
||||||
assert plan.agent == "project_agent"
|
|
||||||
assert len(plan.steps) == 2
|
|
||||||
assert plan.steps[0].prompt_template == "tpl_task_extract_from_project"
|
|
||||||
assert plan.steps[1].data_from_step == 0
|
|
||||||
|
|
||||||
def test_playbook_generate_weekly_note(self) -> None:
|
|
||||||
plan = plan_cache.get_plan("generate_weekly_note")
|
|
||||||
assert plan is not None
|
|
||||||
assert plan.agent == "note_agent"
|
|
||||||
assert len(plan.steps) == 2
|
|
||||||
assert plan.steps[0].prompt_template == "tpl_note_weekly_summary"
|
|
||||||
assert plan.steps[1].data_from_step == 0
|
|
||||||
|
|
||||||
def test_playbook_steps_have_no_raw_prompt_text(self) -> None:
|
|
||||||
"""Plans must not embed prompt text — only template IDs."""
|
|
||||||
for plan in plan_cache.get_all_playbooks():
|
|
||||||
for step in plan.steps:
|
|
||||||
if step.prompt_template is not None:
|
|
||||||
assert step.prompt_template.startswith("tpl_"), (
|
|
||||||
f"prompt_template looks like raw text: {step.prompt_template!r}"
|
|
||||||
)
|
|
||||||
729
tests/test_integrations.py
Normal file
729
tests/test_integrations.py
Normal file
@@ -0,0 +1,729 @@
|
|||||||
|
"""Tests for Step 3.6: cloud provider integration clients.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
Unit \u2014 app/integrations/__init__.py:
|
||||||
|
- encrypt_token / decrypt_token round-trip
|
||||||
|
- decrypt_token raises ValueError on invalid ciphertext
|
||||||
|
- encrypt_token raises ValueError on empty/non-dict input
|
||||||
|
- _get_fernet raises RuntimeError when OAUTH_ENCRYPTION_KEY not set
|
||||||
|
- get_provider returns GmailClient for 'gmail'
|
||||||
|
- get_provider returns MSGraphClient for 'outlook' and 'teams'
|
||||||
|
- get_provider raises ValueError for unknown provider
|
||||||
|
|
||||||
|
Unit \u2014 app/integrations/gmail.py:
|
||||||
|
- _build_gmail_query with no filter returns empty string
|
||||||
|
- _build_gmail_query with labels builds label: expr
|
||||||
|
- _build_gmail_query with senders builds from: expr
|
||||||
|
- _build_gmail_query with date_range builds after:/before: exprs
|
||||||
|
- _build_gmail_query since overrides date_range.from when more recent
|
||||||
|
- _build_gmail_query date_range.from overrides since when more recent
|
||||||
|
- _parse_body extracts text/plain part
|
||||||
|
- _parse_body extracts text/html part (stripped)
|
||||||
|
- _parse_body recurses into multipart, prefers text/plain
|
||||||
|
- GmailClient.fetch_messages: happy path with mocked service
|
||||||
|
- GmailClient.fetch_messages: no messages returns empty list
|
||||||
|
- GmailClient.fetch_messages: HTTP error on messages.list raises RuntimeError
|
||||||
|
- GmailClient.refreshed_credentials: None when token unchanged
|
||||||
|
- GmailClient.refreshed_credentials: returns dict when token changes
|
||||||
|
|
||||||
|
Unit \u2014 app/integrations/ms_graph.py:
|
||||||
|
- _build_email_filter with no filter returns empty string
|
||||||
|
- _build_email_filter with senders builds OData from clause
|
||||||
|
- _build_email_filter with since builds receivedDateTime ge clause
|
||||||
|
- MSGraphClient.fetch_emails: happy path with mocked httpx
|
||||||
|
- MSGraphClient.fetch_emails: 401 triggers token refresh and retries
|
||||||
|
- MSGraphClient.fetch_messages: happy path with mocked httpx
|
||||||
|
- MSGraphClient.fetch_messages: 403 from getAllMessages degrades gracefully
|
||||||
|
- MSGraphClient.refreshed_credentials: None when token unchanged
|
||||||
|
- MSGraphClient._refresh_access_token: MSAL error raises RuntimeError
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.integrations import (
|
||||||
|
ChatMessage,
|
||||||
|
EmailMessage,
|
||||||
|
decrypt_token,
|
||||||
|
encrypt_token,
|
||||||
|
get_provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
# Helpers
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
|
||||||
|
_FERNET_KEY = "eW91LXNob3VsZC1ub3QtdXNlLXRoaXMta2V5LWluLXByb2Q="
|
||||||
|
# ^ 32-char URL-safe base64 (generated for tests only; not a real Fernet key length,
|
||||||
|
# so we generate a proper one below)
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet as _Fernet # noqa: E402
|
||||||
|
|
||||||
|
_VALID_KEY = _Fernet.generate_key().decode("utf-8")
|
||||||
|
|
||||||
|
_TOKEN_DICT = {
|
||||||
|
"token": "access_abc",
|
||||||
|
"refresh_token": "refresh_xyz",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"client_id": "client_id_123",
|
||||||
|
"client_secret": "client_secret_456",
|
||||||
|
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||||
|
}
|
||||||
|
|
||||||
|
_MS_TOKEN_DICT = {
|
||||||
|
"access_token": "ms_access_abc",
|
||||||
|
"refresh_token": "ms_refresh_xyz",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "Mail.Read offline_access",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
# encrypt_token / decrypt_token
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
|
||||||
|
|
||||||
|
class TestTokenEncryption:
|
||||||
|
"""encrypt_token / decrypt_token round-trip tests."""
|
||||||
|
|
||||||
|
def test_round_trip(self):
|
||||||
|
with patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||||
|
encrypted = encrypt_token(_TOKEN_DICT)
|
||||||
|
assert isinstance(encrypted, str)
|
||||||
|
assert encrypted != json.dumps(_TOKEN_DICT) # must be ciphertext, not plaintext
|
||||||
|
recovered = decrypt_token(encrypted)
|
||||||
|
assert recovered == _TOKEN_DICT
|
||||||
|
|
||||||
|
def test_decrypt_invalid_ciphertext_raises_value_error(self):
|
||||||
|
with patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||||
|
with pytest.raises(ValueError, match="Failed to decrypt"):
|
||||||
|
decrypt_token("this-is-not-valid-fernet-ciphertext")
|
||||||
|
|
||||||
|
def test_decrypt_wrong_key_raises_value_error(self):
|
||||||
|
"""Decrypting with a different key must fail with ValueError."""
|
||||||
|
other_key = _Fernet.generate_key().decode("utf-8")
|
||||||
|
with patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||||
|
encrypted = encrypt_token(_TOKEN_DICT)
|
||||||
|
with patch("app.integrations.settings") as mock_settings2:
|
||||||
|
mock_settings2.OAUTH_ENCRYPTION_KEY = other_key
|
||||||
|
with pytest.raises(ValueError, match="Failed to decrypt"):
|
||||||
|
decrypt_token(encrypted)
|
||||||
|
|
||||||
|
def test_encrypt_empty_dict_raises_value_error(self):
|
||||||
|
with patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||||
|
with pytest.raises(ValueError, match="non-empty dict"):
|
||||||
|
encrypt_token({})
|
||||||
|
|
||||||
|
def test_encrypt_non_dict_raises_value_error(self):
|
||||||
|
with patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = _VALID_KEY
|
||||||
|
with pytest.raises(ValueError, match="non-empty dict"):
|
||||||
|
encrypt_token("not-a-dict") # type: ignore[arg-type]
|
||||||
|
|
||||||
|
def test_missing_key_raises_runtime_error(self):
|
||||||
|
with patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = ""
|
||||||
|
with pytest.raises(RuntimeError, match="OAUTH_ENCRYPTION_KEY"):
|
||||||
|
encrypt_token(_TOKEN_DICT)
|
||||||
|
|
||||||
|
def test_email_message_as_text(self):
|
||||||
|
msg = EmailMessage(
|
||||||
|
id="m1",
|
||||||
|
subject="Hello",
|
||||||
|
sender="alice@example.com",
|
||||||
|
body_text="Test body",
|
||||||
|
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||||
|
)
|
||||||
|
text = msg.as_text
|
||||||
|
assert "From: alice@example.com" in text
|
||||||
|
assert "Subject: Hello" in text
|
||||||
|
assert "Test body" in text
|
||||||
|
|
||||||
|
def test_chat_message_as_text(self):
|
||||||
|
msg = ChatMessage(
|
||||||
|
id="c1",
|
||||||
|
content="Buy milk",
|
||||||
|
sender="bob",
|
||||||
|
channel="general",
|
||||||
|
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||||
|
)
|
||||||
|
text = msg.as_text
|
||||||
|
assert "From: bob" in text
|
||||||
|
assert "channel: general" in text
|
||||||
|
assert "Buy milk" in text
|
||||||
|
|
||||||
|
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
# get_provider factory
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
|
||||||
|
|
||||||
|
class TestGetProvider:
|
||||||
|
def test_gmail_returns_gmail_client(self):
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
|
||||||
|
client = get_provider("gmail", _TOKEN_DICT)
|
||||||
|
assert isinstance(client, GmailClient)
|
||||||
|
|
||||||
|
def test_outlook_returns_ms_graph_client(self):
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
client = get_provider("outlook", _MS_TOKEN_DICT)
|
||||||
|
assert isinstance(client, MSGraphClient)
|
||||||
|
|
||||||
|
def test_teams_returns_ms_graph_client(self):
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
client = get_provider("teams", _MS_TOKEN_DICT)
|
||||||
|
assert isinstance(client, MSGraphClient)
|
||||||
|
|
||||||
|
def test_unknown_provider_raises_value_error(self):
|
||||||
|
with pytest.raises(ValueError, match="Unknown cloud provider"):
|
||||||
|
get_provider("slack", {})
|
||||||
|
|
||||||
|
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
# Gmail client \u2014 query builder
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildGmailQuery:
|
||||||
|
"""Unit tests for gmail._build_gmail_query."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
from app.integrations.gmail import _build_gmail_query
|
||||||
|
self._fn = _build_gmail_query
|
||||||
|
|
||||||
|
def test_empty_returns_empty_string(self):
|
||||||
|
assert self._fn(None, None) == ""
|
||||||
|
|
||||||
|
def test_single_label(self):
|
||||||
|
q = self._fn({"labels": ["INBOX"]}, None)
|
||||||
|
assert "label:INBOX" in q
|
||||||
|
|
||||||
|
def test_multiple_labels_joined_with_or(self):
|
||||||
|
q = self._fn({"labels": ["INBOX", "work"]}, None)
|
||||||
|
assert "label:INBOX OR label:work" in q
|
||||||
|
|
||||||
|
def test_senders(self):
|
||||||
|
q = self._fn({"senders": ["alice@example.com"]}, None)
|
||||||
|
assert "from:alice@example.com" in q
|
||||||
|
|
||||||
|
def test_date_range_from(self):
|
||||||
|
q = self._fn({"date_range": {"from": "2025-01-15"}}, None)
|
||||||
|
assert "after:2025/01/15" in q
|
||||||
|
|
||||||
|
def test_date_range_to(self):
|
||||||
|
q = self._fn({"date_range": {"to": "2025-03-01"}}, None)
|
||||||
|
assert "before:2025/03/01" in q
|
||||||
|
|
||||||
|
def test_since_overrides_earlier_date_range_from(self):
|
||||||
|
"""since=Feb is more recent than date_range.from=Jan, so after: should be Feb."""
|
||||||
|
since = datetime(2025, 2, 1, tzinfo=timezone.utc)
|
||||||
|
q = self._fn({"date_range": {"from": "2025-01-01"}}, since)
|
||||||
|
assert "after:2025/02/01" in q
|
||||||
|
assert "after:2025/01/01" not in q
|
||||||
|
|
||||||
|
def test_date_range_from_overrides_earlier_since(self):
|
||||||
|
"""date_range.from=Feb is more recent than since=Jan, so after: should be Feb."""
|
||||||
|
since = datetime(2025, 1, 1, tzinfo=timezone.utc)
|
||||||
|
q = self._fn({"date_range": {"from": "2025-02-01"}}, since)
|
||||||
|
assert "after:2025/02/01" in q
|
||||||
|
|
||||||
|
def test_invalid_date_ignored(self):
|
||||||
|
"""An invalid date string in filter_config must not raise, just be skipped."""
|
||||||
|
q = self._fn({"date_range": {"from": "not-a-date"}}, None)
|
||||||
|
assert "after:" not in q
|
||||||
|
|
||||||
|
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
# Gmail client \u2014 body parsing
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
|
||||||
|
|
||||||
|
class TestParseBody:
|
||||||
|
"""Unit tests for gmail._parse_body."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
from app.integrations.gmail import _parse_body
|
||||||
|
self._fn = _parse_body
|
||||||
|
|
||||||
|
def _encode(self, text: str) -> str:
|
||||||
|
import base64
|
||||||
|
return base64.urlsafe_b64encode(text.encode()).decode()
|
||||||
|
|
||||||
|
def test_text_plain_extracted(self):
|
||||||
|
payload = {
|
||||||
|
"mimeType": "text/plain",
|
||||||
|
"body": {"data": self._encode("Hello world")},
|
||||||
|
}
|
||||||
|
assert self._fn(payload) == "Hello world"
|
||||||
|
|
||||||
|
def test_text_html_stripped(self):
|
||||||
|
payload = {
|
||||||
|
"mimeType": "text/html",
|
||||||
|
"body": {"data": self._encode("<p>Hello <b>world</b></p>")},
|
||||||
|
}
|
||||||
|
result = self._fn(payload)
|
||||||
|
assert "Hello" in result
|
||||||
|
assert "<p>" not in result
|
||||||
|
|
||||||
|
def test_multipart_prefers_plain_over_html(self):
|
||||||
|
plain_data = self._encode("Plain text")
|
||||||
|
html_data = self._encode("<p>HTML text</p>")
|
||||||
|
payload = {
|
||||||
|
"mimeType": "multipart/alternative",
|
||||||
|
"body": {},
|
||||||
|
"parts": [
|
||||||
|
{"mimeType": "text/html", "body": {"data": html_data}},
|
||||||
|
{"mimeType": "text/plain", "body": {"data": plain_data}},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
result = self._fn(payload)
|
||||||
|
assert result == "Plain text"
|
||||||
|
|
||||||
|
def test_empty_payload_returns_empty_string(self):
|
||||||
|
assert self._fn({}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
# GmailClient.fetch_messages
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
|
||||||
|
|
||||||
|
def _make_gmail_message(
|
||||||
|
msg_id: str = "msg001",
|
||||||
|
subject: str = "Test email",
|
||||||
|
sender: str = "alice@example.com",
|
||||||
|
body_text: str = "Hello world",
|
||||||
|
date: str = "Mon, 01 Jan 2025 10:00:00 +0000",
|
||||||
|
) -> dict:
|
||||||
|
"""Build a minimal Gmail API message response dict."""
|
||||||
|
import base64
|
||||||
|
body_data = base64.urlsafe_b64encode(body_text.encode()).decode()
|
||||||
|
return {
|
||||||
|
"id": msg_id,
|
||||||
|
"labelIds": ["INBOX"],
|
||||||
|
"payload": {
|
||||||
|
"mimeType": "text/plain",
|
||||||
|
"headers": [
|
||||||
|
{"name": "Subject", "value": subject},
|
||||||
|
{"name": "From", "value": sender},
|
||||||
|
{"name": "Date", "value": date},
|
||||||
|
],
|
||||||
|
"body": {"data": body_data},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestGmailClientFetchMessages:
|
||||||
|
"""GmailClient.fetch_messages tests with mocked Google API."""
|
||||||
|
|
||||||
|
def _make_client(self) -> "GmailClient":
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
return GmailClient(_TOKEN_DICT)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_happy_path_returns_email_messages(self):
|
||||||
|
client = self._make_client()
|
||||||
|
msg = _make_gmail_message()
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_users = mock_service.users.return_value
|
||||||
|
mock_messages = mock_users.messages.return_value
|
||||||
|
mock_messages.list.return_value.execute.return_value = {
|
||||||
|
"messages": [{"id": "msg001"}]
|
||||||
|
}
|
||||||
|
mock_messages.get.return_value.execute.return_value = msg
|
||||||
|
|
||||||
|
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
|
||||||
|
# Simulate to_thread running the sync function and returning results.
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
mock_thread.side_effect = fake_to_thread
|
||||||
|
|
||||||
|
with patch("googleapiclient.discovery.build", return_value=mock_service), \
|
||||||
|
patch("google.auth.transport.requests.Request"), \
|
||||||
|
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
|
||||||
|
results = await client.fetch_messages()
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].subject == "Test email"
|
||||||
|
assert results[0].sender == "alice@example.com"
|
||||||
|
assert results[0].body_text == "Hello world"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_no_messages_returns_empty_list(self):
|
||||||
|
client = self._make_client()
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_users = mock_service.users.return_value
|
||||||
|
mock_messages = mock_users.messages.return_value
|
||||||
|
mock_messages.list.return_value.execute.return_value = {"messages": []}
|
||||||
|
|
||||||
|
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
mock_thread.side_effect = fake_to_thread
|
||||||
|
|
||||||
|
with patch("googleapiclient.discovery.build", return_value=mock_service), \
|
||||||
|
patch("google.auth.transport.requests.Request"), \
|
||||||
|
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
|
||||||
|
results = await client.fetch_messages()
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_http_error_raises_runtime_error(self):
|
||||||
|
import googleapiclient.errors
|
||||||
|
client = self._make_client()
|
||||||
|
|
||||||
|
mock_service = MagicMock()
|
||||||
|
mock_users = mock_service.users.return_value
|
||||||
|
mock_messages = mock_users.messages.return_value
|
||||||
|
mock_resp = MagicMock()
|
||||||
|
mock_resp.status = 403
|
||||||
|
mock_resp.reason = "Forbidden"
|
||||||
|
mock_messages.list.return_value.execute.side_effect = (
|
||||||
|
googleapiclient.errors.HttpError(mock_resp, b"Forbidden")
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.integrations.gmail.asyncio.to_thread") as mock_thread:
|
||||||
|
async def fake_to_thread(fn, *args, **kwargs):
|
||||||
|
return fn(*args, **kwargs)
|
||||||
|
mock_thread.side_effect = fake_to_thread
|
||||||
|
|
||||||
|
with patch("googleapiclient.discovery.build", return_value=mock_service), \
|
||||||
|
patch("google.auth.transport.requests.Request"), \
|
||||||
|
patch.object(type(client._credentials), "expired", new_callable=PropertyMock, return_value=False):
|
||||||
|
with pytest.raises(RuntimeError, match="Gmail messages.list failed"):
|
||||||
|
await client.fetch_messages()
|
||||||
|
|
||||||
|
def test_refreshed_credentials_none_when_unchanged(self):
|
||||||
|
client = self._make_client()
|
||||||
|
# Token unchanged — should return None.
|
||||||
|
assert client.refreshed_credentials is None
|
||||||
|
|
||||||
|
def test_refreshed_credentials_returns_dict_when_token_changes(self):
|
||||||
|
client = self._make_client()
|
||||||
|
# Simulate a token refresh by changing the access token on the credentials object.
|
||||||
|
client._credentials.token = "new_access_token_xyz"
|
||||||
|
refreshed = client.refreshed_credentials
|
||||||
|
assert refreshed is not None
|
||||||
|
assert refreshed["token"] == "new_access_token_xyz"
|
||||||
|
|
||||||
|
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
# MS Graph client \u2014 email filter builder
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
|
||||||
|
|
||||||
|
class TestBuildEmailFilter:
|
||||||
|
"""Unit tests for ms_graph._build_email_filter."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
from app.integrations.ms_graph import _build_email_filter
|
||||||
|
self._fn = _build_email_filter
|
||||||
|
|
||||||
|
def test_empty_returns_empty_string(self):
|
||||||
|
assert self._fn(None, None) == ""
|
||||||
|
|
||||||
|
def test_single_sender(self):
|
||||||
|
result = self._fn({"senders": ["alice@example.com"]}, None)
|
||||||
|
assert "from/emailAddress/address eq 'alice@example.com'" in result
|
||||||
|
|
||||||
|
def test_multiple_senders_joined_with_or(self):
|
||||||
|
result = self._fn({"senders": ["a@x.com", "b@x.com"]}, None)
|
||||||
|
assert " or " in result
|
||||||
|
assert "a@x.com" in result
|
||||||
|
assert "b@x.com" in result
|
||||||
|
|
||||||
|
def test_since_adds_received_date_ge_clause(self):
|
||||||
|
since = datetime(2025, 3, 1, tzinfo=timezone.utc)
|
||||||
|
result = self._fn(None, since)
|
||||||
|
assert "receivedDateTime ge 2025-03-01T00:00:00Z" in result
|
||||||
|
|
||||||
|
def test_date_range_to_adds_received_date_le_clause(self):
|
||||||
|
result = self._fn({"date_range": {"to": "2025-06-30"}}, None)
|
||||||
|
assert "receivedDateTime le" in result
|
||||||
|
|
||||||
|
def test_since_overrides_earlier_date_range_from(self):
|
||||||
|
since = datetime(2025, 2, 1, tzinfo=timezone.utc)
|
||||||
|
result = self._fn({"date_range": {"from": "2025-01-01"}}, since)
|
||||||
|
assert "2025-02-01T00:00:00Z" in result
|
||||||
|
assert "2025-01-01" not in result
|
||||||
|
|
||||||
|
def test_invalid_date_ignored(self):
|
||||||
|
result = self._fn({"date_range": {"from": "bad-date"}}, None)
|
||||||
|
assert "receivedDateTime" not in result
|
||||||
|
|
||||||
|
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
# MSGraphClient.fetch_emails
|
||||||
|
# \u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500\u2500
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_email(
|
||||||
|
msg_id: str = "email001",
|
||||||
|
subject: str = "Meeting tomorrow",
|
||||||
|
sender_address: str = "boss@company.com",
|
||||||
|
body_content: str = "Please prepare the report.",
|
||||||
|
received: str = "2025-06-01T10:00:00Z",
|
||||||
|
) -> dict:
|
||||||
|
"""Build a minimal MS Graph message item dict."""
|
||||||
|
return {
|
||||||
|
"id": msg_id,
|
||||||
|
"subject": subject,
|
||||||
|
"from": {"emailAddress": {"address": sender_address}},
|
||||||
|
"receivedDateTime": received,
|
||||||
|
"body": {"contentType": "text", "content": body_content},
|
||||||
|
"bodyPreview": body_content[:100],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _make_graph_teams_message(
|
||||||
|
msg_id: str = "teams001",
|
||||||
|
content: str = "Stand-up at 9am",
|
||||||
|
sender: str = "alice",
|
||||||
|
channel_id: str = "chan001",
|
||||||
|
created: str = "2025-06-01T08:00:00Z",
|
||||||
|
) -> dict:
|
||||||
|
return {
|
||||||
|
"id": msg_id,
|
||||||
|
"body": {"contentType": "text", "content": content},
|
||||||
|
"from": {"user": {"displayName": sender}},
|
||||||
|
"channelIdentity": {"channelId": channel_id},
|
||||||
|
"createdDateTime": created,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientFetchEmails:
|
||||||
|
"""MSGraphClient.fetch_emails tests with mocked httpx."""
|
||||||
|
|
||||||
|
def _make_client(self) -> "MSGraphClient":
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(_MS_TOKEN_DICT)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_happy_path_returns_email_messages(self):
|
||||||
|
client = self._make_client()
|
||||||
|
graph_email = _make_graph_email()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"value": [graph_email]}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_emails()
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].subject == "Meeting tomorrow"
|
||||||
|
assert results[0].sender == "boss@company.com"
|
||||||
|
assert results[0].body_text == "Please prepare the report."
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_pagination_stops_at_max_emails(self):
|
||||||
|
"""No nextLink in first page \u2014 only one batch returned."""
|
||||||
|
client = self._make_client()
|
||||||
|
emails_batch = [_make_graph_email(msg_id=str(i)) for i in range(3)]
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"value": emails_batch} # no @odata.nextLink
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_emails()
|
||||||
|
|
||||||
|
assert len(results) == 3
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_401_triggers_token_refresh_and_retries(self):
|
||||||
|
"""On first 401, token refresh is attempted and the request retried."""
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
client = MSGraphClient(_MS_TOKEN_DICT)
|
||||||
|
|
||||||
|
graph_email = _make_graph_email()
|
||||||
|
|
||||||
|
response_401 = MagicMock()
|
||||||
|
response_401.status_code = 401
|
||||||
|
|
||||||
|
response_200 = MagicMock()
|
||||||
|
response_200.status_code = 200
|
||||||
|
response_200.json.return_value = {"value": [graph_email]}
|
||||||
|
response_200.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
call_count = 0
|
||||||
|
|
||||||
|
async def fake_get(url, params=None, headers=None):
|
||||||
|
nonlocal call_count
|
||||||
|
call_count += 1
|
||||||
|
if call_count == 1:
|
||||||
|
return response_401
|
||||||
|
return response_200
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls, \
|
||||||
|
patch.object(client, "_refresh_access_token", new_callable=AsyncMock) as mock_refresh:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = fake_get
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_emails()
|
||||||
|
|
||||||
|
mock_refresh.assert_called_once()
|
||||||
|
assert len(results) == 1
|
||||||
|
|
||||||
|
def test_refreshed_credentials_none_when_token_unchanged(self):
|
||||||
|
client = self._make_client()
|
||||||
|
assert client.refreshed_credentials is None
|
||||||
|
|
||||||
|
def test_refreshed_credentials_returns_dict_when_token_changes(self):
|
||||||
|
client = self._make_client()
|
||||||
|
client._access_token = "new_token_abc"
|
||||||
|
assert client.refreshed_credentials is not None
|
||||||
|
assert client.refreshed_credentials["access_token"] == "new_token_abc"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientFetchMessages:
|
||||||
|
"""MSGraphClient.fetch_messages (Teams) tests."""
|
||||||
|
|
||||||
|
def _make_client(self) -> "MSGraphClient":
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(_MS_TOKEN_DICT)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_happy_path_returns_chat_messages(self):
|
||||||
|
client = self._make_client()
|
||||||
|
teams_msg = _make_graph_teams_message()
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"value": [teams_msg]}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_messages()
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].content == "Stand-up at 9am"
|
||||||
|
assert results[0].sender == "alice"
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_403_degrades_gracefully(self):
|
||||||
|
"""getAllMessages returning 403 (license issue) returns empty list, no exception."""
|
||||||
|
import httpx as _httpx
|
||||||
|
|
||||||
|
client = self._make_client()
|
||||||
|
|
||||||
|
error_response = MagicMock()
|
||||||
|
error_response.status_code = 403
|
||||||
|
http_error = _httpx.HTTPStatusError(
|
||||||
|
"Forbidden", request=MagicMock(), response=error_response
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(side_effect=http_error)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_messages()
|
||||||
|
|
||||||
|
assert results == []
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_channel_filter_applied(self):
|
||||||
|
"""Messages from non-matching channels are filtered out."""
|
||||||
|
client = self._make_client()
|
||||||
|
matching = _make_graph_teams_message(channel_id="dev-channel", content="Deploy today")
|
||||||
|
non_matching = _make_graph_teams_message(msg_id="t2", channel_id="random", content="Lunch?")
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.json.return_value = {"value": [matching, non_matching]}
|
||||||
|
mock_response.raise_for_status = MagicMock()
|
||||||
|
|
||||||
|
with patch("app.integrations.ms_graph.httpx.AsyncClient") as mock_client_cls:
|
||||||
|
mock_http = AsyncMock()
|
||||||
|
mock_http.get = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client_cls.return_value.__aenter__ = AsyncMock(return_value=mock_http)
|
||||||
|
mock_client_cls.return_value.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
|
||||||
|
results = await client.fetch_messages(
|
||||||
|
filter_config={"channels": ["dev-channel"]}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(results) == 1
|
||||||
|
assert results[0].content == "Deploy today"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMSGraphClientRefreshToken:
|
||||||
|
"""MSGraphClient._refresh_access_token with mocked MSAL."""
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_msal_error_raises_runtime_error(self):
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_test"})
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.acquire_token_by_refresh_token.return_value = {
|
||||||
|
"error": "invalid_grant",
|
||||||
|
"error_description": "Refresh token expired",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("msal.ConfidentialClientApplication", return_value=mock_app), \
|
||||||
|
patch("app.integrations.ms_graph.settings") as mock_settings:
|
||||||
|
mock_settings.MS_CLIENT_ID = "client_id"
|
||||||
|
mock_settings.MS_CLIENT_SECRET = "secret"
|
||||||
|
mock_settings.MS_TENANT_ID = "common"
|
||||||
|
with pytest.raises(RuntimeError, match="MS Graph token refresh failed"):
|
||||||
|
await client._refresh_access_token()
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_successful_refresh_updates_access_token(self):
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
client = MSGraphClient({**_MS_TOKEN_DICT, "refresh_token": "rt_old"})
|
||||||
|
|
||||||
|
mock_app = MagicMock()
|
||||||
|
mock_app.acquire_token_by_refresh_token.return_value = {
|
||||||
|
"access_token": "new_access_token",
|
||||||
|
"refresh_token": "new_refresh_token",
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch("msal.ConfidentialClientApplication", return_value=mock_app), \
|
||||||
|
patch("app.integrations.ms_graph.settings") as mock_settings:
|
||||||
|
mock_settings.MS_CLIENT_ID = "client_id"
|
||||||
|
mock_settings.MS_CLIENT_SECRET = "secret"
|
||||||
|
mock_settings.MS_TENANT_ID = "common"
|
||||||
|
await client._refresh_access_token()
|
||||||
|
|
||||||
|
assert client._access_token == "new_access_token"
|
||||||
|
assert client._refresh_token == "new_refresh_token"
|
||||||
284
tests/test_memory_middleware.py
Normal file
284
tests/test_memory_middleware.py
Normal file
@@ -0,0 +1,284 @@
|
|||||||
|
"""Tests for Step 7 — MemoryMiddleware.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
1. enrich_context returns core prefs + associative + episodic + proactive
|
||||||
|
2. store_episode creates an encrypted row decryptable with the user's key
|
||||||
|
3. update_core upserts correctly
|
||||||
|
4. User with no encryption_key returns empty context (no crash)
|
||||||
|
5. End-to-end: home_request WS frame results in an episodic row being stored
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import (
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
||||||
|
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["power"]
|
||||||
|
_FERNET_KEY = Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── DB override ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixtures ──────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def user_with_key(db_session):
|
||||||
|
"""Set encryption_key on the seeded power user."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = _FERNET_KEY
|
||||||
|
await db_session.commit()
|
||||||
|
return user
|
||||||
|
|
||||||
|
|
||||||
|
def _fernet():
|
||||||
|
return Fernet(_FERNET_KEY.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def _enc(plaintext: str) -> str:
|
||||||
|
return _fernet().encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _dec(ciphertext: str) -> str:
|
||||||
|
return _fernet().decrypt(ciphertext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── enrich_context ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_core_memory(db_session, user_with_key):
|
||||||
|
# Seed a core memory row
|
||||||
|
db_session.add(MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
key="timezone",
|
||||||
|
value_encrypted=_enc("UTC"),
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "What are my tasks?")
|
||||||
|
|
||||||
|
assert "core_memory" in ctx
|
||||||
|
assert ctx["core_memory"]["timezone"] == "UTC"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_episodic_memory(db_session, user_with_key):
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
db_session.add(MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=_enc("User asked about Q1 tasks"),
|
||||||
|
session_id=session_id,
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "any message")
|
||||||
|
|
||||||
|
assert "episodic_memory" in ctx
|
||||||
|
assert any("Q1 tasks" in s for s in ctx["episodic_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
|
||||||
|
# Add one pattern above threshold and one below
|
||||||
|
db_session.add(MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=_enc("User prefers short summaries"),
|
||||||
|
confidence=0.9,
|
||||||
|
source="inferred",
|
||||||
|
))
|
||||||
|
db_session.add(MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=_enc("User likes dark mode"),
|
||||||
|
confidence=0.1,
|
||||||
|
source="inferred",
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "any message")
|
||||||
|
|
||||||
|
assert "proactive_hints" in ctx
|
||||||
|
hints = ctx["proactive_hints"]
|
||||||
|
assert any("short summaries" in h for h in hints)
|
||||||
|
assert not any("dark mode" in h for h in hints)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_returns_associative_memory(db_session, user_with_key):
|
||||||
|
db_session.add(MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
content_encrypted=_enc("Related memory about meetings"),
|
||||||
|
embedding=None,
|
||||||
|
entity_type="note",
|
||||||
|
))
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "meetings")
|
||||||
|
|
||||||
|
assert "associative_memory" in ctx
|
||||||
|
assert any("meetings" in m for m in ctx["associative_memory"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_enrich_context_empty_for_user_without_key(db_session):
|
||||||
|
"""User with no encryption_key → empty context, no crash."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = None
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
ctx = await middleware.enrich_context(USER_ID, "hello")
|
||||||
|
assert ctx == {}
|
||||||
|
|
||||||
|
|
||||||
|
# ── store_episode ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_episode_creates_encrypted_row(db_session, user_with_key):
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.store_episode(USER_ID, session_id, "hello", "world")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
plaintext = _dec(row.summary_encrypted)
|
||||||
|
assert "hello" in plaintext
|
||||||
|
assert "world" in plaintext
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_store_episode_decryptable(db_session, user_with_key):
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.store_episode(USER_ID, session_id, "msg", "resp")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
# Decrypt using the same key — must not raise
|
||||||
|
decrypted = _dec(row.summary_encrypted)
|
||||||
|
assert len(decrypted) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── update_core ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_core_insert(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.update_core(USER_ID, "lang", "en")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
|
||||||
|
)
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert _dec(row.value_encrypted) == "en"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_update_core_upsert(db_session, user_with_key):
|
||||||
|
middleware = MemoryMiddleware(db_session)
|
||||||
|
await middleware.update_core(USER_ID, "lang", "en")
|
||||||
|
await middleware.update_core(USER_ID, "lang", "fr")
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID, MemoryCore.key == "lang")
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
assert len(rows) == 1
|
||||||
|
assert _dec(rows[0].value_encrypted) == "fr"
|
||||||
|
|
||||||
|
|
||||||
|
# ── End-to-end WS: memory middleware is called during home_request ────────────
|
||||||
|
|
||||||
|
def test_home_request_calls_memory_middleware(client):
|
||||||
|
"""home_request triggers enrich_context before and store_episode after the LLM."""
|
||||||
|
enrich_calls: list[tuple] = []
|
||||||
|
store_calls: list[tuple] = []
|
||||||
|
|
||||||
|
class _MockMiddleware:
|
||||||
|
def __init__(self, db):
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def enrich_context(self, user_id, message):
|
||||||
|
enrich_calls.append((user_id, message))
|
||||||
|
return {"core_memory": {"tz": "UTC"}}
|
||||||
|
|
||||||
|
async def store_episode(self, user_id, session_id, message, response):
|
||||||
|
store_calls.append((user_id, session_id, message, response))
|
||||||
|
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
async def _mock_stream(user_id, message, context, db_session_factory=None):
|
||||||
|
# Verify memory context was injected
|
||||||
|
assert context.get("core_memory") == {"tz": "UTC"}
|
||||||
|
yield ("token", "Done")
|
||||||
|
yield ("mutations", [])
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.api.routes.device_ws.MemoryMiddleware", _MockMiddleware),
|
||||||
|
patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_stream),
|
||||||
|
):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-mem", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "home_request",
|
||||||
|
"request_id": "r-mem",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Show tasks",
|
||||||
|
}))
|
||||||
|
for _ in range(20):
|
||||||
|
raw = ws.receive_text()
|
||||||
|
frame = json.loads(raw)
|
||||||
|
if frame.get("type") == "stream_end":
|
||||||
|
break
|
||||||
|
|
||||||
|
assert len(enrich_calls) == 1
|
||||||
|
assert enrich_calls[0] == (USER_ID, "Show tasks")
|
||||||
|
assert len(store_calls) == 1
|
||||||
|
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
||||||
|
assert stored_session_id == session_id
|
||||||
|
assert stored_message == "Show tasks"
|
||||||
205
tests/test_memory_models.py
Normal file
205
tests/test_memory_models.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Tests for Step 6 — memory ORM models and User.encryption_key.
|
||||||
|
|
||||||
|
Uses the SQLite in-memory test DB (from conftest). The pgvector embedding
|
||||||
|
column is stored as JSON in tests (SQLite-compatible).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.models import MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, User
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _fernet_key() -> str:
|
||||||
|
return Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _encrypt(key: str, plaintext: str) -> str:
|
||||||
|
return Fernet(key.encode()).encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _decrypt(key: str, ciphertext: str) -> str:
|
||||||
|
return Fernet(key.encode()).decrypt(ciphertext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
# ── User.encryption_key ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_encryption_key_column_exists(db_session):
|
||||||
|
"""User model has encryption_key column and it can be set."""
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
# Column exists (may be None for seeded users)
|
||||||
|
assert hasattr(user, "encryption_key")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_user_encryption_key_can_be_set(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
result = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.encryption_key = key
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result2 = await db_session.execute(select(User).where(User.id == USER_ID))
|
||||||
|
user2 = result2.scalar_one()
|
||||||
|
assert user2.encryption_key == key
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryCore ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_core_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
encrypted_val = _encrypt(key, "UTC")
|
||||||
|
|
||||||
|
row = MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
key="timezone",
|
||||||
|
value_encrypted=encrypted_val,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.key == "timezone"
|
||||||
|
assert _decrypt(key, fetched.value_encrypted) == "UTC"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_core_cascade_delete(db_session):
|
||||||
|
"""Deleting a user cascades to memory_core."""
|
||||||
|
row = MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
key="lang",
|
||||||
|
value_encrypted="enc",
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
user = (await db_session.execute(select(User).where(User.id == USER_ID))).scalar_one()
|
||||||
|
await db_session.delete(user)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
remaining = (
|
||||||
|
await db_session.execute(select(MemoryCore).where(MemoryCore.user_id == USER_ID))
|
||||||
|
).scalars().all()
|
||||||
|
assert remaining == []
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryAssociative ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_associative_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
content = _encrypt(key, "User prefers morning meetings")
|
||||||
|
embedding = [0.1] * 1536 # fake embedding
|
||||||
|
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
content_encrypted=content,
|
||||||
|
embedding=embedding,
|
||||||
|
entity_type="preference",
|
||||||
|
entity_id=None,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryAssociative).where(MemoryAssociative.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.entity_type == "preference"
|
||||||
|
assert _decrypt(key, fetched.content_encrypted) == "User prefers morning meetings"
|
||||||
|
assert len(fetched.embedding) == 1536
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryEpisodic ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_episodic_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
summary = _encrypt(key, "User asked about Q1 tasks")
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
summary_encrypted=summary,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryEpisodic).where(MemoryEpisodic.session_id == session_id)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert _decrypt(key, fetched.summary_encrypted) == "User asked about Q1 tasks"
|
||||||
|
assert isinstance(fetched.created_at, datetime)
|
||||||
|
|
||||||
|
|
||||||
|
# ── MemoryProactive ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_memory_proactive_create_and_read(db_session):
|
||||||
|
key = _fernet_key()
|
||||||
|
pattern = _encrypt(key, "User always assigns tasks to self")
|
||||||
|
|
||||||
|
row = MemoryProactive(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=USER_ID,
|
||||||
|
pattern_encrypted=pattern,
|
||||||
|
confidence=0.85,
|
||||||
|
source="inferred",
|
||||||
|
)
|
||||||
|
db_session.add(row)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
result = await db_session.execute(
|
||||||
|
select(MemoryProactive).where(MemoryProactive.user_id == USER_ID)
|
||||||
|
)
|
||||||
|
fetched = result.scalar_one()
|
||||||
|
assert fetched.confidence == pytest.approx(0.85)
|
||||||
|
assert fetched.source == "inferred"
|
||||||
|
assert _decrypt(key, fetched.pattern_encrypted) == "User always assigns tasks to self"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Auth registration generates encryption_key ───────────────────────────────
|
||||||
|
|
||||||
|
def test_register_sets_encryption_key(client):
|
||||||
|
"""POST /api/v1/auth/register creates a user with a valid Fernet key."""
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": "newuser@test.com", "password": "testpassword123"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
|
||||||
|
# Fetch the newly created user via the access token
|
||||||
|
token = resp.json()["access_token"]
|
||||||
|
me_resp = client.get(
|
||||||
|
"/api/v1/auth/me",
|
||||||
|
headers={"Authorization": f"Bearer {token}"},
|
||||||
|
)
|
||||||
|
assert me_resp.status_code == 200
|
||||||
|
# We can't see encryption_key in the API response (not in UserProfile),
|
||||||
|
# but we verify registration didn't crash — key generation is implicit.
|
||||||
@@ -18,13 +18,29 @@ from fastapi.testclient import TestClient
|
|||||||
from jose import jwt
|
from jose import jwt
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.db import get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.schemas import ChatResponse
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Autouse: redirect all DB access to the in-memory SQLite test engine.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
"""Route all get_session calls to the test SQLite session."""
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
_CHAT_BODY = {
|
_CHAT_BODY = {
|
||||||
"message": "hello",
|
"message": "hello",
|
||||||
"context": {
|
"context": {
|
||||||
@@ -33,7 +49,6 @@ _CHAT_BODY = {
|
|||||||
"recent_tasks": [],
|
"recent_tasks": [],
|
||||||
"conversation_history": [],
|
"conversation_history": [],
|
||||||
},
|
},
|
||||||
"execution_mode": "direct",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -74,14 +89,15 @@ class TestAuthMiddleware:
|
|||||||
"""Tests exercised via GET /api/v1/auth/me."""
|
"""Tests exercised via GET /api/v1/auth/me."""
|
||||||
|
|
||||||
def test_valid_token_returns_profile(self) -> None:
|
def test_valid_token_returns_profile(self) -> None:
|
||||||
uid = str(uuid.uuid4())
|
# Use the seeded pro user so the subscription lookup returns 'pro'.
|
||||||
token = _make_jwt(user_id=uid, email="alice@example.com", tier="pro")
|
uid = TEST_USER_IDS["pro"]
|
||||||
|
token = _make_jwt(user_id=uid, email="pro@test.com", tier="pro")
|
||||||
with TestClient(app) as client:
|
with TestClient(app) as client:
|
||||||
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
resp = client.get("/api/v1/auth/me", headers=_auth_header(token))
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["id"] == uid
|
assert data["id"] == uid
|
||||||
assert data["email"] == "alice@example.com"
|
assert data["email"] == "pro@test.com"
|
||||||
assert data["tier"] == "pro"
|
assert data["tier"] == "pro"
|
||||||
|
|
||||||
def test_missing_token_returns_401(self) -> None:
|
def test_missing_token_returns_401(self) -> None:
|
||||||
@@ -222,7 +238,7 @@ class TestRateLimitMiddleware:
|
|||||||
|
|
||||||
|
|
||||||
class TestSanitizerMiddleware:
|
class TestSanitizerMiddleware:
|
||||||
"""Mock ``orchestrate`` to inject controlled strings into chat responses."""
|
"""Mock ``run_home`` to inject controlled strings into chat responses."""
|
||||||
|
|
||||||
_CHAT_PATH = "/api/v1/chat"
|
_CHAT_PATH = "/api/v1/chat"
|
||||||
|
|
||||||
@@ -230,11 +246,10 @@ class TestSanitizerMiddleware:
|
|||||||
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
return _make_jwt(user_id=str(uuid.uuid4()), tier="pro")
|
||||||
|
|
||||||
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
def _post_chat(self, client: TestClient, response_text: str) -> dict:
|
||||||
mock_response = ChatResponse(response=response_text, actions=[])
|
|
||||||
with patch(
|
with patch(
|
||||||
"app.api.routes.chat.orchestrate",
|
"app.api.routes.chat.run_home",
|
||||||
new_callable=AsyncMock,
|
new_callable=AsyncMock,
|
||||||
return_value=mock_response,
|
return_value=response_text,
|
||||||
):
|
):
|
||||||
resp = client.post(
|
resp = client.post(
|
||||||
self._CHAT_PATH,
|
self._CHAT_PATH,
|
||||||
|
|||||||
@@ -1,348 +0,0 @@
|
|||||||
"""Integration tests for the orchestrator module."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
|
|
||||||
from app.core.agent_registry import AgentRegistry, ChatAgent
|
|
||||||
from app.core.orchestrator import (
|
|
||||||
classify_intent,
|
|
||||||
orchestrate,
|
|
||||||
orchestrate_stream,
|
|
||||||
route_pipeline,
|
|
||||||
route_single,
|
|
||||||
)
|
|
||||||
from app.schemas import ChatContext, ChatRequest, ChatResponse, ExecutionPlan
|
|
||||||
|
|
||||||
|
|
||||||
# ── Stub agents ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _TaskAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "task_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages tasks: create, update, list, suggest"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return f"task: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
class _CalendarAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "calendar_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Calendar management: events, conflicts, scheduling"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
return f"calendar: {query}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _mock_llm(response_text: str) -> MagicMock:
|
|
||||||
"""Return a mock LLM that always produces *response_text*."""
|
|
||||||
msg = MagicMock()
|
|
||||||
msg.content = response_text
|
|
||||||
llm = MagicMock()
|
|
||||||
llm.ainvoke = AsyncMock(return_value=msg)
|
|
||||||
return llm
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixtures ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
|
||||||
def _fresh_registry():
|
|
||||||
"""Reset the AgentRegistry singleton between tests."""
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
yield
|
|
||||||
AgentRegistry._instance = None
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture()
|
|
||||||
def reg() -> AgentRegistry:
|
|
||||||
r = AgentRegistry()
|
|
||||||
r.register(_TaskAgent)
|
|
||||||
r.register(_CalendarAgent)
|
|
||||||
return r
|
|
||||||
|
|
||||||
|
|
||||||
# ── classify_intent ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestClassifyIntent:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_routes_to_known_agent(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
result = await classify_intent("add a task", {}, reg)
|
|
||||||
assert result == "task_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_routes_to_calendar_agent(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
|
||||||
result = await classify_intent("schedule a meeting", {}, reg)
|
|
||||||
assert result == "calendar_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_falls_back_on_unknown_name(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("nonexistent_agent")
|
|
||||||
result = await classify_intent("do something", {}, reg)
|
|
||||||
assert result == "task_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_empty_registry_returns_fallback_without_llm_call(self) -> None:
|
|
||||||
empty_reg = AgentRegistry()
|
|
||||||
# No LLM should be instantiated — early return path
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
result = await classify_intent("anything", {}, empty_reg)
|
|
||||||
mock_cls.assert_not_called()
|
|
||||||
assert result == "task_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_whitespace_stripped_from_response(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm(" task_agent \n")
|
|
||||||
result = await classify_intent("create task", {}, reg)
|
|
||||||
assert result == "task_agent"
|
|
||||||
|
|
||||||
|
|
||||||
# ── route_single ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestRouteSingle:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
|
||||||
result = await route_single("task_agent", "create a task", {}, reg)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_response_contains_agent_output(self, reg: AgentRegistry) -> None:
|
|
||||||
result = await route_single("task_agent", "create a task", {}, reg)
|
|
||||||
assert result.response == "task: create a task"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_unknown_agent_raises_key_error(self, reg: AgentRegistry) -> None:
|
|
||||||
with pytest.raises(KeyError):
|
|
||||||
await route_single("nonexistent", "hello", {}, reg)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_actions_default_empty(self, reg: AgentRegistry) -> None:
|
|
||||||
result = await route_single("task_agent", "hi", {}, reg)
|
|
||||||
assert result.actions == []
|
|
||||||
|
|
||||||
|
|
||||||
# ── route_pipeline ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestRoutePipeline:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_returns_chat_response(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("synthesized result")
|
|
||||||
result = await route_pipeline(
|
|
||||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
|
||||||
)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_response_is_synthesis_output(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("synthesized result")
|
|
||||||
result = await route_pipeline(
|
|
||||||
["task_agent", "calendar_agent"], "plan my week", {}, reg
|
|
||||||
)
|
|
||||||
assert result.response == "synthesized result"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_passes_previous_results_to_subsequent_agents(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
"""Each agent after the first should receive prior outputs in context."""
|
|
||||||
received_contexts: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
class _CapturingAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "capture"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "captures context for testing"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return []
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
received_contexts.append(dict(context))
|
|
||||||
return "captured"
|
|
||||||
|
|
||||||
reg.register(_CapturingAgent)
|
|
||||||
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("done")
|
|
||||||
await route_pipeline(["task_agent", "capture"], "hi", {}, reg)
|
|
||||||
|
|
||||||
# The second agent (capture) must have received previous results
|
|
||||||
assert len(received_contexts) == 1
|
|
||||||
assert "previous_results" in received_contexts[0]
|
|
||||||
assert received_contexts[0]["previous_results"] == ["task: hi"]
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_single_agent_pipeline(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("single result")
|
|
||||||
result = await route_pipeline(["task_agent"], "one agent", {}, reg)
|
|
||||||
assert result.response == "single result"
|
|
||||||
|
|
||||||
|
|
||||||
# ── orchestrate ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestOrchestrate:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_direct_mode_returns_chat_response(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_direct_mode_response_content(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
assert result.response == "task: add a task"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plan_mode_returns_execution_plan(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="plan my tasks", execution_mode="plan")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ExecutionPlan)
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plan_mode_agent_matches_classified(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("calendar_agent")
|
|
||||||
request = ChatRequest(
|
|
||||||
message="schedule something", execution_mode="plan"
|
|
||||||
)
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ExecutionPlan)
|
|
||||||
assert result.agent == "calendar_agent"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plan_mode_has_steps(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ExecutionPlan)
|
|
||||||
assert len(result.steps) >= 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_plan_mode_template_id_contains_agent_name(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="plan tasks", execution_mode="plan")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ExecutionPlan)
|
|
||||||
assert result.steps[0].prompt_template is not None
|
|
||||||
assert "task_agent" in result.steps[0].prompt_template
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_default_execution_mode_is_direct(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
# execution_mode defaults to "direct"
|
|
||||||
request = ChatRequest(message="help me")
|
|
||||||
result = await orchestrate(request, reg)
|
|
||||||
assert isinstance(result, ChatResponse)
|
|
||||||
|
|
||||||
|
|
||||||
# ── orchestrate_stream ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class TestOrchestrateStream:
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_yields_at_least_one_chunk(self, reg: AgentRegistry) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
|
||||||
assert len(chunks) >= 1
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_last_chunk_is_final_json_frame(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="add a task", execution_mode="direct")
|
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
|
||||||
|
|
||||||
last = json.loads(chunks[-1])
|
|
||||||
assert last["done"] is True
|
|
||||||
assert "response" in last
|
|
||||||
assert "actions" in last
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_final_frame_response_matches_agent_output(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(message="create a task", execution_mode="direct")
|
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
|
||||||
|
|
||||||
final = json.loads(chunks[-1])
|
|
||||||
assert final["response"] == "task: create a task"
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_text_chunks_before_final_frame(
|
|
||||||
self, reg: AgentRegistry
|
|
||||||
) -> None:
|
|
||||||
with patch("app.core.orchestrator.ChatOpenAI") as mock_cls:
|
|
||||||
mock_cls.return_value = _mock_llm("task_agent")
|
|
||||||
request = ChatRequest(
|
|
||||||
message="x" * 200, execution_mode="direct"
|
|
||||||
) # long enough to produce multiple chunks
|
|
||||||
chunks = [chunk async for chunk in orchestrate_stream(request, reg)]
|
|
||||||
|
|
||||||
# All but the last chunk should be plain text (not valid final JSON)
|
|
||||||
non_final = chunks[:-1]
|
|
||||||
for chunk in non_final:
|
|
||||||
try:
|
|
||||||
parsed = json.loads(chunk)
|
|
||||||
assert parsed.get("done") is not True
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass # plain text chunk — expected
|
|
||||||
214
tests/test_output_formatter.py
Normal file
214
tests/test_output_formatter.py
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
"""Tests for app.core.output_formatter — HomeFormatter and FloatingFormatter."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _stream(*events: tuple[str, object]):
|
||||||
|
"""Async generator that yields (event_type, data) tuples."""
|
||||||
|
for event in events:
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def collect(formatter, event_stream):
|
||||||
|
frames = []
|
||||||
|
async for frame in formatter.format(event_stream):
|
||||||
|
frames.append(frame)
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
# ── HomeFormatter ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_plain_text():
|
||||||
|
req_id = "req-1"
|
||||||
|
events = [
|
||||||
|
("token", "Hello world"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
|
assert frames[0].request_id == req_id
|
||||||
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
|
assert any("Hello world" in f.chunk for f in text_frames)
|
||||||
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_entity_tags_passed_through():
|
||||||
|
"""Entity tags are streamed as-is — the frontend parses them."""
|
||||||
|
req_id = "req-2"
|
||||||
|
events = [
|
||||||
|
("token", "Here is your project:\n<project>[abc-123]</project>\nAll good."),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
|
assert "<project>[abc-123]</project>" in text
|
||||||
|
assert "Here is your project:" in text
|
||||||
|
assert "All good." in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_multiple_tags_passed_through():
|
||||||
|
req_id = "req-3"
|
||||||
|
events = [
|
||||||
|
("token", "<project>[p1]</project>\nText\n<task>[t1,t2]</task>"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
|
assert "<project>[p1]</project>" in text
|
||||||
|
assert "<task>[t1,t2]</task>" in text
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_tool_end_ignored():
|
||||||
|
"""tool_end events are silently ignored by HomeFormatter."""
|
||||||
|
req_id = "req-4"
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "task_agent", "result": "3 tasks"}),
|
||||||
|
("token", "No tags here."),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
text = "".join(f.chunk for f in frames if isinstance(f, WsStreamText))
|
||||||
|
assert text == "No tags here."
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_mutations_in_stream_end():
|
||||||
|
req_id = "req-5"
|
||||||
|
muts = [{"action": "insert", "table": "tasks", "data": {"id": "t1"}}]
|
||||||
|
events = [
|
||||||
|
("token", "Done"),
|
||||||
|
("mutations", muts),
|
||||||
|
]
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
end_frame = frames[-1]
|
||||||
|
assert isinstance(end_frame, WsStreamEnd)
|
||||||
|
assert len(end_frame.mutations) == 1
|
||||||
|
assert end_frame.mutations[0]["action"] == "insert"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_home_formatter_frame_order():
|
||||||
|
"""stream_start is first, stream_end is last."""
|
||||||
|
req_id = "req-6"
|
||||||
|
formatter = HomeFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(("token", "Hi"), ("mutations", [])))
|
||||||
|
assert isinstance(frames[0], WsStreamStart)
|
||||||
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
# ── FloatingFormatter ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_domain_from_tool_end():
|
||||||
|
req_id = "pop-1"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "task_agent", "result": "ok"}),
|
||||||
|
("token", "Hello"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
|
assert frames[0].domain == "tasks"
|
||||||
|
assert frames[0].request_id == req_id
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_text_only():
|
||||||
|
req_id = "pop-2"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "timeline_agent", "result": "done"}),
|
||||||
|
("token", "Summary"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
|
assert frames[0].domain == "timelines"
|
||||||
|
text_frames = [f for f in frames if isinstance(f, WsStreamText)]
|
||||||
|
assert len(text_frames) == 1
|
||||||
|
assert text_frames[0].chunk == "Summary"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_no_entity_tags():
|
||||||
|
"""FloatingFormatter never emits entity tag blocks."""
|
||||||
|
req_id = "pop-3"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "note_agent", "result": "data"}),
|
||||||
|
("token", "some text"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
# Only expected frame types
|
||||||
|
for f in frames:
|
||||||
|
assert isinstance(f, (WsFloatingDomain, WsStreamStart, WsStreamText, WsStreamEnd))
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_end_frame():
|
||||||
|
req_id = "pop-4"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [
|
||||||
|
("tool_end", {"name": "project_agent", "result": "ok"}),
|
||||||
|
("token", "Done"),
|
||||||
|
("mutations", []),
|
||||||
|
]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
assert isinstance(frames[-1], WsStreamEnd)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_default_domain_on_early_token():
|
||||||
|
"""When the first event is a token (no tool_end yet), default to 'tasks'."""
|
||||||
|
req_id = "pop-5"
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
events = [("token", "hi"), ("mutations", [])]
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
assert isinstance(frames[0], WsFloatingDomain)
|
||||||
|
assert frames[0].domain == "tasks"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_floating_formatter_mutations_in_stream_end():
|
||||||
|
req_id = "pop-6"
|
||||||
|
muts = [{"action": "update", "table": "tasks", "data": {"id": "t2"}}]
|
||||||
|
events = [
|
||||||
|
("token", "Updated"),
|
||||||
|
("mutations", muts),
|
||||||
|
]
|
||||||
|
formatter = FloatingFormatter(request_id=req_id)
|
||||||
|
frames = await collect(formatter, _stream(*events))
|
||||||
|
|
||||||
|
end_frame = frames[-1]
|
||||||
|
assert isinstance(end_frame, WsStreamEnd)
|
||||||
|
assert len(end_frame.mutations) == 1
|
||||||
@@ -1,52 +1,32 @@
|
|||||||
"""Tests for Step 10: Plugin Marketplace.
|
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
|
||||||
|
|
||||||
Covers:
|
Covers:
|
||||||
- PluginRegistry: catalog management, filtering, sorting, install counts
|
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
|
||||||
- ReviewQueue: pending queue, review decisions, manifest security checklist
|
- ReviewQueue: pending queue, review decisions, manifest security checklist
|
||||||
- RevenueShare: install event recording, earnings aggregation
|
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
|
||||||
- Route integration: tier gate, list/get/install/uninstall via TestClient
|
- Route integration: tier gate, list/get/install/uninstall via TestClient
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
from sqlalchemy import select
|
||||||
from fastapi.testclient import TestClient
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
from jose import jwt
|
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.main import app
|
|
||||||
from app.marketplace.plugin_registry import PluginRegistry
|
from app.marketplace.plugin_registry import PluginRegistry
|
||||||
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
|
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
|
||||||
from app.marketplace.revenue_share import RevenueShare
|
from app.marketplace.revenue_share import RevenueShare
|
||||||
|
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
|
||||||
from app.schemas import PluginManifest
|
from app.schemas import PluginManifest
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
def _make_jwt(tier: str = "power", user_id: str | None = None) -> str:
|
|
||||||
uid = user_id or str(uuid.uuid4())
|
|
||||||
now = int(time.time())
|
|
||||||
payload = {
|
|
||||||
"sub": uid,
|
|
||||||
"email": f"{uid[:8]}@example.com",
|
|
||||||
"tier": tier,
|
|
||||||
"exp": now + 3600,
|
|
||||||
"iat": now,
|
|
||||||
}
|
|
||||||
return jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
|
||||||
|
|
||||||
|
|
||||||
def _auth(tier: str = "power") -> dict[str, str]:
|
|
||||||
return {"Authorization": f"Bearer {_make_jwt(tier)}"}
|
|
||||||
|
|
||||||
|
|
||||||
def _fresh_manifest(
|
def _fresh_manifest(
|
||||||
plugin_id: str | None = None,
|
plugin_id: str | None = None,
|
||||||
category: str = "productivity",
|
category: str = "productivity",
|
||||||
@@ -67,118 +47,150 @@ def _fresh_manifest(
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# PluginRegistry
|
# PluginRegistry (DB-backed)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestPluginRegistry:
|
class TestPluginRegistry:
|
||||||
"""Each test uses a fresh PluginRegistry instance to avoid catalog pollution."""
|
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def reg(self) -> PluginRegistry:
|
def reg(self) -> PluginRegistry:
|
||||||
return PluginRegistry()
|
return PluginRegistry()
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_seed_plugins_are_approved(self, reg: PluginRegistry) -> None:
|
async def test_seed_plugins_are_listed(
|
||||||
result = await reg.list_plugins()
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session)
|
||||||
assert result.total == 3
|
assert result.total == 3
|
||||||
assert all(p.id.startswith("plugin-") for p in result.plugins)
|
assert all(p.id.startswith("plugin-") for p in result.plugins)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_approved_only(self, reg: PluginRegistry) -> None:
|
async def test_list_approved_only(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
manifest = _fresh_manifest()
|
manifest = _fresh_manifest()
|
||||||
await reg.submit_plugin(manifest, "plugins/key.zip")
|
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
|
||||||
result = await reg.list_plugins()
|
result = await reg.list_plugins(db_session)
|
||||||
ids = [p.id for p in result.plugins]
|
ids = [p.id for p in result.plugins]
|
||||||
assert manifest.id not in ids # still pending
|
assert manifest.id not in ids # still pending
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_filter_by_category(self, reg: PluginRegistry) -> None:
|
async def test_list_filter_by_category(
|
||||||
result = await reg.list_plugins(category="communication")
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session, category="communication")
|
||||||
assert result.total == 1
|
assert result.total == 1
|
||||||
assert result.plugins[0].id == "plugin-slack-notify"
|
assert result.plugins[0].id == "plugin-slack-notify"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_filter_by_query(self, reg: PluginRegistry) -> None:
|
async def test_list_filter_by_query(
|
||||||
result = await reg.list_plugins(query="time")
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
result = await reg.list_plugins(db_session, query="time tracker")
|
||||||
assert result.total == 1
|
assert result.total == 1
|
||||||
assert result.plugins[0].id == "plugin-time-tracker"
|
assert result.plugins[0].id == "plugin-time-tracker"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_list_sort_by_installs(self, reg: PluginRegistry) -> None:
|
async def test_list_sort_by_installs(
|
||||||
await reg.record_install("plugin-slack-notify")
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
await reg.record_install("plugin-slack-notify")
|
) -> None:
|
||||||
result = await reg.list_plugins(sort="installs")
|
await reg.record_install(db_session, "plugin-slack-notify")
|
||||||
|
await reg.record_install(db_session, "plugin-slack-notify")
|
||||||
|
result = await reg.list_plugins(db_session, sort="installs")
|
||||||
assert result.plugins[0].id == "plugin-slack-notify"
|
assert result.plugins[0].id == "plugin-slack-notify"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_plugin_found(self, reg: PluginRegistry) -> None:
|
async def test_get_plugin_found(
|
||||||
entry = await reg.get_plugin("plugin-github-sync")
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
|
) -> None:
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
assert entry is not None
|
assert entry is not None
|
||||||
assert entry["manifest"].id == "plugin-github-sync"
|
assert entry["manifest"].id == "plugin-github-sync"
|
||||||
assert "install_count" in entry
|
assert "install_count" in entry
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_plugin_not_found(self, reg: PluginRegistry) -> None:
|
async def test_get_plugin_not_found(
|
||||||
entry = await reg.get_plugin("no-such-plugin")
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
entry = await reg.get_plugin(db_session, "no-such-plugin")
|
||||||
assert entry is None
|
assert entry is None
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_submit_sets_pending(self, reg: PluginRegistry) -> None:
|
async def test_submit_sets_pending(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
manifest = _fresh_manifest()
|
manifest = _fresh_manifest()
|
||||||
plugin_id = await reg.submit_plugin(manifest, "key.zip")
|
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
assert plugin_id == manifest.id
|
assert plugin_id == manifest.id
|
||||||
assert reg._catalog[plugin_id]["status"] == "pending_review"
|
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "pending_review"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_approve_makes_visible(self, reg: PluginRegistry) -> None:
|
async def test_approve_makes_visible(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
manifest = _fresh_manifest()
|
manifest = _fresh_manifest()
|
||||||
await reg.submit_plugin(manifest, "key.zip")
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
await reg.approve_plugin(manifest.id)
|
await reg.approve_plugin(db_session, manifest.id)
|
||||||
result = await reg.list_plugins()
|
result = await reg.list_plugins(db_session)
|
||||||
assert manifest.id in [p.id for p in result.plugins]
|
assert manifest.id in [p.id for p in result.plugins]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_reject_stores_reason(self, reg: PluginRegistry) -> None:
|
async def test_reject_stores_reason(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
manifest = _fresh_manifest()
|
manifest = _fresh_manifest()
|
||||||
await reg.submit_plugin(manifest, "key.zip")
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
await reg.reject_plugin(manifest.id, reason="Unsafe permissions")
|
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
|
||||||
assert reg._catalog[manifest.id]["status"] == "rejected"
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
assert reg._catalog[manifest.id]["rejection_reason"] == "Unsafe permissions"
|
row = result.scalar_one()
|
||||||
result = await reg.list_plugins()
|
assert row.status == "rejected"
|
||||||
assert manifest.id not in [p.id for p in result.plugins]
|
assert row.rejection_reason == "Unsafe permissions"
|
||||||
|
listed = await reg.list_plugins(db_session)
|
||||||
|
assert manifest.id not in [p.id for p in listed.plugins]
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_approve_unknown_raises_key_error(self, reg: PluginRegistry) -> None:
|
async def test_approve_unknown_raises_key_error(
|
||||||
|
self, reg: PluginRegistry, db_session: AsyncSession
|
||||||
|
) -> None:
|
||||||
with pytest.raises(KeyError):
|
with pytest.raises(KeyError):
|
||||||
await reg.approve_plugin("ghost-plugin")
|
await reg.approve_plugin(db_session, "ghost-plugin")
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_record_install_increments_count(self, reg: PluginRegistry) -> None:
|
async def test_record_install_increments_count(
|
||||||
await reg.record_install("plugin-github-sync")
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
entry = await reg.get_plugin("plugin-github-sync")
|
) -> None:
|
||||||
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
assert entry is not None
|
assert entry is not None
|
||||||
assert entry["install_count"] == 1
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_record_uninstall_decrements_count(self, reg: PluginRegistry) -> None:
|
async def test_record_uninstall_decrements_count(
|
||||||
await reg.record_install("plugin-github-sync")
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
await reg.record_install("plugin-github-sync")
|
) -> None:
|
||||||
await reg.record_uninstall("plugin-github-sync")
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
entry = await reg.get_plugin("plugin-github-sync")
|
await reg.record_install(db_session, "plugin-github-sync")
|
||||||
|
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
assert entry is not None
|
assert entry is not None
|
||||||
assert entry["install_count"] == 1
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_record_uninstall_floors_at_zero(self, reg: PluginRegistry) -> None:
|
async def test_record_uninstall_floors_at_zero(
|
||||||
await reg.record_uninstall("plugin-github-sync") # already 0
|
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
entry = await reg.get_plugin("plugin-github-sync")
|
) -> None:
|
||||||
|
await reg.record_uninstall(db_session, "plugin-github-sync")
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
assert entry is not None
|
assert entry is not None
|
||||||
assert entry["install_count"] == 0
|
assert entry["install_count"] == 0
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# ReviewQueue
|
# ReviewQueue (DB-backed)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
@@ -188,37 +200,47 @@ class TestReviewQueue:
|
|||||||
return PluginRegistry()
|
return PluginRegistry()
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def queue(self, reg: PluginRegistry) -> ReviewQueue:
|
def queue(self) -> ReviewQueue:
|
||||||
# Patch the 'registry' name as bound inside plugin_review.py
|
return ReviewQueue()
|
||||||
with patch("app.marketplace.plugin_review.registry", reg):
|
|
||||||
yield ReviewQueue()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_pending_returns_submitted_plugins(
|
async def test_get_pending_returns_submitted_plugins(
|
||||||
self, reg: PluginRegistry, queue: ReviewQueue
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
) -> None:
|
) -> None:
|
||||||
manifest = _fresh_manifest()
|
manifest = _fresh_manifest()
|
||||||
await reg.submit_plugin(manifest, "key.zip")
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
pending = await queue.get_pending()
|
pending = await queue.get_pending(db_session)
|
||||||
assert any(p["plugin_id"] == manifest.id for p in pending)
|
assert any(p["plugin_id"] == manifest.id for p in pending)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_submit_review_approved(
|
async def test_submit_review_approved(
|
||||||
self, reg: PluginRegistry, queue: ReviewQueue
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
) -> None:
|
) -> None:
|
||||||
manifest = _fresh_manifest()
|
manifest = _fresh_manifest()
|
||||||
await reg.submit_plugin(manifest, "key.zip")
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
await queue.submit_review(manifest.id, "reviewer-1", "approved", "Looks good")
|
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
|
||||||
assert reg._catalog[manifest.id]["status"] == "approved"
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "approved"
|
||||||
|
# Check review row was persisted
|
||||||
|
review_result = await db_session.execute(
|
||||||
|
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
|
||||||
|
)
|
||||||
|
review = review_result.scalar_one()
|
||||||
|
assert review.decision == "approved"
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_submit_review_rejected(
|
async def test_submit_review_rejected(
|
||||||
self, reg: PluginRegistry, queue: ReviewQueue
|
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
|
||||||
) -> None:
|
) -> None:
|
||||||
manifest = _fresh_manifest()
|
manifest = _fresh_manifest()
|
||||||
await reg.submit_plugin(manifest, "key.zip")
|
await reg.submit_plugin(db_session, manifest, "key.zip")
|
||||||
await queue.submit_review(manifest.id, "reviewer-1", "rejected", "Bad permissions")
|
await queue.submit_review(
|
||||||
assert reg._catalog[manifest.id]["status"] == "rejected"
|
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
|
||||||
|
)
|
||||||
|
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
|
||||||
|
row = result.scalar_one()
|
||||||
|
assert row.status == "rejected"
|
||||||
|
|
||||||
def test_validate_manifest_ok(self) -> None:
|
def test_validate_manifest_ok(self) -> None:
|
||||||
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
|
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
|
||||||
@@ -241,65 +263,66 @@ class TestReviewQueue:
|
|||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# RevenueShare
|
# RevenueShare (DB-backed)
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
class TestRevenueShare:
|
class TestRevenueShare:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def reg(self) -> PluginRegistry:
|
def rs(self) -> RevenueShare:
|
||||||
return PluginRegistry()
|
return RevenueShare()
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def rs(self, reg: PluginRegistry) -> RevenueShare:
|
|
||||||
# Patch the 'registry' name as bound inside revenue_share.py
|
|
||||||
with patch("app.marketplace.revenue_share.registry", reg):
|
|
||||||
yield RevenueShare()
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_record_install_free_plugin(
|
async def test_record_install_free_plugin(
|
||||||
self, reg: PluginRegistry, rs: RevenueShare
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
) -> None:
|
) -> None:
|
||||||
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
|
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||||
assert len(rs._events) == 1
|
result = await db_session.execute(
|
||||||
assert rs._events[0]["developer_share_cents"] == 0
|
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
|
||||||
|
)
|
||||||
|
event = result.scalar_one()
|
||||||
|
assert event.developer_share_cents == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_record_install_paid_plugin_no_stripe(
|
async def test_record_install_paid_plugin_no_stripe(
|
||||||
self, reg: PluginRegistry, rs: RevenueShare
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
) -> None:
|
) -> None:
|
||||||
# No STRIPE_SECRET_KEY configured in test env — should not crash
|
await rs.record_install(
|
||||||
await rs.record_install("plugin-slack-notify", "user-2", amount_cents=499)
|
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
|
||||||
assert len(rs._events) == 1
|
)
|
||||||
assert rs._events[0]["amount_cents"] == 499
|
result = await db_session.execute(
|
||||||
assert rs._events[0]["developer_share_cents"] == int(499 * 0.70)
|
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
|
||||||
|
)
|
||||||
|
event = result.scalar_one()
|
||||||
|
assert event.amount_cents == 499
|
||||||
|
assert event.developer_share_cents == int(499 * 0.70)
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_record_install_increments_registry_count(
|
async def test_record_install_increments_registry_count(
|
||||||
self, reg: PluginRegistry, rs: RevenueShare
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
) -> None:
|
) -> None:
|
||||||
await rs.record_install("plugin-github-sync", "user-1", amount_cents=0)
|
reg = PluginRegistry()
|
||||||
entry = await reg.get_plugin("plugin-github-sync")
|
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
|
||||||
|
entry = await reg.get_plugin(db_session, "plugin-github-sync")
|
||||||
assert entry is not None
|
assert entry is not None
|
||||||
assert entry["install_count"] == 1
|
assert entry["install_count"] == 1
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_earnings_empty(
|
async def test_get_earnings_empty(
|
||||||
self, reg: PluginRegistry, rs: RevenueShare
|
self, rs: RevenueShare, db_session: AsyncSession
|
||||||
) -> None:
|
) -> None:
|
||||||
result = await rs.get_earnings("unknown-dev")
|
result = await rs.get_earnings(db_session, "unknown-dev")
|
||||||
assert result["total_installs"] == 0
|
assert result["total_installs"] == 0
|
||||||
assert result["total_revenue_cents"] == 0
|
assert result["total_revenue_cents"] == 0
|
||||||
assert result["developer_share_cents"] == 0
|
assert result["developer_share_cents"] == 0
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_get_earnings_aggregates(
|
async def test_get_earnings_aggregates(
|
||||||
self, reg: PluginRegistry, rs: RevenueShare
|
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
|
||||||
) -> None:
|
) -> None:
|
||||||
# "Adiuva" is the author of the seeded plugins
|
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
|
||||||
await rs.record_install("plugin-slack-notify", "u1", amount_cents=499)
|
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
|
||||||
await rs.record_install("plugin-slack-notify", "u2", amount_cents=499)
|
result = await rs.get_earnings(db_session, "Adiuva")
|
||||||
result = await rs.get_earnings("Adiuva")
|
|
||||||
assert result["total_installs"] == 2
|
assert result["total_installs"] == 2
|
||||||
assert result["total_revenue_cents"] == 998
|
assert result["total_revenue_cents"] == 998
|
||||||
assert result["developer_share_cents"] == int(499 * 0.70) * 2
|
assert result["developer_share_cents"] == int(499 * 0.70) * 2
|
||||||
@@ -311,77 +334,67 @@ class TestRevenueShare:
|
|||||||
|
|
||||||
|
|
||||||
class TestPluginRoutes:
|
class TestPluginRoutes:
|
||||||
def test_list_plugins_requires_power_tier(self) -> None:
|
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
|
||||||
resp = client.get("/api/v1/plugins", headers=_auth("free"))
|
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
|
|
||||||
def test_list_plugins_pro_tier_blocked(self) -> None:
|
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
|
||||||
resp = client.get("/api/v1/plugins", headers=_auth("pro"))
|
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
|
|
||||||
def test_list_plugins_power_tier_ok(self) -> None:
|
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
|
||||||
resp = client.get("/api/v1/plugins", headers=_auth("power"))
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert "plugins" in data
|
assert "plugins" in data
|
||||||
assert data["total"] >= 3
|
assert data["total"] == 3
|
||||||
|
|
||||||
def test_list_plugins_team_tier_ok(self) -> None:
|
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
|
||||||
resp = client.get("/api/v1/plugins", headers=_auth("team"))
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
|
|
||||||
def test_get_plugin_found(self) -> None:
|
def test_get_plugin_found(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
|
||||||
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=_auth())
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["plugin"]["id"] == "plugin-github-sync"
|
assert data["plugin"]["id"] == "plugin-github-sync"
|
||||||
assert "install_count" in data
|
assert "install_count" in data
|
||||||
|
|
||||||
def test_get_plugin_not_found(self) -> None:
|
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
|
||||||
resp = client.get("/api/v1/plugins/no-such-plugin", headers=_auth())
|
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
def test_install_plugin_free(self) -> None:
|
def test_install_plugin_free(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.post(
|
||||||
resp = client.post(
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
"/api/v1/plugins/plugin-github-sync/install",
|
json={"plugin_id": "plugin-github-sync"},
|
||||||
json={"plugin_id": "plugin-github-sync"},
|
headers=auth_header(),
|
||||||
headers=_auth(),
|
)
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
data = resp.json()
|
data = resp.json()
|
||||||
assert data["ok"] is True
|
assert data["ok"] is True
|
||||||
assert "download_url" in data
|
assert "download_url" in data
|
||||||
|
|
||||||
def test_install_plugin_not_found(self) -> None:
|
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.post(
|
||||||
resp = client.post(
|
"/api/v1/plugins/ghost/install",
|
||||||
"/api/v1/plugins/ghost/install",
|
json={"plugin_id": "ghost"},
|
||||||
json={"plugin_id": "ghost"},
|
headers=auth_header(),
|
||||||
headers=_auth(),
|
)
|
||||||
)
|
|
||||||
assert resp.status_code == 404
|
assert resp.status_code == 404
|
||||||
|
|
||||||
def test_uninstall_plugin_ok(self) -> None:
|
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.delete(
|
||||||
resp = client.delete(
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
"/api/v1/plugins/plugin-github-sync/install",
|
headers=auth_header(),
|
||||||
headers=_auth(),
|
)
|
||||||
)
|
|
||||||
assert resp.status_code == 200
|
assert resp.status_code == 200
|
||||||
assert resp.json()["ok"] is True
|
assert resp.json()["ok"] is True
|
||||||
|
|
||||||
def test_install_requires_power_tier(self) -> None:
|
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
|
||||||
with TestClient(app) as client:
|
resp = client.post(
|
||||||
resp = client.post(
|
"/api/v1/plugins/plugin-github-sync/install",
|
||||||
"/api/v1/plugins/plugin-github-sync/install",
|
json={"plugin_id": "plugin-github-sync"},
|
||||||
json={"plugin_id": "plugin-github-sync"},
|
headers=auth_header("free"),
|
||||||
headers=_auth("free"),
|
)
|
||||||
)
|
|
||||||
assert resp.status_code == 403
|
assert resp.status_code == 403
|
||||||
|
|||||||
230
tests/test_schemas_v3.py
Normal file
230
tests/test_schemas_v3.py
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
"""Tests for v3 WebSocket frame protocol schemas."""
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from pydantic import ValidationError
|
||||||
|
|
||||||
|
from app.schemas import (
|
||||||
|
WsFrameType,
|
||||||
|
WsHomeRequest,
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsFloatingRequest,
|
||||||
|
WsFloatingScope,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── WsFrameType ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_v3_frame_types_exist():
|
||||||
|
v3_types = [
|
||||||
|
"home_request",
|
||||||
|
"floating_request",
|
||||||
|
"stream_start",
|
||||||
|
"stream_text",
|
||||||
|
"stream_end",
|
||||||
|
"floating_domain",
|
||||||
|
"data_request",
|
||||||
|
"data_response",
|
||||||
|
"mutation",
|
||||||
|
]
|
||||||
|
for name in v3_types:
|
||||||
|
assert hasattr(WsFrameType, name), f"WsFrameType missing: {name}"
|
||||||
|
assert WsFrameType[name].value == name
|
||||||
|
|
||||||
|
|
||||||
|
def test_v2_frame_types_still_exist():
|
||||||
|
"""Backward compat: v2 types must remain."""
|
||||||
|
v2_types = [
|
||||||
|
"chat_request",
|
||||||
|
"text_chunk",
|
||||||
|
"tool_call",
|
||||||
|
"tool_result",
|
||||||
|
"final",
|
||||||
|
"ping",
|
||||||
|
"agent_run",
|
||||||
|
"agent_data",
|
||||||
|
"agent_complete",
|
||||||
|
"device_hello",
|
||||||
|
]
|
||||||
|
for name in v2_types:
|
||||||
|
assert hasattr(WsFrameType, name), f"v2 WsFrameType missing: {name}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── WsHomeRequest ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_home_request_defaults():
|
||||||
|
frame = WsHomeRequest(message="Hello")
|
||||||
|
assert frame.type == WsFrameType.home_request
|
||||||
|
assert frame.message == "Hello"
|
||||||
|
assert frame.conversation_history == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_home_request_with_history():
|
||||||
|
history = [{"role": "user", "content": "Hi"}, {"role": "assistant", "content": "Hello!"}]
|
||||||
|
frame = WsHomeRequest(message="Follow up", conversation_history=history)
|
||||||
|
assert frame.conversation_history == history
|
||||||
|
|
||||||
|
|
||||||
|
def test_home_request_serializes():
|
||||||
|
frame = WsHomeRequest(message="Test")
|
||||||
|
data = frame.model_dump()
|
||||||
|
assert data["type"] == "home_request"
|
||||||
|
assert data["message"] == "Test"
|
||||||
|
assert data["conversation_history"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_home_request_deserializes():
|
||||||
|
raw = {"type": "home_request", "message": "Hi there"}
|
||||||
|
frame = WsHomeRequest.model_validate(raw)
|
||||||
|
assert frame.message == "Hi there"
|
||||||
|
|
||||||
|
|
||||||
|
def test_home_request_requires_message():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
WsHomeRequest.model_validate({"type": "home_request"})
|
||||||
|
|
||||||
|
|
||||||
|
# ── WsFloatingRequest ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_request_basic():
|
||||||
|
frame = WsFloatingRequest(
|
||||||
|
message="Summarise",
|
||||||
|
scope=WsFloatingScope(type="task", id="task-123"),
|
||||||
|
)
|
||||||
|
assert frame.type == WsFrameType.floating_request
|
||||||
|
assert frame.scope.type == "task"
|
||||||
|
assert frame.scope.id == "task-123"
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_request_scope_without_id():
|
||||||
|
frame = WsFloatingRequest(
|
||||||
|
message="Show all",
|
||||||
|
scope=WsFloatingScope(type="project"),
|
||||||
|
)
|
||||||
|
assert frame.scope.id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_request_serializes():
|
||||||
|
frame = WsFloatingRequest(
|
||||||
|
message="Test",
|
||||||
|
scope=WsFloatingScope(type="note", id="n-1"),
|
||||||
|
)
|
||||||
|
data = frame.model_dump()
|
||||||
|
assert data["type"] == "floating_request"
|
||||||
|
assert data["scope"]["type"] == "note"
|
||||||
|
assert data["scope"]["id"] == "n-1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_request_invalid_scope_type():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
WsFloatingRequest(
|
||||||
|
message="X",
|
||||||
|
scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_request_requires_scope():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
|
||||||
|
|
||||||
|
|
||||||
|
# ── WsStreamStart ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_start():
|
||||||
|
frame = WsStreamStart(request_id="req-abc")
|
||||||
|
assert frame.type == WsFrameType.stream_start
|
||||||
|
assert frame.request_id == "req-abc"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_start_serializes():
|
||||||
|
data = WsStreamStart(request_id="r1").model_dump()
|
||||||
|
assert data == {"type": "stream_start", "request_id": "r1"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_start_deserializes():
|
||||||
|
frame = WsStreamStart.model_validate({"type": "stream_start", "request_id": "r1"})
|
||||||
|
assert frame.request_id == "r1"
|
||||||
|
|
||||||
|
|
||||||
|
# ── WsStreamText ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_text():
|
||||||
|
frame = WsStreamText(request_id="r1", chunk="Hello ")
|
||||||
|
assert frame.type == WsFrameType.stream_text
|
||||||
|
assert frame.chunk == "Hello "
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_text_serializes():
|
||||||
|
data = WsStreamText(request_id="r1", chunk="word").model_dump()
|
||||||
|
assert data == {"type": "stream_text", "request_id": "r1", "chunk": "word"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_text_deserializes():
|
||||||
|
raw = {"type": "stream_text", "request_id": "r2", "chunk": "test"}
|
||||||
|
frame = WsStreamText.model_validate(raw)
|
||||||
|
assert frame.chunk == "test"
|
||||||
|
|
||||||
|
|
||||||
|
# ── WsStreamEnd ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_end_defaults():
|
||||||
|
frame = WsStreamEnd(request_id="r1")
|
||||||
|
assert frame.type == WsFrameType.stream_end
|
||||||
|
assert frame.mutations == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_end_with_mutations():
|
||||||
|
mutations = [{"action": "create", "table": "tasks", "data": {"title": "New task"}}]
|
||||||
|
frame = WsStreamEnd(request_id="r1", mutations=mutations)
|
||||||
|
assert len(frame.mutations) == 1
|
||||||
|
assert frame.mutations[0]["action"] == "create"
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_end_serializes():
|
||||||
|
data = WsStreamEnd(request_id="r2").model_dump()
|
||||||
|
assert data == {"type": "stream_end", "request_id": "r2", "mutations": []}
|
||||||
|
|
||||||
|
|
||||||
|
def test_stream_end_deserializes():
|
||||||
|
raw = {"type": "stream_end", "request_id": "r3", "mutations": []}
|
||||||
|
frame = WsStreamEnd.model_validate(raw)
|
||||||
|
assert frame.request_id == "r3"
|
||||||
|
|
||||||
|
|
||||||
|
# ── WsFloatingDomain ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_domain_tasks():
|
||||||
|
frame = WsFloatingDomain(request_id="r1", domain="tasks")
|
||||||
|
assert frame.type == WsFrameType.floating_domain
|
||||||
|
assert frame.domain == "tasks"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("domain", ["tasks", "timelines", "notes", "projects"])
|
||||||
|
def test_floating_domain_valid_domains(domain: str):
|
||||||
|
frame = WsFloatingDomain(request_id="r1", domain=domain) # type: ignore[arg-type]
|
||||||
|
assert frame.domain == domain
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_domain_invalid():
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
WsFloatingDomain(request_id="r1", domain="invalid") # type: ignore[arg-type]
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_domain_serializes():
|
||||||
|
d = WsFloatingDomain(request_id="r1", domain="notes").model_dump()
|
||||||
|
assert d == {"type": "floating_domain", "request_id": "r1", "domain": "notes"}
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_domain_deserializes():
|
||||||
|
raw = {"type": "floating_domain", "request_id": "r1", "domain": "projects"}
|
||||||
|
frame = WsFloatingDomain.model_validate(raw)
|
||||||
|
assert frame.domain == "projects"
|
||||||
@@ -1,48 +1,30 @@
|
|||||||
"""Tests for the storage layer: encryption, BlobStore, and VectorStore."""
|
"""Tests for the storage layer: encryption, BlobStore, VectorStore, and storage routes."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import hashlib
|
import hashlib
|
||||||
import os
|
|
||||||
from unittest.mock import MagicMock, patch
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
import boto3
|
import boto3
|
||||||
import pytest
|
import pytest
|
||||||
from botocore.exceptions import ClientError
|
from botocore.exceptions import ClientError
|
||||||
from moto import mock_aws
|
|
||||||
|
|
||||||
from app.storage.encryption import reject_if_tampered, verify_checksum
|
from app.storage.encryption import reject_if_tampered, verify_checksum
|
||||||
from app.storage.blob_store import BlobStore
|
from app.storage.blob_store import BlobStore
|
||||||
from app.storage.vector_store import VectorStore, _blob_to_vector
|
from app.storage.vector_store import VectorStore, _blob_to_vector
|
||||||
from app.schemas import VectorItem, VectorSearchResult
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
from tests.conftest import auth_header, S3_TEST_BUCKET
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ───────────────────────────────────────────────────────────
|
# ── Helpers ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_BLOB = b"encrypted-payload-opaque-to-server"
|
_BLOB = b"encrypted-payload-opaque-to-server"
|
||||||
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
|
||||||
_BUCKET = "test-bucket"
|
_BUCKET = S3_TEST_BUCKET
|
||||||
_REGION = "us-east-1"
|
_REGION = "us-east-1"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
|
||||||
def s3_bucket():
|
|
||||||
"""Create a mocked S3 bucket and expose its name."""
|
|
||||||
with mock_aws():
|
|
||||||
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
|
||||||
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
|
||||||
os.environ.setdefault("AWS_DEFAULT_REGION", _REGION)
|
|
||||||
client = boto3.client("s3", region_name=_REGION)
|
|
||||||
client.create_bucket(Bucket=_BUCKET)
|
|
||||||
with patch("app.storage.blob_store.settings") as mock_settings:
|
|
||||||
mock_settings.S3_BUCKET = _BUCKET
|
|
||||||
mock_settings.S3_REGION = _REGION
|
|
||||||
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
|
||||||
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
|
||||||
yield _BUCKET
|
|
||||||
|
|
||||||
|
|
||||||
def _pinecone_mock():
|
def _pinecone_mock():
|
||||||
"""Return a mock Pinecone index with realistic return shapes."""
|
"""Return a mock Pinecone index with realistic return shapes."""
|
||||||
mock_index = MagicMock()
|
mock_index = MagicMock()
|
||||||
@@ -383,3 +365,198 @@ class TestVectorStoreQdrant:
|
|||||||
await store.delete("u1", ["v1"])
|
await store.delete("u1", ["v1"])
|
||||||
call_kwargs = mock_client.delete.call_args[1]
|
call_kwargs = mock_client.delete.call_args[1]
|
||||||
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
assert call_kwargs["collection_name"] == "adiuva_vectors"
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestStorageRoutes (integration) ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestStorageRoutes:
|
||||||
|
"""Integration tests for POST/GET/PUT/DELETE /api/v1/storage/records.
|
||||||
|
|
||||||
|
Pydantic v2 converts JSON string → bytes via ``str.encode('utf-8')``.
|
||||||
|
So "hello" in JSON becomes ``b"hello"`` on the server. We use plain
|
||||||
|
ASCII strings as blob values and compute checksums accordingly.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_BLOB_STR = "encrypted-payload-opaque-to-server"
|
||||||
|
_BLOB_BYTES = _BLOB_STR.encode()
|
||||||
|
_BLOB_CHECKSUM = hashlib.sha256(_BLOB_BYTES).hexdigest()
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _create_payload(cls, blob_str: str | None = None) -> dict:
|
||||||
|
blob_str = blob_str or cls._BLOB_STR
|
||||||
|
checksum = hashlib.sha256(blob_str.encode()).hexdigest()
|
||||||
|
return {
|
||||||
|
"table": "tasks",
|
||||||
|
"blob": blob_str,
|
||||||
|
"checksum": checksum,
|
||||||
|
}
|
||||||
|
|
||||||
|
def _create_record(self, client, tier="power", blob_str=None):
|
||||||
|
payload = self._create_payload(blob_str)
|
||||||
|
return client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header(tier),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Create ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_create_record(self, client, s3_bucket) -> None:
|
||||||
|
resp = self._create_record(client)
|
||||||
|
assert resp.status_code == 201
|
||||||
|
data = resp.json()
|
||||||
|
assert "id" in data
|
||||||
|
assert "created_at" in data
|
||||||
|
|
||||||
|
def test_create_record_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
payload = {
|
||||||
|
"table": "tasks",
|
||||||
|
"blob": self._BLOB_STR,
|
||||||
|
"checksum": "0" * 64,
|
||||||
|
}
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
def test_create_record_free_tier_blocked(self, client, s3_bucket) -> None:
|
||||||
|
"""Free tier has cloud_storage_gb=0 → 402."""
|
||||||
|
resp = self._create_record(client, tier="free")
|
||||||
|
assert resp.status_code == 402
|
||||||
|
|
||||||
|
def test_create_record_pro_tier_allowed(self, client, s3_bucket) -> None:
|
||||||
|
"""Pro tier has cloud_storage_gb=5 → succeeds for small blob."""
|
||||||
|
resp = self._create_record(client, tier="pro")
|
||||||
|
assert resp.status_code == 201
|
||||||
|
|
||||||
|
# ── List ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_list_records(self, client, s3_bucket) -> None:
|
||||||
|
self._create_record(client)
|
||||||
|
self._create_record(client, blob_str="second-blob")
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 2
|
||||||
|
# Each entry has metadata, no blob bytes
|
||||||
|
for item in data:
|
||||||
|
assert "id" in item
|
||||||
|
assert "table" in item
|
||||||
|
assert "checksum" in item
|
||||||
|
assert "blob" not in item
|
||||||
|
|
||||||
|
def test_list_records_filter_by_table(self, client, s3_bucket) -> None:
|
||||||
|
self._create_record(client)
|
||||||
|
# Create in a different table
|
||||||
|
note_blob = "note-blob"
|
||||||
|
payload = {
|
||||||
|
"table": "notes",
|
||||||
|
"blob": note_blob,
|
||||||
|
"checksum": hashlib.sha256(note_blob.encode()).hexdigest(),
|
||||||
|
}
|
||||||
|
client.post(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
json=payload,
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records?table=notes",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert len(data) == 1
|
||||||
|
assert data[0]["table"] == "notes"
|
||||||
|
|
||||||
|
def test_list_records_isolated_per_user(self, client, s3_bucket) -> None:
|
||||||
|
"""One user's records should not appear in another user's list."""
|
||||||
|
self._create_record(client, tier="power")
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records",
|
||||||
|
headers=auth_header("team"),
|
||||||
|
)
|
||||||
|
assert resp.json() == []
|
||||||
|
|
||||||
|
# ── Download ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_download_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.content == self._BLOB_BYTES
|
||||||
|
assert resp.headers["X-Checksum"] == self._BLOB_CHECKSUM
|
||||||
|
|
||||||
|
def test_download_record_not_found(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.get(
|
||||||
|
"/api/v1/storage/records/nonexistent-id",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
# ── Update ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_update_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
new_blob_str = "updated-encrypted-payload"
|
||||||
|
new_checksum = hashlib.sha256(new_blob_str.encode()).hexdigest()
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
json={"blob": new_blob_str, "checksum": new_checksum},
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# Verify download returns the updated blob
|
||||||
|
dl = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert dl.content == new_blob_str.encode()
|
||||||
|
|
||||||
|
def test_update_record_bad_checksum(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.put(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
json={"blob": "some-data", "checksum": "0" * 64},
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 400
|
||||||
|
|
||||||
|
# ── Delete ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_delete_record(self, client, s3_bucket) -> None:
|
||||||
|
create_resp = self._create_record(client)
|
||||||
|
record_id = create_resp.json()["id"]
|
||||||
|
resp = client.delete(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
assert resp.json() == {"ok": True}
|
||||||
|
|
||||||
|
# Subsequent GET should return 404
|
||||||
|
dl = client.get(
|
||||||
|
f"/api/v1/storage/records/{record_id}",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert dl.status_code == 404
|
||||||
|
|
||||||
|
def test_delete_record_not_found(self, client, s3_bucket) -> None:
|
||||||
|
resp = client.delete(
|
||||||
|
"/api/v1/storage/records/nonexistent",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|||||||
158
tests/test_ws_unified.py
Normal file
158
tests/test_ws_unified.py
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
"""Integration tests for the unified WebSocket handler (Step 5).
|
||||||
|
|
||||||
|
Tests the device WS endpoint with home_request and floating_request frames,
|
||||||
|
verifying that the correct v3 frame sequence is returned.
|
||||||
|
|
||||||
|
LLM calls are mocked to avoid network dependency.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.schemas import WsFrameType
|
||||||
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
||||||
|
|
||||||
|
USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
|
||||||
|
# ── helpers ───────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
|
||||||
|
"""Receive frames until stream_end (or stream_end inside floating flow), or max_frames."""
|
||||||
|
frames = []
|
||||||
|
for _ in range(max_frames):
|
||||||
|
raw = ws.receive_text()
|
||||||
|
frame = json.loads(raw)
|
||||||
|
frames.append(frame)
|
||||||
|
if frame.get("type") == WsFrameType.stream_end:
|
||||||
|
break
|
||||||
|
return frames
|
||||||
|
|
||||||
|
|
||||||
|
async def _mock_home_stream(user_id, message, context, db_session_factory=None):
|
||||||
|
yield "token", "Here are your tasks:\n<task>[t1,t2]</task>"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
|
|
||||||
|
async def _mock_floating_stream(user_id, message, context, scope=None, db_session_factory=None):
|
||||||
|
yield "tool_end", {"name": "task_agent", "result": "ok"}
|
||||||
|
yield "token", "Here is a summary"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
|
|
||||||
|
# ── tests ─────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_home_request_produces_stream_frames(client):
|
||||||
|
"""home_request → stream_start, stream_text+, stream_end."""
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|
||||||
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "home_request",
|
||||||
|
"request_id": "r1",
|
||||||
|
"message": "List my tasks",
|
||||||
|
"conversation_history": [],
|
||||||
|
}))
|
||||||
|
frames = _recv_until_end(ws)
|
||||||
|
|
||||||
|
types = [f["type"] for f in frames]
|
||||||
|
assert WsFrameType.stream_start in types
|
||||||
|
assert WsFrameType.stream_end in types
|
||||||
|
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
|
|
||||||
|
def test_floating_request_produces_domain_frame(client):
|
||||||
|
"""floating_request → floating_domain first, then stream_text*, stream_end."""
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|
||||||
|
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "floating_request",
|
||||||
|
"request_id": "p1",
|
||||||
|
"message": "Summarize this task",
|
||||||
|
"scope": {"type": "task", "id": "task-123"},
|
||||||
|
}))
|
||||||
|
frames = _recv_until_end(ws)
|
||||||
|
|
||||||
|
types = [f["type"] for f in frames]
|
||||||
|
assert WsFrameType.floating_domain in types
|
||||||
|
assert WsFrameType.stream_end in types
|
||||||
|
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
|
||||||
|
|
||||||
|
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
|
||||||
|
assert domain_frame["domain"] == "tasks"
|
||||||
|
assert domain_frame["request_id"] == "p1"
|
||||||
|
|
||||||
|
|
||||||
|
def test_home_request_request_id_propagated(client):
|
||||||
|
"""request_id in home_request is echoed in all response frames."""
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
req_id = "my-unique-req-id"
|
||||||
|
|
||||||
|
async def _stream(user_id, message, context, db_session_factory=None):
|
||||||
|
yield "token", "ok"
|
||||||
|
yield "mutations", []
|
||||||
|
|
||||||
|
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-3", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "home_request",
|
||||||
|
"request_id": req_id,
|
||||||
|
"message": "hello",
|
||||||
|
}))
|
||||||
|
frames = _recv_until_end(ws)
|
||||||
|
|
||||||
|
for f in frames:
|
||||||
|
if "request_id" in f:
|
||||||
|
assert f["request_id"] == req_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_tool_result_dispatch_silent_on_unknown_id(client):
|
||||||
|
"""tool_result for unknown call_id is silently ignored — no crash."""
|
||||||
|
token = make_jwt("power", user_id=USER_ID)
|
||||||
|
|
||||||
|
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
|
||||||
|
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "device_hello", "device_id": "dev-4", "agent_ids": []
|
||||||
|
}))
|
||||||
|
ws.send_text(json.dumps({
|
||||||
|
"type": "tool_result", "id": "no-such-id", "ok": True
|
||||||
|
}))
|
||||||
|
# If connection is still alive, we'll get the heartbeat ping
|
||||||
|
msg = json.loads(ws.receive_text())
|
||||||
|
assert msg["type"] == "ping"
|
||||||
|
|
||||||
|
|
||||||
|
def test_invalid_jwt_rejected(client):
|
||||||
|
"""Connection with bad token is closed before or after accept."""
|
||||||
|
with pytest.raises(Exception):
|
||||||
|
with client.websocket_connect("/api/v1/ws/device?token=badtoken") as ws:
|
||||||
|
ws.receive_text()
|
||||||
Reference in New Issue
Block a user