Compare commits
92 Commits
8f7bc25611
...
feature/mi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b7d302ef2 | ||
|
|
7f6ea29525 | ||
|
|
48036397f1 | ||
|
|
57b5648915 | ||
|
|
7e4374c69b | ||
|
|
fe0dd038ee | ||
|
|
d3f7099d93 | ||
|
|
63fa119543 | ||
|
|
d856dfd28c | ||
|
|
ccba54ac24 | ||
|
|
55500cc818 | ||
|
|
75a826c9d8 | ||
|
|
971f1dd84f | ||
|
|
333bba6fdd | ||
|
|
229e20d073 | ||
|
|
0b491b3643 | ||
|
|
0d5fa3e569 | ||
|
|
aff68a9051 | ||
|
|
5e9ef2809e | ||
|
|
90018af311 | ||
|
|
1e2e395676 | ||
|
|
59d3a53980 | ||
|
|
9feeaa79c8 | ||
|
|
aa219a4d08 | ||
|
|
552b8eb305 | ||
|
|
0d93b3960d | ||
|
|
f07580574b | ||
|
|
1a8bf11f90 | ||
|
|
e7cdce8287 | ||
|
|
58bc6efd4b | ||
|
|
6c450805cb | ||
|
|
f340d0fa3e | ||
|
|
edc53cb6eb | ||
|
|
725cece5c1 | ||
|
|
297e20ce8d | ||
|
|
5a03bd1cfb | ||
|
|
87b7a1c6c9 | ||
|
|
826f64d6bb | ||
| 5faa6b1d7c | |||
| 02a9684cd6 | |||
| fae9efee0d | |||
| 30b062dd4a | |||
| 2a0331d7ce | |||
| 13fd8677c1 | |||
| 9bd629cb59 | |||
| 9c97702daa | |||
| a1e364c9c0 | |||
| 5b55f1292a | |||
| 5bc9ea6cd6 | |||
| f7404b6f66 | |||
| d667e43c73 | |||
| fe085a7951 | |||
| 2de67213f8 | |||
| f6ed383b3a | |||
| 9332e29e53 | |||
| 618076193a | |||
| 34f01234c9 | |||
| 0bd46937d3 | |||
| e6b5bc2e7d | |||
| c90ed58078 | |||
| 76c8f2bdad | |||
| 393b3befd6 | |||
| 2c08275934 | |||
| 7cb384fa63 | |||
| 7efaeba283 | |||
| b61ded8458 | |||
| ac71d99f9a | |||
| 3b3b3baf25 | |||
| 45415bb9ee | |||
| a775a2da18 | |||
| 24772f2b67 | |||
| fd1396a710 | |||
| 914f70bd85 | |||
| 608d6c784f | |||
| 19ad5be97f | |||
| 1dfd088e18 | |||
| c6e1e4e7fd | |||
| cc603aba06 | |||
| 6d9a16e513 | |||
| 27c087d5d8 | |||
|
|
4d7fd519c5 | ||
| 06de7c7ab0 | |||
| e3c7547c75 | |||
| 314780d59a | |||
| 091787a6da | |||
| 7f278c6f63 | |||
| 8bfce9da00 | |||
| 480e7ac5bd | |||
| d0b303e745 | |||
| 5d485b3665 | |||
| 9787befd4a | |||
| 9119474e71 |
61
.env.example
61
.env.example
@@ -4,25 +4,64 @@ ENV=dev
|
|||||||
# ── Database ──────────────────────────────────────────────────────────────────
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
||||||
|
|
||||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
# ── Redis ─────────────────────────────────────────────────────────────────────
|
||||||
JWT_SECRET=replace-with-a-long-random-secret
|
REDIS_URL=redis://localhost:6379/0
|
||||||
JWT_ALGORITHM=HS256
|
|
||||||
|
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
|
||||||
|
# Generate keypair:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
# Paste PEM content with literal \n for newlines.
|
||||||
|
#
|
||||||
|
# Private key — ONLY used by the Auth Service (JWT signing).
|
||||||
|
JWT_PRIVATE_KEY=
|
||||||
|
# Public key — used by all services / Traefik ForwardAuth (JWT verification).
|
||||||
|
JWT_PUBLIC_KEY=
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
# ── OpenAI ────────────────────────────────────────────────────────────────────
|
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||||
OPENAI_API_KEY=sk-...
|
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||||
|
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||||
|
OPENAI_API_KEY=
|
||||||
|
ANTHROPIC_API_KEY=
|
||||||
|
GOOGLE_API_KEY=
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
|
||||||
# ── Stripe ────────────────────────────────────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=sk_test_...
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=whsec_...
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
||||||
S3_BUCKET=adiuva-backups
|
S3_BUCKET=adiuva
|
||||||
S3_REGION=us-east-1
|
S3_REGION=us-east-1
|
||||||
AWS_ACCESS_KEY_ID=AKIA...
|
S3_ENDPOINT_URL=
|
||||||
AWS_SECRET_ACCESS_KEY=...
|
AWS_ACCESS_KEY_ID=
|
||||||
|
AWS_SECRET_ACCESS_KEY=
|
||||||
|
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
|
||||||
|
# ── Vector Store ──────────────────────────────────────────────────────────────
|
||||||
|
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
|
||||||
|
PINECONE_API_KEY=
|
||||||
|
PINECONE_INDEX=adiuva
|
||||||
|
QDRANT_URL=
|
||||||
|
QDRANT_API_KEY=
|
||||||
|
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||||
|
|
||||||
|
# ── Langfuse (observability) ─────────────────────────────────────────────────
|
||||||
|
LANGFUSE_SECRET_KEY=sk-lf-...
|
||||||
|
LANGFUSE_PUBLIC_KEY=pk-lf-...
|
||||||
|
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL
|
||||||
|
|
||||||
|
# ── Cloudflare (Traefik ACME DNS-01 challenge) ───────────────────────────────
|
||||||
|
CF_DNS_API_TOKEN=
|
||||||
|
ACME_EMAIL=
|
||||||
|
|
||||||
|
# ── PostgreSQL (used by docker-compose) ──────────────────────────────────────
|
||||||
|
POSTGRES_USER=postgres
|
||||||
|
POSTGRES_PASSWORD=postgres
|
||||||
|
POSTGRES_DB=adiuva
|
||||||
@@ -1,21 +1,93 @@
|
|||||||
name: Deploy to Proxmox Docker
|
name: Test & Deploy API
|
||||||
run-name: Deploying ${{ gitea.sha }}
|
run-name: ${{ gitea.ref_name }} → Docker LXC
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
branches:
|
tags:
|
||||||
- main # O il nome del tuo branch principale
|
- 'v*'
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
Deploy:
|
# ── 1. Run tests in an isolated Python container ──────────────────
|
||||||
runs-on: ubuntu-latest # Questo dipende dalle label che hai dato al tuo act_runner
|
test:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
container:
|
||||||
|
image: python:3.12-slim
|
||||||
|
|
||||||
steps:
|
steps:
|
||||||
- name: Deploying via SSH
|
- name: Install git
|
||||||
|
run: apt-get update && apt-get install -y --no-install-recommends git
|
||||||
|
|
||||||
|
- name: Checkout Code
|
||||||
|
run: |
|
||||||
|
git clone --depth 1 --branch "${GITHUB_REF_NAME}" \
|
||||||
|
"http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . || \
|
||||||
|
git clone --depth 1 "http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . && \
|
||||||
|
git checkout "${GITHUB_SHA}"
|
||||||
|
|
||||||
|
- name: Install Dependencies
|
||||||
|
run: pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
- name: Run Linter
|
||||||
|
run: ruff check app/ tests/
|
||||||
|
|
||||||
|
- name: Run Tests
|
||||||
|
run: pytest tests/ -v --tb=short
|
||||||
|
|
||||||
|
# ── 2. Deploy to Docker LXC via SSH ─────────────────────────────────
|
||||||
|
deploy:
|
||||||
|
needs: test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
if: gitea.event_name == 'push'
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Deploy via SSH
|
||||||
uses: appleboy/ssh-action@v1.0.0
|
uses: appleboy/ssh-action@v1.0.0
|
||||||
with:
|
with:
|
||||||
host: ${{ secrets.SSH_HOST }}
|
host: ${{ secrets.SSH_HOST }}
|
||||||
username: ${{ secrets.SSH_USER }}
|
username: ${{ secrets.SSH_USER }}
|
||||||
key: ${{ secrets.SSH_KEY }}
|
key: ${{ secrets.SSH_KEY }}
|
||||||
script: |
|
script: |
|
||||||
cd /opt/adiuva-api
|
set -e
|
||||||
git pull origin main
|
DEPLOY_DIR="/opt/adiuva-api"
|
||||||
|
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||||
|
TAG="${{ gitea.ref_name }}"
|
||||||
|
|
||||||
|
# ── Pull latest code ──
|
||||||
|
cd /tmp && rm -rf adiuva-api-deploy
|
||||||
|
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
|
||||||
|
|
||||||
|
# ── Sync source (preserve .env) ──
|
||||||
|
cp -rf /tmp/adiuva-api-deploy/app/ \
|
||||||
|
/tmp/adiuva-api-deploy/alembic/ \
|
||||||
|
/tmp/adiuva-api-deploy/alembic.ini \
|
||||||
|
/tmp/adiuva-api-deploy/Dockerfile \
|
||||||
|
/tmp/adiuva-api-deploy/docker-compose.yml \
|
||||||
|
/tmp/adiuva-api-deploy/requirements.txt \
|
||||||
|
"$DEPLOY_DIR/"
|
||||||
|
rm -rf /tmp/adiuva-api-deploy
|
||||||
|
|
||||||
|
# ── Verify .env ──
|
||||||
|
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||||
|
echo "❌ $DEPLOY_DIR/.env not found. Create it before deploying."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
# ── Build & restart ──
|
||||||
|
cd "$DEPLOY_DIR"
|
||||||
|
docker compose down --remove-orphans || true
|
||||||
docker compose up -d --build
|
docker compose up -d --build
|
||||||
|
|
||||||
|
# ── Migrations ──
|
||||||
|
docker compose exec -T app alembic upgrade head
|
||||||
|
|
||||||
|
# ── Health check ──
|
||||||
|
echo "Waiting for app..."
|
||||||
|
sleep 5
|
||||||
|
HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/api/v1/health)
|
||||||
|
if [ "$HTTP_CODE" -eq 200 ]; then
|
||||||
|
echo "✅ API is healthy (HTTP ${HTTP_CODE})"
|
||||||
|
else
|
||||||
|
echo "❌ Health check failed (HTTP ${HTTP_CODE})"
|
||||||
|
docker compose logs app --tail=50
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
64
.github/workflows/ci.yml
vendored
Normal file
64
.github/workflows/ci.yml
vendored
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
name: CI
|
||||||
|
|
||||||
|
on:
|
||||||
|
push:
|
||||||
|
branches: [main]
|
||||||
|
pull_request:
|
||||||
|
branches: [main]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
lint:
|
||||||
|
name: Lint
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Install ruff
|
||||||
|
run: pip install ruff>=0.8.0
|
||||||
|
|
||||||
|
- name: Ruff check
|
||||||
|
run: ruff check .
|
||||||
|
|
||||||
|
- name: Ruff format check
|
||||||
|
run: ruff format --check .
|
||||||
|
|
||||||
|
test:
|
||||||
|
name: Test
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: lint
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: actions/setup-python@v5
|
||||||
|
with:
|
||||||
|
python-version: "3.12"
|
||||||
|
|
||||||
|
- name: Cache pip
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: ~/.cache/pip
|
||||||
|
key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }}
|
||||||
|
restore-keys: ${{ runner.os }}-pip-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: pip install -r requirements.txt
|
||||||
|
|
||||||
|
- name: Run tests
|
||||||
|
run: pytest -v --tb=short
|
||||||
|
|
||||||
|
docker:
|
||||||
|
name: Docker Build
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
needs: test
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Build image
|
||||||
|
run: docker build -t adiuva-api:ci .
|
||||||
|
|
||||||
|
- name: Verify gunicorn installed
|
||||||
|
run: docker run --rm adiuva-api:ci gunicorn --version
|
||||||
7
.gitignore
vendored
7
.gitignore
vendored
@@ -13,6 +13,9 @@ env/
|
|||||||
# Environment variables
|
# Environment variables
|
||||||
.env
|
.env
|
||||||
|
|
||||||
|
# Cryptographic keys
|
||||||
|
*.pem
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
@@ -31,3 +34,7 @@ Thumbs.db
|
|||||||
|
|
||||||
# Claude Code
|
# Claude Code
|
||||||
.claude/
|
.claude/
|
||||||
|
logs/
|
||||||
|
|
||||||
|
# Eval private test data
|
||||||
|
services/batch-agent/eval/fixtures/private_data/
|
||||||
|
|||||||
530
BACKEND_PLAN.md
530
BACKEND_PLAN.md
@@ -1,530 +0,0 @@
|
|||||||
# Backend Plan — Adiuva Cloud API
|
|
||||||
|
|
||||||
> **Separate repository.** This document defines the FastAPI backend that the Electron app communicates with.
|
|
||||||
>
|
|
||||||
> The backend owns: orchestration logic, chat agent intelligence, prompt IP, auth, billing, E2E backup blob storage, cloud storage (encrypted blobs), cloud vector store, and plugin marketplace.
|
|
||||||
> The backend NEVER persists user data in plaintext. Cloud storage blobs are E2E encrypted before upload — the backend only verifies integrity, never decrypts.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── app/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── main.py # FastAPI entry + CORS + lifespan + router includes
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── agent_registry.py # Base classes + singleton registry
|
|
||||||
│ │ ├── orchestrator.py # LLM-based intent router
|
|
||||||
│ │ ├── execution_plan.py # Plan builder + cache
|
|
||||||
│ │ └── plugin_loader.py # Dynamic agent loading
|
|
||||||
│ ├── agents/ # Chat agents (proprietary logic + prompts)
|
|
||||||
│ │ ├── __init__.py # Auto-registers all agents
|
|
||||||
│ │ ├── task_agent.py
|
|
||||||
│ │ ├── calendar_agent.py
|
|
||||||
│ │ ├── email_agent.py
|
|
||||||
│ │ └── analytics_agent.py
|
|
||||||
│ ├── api/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── routes/
|
|
||||||
│ │ │ ├── __init__.py
|
|
||||||
│ │ │ ├── chat.py # POST /chat + WS /chat/stream
|
|
||||||
│ │ │ ├── plans.py # GET /plans/playbook
|
|
||||||
│ │ │ ├── storage.py # CRUD cloud storage (E2E encrypted blobs)
|
|
||||||
│ │ │ ├── vectors.py # Upsert/search cloud vector store
|
|
||||||
│ │ │ ├── backup.py # PUT/GET /backup
|
|
||||||
│ │ │ ├── plugins.py # Plugin marketplace
|
|
||||||
│ │ │ ├── auth.py # Register/login/refresh
|
|
||||||
│ │ │ └── billing.py # Checkout/webhook/subscription
|
|
||||||
│ │ └── middleware/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── auth.py # JWT validation
|
|
||||||
│ │ ├── rate_limit.py # Tier-aware rate limiting
|
|
||||||
│ │ └── sanitizer.py # Strip prompt metadata from responses
|
|
||||||
│ ├── storage/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── blob_store.py # S3 for E2E encrypted blobs
|
|
||||||
│ │ ├── vector_store.py # Cloud vector store (Pinecone/Qdrant)
|
|
||||||
│ │ └── encryption.py # Integrity verification only — NO decryption
|
|
||||||
│ ├── marketplace/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── plugin_registry.py # Plugin catalog (metadata, versions, ratings)
|
|
||||||
│ │ ├── plugin_review.py # Review queue + approval workflow
|
|
||||||
│ │ └── revenue_share.py # 70/30 split tracking with Stripe Connect
|
|
||||||
│ ├── billing/
|
|
||||||
│ │ ├── __init__.py
|
|
||||||
│ │ ├── stripe_service.py # Stripe checkout + webhooks
|
|
||||||
│ │ └── tier_manager.py # Feature matrix per tier
|
|
||||||
│ └── config/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ └── settings.py # Pydantic BaseSettings (env-based)
|
|
||||||
├── tests/
|
|
||||||
│ ├── __init__.py
|
|
||||||
│ ├── conftest.py # Fixtures: test client, mock agents, mock LLM
|
|
||||||
│ ├── test_orchestrator.py
|
|
||||||
│ ├── test_agents.py
|
|
||||||
│ ├── test_auth.py
|
|
||||||
│ ├── test_backup.py
|
|
||||||
│ ├── test_storage.py
|
|
||||||
│ └── test_plugins.py
|
|
||||||
├── alembic/ # DB migrations (auth/billing/marketplace tables only)
|
|
||||||
│ ├── alembic.ini
|
|
||||||
│ └── versions/
|
|
||||||
├── requirements.txt
|
|
||||||
├── Dockerfile
|
|
||||||
├── docker-compose.yml # App + PostgreSQL + Redis (dev)
|
|
||||||
├── .env.example
|
|
||||||
└── README.md
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Step-by-Step Implementation
|
|
||||||
|
|
||||||
### Step 1 — Project scaffolding ✅
|
|
||||||
- [x] Initialize repo with the directory structure above
|
|
||||||
- [x] Write `requirements.txt`:
|
|
||||||
```
|
|
||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
langchain>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
stripe>=11.0.0
|
|
||||||
boto3>=1.35.0
|
|
||||||
slowapi>=0.1.9
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
alembic>=1.14.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
httpx>=0.28.0
|
|
||||||
websockets>=14.0
|
|
||||||
pytest>=8.0.0
|
|
||||||
pytest-asyncio>=0.24.0
|
|
||||||
```
|
|
||||||
- [x] Write `app/main.py`: FastAPI app with CORS (allow `app://`, `http://localhost:*`), lifespan (init DB pool, init agent registry), include all routers under `/api/v1`
|
|
||||||
- [x] Write `app/config/settings.py`: `Settings(BaseSettings)` with fields: `DATABASE_URL`, `JWT_SECRET`, `JWT_ALGORITHM` (default HS256), `STRIPE_SECRET_KEY`, `STRIPE_WEBHOOK_SECRET`, `S3_BUCKET`, `S3_REGION`, `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `OPENAI_API_KEY`, `CORS_ORIGINS`, `ENV` (dev/prod), `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] Write `Dockerfile`: Python 3.12 slim, multi-stage (builder + runtime), non-root user
|
|
||||||
- [x] Write `docker-compose.yml`: app, postgres:16, optional redis
|
|
||||||
- [x] Write `.env.example`
|
|
||||||
- **Outcome:** Runnable FastAPI skeleton (returns 404 on all routes).
|
|
||||||
|
|
||||||
### Step 2 — Pydantic schemas (API contracts) ✅
|
|
||||||
- [x] Create `app/schemas.py` (mirrors `src/shared/api-types.ts` from Electron repo):
|
|
||||||
- `ChatRequest`: `message: str`, `context: ChatContext`, `execution_mode: Literal['direct', 'plan']`
|
|
||||||
- `ChatContext`: `user_profile: dict`, `relevant_documents: list[str]`, `recent_tasks: list[dict]`, `conversation_history: list[dict]`
|
|
||||||
- `ChatResponse`: `response: str`, `actions: list[PlanAction]`
|
|
||||||
- `PlanAction`: `type: Literal['create_record', 'update_record', 'delete_record', 'index_document', 'send_notification', 'call_agent']`, `table: str | None`, `data: dict | None`, `agent: str | None`
|
|
||||||
- `ExecutionPlan`: `agent: str`, `steps: list[PlanStep]`
|
|
||||||
- `PlanStep`: `action: str`, `prompt_template: str | None`, `variables: dict | None`, `data_from_step: int | None`
|
|
||||||
- `BackupMetadata`: `version: int`, `timestamp: int`, `checksum: str`, `chunk_count: int`
|
|
||||||
- `BillingTier`: `Literal['free', 'pro', 'power', 'team']`
|
|
||||||
- `AuthTokens`: `access_token: str`, `refresh_token: str`, `expires_at: int`
|
|
||||||
- `UserProfile`: `id: str`, `email: str`, `tier: BillingTier`
|
|
||||||
- `StorageRecord`: `id: str`, `user_id: str`, `table: str`, `blob: bytes`, `checksum: str`, `created_at: int`, `updated_at: int` — blob is always E2E encrypted by client
|
|
||||||
- `StorageRecordCreate`: `table: str`, `blob: bytes`, `checksum: str`
|
|
||||||
- `StorageRecordUpdate`: `blob: bytes`, `checksum: str`
|
|
||||||
- `VectorUpsertRequest`: `vectors: list[VectorItem]`
|
|
||||||
- `VectorItem`: `id: str`, `blob: bytes`, `checksum: str` — vector + metadata encrypted by client
|
|
||||||
- `VectorSearchRequest`: `query_blob: bytes`, `top_k: int = 10`
|
|
||||||
- `VectorSearchResponse`: `results: list[VectorSearchResult]`
|
|
||||||
- `VectorSearchResult`: `id: str`, `score: float`, `blob: bytes`
|
|
||||||
- `PluginManifest`: `id: str`, `name: str`, `description: str`, `version: str`, `author: str`, `permissions: list[str]`, `category: str`, `price_cents: int = 0`
|
|
||||||
- `PluginListResponse`: `plugins: list[PluginManifest]`, `total: int`, `page: int`
|
|
||||||
- `PluginInstallRequest`: `plugin_id: str`
|
|
||||||
- **Outcome:** All request/response models defined and validated.
|
|
||||||
|
|
||||||
### Step 3 — Agent Registry + base classes ✅
|
|
||||||
- [x] `app/core/agent_registry.py`:
|
|
||||||
- `BaseAgent(ABC)`:
|
|
||||||
- `user_id: str`, `shared_memory: dict`, `vector_store_context: list[str]`, `skills: list[str]`
|
|
||||||
- Abstract `get_name() -> str`, `get_description() -> str`
|
|
||||||
- `ChatAgent(BaseAgent)`:
|
|
||||||
- Abstract `async handle(query: str, context: dict) -> str`
|
|
||||||
- Abstract `get_tools() -> list` (LangChain tool definitions)
|
|
||||||
- Concrete `_tool_loop(llm, messages, tools, max_iter=5) -> str` — shared tool-calling loop
|
|
||||||
- `AgentRegistry` (singleton):
|
|
||||||
- `_agents: dict[str, ChatAgent]`
|
|
||||||
- `register(agent_class)` — decorator pattern
|
|
||||||
- `get(name) -> ChatAgent`
|
|
||||||
- `list_agents() -> list[dict]` — returns `[{name, description}]` for orchestrator prompt
|
|
||||||
- `async call_agent(name, query, context) -> str` — for inter-agent calls
|
|
||||||
- [x] Unit tests: register, get, list, call_agent with mock
|
|
||||||
- **Outcome:** Pluggable agent framework.
|
|
||||||
|
|
||||||
### Step 4 — Orchestrator ✅
|
|
||||||
- [x] `app/core/orchestrator.py`:
|
|
||||||
- `async classify_intent(message, context, registry) -> str`:
|
|
||||||
- System prompt: "You are an intent classifier. Given the user message and context, decide which agent to route to. Available agents: {registry.list_agents()}. Respond with just the agent name."
|
|
||||||
- Uses gpt-4o-mini via LangChain for low latency
|
|
||||||
- Falls back to `task_agent` if no clear match
|
|
||||||
- `async route_single(agent_name, message, context) -> ChatResponse`:
|
|
||||||
- Instantiates agent from registry
|
|
||||||
- Calls `agent.handle(message, context)`
|
|
||||||
- Returns response + any actions the agent produced
|
|
||||||
- `async route_pipeline(agent_names, message, context) -> ChatResponse`:
|
|
||||||
- Executes agents in sequence
|
|
||||||
- Each agent receives `{...context, previous_results: [...]}`
|
|
||||||
- Final synthesis via LLM: "Summarize these agent results into a coherent response"
|
|
||||||
- `async orchestrate(request: ChatRequest) -> ChatResponse | ExecutionPlan`:
|
|
||||||
- Main entry point
|
|
||||||
- Context is transparent to orchestrator — data may originate from local or cloud storage on the client side
|
|
||||||
- Classifies intent
|
|
||||||
- If `execution_mode == 'direct'`: route + return response
|
|
||||||
- If `execution_mode == 'plan'`: route + return execution plan with template IDs
|
|
||||||
- `async orchestrate_stream(request: ChatRequest) -> AsyncGenerator[str, None]`:
|
|
||||||
- Same as orchestrate but yields tokens for WebSocket streaming
|
|
||||||
- [x] Integration tests with mocked LLM and mocked agents
|
|
||||||
- **Outcome:** Intelligent routing with single-agent and pipeline modes.
|
|
||||||
|
|
||||||
### Step 5 — Execution Plan generator ✅
|
|
||||||
- [x] `app/core/execution_plan.py`:
|
|
||||||
- `PromptTemplateRegistry`: dict of `template_id -> prompt_text`. Templates are server-side only — client receives IDs.
|
|
||||||
- `ExecutionPlanBuilder`:
|
|
||||||
- `add_step(action, params) -> self`
|
|
||||||
- `add_llm_step(template_id, variables) -> self`
|
|
||||||
- `add_data_step(action, data_from_step) -> self`
|
|
||||||
- `build() -> ExecutionPlan` — validates step references
|
|
||||||
- `PlanCache`:
|
|
||||||
- In-memory LRU (maxsize=1000)
|
|
||||||
- `cache_plan(key, plan)`, `get_plan(key)`, `get_all_playbooks() -> list[ExecutionPlan]`
|
|
||||||
- Playbooks are pre-built plans for common operations (e.g., "create task from email", "generate weekly report")
|
|
||||||
- **Outcome:** Plans are cacheable as playbooks. Prompt IP never leaves the server.
|
|
||||||
|
|
||||||
### Step 6 — Chat Agents ✅
|
|
||||||
- [x] `app/agents/task_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages tasks and comments: list, create, update, delete, due-today, comments"
|
|
||||||
- Tools (8): `list_tasks(project_id, status, search, order_by)`, `create_task(title, description, status, priority, assignees, due_date, project_id, is_ai_suggested, is_approved)`, `update_task(task_id, ...)`, `delete_task(task_id)`, `list_tasks_due_today()`, `list_task_comments(task_id)`, `add_task_comment(task_id, author, content)`, `delete_task_comment(comment_id)`
|
|
||||||
- status: `todo|in_progress|done`; priority: `high|medium|low`; assignees: JSON-encoded string; due_date: ms timestamp
|
|
||||||
- Accepts flexible context; sentinel `-1` for optional integer update fields
|
|
||||||
- [x] `app/agents/checkpoint_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages project checkpoints (milestones): list, create, update, delete"
|
|
||||||
- Tools (4): `list_checkpoints(project_id)`, `create_checkpoint(project_id, title, date, is_ai_suggested, is_approved)`, `update_checkpoint(checkpoint_id, ...)`, `delete_checkpoint(checkpoint_id)`
|
|
||||||
- `project_id` is required for create; date is a ms timestamp; supports AI-suggestion + approval workflow
|
|
||||||
- [x] `app/agents/project_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages projects: list, get, create, update, archive, delete"
|
|
||||||
- Tools (6): `list_projects(client_id, include_archived)`, `list_all_projects()`, `get_project(project_id)`, `create_project(name, client_id)`, `update_project(project_id, ...)`, `delete_project(project_id)`
|
|
||||||
- status: `active|archived`; prefers archive over deletion (docstring guard on delete)
|
|
||||||
- [x] `app/agents/note_agent.py` — `@registry.register`:
|
|
||||||
- Description: "Manages notes: list, get, create, update, delete"
|
|
||||||
- Tools (5): `list_notes(project_id)`, `get_note(note_id)`, `create_note(title, content, project_id)`, `update_note(note_id, ...)`, `delete_note(note_id)`
|
|
||||||
- content is Markdown; `get_note` should be called before update to preserve existing content
|
|
||||||
- [x] `app/agents/__init__.py`: imports all four agent modules to trigger `@registry.register` decorators
|
|
||||||
- [x] Unit tests per agent with mocked LLM (registration, names, tool counts, handle(), direct tool invocation)
|
|
||||||
- **Outcome:** Four domain-specific agents matching the UI data model (Tasks, Checkpoints, Projects, Notes), all registered and tested.
|
|
||||||
|
|
||||||
### Step 7 — Storage Layer ✅
|
|
||||||
- [x] `app/storage/blob_store.py`:
|
|
||||||
- `BlobStore`: `async upload`, `async download`, `async delete` (idempotent), `async list_keys`
|
|
||||||
- Keys: `{user_id}/{table}/{record_id}` — backend never inspects blob content
|
|
||||||
- boto3 S3 with SSE-S3 at-rest encryption; client checksum stored in S3 object metadata
|
|
||||||
- [x] `app/storage/vector_store.py`:
|
|
||||||
- `VectorStore`: `async upsert`, `async search`, `async delete`
|
|
||||||
- Pinecone (default, `namespace=user_id`) or Qdrant (`user_id` payload filter) — runtime-configurable
|
|
||||||
- 32-dim SHA-256-derived float vector; blob stored as base64 in metadata/payload
|
|
||||||
- ANN on encrypted data: known accuracy trade-off, documented
|
|
||||||
- [x] `app/storage/encryption.py`:
|
|
||||||
- `verify_checksum(blob, checksum) -> bool` — SHA-256 + `hmac.compare_digest` (constant-time)
|
|
||||||
- `reject_if_tampered(blob, checksum)` — raises `HTTP 400` on mismatch
|
|
||||||
- Backend NEVER holds decryption keys
|
|
||||||
- [x] `app/schemas.py`: added `StorageRecord*`, `VectorItem`, `VectorUpsertRequest`, `VectorSearch*`, `Plugin*` schemas
|
|
||||||
- [x] `app/config/settings.py`: added `PINECONE_API_KEY`, `PINECONE_INDEX`, `QDRANT_URL`, `QDRANT_API_KEY`
|
|
||||||
- [x] `requirements.txt`: added `moto[s3]`, `pinecone`, `qdrant-client`
|
|
||||||
- [x] 37 unit tests covering encryption, BlobStore (moto), VectorStore Pinecone, VectorStore Qdrant
|
|
||||||
- **Outcome:** Cloud storage layer that handles E2E encrypted blobs without ever accessing plaintext.
|
|
||||||
|
|
||||||
### Step 8 — API Routes ✅
|
|
||||||
|
|
||||||
#### 8a — Chat endpoint
|
|
||||||
- [x] `app/api/routes/chat.py`:
|
|
||||||
- `POST /api/v1/chat`:
|
|
||||||
- Request: `ChatRequest`
|
|
||||||
- Calls `orchestrate(request)` or `orchestrate()` + `build_plan()`
|
|
||||||
- Response: `ChatResponse` or `ExecutionPlan`
|
|
||||||
- `WebSocket /api/v1/chat/stream`:
|
|
||||||
- Client sends `ChatRequest` as first JSON frame
|
|
||||||
- Server yields token strings via `orchestrate_stream()`
|
|
||||||
- Final frame: JSON `ChatResponse` with `{"done": true, "response": "...", "actions": [...]}`
|
|
||||||
- Heartbeat ping every 30s to keep connection alive
|
|
||||||
|
|
||||||
#### 8b — Plans endpoint
|
|
||||||
- [x] `app/api/routes/plans.py`:
|
|
||||||
- `GET /api/v1/plans/playbook`: Returns all playbooks available for the user's tier
|
|
||||||
- `GET /api/v1/plans/playbook/{plan_id}`: Returns a specific plan
|
|
||||||
|
|
||||||
#### 8c — Storage endpoint (cloud records)
|
|
||||||
- [x] `app/api/routes/storage.py`:
|
|
||||||
- `POST /api/v1/storage/records`: Create encrypted record
|
|
||||||
- Request: `StorageRecordCreate`
|
|
||||||
- Verifies checksum, stores blob in S3, inserts metadata row in PostgreSQL
|
|
||||||
- Response: `{id: str, created_at: int}`
|
|
||||||
- `GET /api/v1/storage/records`: List record metadata (no blobs)
|
|
||||||
- Query params: `table: str`, `page: int`, `limit: int`
|
|
||||||
- Response: `list[{id, table, checksum, created_at, updated_at}]`
|
|
||||||
- `GET /api/v1/storage/records/{id}`: Download encrypted blob
|
|
||||||
- Response: blob bytes + `X-Checksum` header
|
|
||||||
- `PUT /api/v1/storage/records/{id}`: Update encrypted blob
|
|
||||||
- Request: `StorageRecordUpdate`
|
|
||||||
- `DELETE /api/v1/storage/records/{id}`: Delete record + S3 blob
|
|
||||||
- All routes enforce tier cloud_storage_gb quota via `TierManager.check_quota(user_id)`
|
|
||||||
|
|
||||||
#### 8d — Vectors endpoint (cloud vector store)
|
|
||||||
- [x] `app/api/routes/vectors.py`:
|
|
||||||
- `POST /api/v1/storage/vectors/upsert`:
|
|
||||||
- Request: `VectorUpsertRequest`
|
|
||||||
- Verifies checksums, delegates to `VectorStore.upsert()`
|
|
||||||
- Response: `{upserted: int}`
|
|
||||||
- `POST /api/v1/storage/vectors/search`:
|
|
||||||
- Request: `VectorSearchRequest`
|
|
||||||
- Delegates to `VectorStore.search()`
|
|
||||||
- Response: `VectorSearchResponse`
|
|
||||||
- `DELETE /api/v1/storage/vectors`:
|
|
||||||
- Request: `{ids: list[str]}`
|
|
||||||
|
|
||||||
#### 8e — Backup endpoint
|
|
||||||
- [x] `app/api/routes/backup.py`:
|
|
||||||
- `PUT /api/v1/backup`: Accepts binary blob + metadata headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Stores in S3 keyed by `{user_id}/{timestamp}`. Enforces tier limits:
|
|
||||||
- Free: 0 (no backup)
|
|
||||||
- Pro: 5 GB
|
|
||||||
- Power: 25 GB
|
|
||||||
- Team: unlimited
|
|
||||||
- `GET /api/v1/backup`: Returns latest blob for authenticated user. Supports `If-Modified-Since`.
|
|
||||||
- `GET /api/v1/backup/history`: Returns list of `BackupMetadata` (no blobs).
|
|
||||||
- `DELETE /api/v1/backup/{backup_id}`: Delete specific backup.
|
|
||||||
|
|
||||||
#### 8f — Plugins endpoint
|
|
||||||
- [x] `app/api/routes/plugins.py`:
|
|
||||||
- `GET /api/v1/plugins`:
|
|
||||||
- Query params: `category: str | None`, `q: str | None`, `page: int`, `sort: Literal['rating', 'installs', 'newest']`
|
|
||||||
- Response: `PluginListResponse`
|
|
||||||
- Available from Power tier and above
|
|
||||||
- `GET /api/v1/plugins/{id}`:
|
|
||||||
- Response: `PluginManifest` + ratings + install count
|
|
||||||
- `POST /api/v1/plugins/{id}/install`:
|
|
||||||
- Request: `PluginInstallRequest`
|
|
||||||
- Records installation for the user (billing tracking, analytics)
|
|
||||||
- If plugin is paid: triggers Stripe Connect charge + revenue split (70% developer, 30% platform)
|
|
||||||
- Response: `{ok: true, download_url: str}` — signed S3 URL for plugin package
|
|
||||||
- `DELETE /api/v1/plugins/{id}/install`:
|
|
||||||
- Unregisters installation
|
|
||||||
|
|
||||||
#### 8g — Auth endpoint
|
|
||||||
- [x] `app/api/routes/auth.py`:
|
|
||||||
- `POST /api/v1/auth/register`: `{email, password}` → bcrypt hash → insert user → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/login`: Validate credentials → return `AuthTokens`
|
|
||||||
- `POST /api/v1/auth/refresh`: Rotate refresh token → return new `AuthTokens`
|
|
||||||
- `GET /api/v1/auth/me`: Return `UserProfile` for current JWT
|
|
||||||
|
|
||||||
#### 8h — Billing endpoint
|
|
||||||
- [x] `app/api/routes/billing.py`:
|
|
||||||
- `POST /api/v1/billing/checkout`: Creates Stripe checkout session → returns URL
|
|
||||||
- `POST /api/v1/billing/webhook`: Handles Stripe webhooks (subscription lifecycle)
|
|
||||||
- `GET /api/v1/billing/subscription`: Returns current subscription info
|
|
||||||
- `DELETE /api/v1/billing/subscription`: Cancels subscription
|
|
||||||
|
|
||||||
- **Outcome:** Complete REST + WebSocket API covering orchestration, storage, vectors, backup, marketplace.
|
|
||||||
|
|
||||||
### Step 9 — Middleware
|
|
||||||
|
|
||||||
#### 9a — Auth middleware
|
|
||||||
- [x] `app/api/middleware/auth.py`:
|
|
||||||
- FastAPI dependency: `get_current_user(token: str = Depends(oauth2_scheme)) -> UserProfile`
|
|
||||||
- Validates JWT signature, expiry, extracts `user_id` and `tier`
|
|
||||||
- Raises `401` on invalid/expired token
|
|
||||||
- Exempt routes: `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
#### 9b — Rate limiter
|
|
||||||
- [x] `app/api/middleware/rate_limit.py`:
|
|
||||||
- Uses `slowapi` with `Limiter(key_func=get_user_id_from_jwt)`
|
|
||||||
- Tier-based limits:
|
|
||||||
- Free: 20 req/min
|
|
||||||
- Pro: 60 req/min
|
|
||||||
- Power: 120 req/min
|
|
||||||
- Team: 200 req/seat/min
|
|
||||||
- Custom 429 response with `Retry-After` header
|
|
||||||
|
|
||||||
#### 9c — Sanitizer
|
|
||||||
- [x] `app/api/middleware/sanitizer.py`:
|
|
||||||
- Response middleware that scans response bodies
|
|
||||||
- Strips: system prompt fragments, agent internal reasoning, tool schemas, routing metadata
|
|
||||||
- Pattern-based detection + exact match against known prompt fingerprints
|
|
||||||
- Logs sanitization events for monitoring
|
|
||||||
|
|
||||||
- **Outcome:** Secure, rate-limited API with prompt IP protection.
|
|
||||||
|
|
||||||
### Step 10 — Plugin Marketplace ✅
|
|
||||||
- [x] `app/marketplace/plugin_registry.py`:
|
|
||||||
- `PluginRegistry`:
|
|
||||||
- `async list_plugins(category, query, page, sort) -> PluginListResponse`
|
|
||||||
- `async get_plugin(plugin_id) -> PluginManifest | None`
|
|
||||||
- `async submit_plugin(manifest: PluginManifest, package_s3_key: str) -> str` — returns plugin_id, sets status = 'pending_review'
|
|
||||||
- `async approve_plugin(plugin_id) -> None` — admin only, sets status = 'approved'
|
|
||||||
- `async reject_plugin(plugin_id, reason: str) -> None`
|
|
||||||
- [x] `app/marketplace/plugin_review.py`:
|
|
||||||
- `ReviewQueue`:
|
|
||||||
- `async get_pending() -> list[dict]`
|
|
||||||
- `async submit_review(plugin_id, reviewer_id, decision, notes) -> None`
|
|
||||||
- Security checklist enforced before approval: manifest schema valid, permissions are from allowed set, no binary blobs in manifest
|
|
||||||
- [x] `app/marketplace/revenue_share.py`:
|
|
||||||
- `RevenueShare`:
|
|
||||||
- `async record_install(plugin_id, user_id, amount_cents) -> None`
|
|
||||||
- `async payout_developer(plugin_id, period) -> None` — Stripe Connect transfer: 70% to developer
|
|
||||||
- `async get_earnings(developer_id, period) -> dict`
|
|
||||||
- **Outcome:** Plugin marketplace with catalog, review workflow, and revenue split.
|
|
||||||
|
|
||||||
### Step 11 — Billing & Tier management
|
|
||||||
- [ ] `app/billing/stripe_service.py`:
|
|
||||||
- `create_checkout_session(user_id, tier) -> str`
|
|
||||||
- `handle_webhook(payload, sig_header) -> None`: processes `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed`
|
|
||||||
- `get_subscription(user_id) -> dict | None`
|
|
||||||
- `cancel_subscription(user_id) -> None`
|
|
||||||
- [ ] `app/billing/tier_manager.py`:
|
|
||||||
- `TierManager`:
|
|
||||||
- Feature matrix:
|
|
||||||
```python
|
|
||||||
FEATURES = {
|
|
||||||
'free': {
|
|
||||||
'agents': 3,
|
|
||||||
'batch_active': 2,
|
|
||||||
'cloud_storage_gb': 0,
|
|
||||||
'backup_gb': 0,
|
|
||||||
'providers': 1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'pro': {
|
|
||||||
'agents': -1, # unlimited
|
|
||||||
'batch_active': 10,
|
|
||||||
'cloud_storage_gb': 5,
|
|
||||||
'backup_gb': 5,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': False,
|
|
||||||
'plugin_marketplace': False,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'power': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1, # unlimited
|
|
||||||
'cloud_storage_gb': 25,
|
|
||||||
'backup_gb': 25,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': False,
|
|
||||||
},
|
|
||||||
'team': {
|
|
||||||
'agents': -1,
|
|
||||||
'batch_active': -1,
|
|
||||||
'cloud_storage_gb': -1,
|
|
||||||
'backup_gb': -1,
|
|
||||||
'providers': -1,
|
|
||||||
'batch_builder': True,
|
|
||||||
'plugin_marketplace': True,
|
|
||||||
'sso': True,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
```
|
|
||||||
- `get_tier(user_id) -> BillingTier`
|
|
||||||
- `check_feature(user_id, feature) -> bool`
|
|
||||||
- `get_rate_limit(tier) -> int`
|
|
||||||
- `check_quota(user_id) -> bool` — checks cloud_storage_gb current usage vs limit
|
|
||||||
- **Outcome:** Stripe integration with tier-based feature gating matching Free/Pro(15€)/Power(29€)/Team(49€/seat).
|
|
||||||
|
|
||||||
### Step 12 — Database (auth/billing/marketplace only)
|
|
||||||
- [ ] PostgreSQL schema via Alembic:
|
|
||||||
- `users`: `id UUID PK`, `email UNIQUE`, `password_hash`, `tier` (default 'free'), `stripe_customer_id`, `created_at`, `updated_at`
|
|
||||||
- `refresh_tokens`: `id UUID PK`, `user_id FK`, `token_hash`, `expires_at`, `created_at`
|
|
||||||
- `subscriptions`: `id UUID PK`, `user_id FK`, `stripe_subscription_id`, `tier`, `status`, `current_period_end`, `created_at`
|
|
||||||
- `backup_metadata`: `id UUID PK`, `user_id FK`, `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes`, `created_at`
|
|
||||||
- `storage_records`: `id UUID PK`, `user_id FK`, `table_name VARCHAR`, `s3_key`, `checksum`, `size_bytes`, `created_at`, `updated_at` — metadata only, no plaintext
|
|
||||||
- `plugins`: `id UUID PK`, `name`, `description`, `version`, `author_id FK`, `category`, `status` (pending_review/approved/rejected), `price_cents`, `s3_package_key`, `install_count`, `avg_rating`, `created_at`
|
|
||||||
- `plugin_installations`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `installed_at`
|
|
||||||
- `plugin_reviews`: `id UUID PK`, `plugin_id FK`, `reviewer_id FK`, `decision`, `notes`, `reviewed_at`
|
|
||||||
- `revenue_events`: `id UUID PK`, `plugin_id FK`, `user_id FK`, `amount_cents`, `developer_share_cents`, `stripe_transfer_id`, `created_at`
|
|
||||||
- [ ] Initial Alembic migration
|
|
||||||
- [ ] SQLAlchemy models in `app/models.py`
|
|
||||||
- **Outcome:** Auth, billing, storage metadata, and marketplace persistence. Zero user data in plaintext.
|
|
||||||
|
|
||||||
### Step 13 — Testing & deployment
|
|
||||||
- [ ] `tests/conftest.py`: TestClient fixture, mock LLM fixture (`AsyncMock` returning canned responses), mock agent fixture, test DB (SQLite in-memory for speed), mock S3 (moto), mock Pinecone
|
|
||||||
- [ ] `tests/test_orchestrator.py`: classify_intent routing, single agent, pipeline, plan mode
|
|
||||||
- [ ] `tests/test_agents.py`: each agent with mocked tools
|
|
||||||
- [ ] `tests/test_auth.py`: register → login → access protected → refresh → expired token
|
|
||||||
- [ ] `tests/test_backup.py`: upload → download → history → delete, tier limit enforcement
|
|
||||||
- [ ] `tests/test_storage.py`: create record → list → download → update → delete, checksum rejection, quota enforcement
|
|
||||||
- [ ] `tests/test_plugins.py`: list plugins, install, uninstall, revenue event creation, tier gate (free user blocked)
|
|
||||||
- [ ] `Dockerfile` optimized for production (gunicorn + uvicorn workers)
|
|
||||||
- [ ] GitHub Actions CI: lint (ruff), test (pytest), build Docker image
|
|
||||||
- **Outcome:** Fully tested, deployable backend.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Contract Summary
|
|
||||||
|
|
||||||
| Method | Endpoint | Auth | Request | Response |
|
|
||||||
|--------|----------|------|---------|----------|
|
|
||||||
| POST | `/api/v1/auth/register` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/login` | No | `{email, password}` | `AuthTokens` |
|
|
||||||
| POST | `/api/v1/auth/refresh` | No | `{refresh_token}` | `AuthTokens` |
|
|
||||||
| GET | `/api/v1/auth/me` | JWT | — | `UserProfile` |
|
|
||||||
| POST | `/api/v1/chat` | JWT | `ChatRequest` | `ChatResponse \| ExecutionPlan` |
|
|
||||||
| WS | `/api/v1/chat/stream` | JWT | `ChatRequest` (first frame) | Token stream + final JSON |
|
|
||||||
| GET | `/api/v1/plans/playbook` | JWT | — | `ExecutionPlan[]` |
|
|
||||||
| GET | `/api/v1/plans/playbook/:id` | JWT | — | `ExecutionPlan` |
|
|
||||||
| POST | `/api/v1/storage/records` | JWT | `StorageRecordCreate` | `{id, created_at}` |
|
|
||||||
| GET | `/api/v1/storage/records` | JWT | `?table&page&limit` | `RecordMeta[]` |
|
|
||||||
| GET | `/api/v1/storage/records/:id` | JWT | — | Binary blob |
|
|
||||||
| PUT | `/api/v1/storage/records/:id` | JWT | `StorageRecordUpdate` | `{ok: true}` |
|
|
||||||
| DELETE | `/api/v1/storage/records/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/upsert` | JWT | `VectorUpsertRequest` | `{upserted: int}` |
|
|
||||||
| POST | `/api/v1/storage/vectors/search` | JWT | `VectorSearchRequest` | `VectorSearchResponse` |
|
|
||||||
| DELETE | `/api/v1/storage/vectors` | JWT | `{ids: list[str]}` | `{ok: true}` |
|
|
||||||
| PUT | `/api/v1/backup` | JWT | Binary blob + headers | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/backup` | JWT | — | Binary blob |
|
|
||||||
| GET | `/api/v1/backup/history` | JWT | — | `BackupMetadata[]` |
|
|
||||||
| DELETE | `/api/v1/backup/:id` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/plugins` | JWT | `?category&q&page&sort` | `PluginListResponse` |
|
|
||||||
| GET | `/api/v1/plugins/:id` | JWT | — | `PluginManifest` + stats |
|
|
||||||
| POST | `/api/v1/plugins/:id/install` | JWT | `PluginInstallRequest` | `{ok, download_url}` |
|
|
||||||
| DELETE | `/api/v1/plugins/:id/install` | JWT | — | `{ok: true}` |
|
|
||||||
| POST | `/api/v1/billing/checkout` | JWT | `{tier}` | `{checkout_url}` |
|
|
||||||
| POST | `/api/v1/billing/webhook` | Stripe sig | Stripe event | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/billing/subscription` | JWT | — | Subscription info |
|
|
||||||
| DELETE | `/api/v1/billing/subscription` | JWT | — | `{ok: true}` |
|
|
||||||
| GET | `/api/v1/health` | No | — | `{status, version}` |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Stack
|
|
||||||
|
|
||||||
| Layer | Technology |
|
|
||||||
|-------|-----------|
|
|
||||||
| Framework | FastAPI + Uvicorn |
|
|
||||||
| LLM | LangChain + langchain-openai |
|
|
||||||
| Auth | PyJWT + bcrypt + OAuth2 |
|
|
||||||
| Billing | stripe-python + Stripe Connect |
|
|
||||||
| Blob storage | boto3 (S3) |
|
|
||||||
| Vector store | Pinecone or Qdrant (configurable) |
|
|
||||||
| Database | PostgreSQL + SQLAlchemy + Alembic |
|
|
||||||
| Rate limiting | slowapi |
|
|
||||||
| Testing | pytest + pytest-asyncio + httpx + moto (S3 mock) |
|
|
||||||
| Deployment | Docker → fly.io / Railway / AWS ECS |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Development Rules
|
|
||||||
|
|
||||||
1. **NEVER persist user data in plaintext.** The DB stores only auth, billing, storage metadata, and marketplace data. User context arrives in requests and is discarded. Cloud blobs are E2E encrypted client-side — backend only stores opaque bytes.
|
|
||||||
2. **NEVER expose prompts.** System prompts are composed server-side from fragments. Responses are sanitized before sending. In plan mode, `prompt_template` fields are reference IDs only.
|
|
||||||
3. **NEVER decrypt user blobs.** `app/storage/encryption.py` only verifies checksums. No decryption key ever reaches the backend.
|
|
||||||
4. **Stateless request handling.** No server-side session state. All context comes from the client + JWT.
|
|
||||||
5. **Type hints everywhere.** All functions have full type annotations.
|
|
||||||
6. **Test every agent.** Each chat agent has unit tests with mocked LLM responses.
|
|
||||||
7. **Structured logging.** JSON logs with request ID correlation.
|
|
||||||
8. **Tier gates are enforced server-side.** Never trust client-reported tier. Always fetch from DB via `TierManager.get_tier(user_id)`.
|
|
||||||
9. **One step at a time.** Implement one numbered step per session. When the step is fully done, mark all its checkboxes as `[x]` in this file and commit with message `step N complete: <outcome line>`.
|
|
||||||
793
README.md
Normal file
793
README.md
Normal file
@@ -0,0 +1,793 @@
|
|||||||
|
# Adiuva Cloud API
|
||||||
|
|
||||||
|
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
|
||||||
|
|
||||||
|
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Overview](#overview)
|
||||||
|
- [Architecture](#architecture)
|
||||||
|
- [Key Features](#key-features)
|
||||||
|
- [Tech Stack](#tech-stack)
|
||||||
|
- [Getting Started](#getting-started)
|
||||||
|
- [Docker Deployment](#docker-deployment)
|
||||||
|
- [Environment Variables](#environment-variables)
|
||||||
|
- [API Reference](#api-reference)
|
||||||
|
- [Data Model](#data-model)
|
||||||
|
- [AI Agent System](#ai-agent-system)
|
||||||
|
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
||||||
|
- [Middleware](#middleware)
|
||||||
|
- [Storage Layer](#storage-layer)
|
||||||
|
- [Billing & Tiers](#billing--tiers)
|
||||||
|
- [Plugin Marketplace](#plugin-marketplace)
|
||||||
|
- [Testing](#testing)
|
||||||
|
- [Project Structure](#project-structure)
|
||||||
|
- [License](#license)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
|
||||||
|
|
||||||
|
### Design Principles
|
||||||
|
|
||||||
|
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
|
||||||
|
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
||||||
|
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
|
||||||
|
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
||||||
|
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
|
||||||
|
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
|
||||||
|
│ Desktop App │────▶│ │
|
||||||
|
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
|
||||||
|
└──────────────┘ │ │
|
||||||
|
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
||||||
|
│ │ Auth Routes │ │ Chat Routes │ │
|
||||||
|
│ │ Billing Routes │ │ ↓ │ │
|
||||||
|
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
||||||
|
│ │ Backup Routes │ │ ↓ classify intent │ │
|
||||||
|
│ │ Plugin Routes │ │ Agent Registry │ │
|
||||||
|
│ │ Vector Routes │ │ ↓ │ │
|
||||||
|
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
|
||||||
|
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
|
||||||
|
│ │ (GPT-4o + LangChain) │ │
|
||||||
|
│ └────────────────────────────┘ │
|
||||||
|
└────────────────────────────────────────────────────────┘
|
||||||
|
│ │ │
|
||||||
|
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
|
||||||
|
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
|
||||||
|
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
|
||||||
|
│ Billing, │ │ backups) │ │ (Vectors) │
|
||||||
|
│ Metadata) │ └───────────────┘ └────────────────┘
|
||||||
|
└────────────┘
|
||||||
|
│
|
||||||
|
┌────────▼───┐
|
||||||
|
│ Stripe │
|
||||||
|
│ (Billing, │
|
||||||
|
│ Connect) │
|
||||||
|
└────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
||||||
|
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
||||||
|
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
||||||
|
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
||||||
|
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
||||||
|
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
|
||||||
|
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
|
||||||
|
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
||||||
|
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
||||||
|
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
||||||
|
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
||||||
|
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
|
||||||
|
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
||||||
|
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
|
||||||
|
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
|
||||||
|
| Package | Version | Purpose |
|
||||||
|
|---|---|---|
|
||||||
|
| `fastapi` | ≥ 0.115.0 | Web framework |
|
||||||
|
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
|
||||||
|
| `gunicorn` | ≥ 22.0.0 | Production process manager |
|
||||||
|
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
|
||||||
|
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
|
||||||
|
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
|
||||||
|
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
|
||||||
|
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
||||||
|
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
||||||
|
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
||||||
|
| `boto3` | ≥ 1.35.0 | AWS S3 client |
|
||||||
|
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
||||||
|
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
||||||
|
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
||||||
|
| `alembic` | ≥ 1.14.0 | Database migration management |
|
||||||
|
| `bcrypt` | ≥ 4.2.0 | Password hashing |
|
||||||
|
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
|
||||||
|
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
||||||
|
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
||||||
|
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
||||||
|
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
|
||||||
|
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
|
||||||
|
| `pytest` | ≥ 8.0.0 | Test framework |
|
||||||
|
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
||||||
|
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
||||||
|
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
|
||||||
|
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.12+
|
||||||
|
- PostgreSQL 16+
|
||||||
|
- An OpenAI API key (for LLM features)
|
||||||
|
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
||||||
|
- AWS credentials (optional — needed for S3 storage in production)
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone the repository
|
||||||
|
git clone <repo-url> && cd adiuva-api
|
||||||
|
|
||||||
|
# Create a virtual environment
|
||||||
|
python -m venv .venv && source .venv/bin/activate
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Configure environment
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start PostgreSQL (or use the Docker Compose database)
|
||||||
|
docker compose up db -d
|
||||||
|
|
||||||
|
# Run migrations
|
||||||
|
alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run the Development Server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Docker Deployment
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
This starts two services:
|
||||||
|
|
||||||
|
- **app** — FastAPI server on port `8000`
|
||||||
|
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
||||||
|
|
||||||
|
The compose file also includes optional services for fully local deployments:
|
||||||
|
|
||||||
|
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
|
||||||
|
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
|
||||||
|
|
||||||
|
### Dockerfile Details
|
||||||
|
|
||||||
|
The Dockerfile uses a multi-stage build:
|
||||||
|
|
||||||
|
1. **Builder stage** — Installs Python dependencies into a virtual environment.
|
||||||
|
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
|
||||||
|
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Production command (run by the container)
|
||||||
|
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Homelab / Self-Hosted Deployment
|
||||||
|
|
||||||
|
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
|
||||||
|
|
||||||
|
### 1. Start all services
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
|
||||||
|
|
||||||
|
### 2. Create the MinIO bucket
|
||||||
|
|
||||||
|
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
|
||||||
|
docker compose exec minio mc mb local/adiuva
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Configure your `.env`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Database (uses the compose PostgreSQL)
|
||||||
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
|
||||||
|
# S3 → MinIO
|
||||||
|
S3_BUCKET=adiuva
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
AWS_ACCESS_KEY_ID=minioadmin
|
||||||
|
AWS_SECRET_ACCESS_KEY=minioadmin
|
||||||
|
|
||||||
|
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
|
||||||
|
QDRANT_URL=http://qdrant:6333
|
||||||
|
QDRANT_API_KEY=
|
||||||
|
PINECONE_API_KEY=
|
||||||
|
|
||||||
|
# Billing — leave empty to stub (no Stripe needed)
|
||||||
|
STRIPE_SECRET_KEY=
|
||||||
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
|
# LLM — the only external service
|
||||||
|
OPENAI_API_KEY=sk-...
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Auth
|
||||||
|
JWT_SECRET=your-secret-here
|
||||||
|
ENV=dev
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Run migrations
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec app alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### What runs where
|
||||||
|
|
||||||
|
| Service | Runs on | Port | Notes |
|
||||||
|
|---|---|---|---|
|
||||||
|
| FastAPI app | Docker | 8000 | API server |
|
||||||
|
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
|
||||||
|
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
|
||||||
|
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
|
||||||
|
| Stripe | — | — | Stubbed when keys are empty |
|
||||||
|
| OpenAI / LLM | Cloud | — | Only external dependency |
|
||||||
|
|
||||||
|
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
|
||||||
|
|
||||||
|
| Variable | Type | Default | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
|
||||||
|
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
|
||||||
|
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
|
||||||
|
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
||||||
|
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
||||||
|
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
||||||
|
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
|
||||||
|
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
|
||||||
|
| `S3_REGION` | `str` | `us-east-1` | AWS region |
|
||||||
|
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
|
||||||
|
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
|
||||||
|
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
|
||||||
|
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
|
||||||
|
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
|
||||||
|
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
|
||||||
|
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
|
||||||
|
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
|
||||||
|
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
||||||
|
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
||||||
|
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
||||||
|
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
|
||||||
|
|
||||||
|
### Health
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
|
||||||
|
|
||||||
|
### Auth
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
|
||||||
|
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
|
||||||
|
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
|
||||||
|
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
|
||||||
|
|
||||||
|
### Chat
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
||||||
|
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
||||||
|
|
||||||
|
### Plans
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
||||||
|
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
||||||
|
|
||||||
|
### Storage (Cloud Records)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
|
||||||
|
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
|
||||||
|
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
|
||||||
|
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
|
||||||
|
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
|
||||||
|
|
||||||
|
### Vectors (Cloud Vector Store)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
|
||||||
|
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
|
||||||
|
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
|
||||||
|
|
||||||
|
### Backup
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
|
||||||
|
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
|
||||||
|
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
|
||||||
|
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
|
||||||
|
|
||||||
|
### Plugins (Marketplace)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
|
||||||
|
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
|
||||||
|
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
|
||||||
|
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
|
||||||
|
|
||||||
|
### Billing
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
|
||||||
|
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
|
||||||
|
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
|
||||||
|
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Model
|
||||||
|
|
||||||
|
9 tables managed by Alembic migrations. Source: `app/models.py`
|
||||||
|
|
||||||
|
### Tables
|
||||||
|
|
||||||
|
| Table | Primary Key | Key Columns | Purpose |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
||||||
|
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
||||||
|
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
||||||
|
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
|
||||||
|
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
|
||||||
|
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
|
||||||
|
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
|
||||||
|
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
|
||||||
|
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
|
||||||
|
|
||||||
|
### Enum Types
|
||||||
|
|
||||||
|
| Enum | Values |
|
||||||
|
|---|---|
|
||||||
|
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
||||||
|
| `plugin_status` | `pending_review`, `approved`, `rejected` |
|
||||||
|
| `review_decision` | `approved`, `rejected` |
|
||||||
|
|
||||||
|
### Migrations
|
||||||
|
|
||||||
|
| Version | Description |
|
||||||
|
|---|---|
|
||||||
|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
|
||||||
|
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## AI Agent System
|
||||||
|
|
||||||
|
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
|
||||||
|
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
||||||
|
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
||||||
|
|
||||||
|
### Registered Agents
|
||||||
|
|
||||||
|
| Agent | Registry Name | Tools | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
||||||
|
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
||||||
|
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
|
||||||
|
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
||||||
|
|
||||||
|
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
||||||
|
|
||||||
|
### Switching LLM Providers
|
||||||
|
|
||||||
|
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# OpenAI (default)
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Anthropic
|
||||||
|
LLM_MODEL=anthropic/claude-3.5-sonnet
|
||||||
|
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
|
||||||
|
|
||||||
|
# Google Gemini
|
||||||
|
LLM_MODEL=gemini/gemini-pro
|
||||||
|
LLM_ROUTER_MODEL=gemini/gemini-flash
|
||||||
|
|
||||||
|
# Local Ollama
|
||||||
|
LLM_MODEL=ollama/llama3
|
||||||
|
LLM_ROUTER_MODEL=ollama/llama3
|
||||||
|
|
||||||
|
# AWS Bedrock
|
||||||
|
LLM_MODEL=bedrock/anthropic.claude-v2
|
||||||
|
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Orchestration & Execution Plans
|
||||||
|
|
||||||
|
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
||||||
|
|
||||||
|
### Orchestrator
|
||||||
|
|
||||||
|
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
|
||||||
|
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
|
||||||
|
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
|
||||||
|
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
|
||||||
|
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
|
||||||
|
|
||||||
|
### Execution Plans
|
||||||
|
|
||||||
|
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
|
||||||
|
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
|
||||||
|
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
|
||||||
|
|
||||||
|
### Built-in Templates (6)
|
||||||
|
|
||||||
|
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
||||||
|
|
||||||
|
### Built-in Playbooks (2)
|
||||||
|
|
||||||
|
| Playbook | Description |
|
||||||
|
|---|---|
|
||||||
|
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
|
||||||
|
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Middleware
|
||||||
|
|
||||||
|
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
|
||||||
|
|
||||||
|
### JWT Authentication
|
||||||
|
|
||||||
|
Source: `app/api/middleware/auth.py`
|
||||||
|
|
||||||
|
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
|
||||||
|
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
|
||||||
|
- Falls back to `free` when no subscription row exists.
|
||||||
|
- Raises `401 Unauthorized` on invalid or expired tokens.
|
||||||
|
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||||
|
|
||||||
|
### Tier-Based Rate Limiter
|
||||||
|
|
||||||
|
Source: `app/api/middleware/rate_limit.py`
|
||||||
|
|
||||||
|
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
|
||||||
|
- Per-user 60-second window sized by subscription tier:
|
||||||
|
|
||||||
|
| Tier | Requests / Minute |
|
||||||
|
|---|---|
|
||||||
|
| Free | 20 |
|
||||||
|
| Pro | 60 |
|
||||||
|
| Power | 120 |
|
||||||
|
| Team | 200 |
|
||||||
|
|
||||||
|
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
|
||||||
|
- **Exempt paths:** register, login, webhook, health
|
||||||
|
|
||||||
|
### Response Sanitizer
|
||||||
|
|
||||||
|
Source: `app/api/middleware/sanitizer.py`
|
||||||
|
|
||||||
|
- Runs only on `/api/v1/chat` endpoints.
|
||||||
|
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
||||||
|
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
||||||
|
- Logs sanitization events as `WARNING`.
|
||||||
|
- Binary responses (storage, backup) are never touched.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Storage Layer
|
||||||
|
|
||||||
|
### Blob Store
|
||||||
|
|
||||||
|
Source: `app/storage/blob_store.py`
|
||||||
|
|
||||||
|
- S3-backed storage for E2E encrypted blobs.
|
||||||
|
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
|
||||||
|
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
|
||||||
|
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
|
||||||
|
- The backend **never inspects or decrypts blob content**.
|
||||||
|
|
||||||
|
### Vector Store
|
||||||
|
|
||||||
|
Source: `app/storage/vector_store.py`
|
||||||
|
|
||||||
|
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
|
||||||
|
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
|
||||||
|
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
|
||||||
|
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
|
||||||
|
- Methods: `upsert()`, `search()`, `delete()`
|
||||||
|
|
||||||
|
### Encryption Utilities
|
||||||
|
|
||||||
|
Source: `app/storage/encryption.py`
|
||||||
|
|
||||||
|
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
|
||||||
|
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
|
||||||
|
- **No decryption key ever reaches the backend.**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Billing & Tiers
|
||||||
|
|
||||||
|
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
||||||
|
|
||||||
|
### Feature Matrix
|
||||||
|
|
||||||
|
| Feature | Free | Pro | Power | Team |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
||||||
|
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
||||||
|
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
||||||
|
| Batch Builder | — | — | ✓ | ✓ |
|
||||||
|
| Plugin Marketplace | — | — | ✓ | ✓ |
|
||||||
|
| SSO | — | — | — | ✓ |
|
||||||
|
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
||||||
|
|
||||||
|
### Stripe Integration
|
||||||
|
|
||||||
|
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
|
||||||
|
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
|
||||||
|
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
|
||||||
|
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
|
||||||
|
|
||||||
|
### Tier Manager
|
||||||
|
|
||||||
|
- `get_tier(user_id)` — Returns the user's current billing tier.
|
||||||
|
- `check_feature(tier, feature)` — Boolean feature gate check.
|
||||||
|
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
||||||
|
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Plugin Marketplace
|
||||||
|
|
||||||
|
Source: `app/marketplace/`
|
||||||
|
|
||||||
|
### Plugin Registry
|
||||||
|
|
||||||
|
- PostgreSQL-backed catalog of submitted and approved plugins.
|
||||||
|
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
|
||||||
|
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
|
||||||
|
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
|
||||||
|
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
|
||||||
|
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
|
||||||
|
|
||||||
|
### Review Queue
|
||||||
|
|
||||||
|
- Automated security checklist before human review:
|
||||||
|
- Plugin ID must match `^[a-z0-9-]+$`
|
||||||
|
- Permissions must be from the allowed set only
|
||||||
|
- No binary blobs in the manifest
|
||||||
|
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
|
||||||
|
- `get_pending(db)` — Lists plugins awaiting review.
|
||||||
|
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
||||||
|
|
||||||
|
### Revenue Sharing
|
||||||
|
|
||||||
|
- **70% developer / 30% platform** split on all paid plugin sales.
|
||||||
|
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
|
||||||
|
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
|
||||||
|
- Gracefully stubs transfers when Stripe is not configured.
|
||||||
|
|
||||||
|
### Seed Plugins
|
||||||
|
|
||||||
|
| Plugin | Category | Price |
|
||||||
|
|---|---|---|
|
||||||
|
| GitHub Sync | Productivity | Free |
|
||||||
|
| Slack Notifier | Communication | €4.99 |
|
||||||
|
| Time Tracker | Productivity | €9.99 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
pytest
|
||||||
|
|
||||||
|
# Run a specific test file
|
||||||
|
pytest tests/test_auth.py
|
||||||
|
|
||||||
|
# Run with verbose output
|
||||||
|
pytest -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test Infrastructure
|
||||||
|
|
||||||
|
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
||||||
|
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
|
||||||
|
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
||||||
|
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
||||||
|
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
|
||||||
|
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
||||||
|
- **No external dependencies** — all tests run fully offline.
|
||||||
|
|
||||||
|
### Test Coverage
|
||||||
|
|
||||||
|
| File | Coverage |
|
||||||
|
|---|---|
|
||||||
|
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
||||||
|
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
|
||||||
|
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
|
||||||
|
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
|
||||||
|
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
|
||||||
|
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
|
||||||
|
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
|
||||||
|
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
|
||||||
|
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
adiuva-api/
|
||||||
|
├── alembic.ini # Alembic configuration
|
||||||
|
├── BACKEND_PLAN.md # Architecture & design decisions
|
||||||
|
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
||||||
|
├── Dockerfile # Multi-stage production build
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
│
|
||||||
|
├── alembic/ # Database migrations
|
||||||
|
│ ├── env.py # Alembic environment config
|
||||||
|
│ ├── script.py.mako # Migration template
|
||||||
|
│ └── versions/
|
||||||
|
│ ├── 001_initial_schema.py # Tables, indexes, FKs
|
||||||
|
│ └── 002_seed_plugins.py # Seed marketplace plugins
|
||||||
|
│
|
||||||
|
├── app/ # Application source
|
||||||
|
│ ├── main.py # FastAPI app factory, middleware, routes
|
||||||
|
│ ├── db.py # Async SQLAlchemy engine & session
|
||||||
|
│ ├── models.py # SQLAlchemy ORM models (9 tables)
|
||||||
|
│ ├── schemas.py # Pydantic request/response schemas
|
||||||
|
│ │
|
||||||
|
│ ├── config/
|
||||||
|
│ │ └── settings.py # Pydantic Settings (env vars)
|
||||||
|
│ │
|
||||||
|
│ ├── agents/ # LLM-powered domain agents
|
||||||
|
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
||||||
|
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
||||||
|
│ │ ├── timeline_agent.py # Milestones (4 tools)
|
||||||
|
│ │ └── note_agent.py # Markdown notes (5 tools)
|
||||||
|
│ │
|
||||||
|
│ ├── core/ # Orchestration engine
|
||||||
|
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||||
|
│ │ ├── llm.py # LiteLLM factory (get_llm)
|
||||||
|
│ │ ├── orchestrator.py # Intent classification & routing
|
||||||
|
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||||
|
│ │
|
||||||
|
│ ├── api/ # HTTP layer
|
||||||
|
│ │ ├── deps.py # Shared FastAPI dependencies
|
||||||
|
│ │ ├── middleware/
|
||||||
|
│ │ │ ├── auth.py # JWT validation, live tier lookup
|
||||||
|
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
||||||
|
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
||||||
|
│ │ └── routes/
|
||||||
|
│ │ ├── auth.py # Register, login, refresh, me
|
||||||
|
│ │ ├── chat.py # Chat + WebSocket streaming
|
||||||
|
│ │ ├── plans.py # Execution plan playbooks
|
||||||
|
│ │ ├── storage.py # E2E encrypted record CRUD
|
||||||
|
│ │ ├── vectors.py # Vector upsert, search, delete
|
||||||
|
│ │ ├── backup.py # Encrypted backup management
|
||||||
|
│ │ ├── plugins.py # Marketplace browse & install
|
||||||
|
│ │ └── billing.py # Stripe checkout & webhooks
|
||||||
|
│ │
|
||||||
|
│ ├── storage/ # Storage backends
|
||||||
|
│ │ ├── blob_store.py # S3 blob storage
|
||||||
|
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
|
||||||
|
│ │ └── encryption.py # Checksum verification utilities
|
||||||
|
│ │
|
||||||
|
│ ├── billing/ # Subscription management
|
||||||
|
│ │ ├── stripe_service.py # Stripe API integration
|
||||||
|
│ │ └── tier_manager.py # Feature matrix & quota enforcement
|
||||||
|
│ │
|
||||||
|
│ └── marketplace/ # Plugin ecosystem
|
||||||
|
│ ├── plugin_registry.py # Catalog CRUD & search
|
||||||
|
│ ├── plugin_review.py # Security checklist & review queue
|
||||||
|
│ └── revenue_share.py # 70/30 split & Stripe Connect
|
||||||
|
│
|
||||||
|
└── tests/ # Test suite
|
||||||
|
├── conftest.py # Fixtures: DB, S3, auth, seeds
|
||||||
|
├── test_auth.py
|
||||||
|
├── test_orchestrator.py
|
||||||
|
├── test_agents.py
|
||||||
|
├── test_storage.py
|
||||||
|
├── test_backup.py
|
||||||
|
├── test_plugins.py
|
||||||
|
├── test_agent_registry.py
|
||||||
|
├── test_execution_plan.py
|
||||||
|
└── test_middleware.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
*To be determined.*
|
||||||
47
alembic.ini
Normal file
47
alembic.ini
Normal file
@@ -0,0 +1,47 @@
|
|||||||
|
# Alembic configuration file.
|
||||||
|
# The async app uses postgresql+asyncpg:// at runtime.
|
||||||
|
# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var).
|
||||||
|
|
||||||
|
[alembic]
|
||||||
|
script_location = alembic
|
||||||
|
prepend_sys_path = .
|
||||||
|
version_path_separator = os
|
||||||
|
|
||||||
|
# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder.
|
||||||
|
sqlalchemy.url = driver://user:pass@localhost/dbname
|
||||||
|
|
||||||
|
[post_write_hooks]
|
||||||
|
|
||||||
|
[loggers]
|
||||||
|
keys = root,sqlalchemy,alembic
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys = console
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys = generic
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level = WARN
|
||||||
|
handlers = console
|
||||||
|
qualname =
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level = WARN
|
||||||
|
handlers =
|
||||||
|
qualname = sqlalchemy.engine
|
||||||
|
|
||||||
|
[logger_alembic]
|
||||||
|
level = INFO
|
||||||
|
handlers =
|
||||||
|
qualname = alembic
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class = StreamHandler
|
||||||
|
args = (sys.stderr,)
|
||||||
|
level = NOTSET
|
||||||
|
formatter = generic
|
||||||
|
|
||||||
|
[formatter_generic]
|
||||||
|
format = %(levelname)-5.5s [%(name)s] %(message)s
|
||||||
|
datefmt = %H:%M:%S
|
||||||
93
alembic/env.py
Normal file
93
alembic/env.py
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
"""Alembic migration environment — async-compatible.
|
||||||
|
|
||||||
|
At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is
|
||||||
|
synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL
|
||||||
|
env var by replacing the driver prefix.
|
||||||
|
|
||||||
|
Run migrations with:
|
||||||
|
alembic upgrade head
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from logging.config import fileConfig
|
||||||
|
|
||||||
|
from alembic import context
|
||||||
|
from sqlalchemy import engine_from_config, pool
|
||||||
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
config = context.config
|
||||||
|
|
||||||
|
# Set up Python logging from alembic.ini.
|
||||||
|
if config.config_file_name is not None:
|
||||||
|
fileConfig(config.config_file_name)
|
||||||
|
|
||||||
|
# Import the Base so that Alembic can detect model changes for --autogenerate.
|
||||||
|
from app.models import Base # noqa: E402
|
||||||
|
|
||||||
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _sync_url(async_url: str) -> str:
|
||||||
|
"""Convert an asyncpg URL to a psycopg2 URL for Alembic CLI."""
|
||||||
|
return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_url() -> str:
|
||||||
|
db_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not db_url:
|
||||||
|
# Fall back to settings if env var not set directly.
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
db_url = settings.DATABASE_URL
|
||||||
|
return _sync_url(db_url)
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_offline() -> None:
|
||||||
|
"""Emit SQL without a live DB connection."""
|
||||||
|
url = _get_url()
|
||||||
|
context.configure(
|
||||||
|
url=url,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
literal_binds=True,
|
||||||
|
dialect_opts={"paramstyle": "named"},
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
def do_run_migrations(connection): # type: ignore[no-untyped-def]
|
||||||
|
context.configure(
|
||||||
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
compare_type=True,
|
||||||
|
)
|
||||||
|
with context.begin_transaction():
|
||||||
|
context.run_migrations()
|
||||||
|
|
||||||
|
|
||||||
|
async def run_migrations_online_async() -> None:
|
||||||
|
"""Run migrations against a live DB using the async engine."""
|
||||||
|
async_url = os.environ.get("DATABASE_URL", "")
|
||||||
|
if not async_url:
|
||||||
|
from app.config.settings import settings # noqa: PLC0415
|
||||||
|
async_url = settings.DATABASE_URL
|
||||||
|
|
||||||
|
connectable = create_async_engine(async_url, poolclass=pool.NullPool)
|
||||||
|
async with connectable.connect() as connection:
|
||||||
|
await connection.run_sync(do_run_migrations)
|
||||||
|
await connectable.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def run_migrations_online() -> None:
|
||||||
|
asyncio.run(run_migrations_online_async())
|
||||||
|
|
||||||
|
|
||||||
|
if context.is_offline_mode():
|
||||||
|
run_migrations_offline()
|
||||||
|
else:
|
||||||
|
run_migrations_online()
|
||||||
28
alembic/script.py.mako
Normal file
28
alembic/script.py.mako
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
"""${message}
|
||||||
|
|
||||||
|
Revision ID: ${up_revision}
|
||||||
|
Revises: ${down_revision | comma,n}
|
||||||
|
Create Date: ${create_date}
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
${imports if imports else ""}
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = ${repr(up_revision)}
|
||||||
|
down_revision: Union[str, None] = ${repr(down_revision)}
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
|
||||||
|
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
${upgrades if upgrades else "pass"}
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
${downgrades if downgrades else "pass"}
|
||||||
209
alembic/versions/001_initial_schema.py
Normal file
209
alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,209 @@
|
|||||||
|
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||||
|
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||||
|
|
||||||
|
Revision ID: 001
|
||||||
|
Revises:
|
||||||
|
Create Date: 2026-03-02
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "001"
|
||||||
|
down_revision: Union[str, None] = None
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enum types — idempotent creation via exception handling ───────────
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE billing_tier AS ENUM ('free', 'pro', 'power', 'team');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"users",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("email", sa.String(255), nullable=False),
|
||||||
|
sa.Column("password_hash", sa.String(255), nullable=False),
|
||||||
|
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||||
|
sa.Column("stripe_customer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.UniqueConstraint("email"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_users_email", "users", ["email"])
|
||||||
|
|
||||||
|
# ── refresh_tokens ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"refresh_tokens",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("token_hash", sa.String(64), nullable=False),
|
||||||
|
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("token_hash"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"])
|
||||||
|
op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"])
|
||||||
|
|
||||||
|
# ── subscriptions ─────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"subscriptions",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("stripe_subscription_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"),
|
||||||
|
sa.Column("status", sa.String(50), nullable=False, server_default="free"),
|
||||||
|
sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("user_id"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
|
# ── storage_records ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"storage_records",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("table_name", sa.String(100), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||||
|
|
||||||
|
# ── backup_metadata ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"backup_metadata",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer, nullable=False),
|
||||||
|
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugins ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugins",
|
||||||
|
sa.Column("id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||||
|
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||||
|
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||||
|
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
||||||
|
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||||
|
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||||
|
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── plugin_installations ──────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_installations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||||
|
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_reviews",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
||||||
|
sa.Column("notes", sa.Text, nullable=True),
|
||||||
|
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||||
|
|
||||||
|
# ── revenue_events ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"revenue_events",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||||
|
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("revenue_events")
|
||||||
|
op.drop_table("plugin_reviews")
|
||||||
|
op.drop_table("plugin_installations")
|
||||||
|
op.drop_table("plugins")
|
||||||
|
op.drop_table("backup_metadata")
|
||||||
|
op.drop_table("storage_records")
|
||||||
|
op.drop_table("subscriptions")
|
||||||
|
op.drop_table("refresh_tokens")
|
||||||
|
op.drop_table("users")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||||
|
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||||
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
92
alembic/versions/002_seed_plugins.py
Normal file
92
alembic/versions/002_seed_plugins.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2026-03-03
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "002"
|
||||||
|
down_revision: Union[str, None] = "001"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
{
|
||||||
|
"id": "plugin-github-sync",
|
||||||
|
"name": "GitHub Sync",
|
||||||
|
"description": "Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 0,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-slack-notify",
|
||||||
|
"name": "Slack Notifier",
|
||||||
|
"description": "Post task and timeline updates to Slack channels.",
|
||||||
|
"version": "1.2.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "communication",
|
||||||
|
"price_cents": 499,
|
||||||
|
"permissions": json.dumps(["read:tasks", "read:timelines"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-time-tracker",
|
||||||
|
"name": "Time Tracker",
|
||||||
|
"description": "Track time spent on tasks with automatic reporting.",
|
||||||
|
"version": "0.9.1",
|
||||||
|
"author_name": "Third Party",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 999,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
plugins = sa.table(
|
||||||
|
"plugins",
|
||||||
|
sa.column("id", sa.String),
|
||||||
|
sa.column("name", sa.String),
|
||||||
|
sa.column("description", sa.Text),
|
||||||
|
sa.column("version", sa.String),
|
||||||
|
sa.column("author_name", sa.String),
|
||||||
|
sa.column("category", sa.String),
|
||||||
|
sa.column("price_cents", sa.Integer),
|
||||||
|
sa.column("permissions", sa.Text),
|
||||||
|
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
||||||
|
sa.column("s3_package_key", sa.String),
|
||||||
|
sa.column("install_count", sa.Integer),
|
||||||
|
sa.column("avg_rating", sa.Float),
|
||||||
|
)
|
||||||
|
op.bulk_insert(plugins, _SEED_PLUGINS)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"DELETE FROM plugins WHERE id IN ("
|
||||||
|
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
||||||
|
")"
|
||||||
|
)
|
||||||
127
alembic/versions/003_agent_tables.py
Normal file
127
alembic/versions/003_agent_tables.py
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
"""Add agent config and run log tables: local_agent_configs, cloud_agent_configs, agent_run_logs.
|
||||||
|
|
||||||
|
Revision ID: 003
|
||||||
|
Revises: 002
|
||||||
|
Create Date: 2026-03-05
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "003"
|
||||||
|
down_revision: Union[str, None] = "002"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enum types — idempotent creation ──────────────────────────────────
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
# ── local_agent_configs ───────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
# ── cloud_agent_configs ───────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
# ── agent_run_logs ─────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"agent_run_logs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
# Plain string — not a FK because it references either local_agent_configs or
|
||||||
|
# cloud_agent_configs depending on agent_type.
|
||||||
|
sa.Column("agent_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"agent_type",
|
||||||
|
postgresql.ENUM("local", "cloud", name="agent_type", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"status",
|
||||||
|
postgresql.ENUM("running", "success", "error", "partial", name="agent_run_status", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
server_default="running",
|
||||||
|
),
|
||||||
|
sa.Column("items_processed", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("items_created", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("errors", sa.JSON, nullable=True),
|
||||||
|
sa.Column("started_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_agent_run_logs_user_id", "agent_run_logs", ["user_id"])
|
||||||
|
op.create_index("ix_agent_run_logs_agent_id", "agent_run_logs", ["agent_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("agent_run_logs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS cloud_provider;")
|
||||||
|
op.execute("DROP TYPE IF EXISTS agent_run_status;")
|
||||||
|
op.execute("DROP TYPE IF EXISTS agent_type;")
|
||||||
144
alembic/versions/004_add_memory_tables.py
Normal file
144
alembic/versions/004_add_memory_tables.py
Normal file
@@ -0,0 +1,144 @@
|
|||||||
|
"""Add memory tables and user encryption_key column.
|
||||||
|
|
||||||
|
Memory tables:
|
||||||
|
memory_core — per-user key/value preferences (encrypted)
|
||||||
|
memory_associative — semantic memory with pgvector embedding (encrypted)
|
||||||
|
memory_episodic — session summaries (encrypted)
|
||||||
|
memory_proactive — behavioral patterns (encrypted)
|
||||||
|
|
||||||
|
Also adds encryption_key column to users table.
|
||||||
|
|
||||||
|
Revision ID: 004
|
||||||
|
Revises: 003
|
||||||
|
Create Date: 2026-03-08
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "004"
|
||||||
|
down_revision: Union[str, None] = "003"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── Enable pgvector extension (idempotent) ────────────────────────────────
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||||
|
|
||||||
|
# ── Add encryption_key to users ───────────────────────────────────────────
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column("encryption_key", sa.String(64), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── memory_core ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_core",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("key", sa.String(255), nullable=False),
|
||||||
|
sa.Column("value_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"])
|
||||||
|
|
||||||
|
# ── memory_associative ────────────────────────────────────────────────────
|
||||||
|
# The embedding column uses pgvector's vector(1536) type.
|
||||||
|
op.create_table(
|
||||||
|
"memory_associative",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("content_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("entity_type", sa.String(100), nullable=True),
|
||||||
|
sa.Column("entity_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"updated_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
# Add the pgvector column separately (not supported by generic sa types)
|
||||||
|
op.execute(
|
||||||
|
"ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);"
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"])
|
||||||
|
# IVFFlat index for approximate nearest-neighbour search
|
||||||
|
op.execute(
|
||||||
|
"CREATE INDEX ix_memory_associative_embedding "
|
||||||
|
"ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);"
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── memory_episodic ───────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_episodic",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("summary_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("session_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"])
|
||||||
|
op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"])
|
||||||
|
|
||||||
|
# ── memory_proactive ──────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"memory_proactive",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||||
|
sa.Column(
|
||||||
|
"user_id",
|
||||||
|
postgresql.UUID(as_uuid=False),
|
||||||
|
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("pattern_encrypted", sa.Text, nullable=False),
|
||||||
|
sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"),
|
||||||
|
sa.Column("source", sa.String(50), nullable=False, server_default="inferred"),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.func.now(),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_table("memory_proactive")
|
||||||
|
op.drop_table("memory_episodic")
|
||||||
|
op.drop_index("ix_memory_associative_embedding", "memory_associative")
|
||||||
|
op.drop_table("memory_associative")
|
||||||
|
op.drop_table("memory_core")
|
||||||
|
op.drop_column("users", "encryption_key")
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
"""add name and surname to users table
|
||||||
|
|
||||||
|
Revision ID: 818478c251dc
|
||||||
|
Revises: 004
|
||||||
|
Create Date: 2026-03-10 15:10:42.811947
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = '818478c251dc'
|
||||||
|
down_revision: Union[str, None] = '004'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column('users', sa.Column('name', sa.String(length=100), nullable=True))
|
||||||
|
op.add_column('users', sa.Column('surname', sa.String(length=100), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column('users', 'surname')
|
||||||
|
op.drop_column('users', 'name')
|
||||||
@@ -0,0 +1,92 @@
|
|||||||
|
"""Deprecate backend agent config tables.
|
||||||
|
|
||||||
|
The Electron client is now the source of truth for agent configuration
|
||||||
|
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
||||||
|
billing checks and trigger/run logs only.
|
||||||
|
|
||||||
|
Revision ID: 9a1f2d0b6c7e
|
||||||
|
Revises: 818478c251dc
|
||||||
|
Create Date: 2026-03-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
revision: str = "9a1f2d0b6c7e"
|
||||||
|
down_revision: Union[str, None] = "818478c251dc"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
if "cloud_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
|
||||||
|
if "local_agent_configs" in existing:
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
"""Import all agent modules to trigger @registry.register decorators."""
|
|
||||||
|
|
||||||
from app.agents import checkpoint_agent, note_agent, project_agent, task_agent
|
|
||||||
|
|
||||||
__all__ = ["checkpoint_agent", "note_agent", "project_agent", "task_agent"]
|
|
||||||
@@ -1,122 +0,0 @@
|
|||||||
"""Checkpoint agent — project milestone management (list, create, update, delete)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a project checkpoint assistant. Checkpoints are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a checkpoint, 0 otherwise\n"
|
|
||||||
" - is_approved: 0 until the user explicitly confirms; then 1\n"
|
|
||||||
" - For update_checkpoint, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all checkpoints across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_checkpoints(project_id: str = "") -> str:
|
|
||||||
"""List checkpoints. Provide project_id to scope to a specific project."""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "list",
|
|
||||||
"table": "checkpoints",
|
|
||||||
"filters": {"projectId": project_id or None},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def create_checkpoint(
|
|
||||||
project_id: str,
|
|
||||||
title: str,
|
|
||||||
date: int,
|
|
||||||
is_ai_suggested: int = 0,
|
|
||||||
is_approved: int = 0,
|
|
||||||
) -> str:
|
|
||||||
"""Create a project checkpoint (milestone).
|
|
||||||
project_id: REQUIRED UUID of the parent project
|
|
||||||
title: descriptive name for the milestone
|
|
||||||
date: Unix timestamp in milliseconds
|
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
|
||||||
is_approved: 0 until the user confirms
|
|
||||||
"""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "create_record",
|
|
||||||
"table": "checkpoints",
|
|
||||||
"data": {
|
|
||||||
"projectId": project_id,
|
|
||||||
"title": title,
|
|
||||||
"date": date,
|
|
||||||
"isAiSuggested": is_ai_suggested,
|
|
||||||
"isApproved": is_approved,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def update_checkpoint(
|
|
||||||
checkpoint_id: str,
|
|
||||||
title: str = "",
|
|
||||||
date: int = -1,
|
|
||||||
is_approved: int = -1,
|
|
||||||
) -> str:
|
|
||||||
"""Update a checkpoint. Only pass fields that should change.
|
|
||||||
checkpoint_id: UUID of the checkpoint (required)
|
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
|
||||||
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
|
||||||
"""
|
|
||||||
updates: dict[str, Any] = {}
|
|
||||||
if title:
|
|
||||||
updates["title"] = title
|
|
||||||
if date != -1:
|
|
||||||
updates["date"] = date
|
|
||||||
if is_approved != -1:
|
|
||||||
updates["isApproved"] = is_approved
|
|
||||||
return json.dumps({
|
|
||||||
"action": "update_record",
|
|
||||||
"table": "checkpoints",
|
|
||||||
"data": {"id": checkpoint_id, "updates": updates},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_checkpoint(checkpoint_id: str) -> str:
|
|
||||||
"""Delete a checkpoint permanently by its UUID."""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "delete_record",
|
|
||||||
"table": "checkpoints",
|
|
||||||
"data": {"id": checkpoint_id},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class CheckpointAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "checkpoint_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages project checkpoints (milestones): list, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_checkpoints, create_checkpoint, update_checkpoint, delete_checkpoint]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
@@ -1,123 +0,0 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete)."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.agent_registry import ChatAgent, registry
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_notes(project_id: str = "") -> str:
|
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "list",
|
|
||||||
"table": "notes",
|
|
||||||
"filters": {"projectId": project_id or None},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_note(note_id: str) -> str:
|
|
||||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "get",
|
|
||||||
"table": "notes",
|
|
||||||
"data": {"id": note_id},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def create_note(
|
|
||||||
title: str,
|
|
||||||
content: str,
|
|
||||||
project_id: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""Create a new note.
|
|
||||||
title: note heading (required)
|
|
||||||
content: Markdown body text (required)
|
|
||||||
project_id: optional UUID linking this note to a project
|
|
||||||
"""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "create_record",
|
|
||||||
"table": "notes",
|
|
||||||
"data": {
|
|
||||||
"title": title,
|
|
||||||
"content": content,
|
|
||||||
"projectId": project_id or None,
|
|
||||||
},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def update_note(
|
|
||||||
note_id: str,
|
|
||||||
title: str = "",
|
|
||||||
content: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""Update an existing note. Only pass fields that should change.
|
|
||||||
note_id: UUID of the note (required)
|
|
||||||
If you need to preserve existing content, call get_note first.
|
|
||||||
"""
|
|
||||||
updates: dict[str, Any] = {}
|
|
||||||
if title:
|
|
||||||
updates["title"] = title
|
|
||||||
if content:
|
|
||||||
updates["content"] = content
|
|
||||||
return json.dumps({
|
|
||||||
"action": "update_record",
|
|
||||||
"table": "notes",
|
|
||||||
"data": {"id": note_id, "updates": updates},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_note(note_id: str) -> str:
|
|
||||||
"""Delete a note permanently by its UUID."""
|
|
||||||
return json.dumps({
|
|
||||||
"action": "delete_record",
|
|
||||||
"table": "notes",
|
|
||||||
"data": {"id": note_id},
|
|
||||||
})
|
|
||||||
|
|
||||||
|
|
||||||
@registry.register
|
|
||||||
class NoteAgent(ChatAgent):
|
|
||||||
def get_name(self) -> str:
|
|
||||||
return "note_agent"
|
|
||||||
|
|
||||||
def get_description(self) -> str:
|
|
||||||
return "Manages notes: list, get, create, update, delete"
|
|
||||||
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
return [list_notes, get_note, create_note, update_note, delete_note]
|
|
||||||
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
llm = ChatOpenAI(model="gpt-4o", temperature=0, api_key=settings.OPENAI_API_KEY)
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(
|
|
||||||
content=f"User query: {query}\nContext: {json.dumps(context)[:1000]}"
|
|
||||||
),
|
|
||||||
]
|
|
||||||
return await self._tool_loop(llm, messages, self.get_tools())
|
|
||||||
@@ -1,14 +0,0 @@
|
|||||||
"""Shared FastAPI dependencies.
|
|
||||||
|
|
||||||
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
|
||||||
(the canonical location per Step 9). This module re-exports them so that all
|
|
||||||
existing route imports (``from app.api.deps import get_current_user``) continue
|
|
||||||
to work without modification.
|
|
||||||
|
|
||||||
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
|
||||||
instead of reading it from the JWT payload.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
|
||||||
|
|
||||||
__all__ = ["get_current_user", "oauth2_scheme"]
|
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
"""API middleware package.
|
|
||||||
|
|
||||||
Exports the three middleware components introduced in Step 9:
|
|
||||||
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
|
||||||
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
|
||||||
- Sanitizer: ``SanitizerMiddleware``
|
|
||||||
"""
|
|
||||||
|
|
||||||
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
|
||||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
|
||||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
|
||||||
|
|
||||||
__all__ = [
|
|
||||||
"get_current_user",
|
|
||||||
"oauth2_scheme",
|
|
||||||
"TierRateLimitMiddleware",
|
|
||||||
"limiter",
|
|
||||||
"SanitizerMiddleware",
|
|
||||||
]
|
|
||||||
@@ -1,51 +0,0 @@
|
|||||||
"""Auth middleware — JWT validation dependency.
|
|
||||||
|
|
||||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
|
||||||
It decodes the Bearer JWT, validates signature and expiry, and returns a
|
|
||||||
``UserProfile`` carrying ``id``, ``email``, and ``tier``.
|
|
||||||
|
|
||||||
Exempt routes (no JWT required):
|
|
||||||
- POST /api/v1/auth/register
|
|
||||||
- POST /api/v1/auth/login
|
|
||||||
- POST /api/v1/billing/webhook
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.schemas import UserProfile
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
|
||||||
token: str = Depends(oauth2_scheme),
|
|
||||||
) -> UserProfile:
|
|
||||||
"""Validate a Bearer JWT and return the authenticated user.
|
|
||||||
|
|
||||||
Raises HTTP 401 on any invalid or expired token.
|
|
||||||
The tier embedded in the JWT is used for feature-gating until Step 12
|
|
||||||
adds a live DB lookup.
|
|
||||||
"""
|
|
||||||
credentials_exc = HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Could not validate credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
|
||||||
)
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
email: str | None = payload.get("email")
|
|
||||||
tier: str = payload.get("tier", "free")
|
|
||||||
if not user_id or not email:
|
|
||||||
raise credentials_exc
|
|
||||||
except JWTError:
|
|
||||||
raise credentials_exc
|
|
||||||
|
|
||||||
return UserProfile(id=user_id, email=email, tier=tier) # type: ignore[arg-type]
|
|
||||||
@@ -1,129 +0,0 @@
|
|||||||
"""Tier-aware rate limiting middleware.
|
|
||||||
|
|
||||||
Uses a per-user sliding-window counter (in-process, no Redis required).
|
|
||||||
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
|
||||||
|
|
||||||
Limits (requests per minute):
|
|
||||||
- free: 20
|
|
||||||
- pro: 60
|
|
||||||
- power: 120
|
|
||||||
- team: 200
|
|
||||||
|
|
||||||
Exempt paths bypass the limiter entirely:
|
|
||||||
- POST /api/v1/auth/register
|
|
||||||
- POST /api/v1/auth/login
|
|
||||||
- POST /api/v1/billing/webhook
|
|
||||||
- GET /api/v1/health
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import time
|
|
||||||
from collections import defaultdict
|
|
||||||
|
|
||||||
from fastapi import Request, Response
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
from slowapi import Limiter
|
|
||||||
from slowapi.util import get_remote_address
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.types import ASGIApp
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
|
|
||||||
_TIER_LIMITS: dict[str, int] = {
|
|
||||||
"free": 20,
|
|
||||||
"pro": 60,
|
|
||||||
"power": 120,
|
|
||||||
"team": 200,
|
|
||||||
}
|
|
||||||
|
|
||||||
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"/api/v1/auth/register",
|
|
||||||
"/api/v1/auth/login",
|
|
||||||
"/api/v1/billing/webhook",
|
|
||||||
"/api/v1/health",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_user_id_from_jwt(request: Request) -> str:
|
|
||||||
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
|
|
||||||
auth = request.headers.get("Authorization", "")
|
|
||||||
token = auth.removeprefix("Bearer ").strip()
|
|
||||||
if not token:
|
|
||||||
return get_remote_address(request)
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
|
||||||
)
|
|
||||||
return payload.get("sub") or get_remote_address(request)
|
|
||||||
except JWTError:
|
|
||||||
return get_remote_address(request)
|
|
||||||
|
|
||||||
|
|
||||||
# Exported Limiter instance — available for optional route-level decoration.
|
|
||||||
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
|
||||||
|
|
||||||
|
|
||||||
class TierRateLimitMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""Sliding-window rate limiter applied globally across all non-exempt routes.
|
|
||||||
|
|
||||||
Each authenticated user gets their own 60-second window sized by tier.
|
|
||||||
Unauthenticated requests pass through (the auth dependency will reject them
|
|
||||||
with 401 before the route handler runs).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, app: ASGIApp) -> None:
|
|
||||||
super().__init__(app)
|
|
||||||
# user_id → list of request timestamps (float, seconds since epoch)
|
|
||||||
self._window: dict[str, list[float]] = defaultdict(list)
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
|
||||||
if request.url.path in _EXEMPT_PATHS:
|
|
||||||
return await call_next(request)
|
|
||||||
|
|
||||||
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
|
|
||||||
auth = request.headers.get("Authorization", "")
|
|
||||||
token = auth.removeprefix("Bearer ").strip()
|
|
||||||
if not token:
|
|
||||||
return await call_next(request)
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
|
||||||
)
|
|
||||||
user_id: str = payload.get("sub") or get_remote_address(request)
|
|
||||||
tier: str = payload.get("tier", "free")
|
|
||||||
except JWTError:
|
|
||||||
return await call_next(request)
|
|
||||||
|
|
||||||
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
|
|
||||||
now = time.monotonic()
|
|
||||||
window_start = now - 60.0
|
|
||||||
|
|
||||||
# Slide the window: discard timestamps older than 60 seconds.
|
|
||||||
timestamps = [t for t in self._window[user_id] if t > window_start]
|
|
||||||
|
|
||||||
if len(timestamps) >= limit:
|
|
||||||
retry_after = max(1, int(60 - (now - min(timestamps))))
|
|
||||||
return Response(
|
|
||||||
content=json.dumps(
|
|
||||||
{
|
|
||||||
"detail": (
|
|
||||||
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
|
|
||||||
f"Retry in {retry_after}s."
|
|
||||||
)
|
|
||||||
}
|
|
||||||
),
|
|
||||||
status_code=429,
|
|
||||||
headers={
|
|
||||||
"Retry-After": str(retry_after),
|
|
||||||
"Content-Type": "application/json",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
timestamps.append(now)
|
|
||||||
self._window[user_id] = timestamps
|
|
||||||
return await call_next(request)
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
"""Response sanitizer middleware.
|
|
||||||
|
|
||||||
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
|
||||||
that could reveal server-side prompt IP:
|
|
||||||
- System prompt openers ("You are a/an/the …")
|
|
||||||
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
|
||||||
- LangChain tool schema fragments (``"type": "function"``)
|
|
||||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
|
||||||
- Exact-match known prompt fingerprints
|
|
||||||
|
|
||||||
Binary responses (storage blobs, backup data) are never touched — the
|
|
||||||
middleware only activates for paths under /api/v1/chat.
|
|
||||||
|
|
||||||
Any sanitisation event is logged as a WARNING with the request path and the
|
|
||||||
names of the fields that were modified.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
|
|
||||||
from fastapi import Request, Response
|
|
||||||
from starlette.middleware.base import BaseHTTPMiddleware
|
|
||||||
from starlette.types import ASGIApp
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
# Detection patterns — order matters: fingerprints checked first (exact),
|
|
||||||
# then compiled regexes.
|
|
||||||
# ---------------------------------------------------------------------------
|
|
||||||
|
|
||||||
_FINGERPRINTS: tuple[str, ...] = (
|
|
||||||
"You are an intent classifier",
|
|
||||||
"Respond with just the agent name",
|
|
||||||
"Summarize these agent results",
|
|
||||||
"Available agents:",
|
|
||||||
"route to:",
|
|
||||||
)
|
|
||||||
|
|
||||||
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
|
||||||
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
|
||||||
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
|
||||||
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
|
||||||
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
|
||||||
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
|
||||||
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
|
||||||
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
|
||||||
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _sanitize_text(text: str) -> tuple[str, bool]:
|
|
||||||
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
|
||||||
|
|
||||||
Returns ``(cleaned_text, was_changed)``.
|
|
||||||
"""
|
|
||||||
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
|
||||||
for fp in _FINGERPRINTS:
|
|
||||||
if fp in text:
|
|
||||||
return "[REDACTED]", True
|
|
||||||
|
|
||||||
changed = False
|
|
||||||
for pattern in _PATTERNS:
|
|
||||||
new_text, n = pattern.subn("[REDACTED]", text)
|
|
||||||
if n:
|
|
||||||
text = new_text
|
|
||||||
changed = True
|
|
||||||
|
|
||||||
return text, changed
|
|
||||||
|
|
||||||
|
|
||||||
class SanitizerMiddleware(BaseHTTPMiddleware):
|
|
||||||
"""Strip prompt IP from /api/v1/chat JSON responses."""
|
|
||||||
|
|
||||||
def __init__(self, app: ASGIApp) -> None:
|
|
||||||
super().__init__(app)
|
|
||||||
|
|
||||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
|
||||||
response: Response = await call_next(request)
|
|
||||||
|
|
||||||
# Only process chat endpoint responses.
|
|
||||||
if not request.url.path.startswith("/api/v1/chat"):
|
|
||||||
return response
|
|
||||||
|
|
||||||
# Read body — collect streaming chunks.
|
|
||||||
body_bytes = b""
|
|
||||||
async for chunk in response.body_iterator:
|
|
||||||
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
|
|
||||||
|
|
||||||
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
|
|
||||||
try:
|
|
||||||
body = json.loads(body_bytes.decode("utf-8"))
|
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
|
||||||
return Response(
|
|
||||||
content=body_bytes,
|
|
||||||
status_code=response.status_code,
|
|
||||||
headers=dict(response.headers),
|
|
||||||
media_type=response.media_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not isinstance(body, dict):
|
|
||||||
return Response(
|
|
||||||
content=body_bytes,
|
|
||||||
status_code=response.status_code,
|
|
||||||
headers=dict(response.headers),
|
|
||||||
media_type=response.media_type,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Walk top-level string fields and sanitise.
|
|
||||||
sanitised_fields: list[str] = []
|
|
||||||
for key, value in body.items():
|
|
||||||
if isinstance(value, str):
|
|
||||||
cleaned, changed = _sanitize_text(value)
|
|
||||||
if changed:
|
|
||||||
body[key] = cleaned
|
|
||||||
sanitised_fields.append(key)
|
|
||||||
|
|
||||||
if sanitised_fields:
|
|
||||||
logger.warning(
|
|
||||||
"Sanitizer redacted prompt fragments",
|
|
||||||
extra={
|
|
||||||
"path": request.url.path,
|
|
||||||
"fields": sanitised_fields,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
new_body = json.dumps(body).encode("utf-8")
|
|
||||||
headers = dict(response.headers)
|
|
||||||
headers["content-length"] = str(len(new_body))
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
content=new_body,
|
|
||||||
status_code=response.status_code,
|
|
||||||
headers=headers,
|
|
||||||
media_type="application/json",
|
|
||||||
)
|
|
||||||
@@ -1,118 +0,0 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
|
||||||
|
|
||||||
Users and refresh tokens are kept in an in-memory dict until Step 12
|
|
||||||
migrates them to PostgreSQL.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from jose import jwt
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.schemas import AuthTokens, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
|
||||||
|
|
||||||
# ── In-memory stores (replaced by PostgreSQL in Step 12) ─────────────
|
|
||||||
_users: dict[str, dict[str, Any]] = {} # email → user record
|
|
||||||
_refresh_tokens: dict[str, str] = {} # plain token → user_id
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _hash_password(password: str) -> str:
|
|
||||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_password(password: str, hashed: str) -> bool:
|
|
||||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
|
||||||
|
|
||||||
|
|
||||||
def _make_tokens(user_id: str, email: str, tier: str) -> AuthTokens:
|
|
||||||
now = int(time.time())
|
|
||||||
access_exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
|
||||||
access_payload = {
|
|
||||||
"sub": user_id,
|
|
||||||
"email": email,
|
|
||||||
"tier": tier,
|
|
||||||
"exp": access_exp,
|
|
||||||
"iat": now,
|
|
||||||
}
|
|
||||||
access_token = jwt.encode(
|
|
||||||
access_payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM
|
|
||||||
)
|
|
||||||
refresh_token = str(uuid.uuid4())
|
|
||||||
_refresh_tokens[refresh_token] = user_id
|
|
||||||
return AuthTokens(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=refresh_token,
|
|
||||||
expires_at=access_exp * 1000, # milliseconds for client
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _RegisterRequest(BaseModel):
|
|
||||||
email: str
|
|
||||||
password: str
|
|
||||||
|
|
||||||
|
|
||||||
class _LoginRequest(BaseModel):
|
|
||||||
email: str
|
|
||||||
password: str
|
|
||||||
|
|
||||||
|
|
||||||
class _RefreshRequest(BaseModel):
|
|
||||||
refresh_token: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def register(body: _RegisterRequest) -> AuthTokens:
|
|
||||||
"""Create a new account and return JWT tokens."""
|
|
||||||
if body.email in _users:
|
|
||||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
|
||||||
user_id = str(uuid.uuid4())
|
|
||||||
_users[body.email] = {
|
|
||||||
"id": user_id,
|
|
||||||
"email": body.email,
|
|
||||||
"password_hash": _hash_password(body.password),
|
|
||||||
"tier": "free",
|
|
||||||
}
|
|
||||||
return _make_tokens(user_id, body.email, "free")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=AuthTokens)
|
|
||||||
async def login(body: _LoginRequest) -> AuthTokens:
|
|
||||||
"""Validate credentials and return JWT tokens."""
|
|
||||||
user = _users.get(body.email)
|
|
||||||
if not user or not _verify_password(body.password, user["password_hash"]):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
|
||||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=AuthTokens)
|
|
||||||
async def refresh(body: _RefreshRequest) -> AuthTokens:
|
|
||||||
"""Rotate a refresh token and return a new token pair."""
|
|
||||||
user_id = _refresh_tokens.pop(body.refresh_token, None)
|
|
||||||
if user_id is None:
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
|
||||||
user = next((u for u in _users.values() if u["id"] == user_id), None)
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
|
||||||
return _make_tokens(user["id"], user["email"], user["tier"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserProfile)
|
|
||||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
|
||||||
"""Return the profile for the authenticated user."""
|
|
||||||
return current_user
|
|
||||||
@@ -1,158 +0,0 @@
|
|||||||
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
|
||||||
|
|
||||||
Blobs are stored in S3 via BlobStore. Backup metadata is kept in an
|
|
||||||
in-memory dict until Step 12 migrates it to PostgreSQL (backup_metadata table).
|
|
||||||
|
|
||||||
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
|
||||||
treating "history" as a ``{backup_id}`` path parameter.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
from email.utils import parsedate_to_datetime
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.schemas import BackupMetadata, UserProfile
|
|
||||||
from app.storage.blob_store import BlobStore
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/backup", tags=["backup"])
|
|
||||||
|
|
||||||
_blob_store = BlobStore()
|
|
||||||
|
|
||||||
# In-memory backup metadata — replaced by PostgreSQL backup_metadata table in Step 12
|
|
||||||
_backups: dict[str, list[dict[str, Any]]] = {} # user_id → list of backup records
|
|
||||||
|
|
||||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
|
||||||
_TIER_BACKUP_LIMITS_GB: dict[str, int] = {
|
|
||||||
"free": 0,
|
|
||||||
"pro": 5,
|
|
||||||
"power": 25,
|
|
||||||
"team": -1, # unlimited
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _check_backup_quota(user_id: str, tier: str, size_bytes: int) -> None:
|
|
||||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
|
||||||
limit_gb = _TIER_BACKUP_LIMITS_GB.get(tier, 0)
|
|
||||||
if limit_gb == 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail="Backup is not available on the free tier",
|
|
||||||
)
|
|
||||||
if limit_gb == -1:
|
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024**3
|
|
||||||
used = sum(b["size_bytes"] for b in _backups.get(user_id, []))
|
|
||||||
if used + size_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("")
|
|
||||||
async def upload_backup(
|
|
||||||
request: Request,
|
|
||||||
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
|
||||||
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
|
||||||
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Upload an E2E-encrypted backup blob.
|
|
||||||
|
|
||||||
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
|
||||||
"""
|
|
||||||
blob = await request.body()
|
|
||||||
reject_if_tampered(blob, x_backup_checksum)
|
|
||||||
_check_backup_quota(current_user.id, current_user.tier, len(blob))
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
backup_record: dict[str, Any] = {
|
|
||||||
"id": str(x_backup_timestamp),
|
|
||||||
"s3_key": s3_key,
|
|
||||||
"version": x_backup_version,
|
|
||||||
"timestamp": x_backup_timestamp,
|
|
||||||
"checksum": x_backup_checksum,
|
|
||||||
"size_bytes": len(blob),
|
|
||||||
}
|
|
||||||
|
|
||||||
user_backups = _backups.setdefault(current_user.id, [])
|
|
||||||
user_backups.append(backup_record)
|
|
||||||
user_backups.sort(key=lambda b: b["timestamp"], reverse=True)
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/history", response_model=list[BackupMetadata])
|
|
||||||
async def backup_history(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[BackupMetadata]:
|
|
||||||
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
|
||||||
return [
|
|
||||||
BackupMetadata(
|
|
||||||
version=b["version"],
|
|
||||||
timestamp=b["timestamp"],
|
|
||||||
checksum=b["checksum"],
|
|
||||||
chunk_count=1, # single-chunk uploads for now — TODO(Step12): track real count
|
|
||||||
)
|
|
||||||
for b in _backups.get(current_user.id, [])
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
|
||||||
async def download_backup(
|
|
||||||
request: Request,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> Response:
|
|
||||||
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
|
||||||
user_backups = _backups.get(current_user.id, [])
|
|
||||||
if not user_backups:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
|
||||||
|
|
||||||
latest = user_backups[0]
|
|
||||||
|
|
||||||
ims_header = request.headers.get("If-Modified-Since")
|
|
||||||
if ims_header:
|
|
||||||
try:
|
|
||||||
ims_dt = parsedate_to_datetime(ims_header)
|
|
||||||
ims_ms = int(ims_dt.timestamp() * 1000)
|
|
||||||
if latest["timestamp"] <= ims_ms:
|
|
||||||
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
|
||||||
except Exception:
|
|
||||||
pass # malformed header — ignore and serve the blob
|
|
||||||
|
|
||||||
blob = await _blob_store.download(current_user.id, latest["s3_key"])
|
|
||||||
return Response(
|
|
||||||
content=blob,
|
|
||||||
media_type="application/octet-stream",
|
|
||||||
headers={
|
|
||||||
"X-Backup-Version": str(latest["version"]),
|
|
||||||
"X-Backup-Timestamp": str(latest["timestamp"]),
|
|
||||||
"X-Checksum": latest["checksum"],
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{backup_id}", response_model=dict)
|
|
||||||
async def delete_backup(
|
|
||||||
backup_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a specific backup by ID."""
|
|
||||||
user_backups = _backups.get(current_user.id, [])
|
|
||||||
target = next((b for b in user_backups if b["id"] == backup_id), None)
|
|
||||||
if target is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
|
||||||
|
|
||||||
await _blob_store.delete(current_user.id, target["s3_key"])
|
|
||||||
_backups[current_user.id] = [b for b in user_backups if b["id"] != backup_id]
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,184 +0,0 @@
|
|||||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
|
||||||
|
|
||||||
Subscription records are kept in-memory until Step 12 migrates them to
|
|
||||||
PostgreSQL (subscriptions table). Stripe calls are gracefully stubbed when
|
|
||||||
STRIPE_SECRET_KEY is not configured, allowing local development without keys.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import stripe as stripe_lib
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.schemas import BillingTier, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
|
||||||
|
|
||||||
# In-memory subscriptions — replaced by PostgreSQL subscriptions table in Step 12
|
|
||||||
_subscriptions: dict[str, dict[str, Any]] = {} # user_id → subscription record
|
|
||||||
|
|
||||||
_TIER_PRICE_IDS: dict[str, str] = {
|
|
||||||
"pro": "price_pro_monthly", # replace with real Stripe price IDs
|
|
||||||
"power": "price_power_monthly",
|
|
||||||
"team": "price_team_monthly",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _stripe_configured() -> bool:
|
|
||||||
return bool(settings.STRIPE_SECRET_KEY)
|
|
||||||
|
|
||||||
|
|
||||||
def _stripe() -> Any:
|
|
||||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
|
||||||
return stripe_lib
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _CheckoutRequest(BaseModel):
|
|
||||||
tier: BillingTier
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/checkout", response_model=dict)
|
|
||||||
async def create_checkout(
|
|
||||||
body: _CheckoutRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Create a Stripe checkout session for a tier upgrade.
|
|
||||||
|
|
||||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
|
||||||
"""
|
|
||||||
if body.tier == "free":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Cannot create a checkout session for the free tier",
|
|
||||||
)
|
|
||||||
|
|
||||||
if _stripe_configured():
|
|
||||||
price_id = _TIER_PRICE_IDS.get(body.tier)
|
|
||||||
if not price_id:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail=f"Unknown tier: {body.tier}",
|
|
||||||
)
|
|
||||||
s = _stripe()
|
|
||||||
session = s.checkout.Session.create(
|
|
||||||
payment_method_types=["card"],
|
|
||||||
mode="subscription",
|
|
||||||
line_items=[{"price": price_id, "quantity": 1}],
|
|
||||||
success_url=(
|
|
||||||
"https://app.adiuva.app/billing/success"
|
|
||||||
"?session_id={CHECKOUT_SESSION_ID}"
|
|
||||||
),
|
|
||||||
cancel_url="https://app.adiuva.app/billing/cancel",
|
|
||||||
metadata={"user_id": current_user.id, "tier": body.tier},
|
|
||||||
)
|
|
||||||
return {"checkout_url": session.url}
|
|
||||||
|
|
||||||
return {"checkout_url": "https://stripe.com/stub-checkout"}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/webhook", response_model=dict)
|
|
||||||
async def stripe_webhook(
|
|
||||||
request: Request,
|
|
||||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Handle Stripe webhook events.
|
|
||||||
|
|
||||||
No JWT auth — authenticated via Stripe signature verification instead.
|
|
||||||
Returns 200 immediately when Stripe is not configured (local dev).
|
|
||||||
"""
|
|
||||||
payload = await request.body()
|
|
||||||
|
|
||||||
if not _stripe_configured():
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
try:
|
|
||||||
s = _stripe()
|
|
||||||
event = s.Webhook.construct_event(
|
|
||||||
payload, stripe_signature, settings.STRIPE_WEBHOOK_SECRET
|
|
||||||
)
|
|
||||||
except stripe_lib.error.SignatureVerificationError:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Invalid Stripe signature",
|
|
||||||
)
|
|
||||||
|
|
||||||
event_type: str = event["type"]
|
|
||||||
data: dict[str, Any] = event["data"]["object"]
|
|
||||||
|
|
||||||
if event_type == "checkout.session.completed":
|
|
||||||
user_id = data.get("metadata", {}).get("user_id")
|
|
||||||
tier = data.get("metadata", {}).get("tier", "free")
|
|
||||||
sub_id = data.get("subscription")
|
|
||||||
if user_id:
|
|
||||||
_subscriptions[user_id] = {
|
|
||||||
"tier": tier,
|
|
||||||
"stripe_subscription_id": sub_id,
|
|
||||||
"status": "active",
|
|
||||||
"current_period_end": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
elif event_type == "customer.subscription.updated":
|
|
||||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, then update tier
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif event_type == "customer.subscription.deleted":
|
|
||||||
# TODO(Step12): look up user_id from stripe_customer_id in DB, set tier to free
|
|
||||||
pass
|
|
||||||
|
|
||||||
elif event_type == "invoice.payment_failed":
|
|
||||||
# TODO(Step12): flag subscription as past_due, notify user
|
|
||||||
pass
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/subscription", response_model=dict)
|
|
||||||
async def get_subscription(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return the current subscription info for the authenticated user."""
|
|
||||||
sub = _subscriptions.get(current_user.id)
|
|
||||||
if sub is None:
|
|
||||||
return {
|
|
||||||
"tier": current_user.tier,
|
|
||||||
"status": "free",
|
|
||||||
"stripe_subscription_id": None,
|
|
||||||
"current_period_end": None,
|
|
||||||
}
|
|
||||||
return sub
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/subscription", response_model=dict)
|
|
||||||
async def cancel_subscription(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Cancel the active subscription."""
|
|
||||||
sub = _subscriptions.get(current_user.id)
|
|
||||||
if sub is None or not sub.get("stripe_subscription_id"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail="No active subscription found",
|
|
||||||
)
|
|
||||||
|
|
||||||
if _stripe_configured():
|
|
||||||
s = _stripe()
|
|
||||||
s.Subscription.cancel(sub["stripe_subscription_id"])
|
|
||||||
|
|
||||||
_subscriptions[current_user.id] = {
|
|
||||||
**sub,
|
|
||||||
"tier": "free",
|
|
||||||
"status": "canceled",
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,78 +0,0 @@
|
|||||||
"""Chat routes: POST /chat and WebSocket /chat/stream."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, WebSocket, WebSocketDisconnect
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.orchestrator import orchestrate, orchestrate_stream
|
|
||||||
from app.schemas import ChatRequest, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
|
||||||
async def chat(
|
|
||||||
body: ChatRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> JSONResponse:
|
|
||||||
"""Route a chat message through the orchestrator.
|
|
||||||
|
|
||||||
Returns ``ChatResponse`` for ``execution_mode='direct'``,
|
|
||||||
or ``ExecutionPlan`` for ``execution_mode='plan'``.
|
|
||||||
"""
|
|
||||||
result = await orchestrate(body)
|
|
||||||
return JSONResponse(content=result.model_dump())
|
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/stream")
|
|
||||||
async def chat_stream(websocket: WebSocket) -> None:
|
|
||||||
"""Streaming chat via WebSocket.
|
|
||||||
|
|
||||||
Auth: ``?token=<jwt>`` query param (Bearer not possible during WS handshake).
|
|
||||||
|
|
||||||
Protocol:
|
|
||||||
1. Client sends ``ChatRequest`` as the first JSON text frame.
|
|
||||||
2. Server streams response text chunks.
|
|
||||||
3. Final frame: JSON ``{"done": true, "response": "...", "actions": [...]}``.
|
|
||||||
4. Server pings every 30 s to keep the connection alive.
|
|
||||||
"""
|
|
||||||
# Authenticate before accepting the connection
|
|
||||||
token = websocket.query_params.get("token", "")
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM])
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
if not user_id:
|
|
||||||
raise JWTError("missing sub")
|
|
||||||
except JWTError:
|
|
||||||
await websocket.close(code=1008) # 1008 = Policy Violation
|
|
||||||
return
|
|
||||||
|
|
||||||
await websocket.accept()
|
|
||||||
|
|
||||||
try:
|
|
||||||
raw = await websocket.receive_text()
|
|
||||||
body = ChatRequest.model_validate_json(raw)
|
|
||||||
|
|
||||||
async def _heartbeat() -> None:
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
|
||||||
await websocket.send_text(json.dumps({"ping": True}))
|
|
||||||
|
|
||||||
heartbeat_task = asyncio.create_task(_heartbeat())
|
|
||||||
try:
|
|
||||||
async for chunk in orchestrate_stream(body):
|
|
||||||
await websocket.send_text(chunk)
|
|
||||||
finally:
|
|
||||||
heartbeat_task.cancel()
|
|
||||||
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Plans routes: GET /plans/playbook and GET /plans/playbook/{plan_id}."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.execution_plan import plan_cache
|
|
||||||
from app.schemas import ExecutionPlan, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plans", tags=["plans"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook", response_model=list[ExecutionPlan])
|
|
||||||
async def list_playbooks(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached execution plan playbooks for the authenticated user.
|
|
||||||
|
|
||||||
TODO(Step11): filter by tier — power+ plans gated behind batch_builder feature.
|
|
||||||
"""
|
|
||||||
return plan_cache.get_all_playbooks()
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/playbook/{plan_id}", response_model=ExecutionPlan)
|
|
||||||
async def get_playbook(
|
|
||||||
plan_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> ExecutionPlan:
|
|
||||||
"""Return a specific execution plan playbook by ID."""
|
|
||||||
plan = plan_cache.get_plan(plan_id)
|
|
||||||
if plan is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Plan not found: {plan_id}",
|
|
||||||
)
|
|
||||||
return plan
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
"""Plugins routes: browse and install plugins from the marketplace.
|
|
||||||
|
|
||||||
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes introduced
|
|
||||||
in Step 10. Step 12 will swap those services' in-memory stores for
|
|
||||||
PostgreSQL persistence.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
from app.marketplace.revenue_share import revenue_share
|
|
||||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier gate ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _require_plugin_tier(user: UserProfile) -> None:
|
|
||||||
"""Raise HTTP 403 for users below Power tier."""
|
|
||||||
if user.tier not in ("power", "team"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Plugin marketplace requires Power tier or above",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local detail schema ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _PluginDetail(BaseModel):
|
|
||||||
plugin: PluginManifest
|
|
||||||
install_count: int
|
|
||||||
ratings: list[Any] # Step 12 populates from plugin_reviews table
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("", response_model=PluginListResponse)
|
|
||||||
async def list_plugins(
|
|
||||||
category: str | None = Query(default=None),
|
|
||||||
q: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> PluginListResponse:
|
|
||||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
return await registry.list_plugins(category=category, query=q, page=page, sort=sort)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
|
||||||
async def get_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> _PluginDetail:
|
|
||||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
entry = await registry.get_plugin(plugin_id)
|
|
||||||
if entry is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
|
||||||
return _PluginDetail(
|
|
||||||
plugin=entry["manifest"],
|
|
||||||
install_count=entry["install_count"],
|
|
||||||
ratings=[], # Step 12 populates from plugin_reviews table
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{plugin_id}/install", response_model=dict)
|
|
||||||
async def install_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
|
||||||
|
|
||||||
Requires Power tier or above.
|
|
||||||
"""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
entry = await registry.get_plugin(plugin_id)
|
|
||||||
if entry is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
|
||||||
|
|
||||||
await revenue_share.record_install(
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
amount_cents=entry["manifest"].price_cents,
|
|
||||||
)
|
|
||||||
|
|
||||||
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
|
||||||
return {"ok": True, "download_url": download_url}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{plugin_id}/install", response_model=dict)
|
|
||||||
async def uninstall_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Unregister a plugin installation."""
|
|
||||||
await registry.record_uninstall(plugin_id)
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,185 +0,0 @@
|
|||||||
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
|
||||||
|
|
||||||
Blobs are stored in S3 via BlobStore. Record metadata is kept in an
|
|
||||||
in-memory dict until Step 12 migrates it to PostgreSQL (storage_records table).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
|
||||||
from app.storage.blob_store import BlobStore
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/storage", tags=["storage"])
|
|
||||||
|
|
||||||
_blob_store = BlobStore()
|
|
||||||
|
|
||||||
# In-memory record metadata — replaced by PostgreSQL storage_records table in Step 12
|
|
||||||
_records: dict[str, dict[str, Any]] = {}
|
|
||||||
|
|
||||||
# TODO(Step11/12): replace with TierManager.check_quota(user_id)
|
|
||||||
_TIER_STORAGE_LIMITS_GB: dict[str, int] = {
|
|
||||||
"free": 0,
|
|
||||||
"pro": 5,
|
|
||||||
"power": 25,
|
|
||||||
"team": -1, # unlimited
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local response schemas ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _CreateResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
created_at: int
|
|
||||||
|
|
||||||
|
|
||||||
class _RecordMeta(BaseModel):
|
|
||||||
id: str
|
|
||||||
table: str
|
|
||||||
checksum: str
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _check_quota(user_id: str, tier: str, additional_bytes: int) -> None:
|
|
||||||
"""Raise HTTP 402 if adding ``additional_bytes`` would exceed the tier limit."""
|
|
||||||
limit_gb = _TIER_STORAGE_LIMITS_GB.get(tier, 0)
|
|
||||||
if limit_gb == -1:
|
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024**3
|
|
||||||
used = sum(r["size_bytes"] for r in _records.values() if r["user_id"] == user_id)
|
|
||||||
if used + additional_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_record_for_user(record_id: str, user_id: str) -> dict[str, Any]:
|
|
||||||
"""Look up a record and verify ownership. Always returns 404 on mismatch
|
|
||||||
to prevent user enumeration attacks."""
|
|
||||||
record = _records.get(record_id)
|
|
||||||
if record is None or record["user_id"] != user_id:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_record(
|
|
||||||
body: StorageRecordCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> _CreateResponse:
|
|
||||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
|
||||||
_check_quota(current_user.id, current_user.tier, len(body.blob))
|
|
||||||
|
|
||||||
record_id = str(uuid.uuid4())
|
|
||||||
now = int(time.time() * 1000)
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, body.table, record_id, body.blob, body.checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
_records[record_id] = {
|
|
||||||
"id": record_id,
|
|
||||||
"user_id": current_user.id,
|
|
||||||
"table": body.table,
|
|
||||||
"s3_key": s3_key,
|
|
||||||
"checksum": body.checksum,
|
|
||||||
"size_bytes": len(body.blob),
|
|
||||||
"created_at": now,
|
|
||||||
"updated_at": now,
|
|
||||||
}
|
|
||||||
|
|
||||||
return _CreateResponse(id=record_id, created_at=now)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/records", response_model=list[_RecordMeta])
|
|
||||||
async def list_records(
|
|
||||||
table: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> list[_RecordMeta]:
|
|
||||||
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
|
||||||
all_records = [
|
|
||||||
r for r in _records.values()
|
|
||||||
if r["user_id"] == current_user.id and (table is None or r["table"] == table)
|
|
||||||
]
|
|
||||||
start = (page - 1) * limit
|
|
||||||
page_records = all_records[start : start + limit]
|
|
||||||
return [
|
|
||||||
_RecordMeta(
|
|
||||||
id=r["id"],
|
|
||||||
table=r["table"],
|
|
||||||
checksum=r["checksum"],
|
|
||||||
created_at=r["created_at"],
|
|
||||||
updated_at=r["updated_at"],
|
|
||||||
)
|
|
||||||
for r in page_records
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/records/{record_id}")
|
|
||||||
async def download_record(
|
|
||||||
record_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> Response:
|
|
||||||
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
|
||||||
record = _get_record_for_user(record_id, current_user.id)
|
|
||||||
blob = await _blob_store.download(current_user.id, record["s3_key"])
|
|
||||||
return Response(
|
|
||||||
content=blob,
|
|
||||||
media_type="application/octet-stream",
|
|
||||||
headers={"X-Checksum": record["checksum"]},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/records/{record_id}", response_model=dict)
|
|
||||||
async def update_record(
|
|
||||||
record_id: str,
|
|
||||||
body: StorageRecordUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
|
||||||
record = _get_record_for_user(record_id, current_user.id)
|
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
|
||||||
|
|
||||||
delta = len(body.blob) - record["size_bytes"]
|
|
||||||
if delta > 0:
|
|
||||||
_check_quota(current_user.id, current_user.tier, delta)
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, record["table"], record_id, body.blob, body.checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
record["s3_key"] = s3_key
|
|
||||||
record["checksum"] = body.checksum
|
|
||||||
record["size_bytes"] = len(body.blob)
|
|
||||||
record["updated_at"] = int(time.time() * 1000)
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/records/{record_id}", response_model=dict)
|
|
||||||
async def delete_record(
|
|
||||||
record_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a record and its S3 blob."""
|
|
||||||
record = _get_record_for_user(record_id, current_user.id)
|
|
||||||
await _blob_store.delete(current_user.id, record["s3_key"])
|
|
||||||
del _records[record_id]
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
"""Vectors routes: upsert, search, and delete cloud vector store entries."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.schemas import (
|
|
||||||
UserProfile,
|
|
||||||
VectorSearchRequest,
|
|
||||||
VectorSearchResponse,
|
|
||||||
VectorUpsertRequest,
|
|
||||||
)
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
from app.storage.vector_store import VectorStore
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/storage", tags=["vectors"])
|
|
||||||
|
|
||||||
_vector_store = VectorStore()
|
|
||||||
|
|
||||||
|
|
||||||
class _VectorDeleteRequest(BaseModel):
|
|
||||||
ids: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/upsert", response_model=dict)
|
|
||||||
async def upsert_vectors(
|
|
||||||
body: VectorUpsertRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, int]:
|
|
||||||
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
|
||||||
for item in body.vectors:
|
|
||||||
reject_if_tampered(item.blob, item.checksum)
|
|
||||||
await _vector_store.upsert(current_user.id, body.vectors)
|
|
||||||
return {"upserted": len(body.vectors)}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
|
||||||
async def search_vectors(
|
|
||||||
body: VectorSearchRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> VectorSearchResponse:
|
|
||||||
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
|
||||||
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
|
||||||
return VectorSearchResponse(results=results)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/vectors", response_model=dict)
|
|
||||||
async def delete_vectors(
|
|
||||||
body: _VectorDeleteRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete vectors by ID, scoped to the authenticated user."""
|
|
||||||
await _vector_store.delete(current_user.id, body.ids)
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
from typing import Literal
|
|
||||||
from pydantic_settings import BaseSettings
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
|
||||||
JWT_SECRET: str = "change-me-in-production"
|
|
||||||
JWT_ALGORITHM: str = "HS256"
|
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
|
||||||
|
|
||||||
STRIPE_SECRET_KEY: str = ""
|
|
||||||
STRIPE_WEBHOOK_SECRET: str = ""
|
|
||||||
|
|
||||||
S3_BUCKET: str = ""
|
|
||||||
S3_REGION: str = "us-east-1"
|
|
||||||
AWS_ACCESS_KEY_ID: str = ""
|
|
||||||
AWS_SECRET_ACCESS_KEY: str = ""
|
|
||||||
|
|
||||||
PINECONE_API_KEY: str = ""
|
|
||||||
PINECONE_INDEX: str = "adiuva"
|
|
||||||
QDRANT_URL: str = ""
|
|
||||||
QDRANT_API_KEY: str = ""
|
|
||||||
|
|
||||||
OPENAI_API_KEY: str = ""
|
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
env_file = ".env"
|
|
||||||
env_file_encoding = "utf-8"
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
@@ -1,137 +0,0 @@
|
|||||||
"""Agent Registry — base classes and singleton registry for chat agents."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
|
||||||
"""Common base for all agents."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
user_id: str = "",
|
|
||||||
shared_memory: dict[str, Any] | None = None,
|
|
||||||
vector_store_context: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.user_id = user_id
|
|
||||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
|
||||||
self.vector_store_context: list[str] = vector_store_context or []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_name(self) -> str: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_description(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def skills(self) -> list[str]:
|
|
||||||
"""Override in subclasses to advertise capabilities."""
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
class ChatAgent(BaseAgent):
|
|
||||||
"""Base class for LLM-powered chat agents."""
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
async def handle(self, query: str, context: dict[str, Any]) -> str:
|
|
||||||
"""Process a user query and return a text response."""
|
|
||||||
...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_tools(self) -> list[Any]:
|
|
||||||
"""Return LangChain tool definitions available to this agent."""
|
|
||||||
...
|
|
||||||
|
|
||||||
async def _tool_loop(
|
|
||||||
self,
|
|
||||||
llm: Any,
|
|
||||||
messages: list[Any],
|
|
||||||
tools: list[Any],
|
|
||||||
max_iter: int = 5,
|
|
||||||
) -> str:
|
|
||||||
"""Shared tool-calling loop.
|
|
||||||
|
|
||||||
Binds *tools* to *llm*, invokes iteratively until the model stops
|
|
||||||
requesting tool calls or *max_iter* is reached, and returns the
|
|
||||||
final text response.
|
|
||||||
"""
|
|
||||||
from langchain_core.messages import AIMessage, ToolMessage
|
|
||||||
|
|
||||||
llm_with_tools = llm.bind_tools(tools) if tools else llm
|
|
||||||
|
|
||||||
for _ in range(max_iter):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
# Execute each requested tool call
|
|
||||||
tool_map = {t.name: t for t in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_fn = tool_map.get(call["name"])
|
|
||||||
if tool_fn is None:
|
|
||||||
result = f"Unknown tool: {call['name']}"
|
|
||||||
else:
|
|
||||||
result = await tool_fn.ainvoke(call["args"])
|
|
||||||
messages.append(
|
|
||||||
ToolMessage(content=str(result), tool_call_id=call["id"])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Exhausted iterations — ask model for a final answer without tools
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
return str(response.content)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRegistry:
|
|
||||||
"""Singleton registry for ChatAgent subclasses."""
|
|
||||||
|
|
||||||
_instance: AgentRegistry | None = None
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._agents: dict[str, type[ChatAgent]] = {}
|
|
||||||
|
|
||||||
def __new__(cls) -> AgentRegistry:
|
|
||||||
if cls._instance is None:
|
|
||||||
cls._instance = super().__new__(cls)
|
|
||||||
cls._instance._agents = {}
|
|
||||||
return cls._instance
|
|
||||||
|
|
||||||
# ── public API ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def register(self, agent_class: type[ChatAgent]) -> type[ChatAgent]:
|
|
||||||
"""Class decorator — registers an agent by its name."""
|
|
||||||
instance = agent_class()
|
|
||||||
name = instance.get_name()
|
|
||||||
self._agents[name] = agent_class
|
|
||||||
return agent_class
|
|
||||||
|
|
||||||
def get(self, name: str) -> ChatAgent:
|
|
||||||
"""Return a fresh instance of the named agent."""
|
|
||||||
cls = self._agents.get(name)
|
|
||||||
if cls is None:
|
|
||||||
raise KeyError(f"Agent not found: {name}")
|
|
||||||
return cls()
|
|
||||||
|
|
||||||
def list_agents(self) -> list[dict[str, str]]:
|
|
||||||
"""Return ``[{name, description}]`` for the orchestrator prompt."""
|
|
||||||
result: list[dict[str, str]] = []
|
|
||||||
for cls in self._agents.values():
|
|
||||||
inst = cls()
|
|
||||||
result.append(
|
|
||||||
{"name": inst.get_name(), "description": inst.get_description()}
|
|
||||||
)
|
|
||||||
return result
|
|
||||||
|
|
||||||
async def call_agent(
|
|
||||||
self, name: str, query: str, context: dict[str, Any]
|
|
||||||
) -> str:
|
|
||||||
"""Instantiate the named agent and call its ``handle`` method."""
|
|
||||||
agent = self.get(name)
|
|
||||||
return await agent.handle(query, context)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = AgentRegistry()
|
|
||||||
@@ -1,222 +0,0 @@
|
|||||||
"""Execution Plan generator — builder, template registry, and LRU plan cache."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections import OrderedDict
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.schemas import ExecutionPlan, PlanStep
|
|
||||||
|
|
||||||
|
|
||||||
# ── Prompt Template Registry ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PromptTemplateRegistry:
|
|
||||||
"""Server-side store mapping template IDs to prompt text.
|
|
||||||
|
|
||||||
Clients only ever receive template IDs (e.g. ``"tpl_task_agent_default"``).
|
|
||||||
The actual prompt text is resolved here on the server, keeping prompt IP
|
|
||||||
out of API responses.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._templates: dict[str, str] = {}
|
|
||||||
|
|
||||||
def register(self, template_id: str, prompt_text: str) -> None:
|
|
||||||
self._templates[template_id] = prompt_text
|
|
||||||
|
|
||||||
def get(self, template_id: str) -> str:
|
|
||||||
"""Resolve a template ID to its prompt text.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the template is not registered.
|
|
||||||
"""
|
|
||||||
text = self._templates.get(template_id)
|
|
||||||
if text is None:
|
|
||||||
raise KeyError(f"Template not found: {template_id!r}")
|
|
||||||
return text
|
|
||||||
|
|
||||||
def has(self, template_id: str) -> bool:
|
|
||||||
return template_id in self._templates
|
|
||||||
|
|
||||||
def list_ids(self) -> list[str]:
|
|
||||||
"""Return all registered template IDs (never the text)."""
|
|
||||||
return list(self._templates.keys())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plan Builder ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlanBuilder:
|
|
||||||
"""Fluent builder for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Example::
|
|
||||||
|
|
||||||
plan = (
|
|
||||||
ExecutionPlanBuilder("task_agent")
|
|
||||||
.add_llm_step("tpl_task_agent_default", {"message": user_msg})
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build()
|
|
||||||
)
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, agent: str) -> None:
|
|
||||||
self._agent = agent
|
|
||||||
self._steps: list[PlanStep] = []
|
|
||||||
|
|
||||||
# ── step adders ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def add_step(
|
|
||||||
self, action: str, params: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a generic action step with optional parameters."""
|
|
||||||
self._steps.append(PlanStep(action=action, variables=params))
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_llm_step(
|
|
||||||
self, template_id: str, variables: dict[str, Any] | None = None
|
|
||||||
) -> ExecutionPlanBuilder:
|
|
||||||
"""Append an LLM step referencing a server-side template by ID."""
|
|
||||||
self._steps.append(
|
|
||||||
PlanStep(action="llm", prompt_template=template_id, variables=variables)
|
|
||||||
)
|
|
||||||
return self
|
|
||||||
|
|
||||||
def add_data_step(self, action: str, data_from_step: int) -> ExecutionPlanBuilder:
|
|
||||||
"""Append a step whose input comes from the output of an earlier step."""
|
|
||||||
self._steps.append(PlanStep(action=action, data_from_step=data_from_step))
|
|
||||||
return self
|
|
||||||
|
|
||||||
# ── build ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def build(self) -> ExecutionPlan:
|
|
||||||
"""Validate step references and return the ``ExecutionPlan``.
|
|
||||||
|
|
||||||
Raises ``ValueError`` if any ``data_from_step`` references a
|
|
||||||
non-existent or future step index.
|
|
||||||
"""
|
|
||||||
for i, step in enumerate(self._steps):
|
|
||||||
if step.data_from_step is not None:
|
|
||||||
if not (0 <= step.data_from_step < i):
|
|
||||||
raise ValueError(
|
|
||||||
f"Step {i}: data_from_step={step.data_from_step} must "
|
|
||||||
f"reference a preceding step index in range 0..{i - 1}"
|
|
||||||
)
|
|
||||||
return ExecutionPlan(agent=self._agent, steps=list(self._steps))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plan Cache (LRU) ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class PlanCache:
|
|
||||||
"""In-memory LRU cache for ``ExecutionPlan`` objects.
|
|
||||||
|
|
||||||
Plans stored here are accessible as playbooks via ``get_all_playbooks()``.
|
|
||||||
The cache also serves as a runtime memoisation layer so that repeated
|
|
||||||
identical intent classifications can skip re-building the plan.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self, maxsize: int = 1000) -> None:
|
|
||||||
self._maxsize = maxsize
|
|
||||||
self._cache: OrderedDict[str, ExecutionPlan] = OrderedDict()
|
|
||||||
|
|
||||||
def cache_plan(self, key: str, plan: ExecutionPlan) -> None:
|
|
||||||
"""Store *plan* under *key*, evicting the LRU entry if at capacity."""
|
|
||||||
if key in self._cache:
|
|
||||||
del self._cache[key] # remove so re-insertion places it at the end
|
|
||||||
elif len(self._cache) >= self._maxsize:
|
|
||||||
self._cache.popitem(last=False) # evict least-recently-used
|
|
||||||
self._cache[key] = plan
|
|
||||||
|
|
||||||
def get_plan(self, key: str) -> ExecutionPlan | None:
|
|
||||||
"""Return the cached plan for *key*, or ``None`` if not present.
|
|
||||||
|
|
||||||
Accessing a plan marks it as most-recently used.
|
|
||||||
"""
|
|
||||||
if key not in self._cache:
|
|
||||||
return None
|
|
||||||
self._cache.move_to_end(key)
|
|
||||||
return self._cache[key]
|
|
||||||
|
|
||||||
def get_all_playbooks(self) -> list[ExecutionPlan]:
|
|
||||||
"""Return all cached plans (most-recently used last)."""
|
|
||||||
return list(self._cache.values())
|
|
||||||
|
|
||||||
|
|
||||||
# ── Module-level singletons ───────────────────────────────────────────
|
|
||||||
|
|
||||||
template_registry = PromptTemplateRegistry()
|
|
||||||
plan_cache = PlanCache()
|
|
||||||
|
|
||||||
|
|
||||||
def _register_builtin_templates() -> None:
|
|
||||||
"""Register the built-in server-side prompt templates.
|
|
||||||
|
|
||||||
These strings never leave the server. Clients only receive the IDs.
|
|
||||||
"""
|
|
||||||
_tpls: dict[str, str] = {
|
|
||||||
"tpl_task_agent_default": (
|
|
||||||
"You are a task management assistant. Help the user create, update, "
|
|
||||||
"list, and track tasks. Use correct status values (todo, in_progress, "
|
|
||||||
"done) and priority values (high, medium, low) from the workspace model."
|
|
||||||
),
|
|
||||||
"tpl_checkpoint_agent_default": (
|
|
||||||
"You are a project checkpoint assistant. Help the user create and manage "
|
|
||||||
"milestone checkpoints on their projects. Every checkpoint requires a "
|
|
||||||
"project_id and a date expressed as a Unix timestamp in milliseconds."
|
|
||||||
),
|
|
||||||
"tpl_project_agent_default": (
|
|
||||||
"You are a project management assistant. Help the user create, find, "
|
|
||||||
"update, and archive projects. Projects have a name, an optional client, "
|
|
||||||
"and a status of either active or archived."
|
|
||||||
),
|
|
||||||
"tpl_note_agent_default": (
|
|
||||||
"You are a note-taking assistant. Help the user create, retrieve, update, "
|
|
||||||
"and delete Markdown notes. Notes can optionally be linked to a project."
|
|
||||||
),
|
|
||||||
"tpl_task_extract_from_project": (
|
|
||||||
"Extract all actionable tasks from the provided project context. "
|
|
||||||
"Return a structured list of tasks, each with a title, inferred priority "
|
|
||||||
"(high, medium, or low), suggested status (todo), and a due_date in "
|
|
||||||
"milliseconds where a deadline can be inferred."
|
|
||||||
),
|
|
||||||
"tpl_note_weekly_summary": (
|
|
||||||
"Generate a weekly project summary note from the provided workspace data. "
|
|
||||||
"Include: tasks completed this week, tasks due soon, active projects, "
|
|
||||||
"and upcoming checkpoints. Format the output as clean Markdown."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
for tid, text in _tpls.items():
|
|
||||||
template_registry.register(tid, text)
|
|
||||||
|
|
||||||
|
|
||||||
def _load_playbooks() -> None:
|
|
||||||
"""Pre-build and cache the built-in playbooks."""
|
|
||||||
playbooks: list[tuple[str, ExecutionPlan]] = [
|
|
||||||
(
|
|
||||||
"create_tasks_from_project",
|
|
||||||
ExecutionPlanBuilder("project_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_task_extract_from_project",
|
|
||||||
{"source": "project_context"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"generate_weekly_note",
|
|
||||||
ExecutionPlanBuilder("note_agent")
|
|
||||||
.add_llm_step(
|
|
||||||
"tpl_note_weekly_summary",
|
|
||||||
{"period": "last_7_days"},
|
|
||||||
)
|
|
||||||
.add_data_step("create_record", data_from_step=0)
|
|
||||||
.build(),
|
|
||||||
),
|
|
||||||
]
|
|
||||||
for key, plan in playbooks:
|
|
||||||
plan_cache.cache_plan(key, plan)
|
|
||||||
|
|
||||||
|
|
||||||
# Initialise on module load
|
|
||||||
_register_builtin_templates()
|
|
||||||
_load_playbooks()
|
|
||||||
@@ -1,169 +0,0 @@
|
|||||||
"""Orchestrator — LLM-based intent router and agent pipeline."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, AsyncGenerator
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.agent_registry import AgentRegistry
|
|
||||||
from app.core.agent_registry import registry as _default_registry
|
|
||||||
from app.schemas import ChatRequest, ChatResponse, ExecutionPlan
|
|
||||||
|
|
||||||
_FALLBACK_AGENT = "task_agent"
|
|
||||||
|
|
||||||
_CLASSIFY_SYSTEM = (
|
|
||||||
"You are an intent classifier. Given the user message and context, decide "
|
|
||||||
"which agent to route to.\n"
|
|
||||||
"Available agents: {agents}\n"
|
|
||||||
"Respond with just the agent name, nothing else."
|
|
||||||
)
|
|
||||||
|
|
||||||
_SYNTHESIZE_HUMAN = (
|
|
||||||
"Combine the following agent results into one coherent response.\n\n"
|
|
||||||
"Agent results:\n{results}\n\n"
|
|
||||||
"Original message: {message}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_llm(model: str = "gpt-4o-mini") -> ChatOpenAI:
|
|
||||||
return ChatOpenAI(model=model, temperature=0, api_key=settings.OPENAI_API_KEY)
|
|
||||||
|
|
||||||
|
|
||||||
async def classify_intent(
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> str:
|
|
||||||
"""Use gpt-4o-mini to classify intent and return the matching agent name.
|
|
||||||
|
|
||||||
Falls back to ``task_agent`` when the registry is empty or the model
|
|
||||||
returns a name that is not registered.
|
|
||||||
"""
|
|
||||||
agents = reg.list_agents()
|
|
||||||
if not agents:
|
|
||||||
return _FALLBACK_AGENT
|
|
||||||
|
|
||||||
system = _CLASSIFY_SYSTEM.format(agents=json.dumps(agents))
|
|
||||||
# Truncate context to keep the classification prompt short
|
|
||||||
human = f"Message: {message}\nContext summary: {json.dumps(context)[:500]}"
|
|
||||||
|
|
||||||
llm = _make_llm()
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[SystemMessage(content=system), HumanMessage(content=human)]
|
|
||||||
)
|
|
||||||
|
|
||||||
agent_name = str(response.content).strip().lower()
|
|
||||||
known = {a["name"] for a in agents}
|
|
||||||
return agent_name if agent_name in known else _FALLBACK_AGENT
|
|
||||||
|
|
||||||
|
|
||||||
async def route_single(
|
|
||||||
agent_name: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Route to a single agent and wrap the result in a ``ChatResponse``."""
|
|
||||||
response_text = await reg.call_agent(agent_name, message, context)
|
|
||||||
return ChatResponse(response=response_text)
|
|
||||||
|
|
||||||
|
|
||||||
async def route_pipeline(
|
|
||||||
agent_names: list[str],
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
reg: AgentRegistry,
|
|
||||||
) -> ChatResponse:
|
|
||||||
"""Execute agents sequentially; each agent receives previous results in context.
|
|
||||||
|
|
||||||
A final LLM synthesis call merges all results into one coherent response.
|
|
||||||
"""
|
|
||||||
previous_results: list[str] = []
|
|
||||||
|
|
||||||
for agent_name in agent_names:
|
|
||||||
ctx = {**context, "previous_results": list(previous_results)}
|
|
||||||
result = await reg.call_agent(agent_name, message, ctx)
|
|
||||||
previous_results.append(result)
|
|
||||||
|
|
||||||
results_str = "\n\n".join(
|
|
||||||
f"[{name}]: {res}" for name, res in zip(agent_names, previous_results)
|
|
||||||
)
|
|
||||||
human = _SYNTHESIZE_HUMAN.format(results=results_str, message=message)
|
|
||||||
llm = _make_llm()
|
|
||||||
synthesis = await llm.ainvoke([HumanMessage(content=human)])
|
|
||||||
return ChatResponse(response=str(synthesis.content))
|
|
||||||
|
|
||||||
|
|
||||||
def _build_plan(agent_name: str, message: str) -> ExecutionPlan:
|
|
||||||
"""Build an ``ExecutionPlan`` for the resolved agent.
|
|
||||||
|
|
||||||
Uses ``ExecutionPlanBuilder`` with the server-side template registry.
|
|
||||||
If a default template exists for the agent, an LLM step is emitted;
|
|
||||||
otherwise a plain ``handle`` action step is used.
|
|
||||||
"""
|
|
||||||
from app.core.execution_plan import ExecutionPlanBuilder, template_registry
|
|
||||||
|
|
||||||
template_id = f"tpl_{agent_name}_default"
|
|
||||||
builder = ExecutionPlanBuilder(agent_name)
|
|
||||||
if template_registry.has(template_id):
|
|
||||||
builder.add_llm_step(template_id, {"message": message})
|
|
||||||
else:
|
|
||||||
builder.add_step("handle", {"message": message})
|
|
||||||
return builder.build()
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> ChatResponse | ExecutionPlan:
|
|
||||||
"""Main orchestration entry point.
|
|
||||||
|
|
||||||
* Classifies the user's intent to select an agent.
|
|
||||||
* ``execution_mode == 'direct'``: routes to the agent and returns a
|
|
||||||
``ChatResponse``.
|
|
||||||
* ``execution_mode == 'plan'``: returns an ``ExecutionPlan`` with the
|
|
||||||
resolved agent and a template-ID-only step (prompt IP stays server-side).
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
|
|
||||||
if request.execution_mode == "direct":
|
|
||||||
return await route_single(agent_name, request.message, context, reg)
|
|
||||||
|
|
||||||
# plan mode — return plan, do not execute
|
|
||||||
return _build_plan(agent_name, request.message)
|
|
||||||
|
|
||||||
|
|
||||||
async def orchestrate_stream(
|
|
||||||
request: ChatRequest,
|
|
||||||
reg: AgentRegistry | None = None,
|
|
||||||
) -> AsyncGenerator[str, None]:
|
|
||||||
"""Streaming orchestration — yields text chunks then a final JSON frame.
|
|
||||||
|
|
||||||
The final frame is a JSON object:
|
|
||||||
``{"done": true, "response": "...", "actions": []}``.
|
|
||||||
|
|
||||||
Agents do not yet support token-level streaming; the full response is
|
|
||||||
fetched first, then emitted in fixed-size chunks. Token-level streaming
|
|
||||||
will be wired in Step 6 when agents expose ``astream()``.
|
|
||||||
"""
|
|
||||||
if reg is None:
|
|
||||||
reg = _default_registry
|
|
||||||
|
|
||||||
context = request.context.model_dump()
|
|
||||||
agent_name = await classify_intent(request.message, context, reg)
|
|
||||||
response_text = await reg.call_agent(agent_name, request.message, context)
|
|
||||||
|
|
||||||
chunk_size = 50
|
|
||||||
for i in range(0, len(response_text), chunk_size):
|
|
||||||
yield response_text[i : i + chunk_size]
|
|
||||||
|
|
||||||
final = ChatResponse(response=response_text)
|
|
||||||
yield json.dumps({"done": True, **final.model_dump()})
|
|
||||||
62
app/main.py
62
app/main.py
@@ -1,62 +0,0 @@
|
|||||||
from contextlib import asynccontextmanager
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
|
||||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
|
||||||
from app.config.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
# Startup: initialise DB connection pool and agent registry
|
|
||||||
from app.core.agent_registry import registry # noqa: F401 — triggers module load
|
|
||||||
import app.agents # noqa: F401 — triggers @registry.register decorators
|
|
||||||
|
|
||||||
yield
|
|
||||||
|
|
||||||
# Shutdown: nothing to clean up for now
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
|
||||||
app = FastAPI(
|
|
||||||
title="Adiuva Cloud API",
|
|
||||||
version="0.1.0",
|
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
|
||||||
redoc_url=None,
|
|
||||||
lifespan=lifespan,
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=settings.CORS_ORIGINS,
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
|
||||||
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
|
||||||
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
|
||||||
app.add_middleware(SanitizerMiddleware)
|
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
|
||||||
|
|
||||||
from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors
|
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
|
||||||
app.include_router(plans.router, prefix="/api/v1")
|
|
||||||
app.include_router(storage.router, prefix="/api/v1")
|
|
||||||
app.include_router(vectors.router, prefix="/api/v1")
|
|
||||||
app.include_router(backup.router, prefix="/api/v1")
|
|
||||||
app.include_router(plugins.router, prefix="/api/v1")
|
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
|
||||||
async def health() -> dict:
|
|
||||||
return {"status": "ok", "version": app.version}
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
app = create_app()
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
"""Plugin marketplace package.
|
|
||||||
|
|
||||||
Three service classes introduced in Step 10:
|
|
||||||
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
|
||||||
- ``ReviewQueue`` — approval workflow + security checklist
|
|
||||||
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
|
||||||
"""
|
|
||||||
@@ -1,211 +0,0 @@
|
|||||||
"""Plugin catalog registry.
|
|
||||||
|
|
||||||
Maintains the authoritative list of plugins, their review status, and
|
|
||||||
aggregate install counts. Storage is in-memory until Step 12 migrates to
|
|
||||||
the ``plugins`` PostgreSQL table.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from app.schemas import PluginListResponse, PluginManifest
|
|
||||||
|
|
||||||
# ── Pre-seeded approved plugins (mirrors the Step 8 stub catalog) ─────
|
|
||||||
|
|
||||||
_SEED_PLUGINS: list[dict[str, Any]] = [
|
|
||||||
{
|
|
||||||
"manifest": PluginManifest(
|
|
||||||
id="plugin-github-sync",
|
|
||||||
name="GitHub Sync",
|
|
||||||
description="Sync tasks with GitHub Issues and pull requests.",
|
|
||||||
version="1.0.0",
|
|
||||||
author="Adiuva",
|
|
||||||
permissions=["read:tasks", "write:tasks"],
|
|
||||||
category="productivity",
|
|
||||||
price_cents=0,
|
|
||||||
),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
"rejection_reason": None,
|
|
||||||
"submitted_at": int(time.time()),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"manifest": PluginManifest(
|
|
||||||
id="plugin-slack-notify",
|
|
||||||
name="Slack Notifier",
|
|
||||||
description="Post task and checkpoint updates to Slack channels.",
|
|
||||||
version="1.2.0",
|
|
||||||
author="Adiuva",
|
|
||||||
permissions=["read:tasks", "read:checkpoints"],
|
|
||||||
category="communication",
|
|
||||||
price_cents=499,
|
|
||||||
),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
"rejection_reason": None,
|
|
||||||
"submitted_at": int(time.time()),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"manifest": PluginManifest(
|
|
||||||
id="plugin-time-tracker",
|
|
||||||
name="Time Tracker",
|
|
||||||
description="Track time spent on tasks with automatic reporting.",
|
|
||||||
version="0.9.1",
|
|
||||||
author="Third Party",
|
|
||||||
permissions=["read:tasks", "write:tasks"],
|
|
||||||
category="productivity",
|
|
||||||
price_cents=999,
|
|
||||||
),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
"rejection_reason": None,
|
|
||||||
"submitted_at": int(time.time()),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
_PAGE_SIZE = 20
|
|
||||||
|
|
||||||
|
|
||||||
class PluginRegistry:
|
|
||||||
"""In-process plugin catalog.
|
|
||||||
|
|
||||||
All mutating methods are ``async`` to make the future DB swap transparent
|
|
||||||
to callers.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
# plugin_id → entry dict (deep-copied so each instance is independent)
|
|
||||||
self._catalog: dict[str, dict[str, Any]] = {
|
|
||||||
e["manifest"].id: copy.deepcopy(e) for e in _SEED_PLUGINS
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── Queries ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def list_plugins(
|
|
||||||
self,
|
|
||||||
category: str | None = None,
|
|
||||||
query: str | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
sort: Literal["rating", "installs", "newest"] = "newest",
|
|
||||||
) -> PluginListResponse:
|
|
||||||
"""Return a page of approved plugins, optionally filtered and sorted."""
|
|
||||||
entries = [e for e in self._catalog.values() if e["status"] == "approved"]
|
|
||||||
|
|
||||||
if category:
|
|
||||||
entries = [e for e in entries if e["manifest"].category == category]
|
|
||||||
|
|
||||||
if query:
|
|
||||||
q_lower = query.lower()
|
|
||||||
entries = [
|
|
||||||
e
|
|
||||||
for e in entries
|
|
||||||
if q_lower in e["manifest"].name.lower()
|
|
||||||
or q_lower in e["manifest"].description.lower()
|
|
||||||
]
|
|
||||||
|
|
||||||
if sort == "installs":
|
|
||||||
entries = sorted(entries, key=lambda e: e["install_count"], reverse=True)
|
|
||||||
elif sort == "rating":
|
|
||||||
entries = sorted(entries, key=lambda e: e["avg_rating"], reverse=True)
|
|
||||||
# "newest" = catalog insertion order (dict preserves insertion in Python 3.7+)
|
|
||||||
|
|
||||||
total = len(entries)
|
|
||||||
start = (page - 1) * _PAGE_SIZE
|
|
||||||
page_entries = entries[start : start + _PAGE_SIZE]
|
|
||||||
|
|
||||||
return PluginListResponse(
|
|
||||||
plugins=[e["manifest"] for e in page_entries],
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_plugin(self, plugin_id: str) -> dict[str, Any] | None:
|
|
||||||
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
|
||||||
entry = self._catalog.get(plugin_id)
|
|
||||||
if entry is None:
|
|
||||||
return None
|
|
||||||
return {
|
|
||||||
"manifest": entry["manifest"],
|
|
||||||
"status": entry["status"],
|
|
||||||
"install_count": entry["install_count"],
|
|
||||||
"avg_rating": entry["avg_rating"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── Mutations ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def submit_plugin(
|
|
||||||
self,
|
|
||||||
manifest: PluginManifest,
|
|
||||||
package_s3_key: str,
|
|
||||||
) -> str:
|
|
||||||
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
|
||||||
|
|
||||||
Returns the plugin_id. If a plugin with the same id already exists
|
|
||||||
it is overwritten (re-submission after rejection).
|
|
||||||
"""
|
|
||||||
plugin_id = manifest.id or str(uuid.uuid4())
|
|
||||||
self._catalog[plugin_id] = {
|
|
||||||
"manifest": manifest,
|
|
||||||
"status": "pending_review",
|
|
||||||
"s3_package_key": package_s3_key,
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
"rejection_reason": None,
|
|
||||||
"submitted_at": int(time.time()),
|
|
||||||
}
|
|
||||||
return plugin_id
|
|
||||||
|
|
||||||
async def approve_plugin(self, plugin_id: str) -> None:
|
|
||||||
"""Set *plugin_id* status to ``'approved'``.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the plugin is not found.
|
|
||||||
"""
|
|
||||||
if plugin_id not in self._catalog:
|
|
||||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
|
||||||
self._catalog[plugin_id]["status"] = "approved"
|
|
||||||
self._catalog[plugin_id]["rejection_reason"] = None
|
|
||||||
|
|
||||||
async def reject_plugin(self, plugin_id: str, reason: str) -> None:
|
|
||||||
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the plugin is not found.
|
|
||||||
"""
|
|
||||||
if plugin_id not in self._catalog:
|
|
||||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
|
||||||
self._catalog[plugin_id]["status"] = "rejected"
|
|
||||||
self._catalog[plugin_id]["rejection_reason"] = reason
|
|
||||||
|
|
||||||
async def record_install(self, plugin_id: str) -> None:
|
|
||||||
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
|
||||||
if plugin_id in self._catalog:
|
|
||||||
self._catalog[plugin_id]["install_count"] += 1
|
|
||||||
|
|
||||||
async def record_uninstall(self, plugin_id: str) -> None:
|
|
||||||
"""Decrement the install count for *plugin_id*, floored at 0."""
|
|
||||||
if plugin_id in self._catalog:
|
|
||||||
current = self._catalog[plugin_id]["install_count"]
|
|
||||||
self._catalog[plugin_id]["install_count"] = max(0, current - 1)
|
|
||||||
|
|
||||||
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
|
||||||
|
|
||||||
def _get_pending_entries(self) -> list[dict[str, Any]]:
|
|
||||||
"""Return all entries with status='pending_review' (synchronous helper)."""
|
|
||||||
return [e for e in self._catalog.values() if e["status"] == "pending_review"]
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = PluginRegistry()
|
|
||||||
@@ -1,127 +0,0 @@
|
|||||||
"""Plugin review workflow.
|
|
||||||
|
|
||||||
Manages the approval queue for newly submitted plugins and enforces a
|
|
||||||
security checklist before any plugin is made visible in the marketplace.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.plugin_review import review_queue
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
import time
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
from app.schemas import PluginManifest
|
|
||||||
|
|
||||||
# ── Security policy ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"read:tasks",
|
|
||||||
"write:tasks",
|
|
||||||
"read:projects",
|
|
||||||
"write:projects",
|
|
||||||
"read:notes",
|
|
||||||
"write:notes",
|
|
||||||
"read:checkpoints",
|
|
||||||
"write:checkpoints",
|
|
||||||
"read:calendar",
|
|
||||||
"write:calendar",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
|
||||||
|
|
||||||
|
|
||||||
def validate_manifest(manifest: PluginManifest) -> None:
|
|
||||||
"""Enforce the plugin security checklist.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``ValueError`` on the first violation found. Callers should catch
|
|
||||||
this and return HTTP 422 / reject the submission.
|
|
||||||
|
|
||||||
Checks:
|
|
||||||
1. Plugin id matches ``^[a-z0-9-]+$``
|
|
||||||
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
|
||||||
3. No manifest field contains raw binary data
|
|
||||||
"""
|
|
||||||
if not _PLUGIN_ID_RE.match(manifest.id):
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid plugin id format: '{manifest.id}'. "
|
|
||||||
"Only lowercase letters, digits, and hyphens are allowed."
|
|
||||||
)
|
|
||||||
|
|
||||||
for perm in manifest.permissions:
|
|
||||||
if perm not in ALLOWED_PERMISSIONS:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown permission: '{perm}'. "
|
|
||||||
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for field_name, value in manifest.model_dump().items():
|
|
||||||
if isinstance(value, (bytes, bytearray)):
|
|
||||||
raise ValueError(
|
|
||||||
f"Binary content is not allowed in manifest field '{field_name}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReviewQueue:
|
|
||||||
"""Approval queue for pending plugin submissions.
|
|
||||||
|
|
||||||
Delegates status changes to the shared ``PluginRegistry`` singleton so
|
|
||||||
there is a single source of truth for plugin state.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
# Completed reviews — Step 12 stores in plugin_reviews table
|
|
||||||
self._reviews: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
async def get_pending(self) -> list[dict[str, Any]]:
|
|
||||||
"""Return all plugins currently awaiting review.
|
|
||||||
|
|
||||||
Each item is ``{plugin_id, manifest, submitted_at}``.
|
|
||||||
"""
|
|
||||||
entries = registry._get_pending_entries()
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"plugin_id": e["manifest"].id,
|
|
||||||
"manifest": e["manifest"],
|
|
||||||
"submitted_at": e["submitted_at"],
|
|
||||||
}
|
|
||||||
for e in entries
|
|
||||||
]
|
|
||||||
|
|
||||||
async def submit_review(
|
|
||||||
self,
|
|
||||||
plugin_id: str,
|
|
||||||
reviewer_id: str,
|
|
||||||
decision: Literal["approved", "rejected"],
|
|
||||||
notes: str = "",
|
|
||||||
) -> None:
|
|
||||||
"""Record a review decision and update the plugin's status.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``KeyError`` if *plugin_id* is not found in the registry.
|
|
||||||
"""
|
|
||||||
if decision == "approved":
|
|
||||||
await registry.approve_plugin(plugin_id)
|
|
||||||
else:
|
|
||||||
await registry.reject_plugin(plugin_id, reason=notes)
|
|
||||||
|
|
||||||
self._reviews.append(
|
|
||||||
{
|
|
||||||
"plugin_id": plugin_id,
|
|
||||||
"reviewer_id": reviewer_id,
|
|
||||||
"decision": decision,
|
|
||||||
"notes": notes,
|
|
||||||
"reviewed_at": int(time.time()),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
review_queue = ReviewQueue()
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
"""Revenue share tracking and Stripe Connect payouts.
|
|
||||||
|
|
||||||
Records every plugin installation as a revenue event and facilitates
|
|
||||||
70 % / 30 % payouts to developers via Stripe Connect. Storage is
|
|
||||||
in-memory until Step 12 migrates to the ``revenue_events`` table.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.revenue_share import revenue_share
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import stripe as stripe_lib
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── Revenue split constants ───────────────────────────────────────────
|
|
||||||
|
|
||||||
DEVELOPER_SHARE: float = 0.70
|
|
||||||
PLATFORM_SHARE: float = 0.30
|
|
||||||
|
|
||||||
|
|
||||||
class RevenueShare:
|
|
||||||
"""Records installation revenue events and coordinates developer payouts.
|
|
||||||
|
|
||||||
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
|
||||||
is not configured, consistent with the rest of the billing layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
# Step 12 replaces with revenue_events DB table
|
|
||||||
self._events: list[dict[str, Any]] = []
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _stripe_configured() -> bool:
|
|
||||||
return bool(settings.STRIPE_SECRET_KEY)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _stripe() -> Any:
|
|
||||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
|
||||||
return stripe_lib
|
|
||||||
|
|
||||||
# ── Core operations ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def record_install(
|
|
||||||
self,
|
|
||||||
plugin_id: str,
|
|
||||||
user_id: str,
|
|
||||||
amount_cents: int,
|
|
||||||
) -> None:
|
|
||||||
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
|
||||||
|
|
||||||
For free plugins (``amount_cents == 0``) no payment is initiated but
|
|
||||||
the event is still recorded for analytics.
|
|
||||||
|
|
||||||
For paid plugins the developer receives 70 % via a Stripe Connect
|
|
||||||
destination charge. If Stripe is not configured or the charge fails
|
|
||||||
the installation still succeeds (the event is recorded and the install
|
|
||||||
count is incremented) — a warning is logged for monitoring.
|
|
||||||
"""
|
|
||||||
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
|
||||||
stripe_transfer_id: str | None = None
|
|
||||||
|
|
||||||
if amount_cents > 0 and self._stripe_configured():
|
|
||||||
plugin_entry = registry._catalog.get(plugin_id)
|
|
||||||
developer_stripe_account: str | None = None
|
|
||||||
if plugin_entry:
|
|
||||||
# Step 12: look up developer's Stripe account from DB
|
|
||||||
# For now, the author field is used as a placeholder key.
|
|
||||||
developer_stripe_account = None # no real account yet
|
|
||||||
|
|
||||||
if developer_stripe_account:
|
|
||||||
try:
|
|
||||||
s = self._stripe()
|
|
||||||
transfer = s.Transfer.create(
|
|
||||||
amount=developer_share_cents,
|
|
||||||
currency="eur",
|
|
||||||
destination=developer_stripe_account,
|
|
||||||
description=f"Revenue share for plugin {plugin_id}",
|
|
||||||
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
|
||||||
)
|
|
||||||
stripe_transfer_id = transfer["id"]
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"Stripe Connect transfer failed for plugin %s: %s",
|
|
||||||
plugin_id,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"No Stripe account on file for plugin %s developer; "
|
|
||||||
"skipping transfer.",
|
|
||||||
plugin_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
self._events.append(
|
|
||||||
{
|
|
||||||
"plugin_id": plugin_id,
|
|
||||||
"user_id": user_id,
|
|
||||||
"amount_cents": amount_cents,
|
|
||||||
"developer_share_cents": developer_share_cents,
|
|
||||||
"stripe_transfer_id": stripe_transfer_id,
|
|
||||||
"paid_at": None,
|
|
||||||
"created_at": int(time.time()),
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
await registry.record_install(plugin_id)
|
|
||||||
|
|
||||||
async def get_earnings(
|
|
||||||
self,
|
|
||||||
developer_id: str,
|
|
||||||
period: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return aggregated earnings for *developer_id*.
|
|
||||||
|
|
||||||
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
|
||||||
|
|
||||||
Returns::
|
|
||||||
|
|
||||||
{
|
|
||||||
"developer_id": str,
|
|
||||||
"period": str | None,
|
|
||||||
"total_installs": int,
|
|
||||||
"total_revenue_cents": int,
|
|
||||||
"developer_share_cents": int,
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
# Find plugin ids belonging to this developer
|
|
||||||
developer_plugin_ids: set[str] = {
|
|
||||||
pid
|
|
||||||
for pid, entry in registry._catalog.items()
|
|
||||||
if entry["manifest"].author == developer_id
|
|
||||||
}
|
|
||||||
|
|
||||||
events = [e for e in self._events if e["plugin_id"] in developer_plugin_ids]
|
|
||||||
|
|
||||||
if period:
|
|
||||||
# Filter by YYYY-MM prefix of the created_at timestamp
|
|
||||||
events = [
|
|
||||||
e
|
|
||||||
for e in events
|
|
||||||
if time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
|
||||||
]
|
|
||||||
|
|
||||||
return {
|
|
||||||
"developer_id": developer_id,
|
|
||||||
"period": period,
|
|
||||||
"total_installs": len(events),
|
|
||||||
"total_revenue_cents": sum(e["amount_cents"] for e in events),
|
|
||||||
"developer_share_cents": sum(e["developer_share_cents"] for e in events),
|
|
||||||
}
|
|
||||||
|
|
||||||
async def payout_developer(self, plugin_id: str, period: str) -> None:
|
|
||||||
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
|
||||||
|
|
||||||
Marks processed events with ``paid_at`` timestamp.
|
|
||||||
Stubs gracefully when Stripe is not configured.
|
|
||||||
"""
|
|
||||||
unpaid = [
|
|
||||||
e
|
|
||||||
for e in self._events
|
|
||||||
if e["plugin_id"] == plugin_id
|
|
||||||
and e["paid_at"] is None
|
|
||||||
and time.strftime("%Y-%m", time.gmtime(e["created_at"])) == period
|
|
||||||
]
|
|
||||||
|
|
||||||
total_dev_share = sum(e["developer_share_cents"] for e in unpaid)
|
|
||||||
if total_dev_share <= 0 or not unpaid:
|
|
||||||
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._stripe_configured():
|
|
||||||
plugin_entry = registry._catalog.get(plugin_id)
|
|
||||||
developer_stripe_account: str | None = None # Step 12: fetch from DB
|
|
||||||
if plugin_entry and developer_stripe_account:
|
|
||||||
try:
|
|
||||||
s = self._stripe()
|
|
||||||
s.Transfer.create(
|
|
||||||
amount=total_dev_share,
|
|
||||||
currency="eur",
|
|
||||||
destination=developer_stripe_account,
|
|
||||||
description=f"Payout for plugin {plugin_id} period {period}",
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
paid_ts = int(time.time())
|
|
||||||
for event in unpaid:
|
|
||||||
event["paid_at"] = paid_ts
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
revenue_share = RevenueShare()
|
|
||||||
157
app/schemas.py
157
app/schemas.py
@@ -1,157 +0,0 @@
|
|||||||
"""Pydantic schemas — API request/response contracts.
|
|
||||||
|
|
||||||
Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
|
|
||||||
|
|
||||||
# ── Billing ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
BillingTier = Literal["free", "pro", "power", "team"]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Auth ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class AuthTokens(BaseModel):
|
|
||||||
access_token: str
|
|
||||||
refresh_token: str
|
|
||||||
expires_at: int
|
|
||||||
|
|
||||||
|
|
||||||
class UserProfile(BaseModel):
|
|
||||||
id: str
|
|
||||||
email: str
|
|
||||||
tier: BillingTier
|
|
||||||
|
|
||||||
|
|
||||||
# ── Chat ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class ChatContext(BaseModel):
|
|
||||||
user_profile: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
relevant_documents: list[str] = Field(default_factory=list)
|
|
||||||
recent_tasks: list[dict[str, Any]] = Field(default_factory=list)
|
|
||||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
class PlanAction(BaseModel):
|
|
||||||
type: Literal[
|
|
||||||
"create_record",
|
|
||||||
"update_record",
|
|
||||||
"delete_record",
|
|
||||||
"index_document",
|
|
||||||
"send_notification",
|
|
||||||
]
|
|
||||||
table: str | None = None
|
|
||||||
data: dict[str, Any] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ChatRequest(BaseModel):
|
|
||||||
message: str
|
|
||||||
context: ChatContext = Field(default_factory=ChatContext)
|
|
||||||
execution_mode: Literal["direct", "plan"] = "direct"
|
|
||||||
|
|
||||||
|
|
||||||
class ChatResponse(BaseModel):
|
|
||||||
response: str
|
|
||||||
actions: list[PlanAction] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Execution Plans ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PlanStep(BaseModel):
|
|
||||||
action: str
|
|
||||||
prompt_template: str | None = None
|
|
||||||
variables: dict[str, Any] | None = None
|
|
||||||
data_from_step: int | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class ExecutionPlan(BaseModel):
|
|
||||||
agent: str
|
|
||||||
steps: list[PlanStep] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Backup ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class BackupMetadata(BaseModel):
|
|
||||||
version: int
|
|
||||||
timestamp: int
|
|
||||||
checksum: str
|
|
||||||
chunk_count: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
|
||||||
|
|
||||||
class StorageRecord(BaseModel):
|
|
||||||
id: str
|
|
||||||
user_id: str
|
|
||||||
table: str
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRecordCreate(BaseModel):
|
|
||||||
table: str
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRecordUpdate(BaseModel):
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
|
||||||
|
|
||||||
class VectorItem(BaseModel):
|
|
||||||
id: str
|
|
||||||
blob: bytes # encrypted vector + metadata — backend never decrypts
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
class VectorUpsertRequest(BaseModel):
|
|
||||||
vectors: list[VectorItem]
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchRequest(BaseModel):
|
|
||||||
query_blob: bytes # encrypted query — backend never decrypts
|
|
||||||
top_k: int = 10
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchResult(BaseModel):
|
|
||||||
id: str
|
|
||||||
score: float
|
|
||||||
blob: bytes
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchResponse(BaseModel):
|
|
||||||
results: list[VectorSearchResult]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plugin Marketplace ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PluginManifest(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
version: str
|
|
||||||
author: str
|
|
||||||
permissions: list[str]
|
|
||||||
category: str
|
|
||||||
price_cents: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class PluginListResponse(BaseModel):
|
|
||||||
plugins: list[PluginManifest]
|
|
||||||
total: int
|
|
||||||
page: int
|
|
||||||
|
|
||||||
|
|
||||||
class PluginInstallRequest(BaseModel):
|
|
||||||
plugin_id: str
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
|
||||||
@@ -1,105 +0,0 @@
|
|||||||
"""S3-backed store for E2E-encrypted blobs.
|
|
||||||
|
|
||||||
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
|
||||||
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import boto3
|
|
||||||
from botocore.exceptions import ClientError
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class BlobStore:
|
|
||||||
"""Thin wrapper around boto3 S3.
|
|
||||||
|
|
||||||
All blobs must be E2E encrypted by the client before upload.
|
|
||||||
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
|
||||||
but cannot decrypt the inner client-side payload.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _client(self) -> Any:
|
|
||||||
return boto3.client(
|
|
||||||
"s3",
|
|
||||||
region_name=settings.S3_REGION,
|
|
||||||
aws_access_key_id=settings.AWS_ACCESS_KEY_ID,
|
|
||||||
aws_secret_access_key=settings.AWS_SECRET_ACCESS_KEY,
|
|
||||||
)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _key(user_id: str, table: str, record_id: str) -> str:
|
|
||||||
return f"{user_id}/{table}/{record_id}"
|
|
||||||
|
|
||||||
async def upload(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
table: str,
|
|
||||||
record_id: str,
|
|
||||||
blob: bytes,
|
|
||||||
checksum: str,
|
|
||||||
) -> str:
|
|
||||||
"""Store *blob* in S3 and return the S3 key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Owner of the blob (used as key prefix).
|
|
||||||
table: Logical table name (e.g. ``"tasks"``).
|
|
||||||
record_id: Record UUID.
|
|
||||||
blob: Raw bytes (pre-encrypted by client).
|
|
||||||
checksum: SHA-256 hex digest supplied by the client; stored as
|
|
||||||
object metadata for download-time verification.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The S3 key under which the blob was stored.
|
|
||||||
"""
|
|
||||||
key = self._key(user_id, table, record_id)
|
|
||||||
self._client().put_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=key,
|
|
||||||
Body=blob,
|
|
||||||
ServerSideEncryption="AES256", # SSE-S3 at rest
|
|
||||||
Metadata={"checksum": checksum},
|
|
||||||
)
|
|
||||||
return key
|
|
||||||
|
|
||||||
async def download(self, user_id: str, s3_key: str) -> bytes:
|
|
||||||
"""Retrieve the blob stored at *s3_key*.
|
|
||||||
|
|
||||||
*user_id* is retained in the signature so higher-level code can
|
|
||||||
enforce ownership without re-parsing the key.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
|
||||||
object does not exist.
|
|
||||||
"""
|
|
||||||
response = self._client().get_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=s3_key,
|
|
||||||
)
|
|
||||||
return response["Body"].read()
|
|
||||||
|
|
||||||
async def delete(self, user_id: str, s3_key: str) -> None:
|
|
||||||
"""Delete the object at *s3_key*.
|
|
||||||
|
|
||||||
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
|
||||||
not exist.
|
|
||||||
"""
|
|
||||||
self._client().delete_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=s3_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
|
||||||
"""Return all S3 keys for a given user + table combination.
|
|
||||||
|
|
||||||
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
|
||||||
"""
|
|
||||||
prefix = f"{user_id}/{table}/"
|
|
||||||
response = self._client().list_objects_v2(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Prefix=prefix,
|
|
||||||
)
|
|
||||||
return [obj["Key"] for obj in response.get("Contents", [])]
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
"""Integrity verification only — the backend NEVER decrypts user data."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
|
|
||||||
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
|
||||||
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
|
||||||
|
|
||||||
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
|
||||||
timing-based side-channel attacks.
|
|
||||||
"""
|
|
||||||
computed = hashlib.sha256(blob).hexdigest()
|
|
||||||
return hmac.compare_digest(computed, checksum)
|
|
||||||
|
|
||||||
|
|
||||||
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
|
||||||
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
|
||||||
|
|
||||||
Call this before storing or forwarding any client-provided blob.
|
|
||||||
The backend never holds decryption keys — this check only verifies
|
|
||||||
that the opaque bytes arrived intact.
|
|
||||||
"""
|
|
||||||
if not verify_checksum(blob, checksum):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Checksum mismatch: blob integrity check failed",
|
|
||||||
)
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
|
||||||
|
|
||||||
Vectors are pre-encrypted blobs from the client. The backend stores them
|
|
||||||
alongside a deterministic 32-dim float representation derived from the blob's
|
|
||||||
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
|
||||||
is a known trade-off documented in the backend plan.
|
|
||||||
|
|
||||||
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
|
||||||
``user_id`` payload field on a shared collection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pinecone import Pinecone
|
|
||||||
from qdrant_client import QdrantClient
|
|
||||||
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.schemas import VectorItem, VectorSearchResult
|
|
||||||
|
|
||||||
_QDRANT_COLLECTION = "adiuva_vectors"
|
|
||||||
|
|
||||||
|
|
||||||
def _blob_to_vector(blob: bytes) -> list[float]:
|
|
||||||
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
|
||||||
|
|
||||||
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
|
||||||
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
|
||||||
semantic meaning on encrypted data.
|
|
||||||
"""
|
|
||||||
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
|
||||||
|
|
||||||
|
|
||||||
class VectorStore:
|
|
||||||
"""Thin wrapper around Pinecone or Qdrant.
|
|
||||||
|
|
||||||
The backend to use is selected at runtime:
|
|
||||||
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
|
||||||
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _use_pinecone(self) -> bool:
|
|
||||||
return bool(settings.PINECONE_API_KEY)
|
|
||||||
|
|
||||||
# ── Pinecone helpers ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _pinecone_index(self) -> Any:
|
|
||||||
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
|
||||||
return pc.Index(settings.PINECONE_INDEX)
|
|
||||||
|
|
||||||
# ── Qdrant helpers ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _qdrant_client(self) -> Any:
|
|
||||||
return QdrantClient(
|
|
||||||
url=settings.QDRANT_URL,
|
|
||||||
api_key=settings.QDRANT_API_KEY or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
"""Store encrypted vectors in the backend.
|
|
||||||
|
|
||||||
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
|
||||||
so it can be returned verbatim during search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Used as Pinecone namespace or Qdrant payload field.
|
|
||||||
vectors: List of encrypted vector items from the client.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
await self._pinecone_upsert(user_id, vectors)
|
|
||||||
else:
|
|
||||||
await self._qdrant_upsert(user_id, vectors)
|
|
||||||
|
|
||||||
async def search(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
query_blob: bytes,
|
|
||||||
top_k: int,
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
"""Query the vector store and return encrypted result blobs.
|
|
||||||
|
|
||||||
The query vector is derived from *query_blob* using the same
|
|
||||||
deterministic mapping as upsert.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Scopes the search to this user's namespace.
|
|
||||||
query_blob: Encrypted query from the client.
|
|
||||||
top_k: Maximum number of results to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
return await self._pinecone_search(user_id, query_blob, top_k)
|
|
||||||
return await self._qdrant_search(user_id, query_blob, top_k)
|
|
||||||
|
|
||||||
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
"""Remove vectors by ID, scoped to *user_id*.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Namespace / payload filter to prevent cross-user deletion.
|
|
||||||
vector_ids: List of vector IDs to remove.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
await self._pinecone_delete(user_id, vector_ids)
|
|
||||||
else:
|
|
||||||
await self._qdrant_delete(user_id, vector_ids)
|
|
||||||
|
|
||||||
# ── Pinecone implementation ───────────────────────────────────────
|
|
||||||
|
|
||||||
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
records = [
|
|
||||||
{
|
|
||||||
"id": v.id,
|
|
||||||
"values": _blob_to_vector(v.blob),
|
|
||||||
"metadata": {
|
|
||||||
"blob": base64.b64encode(v.blob).decode(),
|
|
||||||
"checksum": v.checksum,
|
|
||||||
"user_id": user_id,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for v in vectors
|
|
||||||
]
|
|
||||||
index.upsert(vectors=records, namespace=user_id)
|
|
||||||
|
|
||||||
async def _pinecone_search(
|
|
||||||
self, user_id: str, query_blob: bytes, top_k: int
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
query_vector = _blob_to_vector(query_blob)
|
|
||||||
response = index.query(
|
|
||||||
vector=query_vector,
|
|
||||||
top_k=top_k,
|
|
||||||
namespace=user_id,
|
|
||||||
include_metadata=True,
|
|
||||||
)
|
|
||||||
results: list[VectorSearchResult] = []
|
|
||||||
for match in response.get("matches", []):
|
|
||||||
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
|
||||||
results.append(
|
|
||||||
VectorSearchResult(
|
|
||||||
id=match["id"],
|
|
||||||
score=match["score"],
|
|
||||||
blob=blob_bytes,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
index.delete(ids=vector_ids, namespace=user_id)
|
|
||||||
|
|
||||||
# ── Qdrant implementation ─────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
points = [
|
|
||||||
PointStruct(
|
|
||||||
id=v.id,
|
|
||||||
vector=_blob_to_vector(v.blob),
|
|
||||||
payload={
|
|
||||||
"blob": base64.b64encode(v.blob).decode(),
|
|
||||||
"checksum": v.checksum,
|
|
||||||
"user_id": user_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for v in vectors
|
|
||||||
]
|
|
||||||
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
|
||||||
|
|
||||||
async def _qdrant_search(
|
|
||||||
self, user_id: str, query_blob: bytes, top_k: int
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
query_vector = _blob_to_vector(query_blob)
|
|
||||||
hits = client.search(
|
|
||||||
collection_name=_QDRANT_COLLECTION,
|
|
||||||
query_vector=query_vector,
|
|
||||||
query_filter=Filter(
|
|
||||||
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
|
||||||
),
|
|
||||||
limit=top_k,
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
VectorSearchResult(
|
|
||||||
id=str(hit.id),
|
|
||||||
score=hit.score,
|
|
||||||
blob=base64.b64decode(hit.payload["blob"]),
|
|
||||||
)
|
|
||||||
for hit in hits
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
client.delete(
|
|
||||||
collection_name=_QDRANT_COLLECTION,
|
|
||||||
points_selector=PointIdsList(points=vector_ids),
|
|
||||||
)
|
|
||||||
@@ -1,25 +1,34 @@
|
|||||||
version: "3.9"
|
# ── Adiuva Microservices ─────────────────────────────────────────────
|
||||||
|
# docker compose up --build
|
||||||
|
# docker compose up --build auth ws-gateway chat # subset
|
||||||
|
|
||||||
services:
|
services:
|
||||||
app:
|
|
||||||
build: .
|
# ═══════════════════════════════════════════════════════════════════
|
||||||
|
# Infrastructure
|
||||||
|
# ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
traefik:
|
||||||
|
image: traefik:v3.1
|
||||||
ports:
|
ports:
|
||||||
- "8000:8000"
|
- "80:80"
|
||||||
env_file:
|
- "443:443"
|
||||||
- .env
|
- "8080:8080" # dashboard (dev only)
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
CF_DNS_API_TOKEN: ${CF_DNS_API_TOKEN:-}
|
||||||
depends_on:
|
volumes:
|
||||||
db:
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
condition: service_healthy
|
- ./traefik/traefik.yml:/etc/traefik/traefik.yml:ro
|
||||||
|
- ./traefik/dynamic:/etc/traefik/dynamic:ro
|
||||||
|
- traefik_acme:/etc/traefik/acme
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
db:
|
db:
|
||||||
image: postgres:16-alpine
|
image: pgvector/pgvector:pg16
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: ${POSTGRES_USER:-postgres}
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres}
|
||||||
POSTGRES_DB: adiuva
|
POSTGRES_DB: ${POSTGRES_DB:-adiuva}
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
@@ -29,10 +38,161 @@ services:
|
|||||||
retries: 5
|
retries: 5
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# Optional Redis for future rate-limit or caching needs
|
redis:
|
||||||
# redis:
|
image: redis:7-alpine
|
||||||
# image: redis:7-alpine
|
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Optional infrastructure (uncomment as needed) ────────────────
|
||||||
|
|
||||||
|
# minio:
|
||||||
|
# image: minio/minio:latest
|
||||||
|
# command: server /data --console-address ":9001"
|
||||||
|
# ports:
|
||||||
|
# - "9000:9000"
|
||||||
|
# - "9001:9001"
|
||||||
|
# environment:
|
||||||
|
# MINIO_ROOT_USER: minioadmin
|
||||||
|
# MINIO_ROOT_PASSWORD: minioadmin
|
||||||
|
# volumes:
|
||||||
|
# - minio_data:/data
|
||||||
|
# healthcheck:
|
||||||
|
# test: ["CMD", "mc", "ready", "local"]
|
||||||
|
# interval: 5s
|
||||||
|
# timeout: 5s
|
||||||
|
# retries: 5
|
||||||
# restart: unless-stopped
|
# restart: unless-stopped
|
||||||
|
|
||||||
|
# qdrant:
|
||||||
|
# image: qdrant/qdrant:latest
|
||||||
|
# ports:
|
||||||
|
# - "6333:6333"
|
||||||
|
# - "6334:6334"
|
||||||
|
# volumes:
|
||||||
|
# - qdrant_data:/qdrant/storage
|
||||||
|
# restart: unless-stopped
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════
|
||||||
|
# Migrations (run once, then exit)
|
||||||
|
# ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
migrate:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
command: ["python", "-m", "alembic", "upgrade", "head"]
|
||||||
|
env_file:
|
||||||
|
- path: .env
|
||||||
|
required: false
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
restart: "no"
|
||||||
|
|
||||||
|
# ═══════════════════════════════════════════════════════════════════
|
||||||
|
# Application Services
|
||||||
|
# ═══════════════════════════════════════════════════════════════════
|
||||||
|
|
||||||
|
auth:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: services/auth/Dockerfile
|
||||||
|
env_file:
|
||||||
|
- path: .env
|
||||||
|
required: false
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||||
|
REDIS_URL: redis://redis:6379/0
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
migrate:
|
||||||
|
condition: service_completed_successfully
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
ws-gateway:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: services/ws-gateway/Dockerfile
|
||||||
|
env_file:
|
||||||
|
- path: .env
|
||||||
|
required: false
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||||
|
REDIS_URL: redis://redis:6379/0
|
||||||
|
depends_on:
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
auth:
|
||||||
|
condition: service_started
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
chat:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: services/chat/Dockerfile
|
||||||
|
env_file:
|
||||||
|
- path: .env
|
||||||
|
required: false
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||||
|
REDIS_URL: redis://redis:6379/0
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
migrate:
|
||||||
|
condition: service_completed_successfully
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
batch-agent:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: services/batch-agent/Dockerfile
|
||||||
|
env_file:
|
||||||
|
- path: .env
|
||||||
|
required: false
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||||
|
REDIS_URL: redis://redis:6379/0
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
migrate:
|
||||||
|
condition: service_completed_successfully
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
billing:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: services/billing/Dockerfile
|
||||||
|
env_file:
|
||||||
|
- path: .env
|
||||||
|
required: false
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
migrate:
|
||||||
|
condition: service_completed_successfully
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
redis_data:
|
||||||
|
traefik_acme:
|
||||||
|
# minio_data:
|
||||||
|
# qdrant_data:
|
||||||
|
|||||||
941
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
941
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
@@ -0,0 +1,941 @@
|
|||||||
|
# Adiuva — Architettura Microservizi (MVP)
|
||||||
|
|
||||||
|
## Panoramica
|
||||||
|
|
||||||
|
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
|
||||||
|
|
||||||
|
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐
|
||||||
|
│ Cloudflare │
|
||||||
|
│ (DNS + CDN) │
|
||||||
|
└──────┬───────┘
|
||||||
|
│ HTTPS / WSS
|
||||||
|
┌──────▼───────┐
|
||||||
|
│ Traefik │
|
||||||
|
│ API Gateway │
|
||||||
|
│ (routing, │
|
||||||
|
│ TLS, rate │
|
||||||
|
│ limiting) │
|
||||||
|
└──────┬───────┘
|
||||||
|
│
|
||||||
|
┌──────────┬───────────┼───────────┐
|
||||||
|
│ │ │ │
|
||||||
|
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
|
||||||
|
│ Auth │ │ Chat │ │ Agent │ │Billing │
|
||||||
|
│ Service │ │Service│ │ Service │ │Service │
|
||||||
|
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
|
||||||
|
│ │ │ │
|
||||||
|
┌─────▼──────────▼──────────▼───────────▼────┐
|
||||||
|
│ Infrastruttura │
|
||||||
|
│ PostgreSQL │ Redis │ Qdrant │
|
||||||
|
└─────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 1. Suddivisione dei Servizi
|
||||||
|
|
||||||
|
### 1.1 Auth Service (`auth-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/auth/register` | POST |
|
||||||
|
| `/api/v1/auth/login` | POST |
|
||||||
|
| `/api/v1/auth/refresh` | POST |
|
||||||
|
| `/api/v1/auth/me` | GET / PUT |
|
||||||
|
|
||||||
|
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
||||||
|
|
||||||
|
**Modifica chiave — JWT con RS256**:
|
||||||
|
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
||||||
|
- L'Auth Service firma i JWT con la **chiave privata**.
|
||||||
|
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
||||||
|
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
||||||
|
|
||||||
|
```python
|
||||||
|
# auth-service/app/auth/jwt.py
|
||||||
|
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PRIVATE_KEY = ... # Da env/secret
|
||||||
|
PUBLIC_KEY = ... # Derivata o da env
|
||||||
|
|
||||||
|
def create_access_token(user_id: str, tier: str) -> str:
|
||||||
|
return jwt.encode(
|
||||||
|
{"sub": user_id, "tier": tier, "exp": ...},
|
||||||
|
PRIVATE_KEY,
|
||||||
|
algorithm="RS256",
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/auth.py (usato da tutti gli altri servizi)
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
||||||
|
|
||||||
|
def verify_token(token: str) -> dict:
|
||||||
|
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||||
|
```
|
||||||
|
|
||||||
|
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
|
||||||
|
|
||||||
|
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
|
||||||
|
|
||||||
|
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
|
||||||
|
|
||||||
|
| Endpoint | Tipo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
|
||||||
|
| `/api/v1/chat` | POST (REST fallback) |
|
||||||
|
|
||||||
|
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
|
||||||
|
|
||||||
|
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
|
||||||
|
|
||||||
|
**Scaling**: 2–N repliche. Sticky cookies per le WS + Redis per cross-instance.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.3 Agent Service (`agent-service`) ⭐ Batch
|
||||||
|
|
||||||
|
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
|
||||||
|
|
||||||
|
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
|
||||||
|
|
||||||
|
| Endpoint | Tipo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/agents/catalog` | GET |
|
||||||
|
| `/api/v1/agents/can-create` | POST |
|
||||||
|
| `/api/v1/agents/trigger` | POST |
|
||||||
|
| `/api/v1/agents/journey/start` | POST (o WS relay) |
|
||||||
|
| `/api/v1/agents/journey/message` | POST (o WS relay) |
|
||||||
|
|
||||||
|
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
|
||||||
|
|
||||||
|
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────┐ ┌──────────────┐ ┌──────────┐
|
||||||
|
│ Agent Service│ │ Redis │ │ Chat │
|
||||||
|
│ (batch run) │ │ │ │ Service │
|
||||||
|
│ │ │ │ │ (ha WS) │
|
||||||
|
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
|
||||||
|
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
|
||||||
|
│ from │ │ │ │ al │
|
||||||
|
│ device │ │ │ │ device│
|
||||||
|
│ │ │ │ │ via WS│
|
||||||
|
│ │ SUBSCRIBE │ │ PUBLISH │ │
|
||||||
|
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
|
||||||
|
│ risultato │ │ │ │ reply │
|
||||||
|
└──────────────┘ └──────────────┘ └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
**Scaling**: 1–N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
|
||||||
|
|
||||||
|
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.4 Billing Service (`billing-service`)
|
||||||
|
|
||||||
|
**Responsabilità**: Stripe checkout, webhook, subscription management.
|
||||||
|
|
||||||
|
| Endpoint originale | Metodo |
|
||||||
|
|---|---|
|
||||||
|
| `/api/v1/billing/checkout` | POST |
|
||||||
|
| `/api/v1/billing/webhook` | POST |
|
||||||
|
| `/api/v1/billing/subscription` | GET / DELETE |
|
||||||
|
|
||||||
|
**Database**: Tabelle `subscriptions` (schema `billing`).
|
||||||
|
|
||||||
|
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
|
||||||
|
|
||||||
|
**Scaling**: 1 replica sufficiente. Basso traffico.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
### 1.5 Servizi esclusi dall'MVP
|
||||||
|
|
||||||
|
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
|
||||||
|
|
||||||
|
| Servizio | Responsabilità | Note |
|
||||||
|
|---|---|---|
|
||||||
|
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
|
||||||
|
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 2. Tier Check — Dove e Come
|
||||||
|
|
||||||
|
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
|
||||||
|
|
||||||
|
### Strategia: Tier nel JWT
|
||||||
|
|
||||||
|
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
|
||||||
|
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"sub": "user_123",
|
||||||
|
"tier": "pro",
|
||||||
|
"exp": 1742515200,
|
||||||
|
"iat": 1742511600
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Ogni servizio:
|
||||||
|
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
|
||||||
|
2. Legge `payload["tier"]` — **zero chiamate extra**
|
||||||
|
3. Applica le sue regole di enforcement localmente
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/auth.py — dependency FastAPI condivisa
|
||||||
|
from fastapi import Depends, HTTPException, Request
|
||||||
|
from jose import jwt
|
||||||
|
|
||||||
|
PUBLIC_KEY = ...
|
||||||
|
|
||||||
|
class CurrentUser:
|
||||||
|
def __init__(self, user_id: str, tier: str):
|
||||||
|
self.user_id = user_id
|
||||||
|
self.tier = tier
|
||||||
|
|
||||||
|
async def get_current_user(request: Request) -> CurrentUser:
|
||||||
|
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
|
||||||
|
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||||
|
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
|
||||||
|
|
||||||
|
def require_tier(*allowed_tiers: str):
|
||||||
|
"""Dependency che blocca se il tier non è tra quelli ammessi."""
|
||||||
|
async def check(user: CurrentUser = Depends(get_current_user)):
|
||||||
|
if user.tier not in allowed_tiers:
|
||||||
|
raise HTTPException(403, "Tier insufficient")
|
||||||
|
return user
|
||||||
|
return check
|
||||||
|
```
|
||||||
|
|
||||||
|
### Cosa succede quando il tier cambia (upgrade/downgrade)?
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
|
||||||
|
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
|
||||||
|
│ │ │ Service │ (Redis pub/sub) │ Service │
|
||||||
|
└──────────┘ └──────────┘ └────┬─────┘
|
||||||
|
│
|
||||||
|
UPDATE users
|
||||||
|
SET tier = 'power'
|
||||||
|
│
|
||||||
|
Al prossimo /refresh
|
||||||
|
il JWT conterrà tier='power'
|
||||||
|
```
|
||||||
|
|
||||||
|
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 15–30 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
|
||||||
|
|
||||||
|
### Dove si applica in ciascun servizio
|
||||||
|
|
||||||
|
| Servizio | Enforcement |
|
||||||
|
|---|---|
|
||||||
|
| **Auth Service** | Nessuno (è lui che scrive il tier) |
|
||||||
|
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
|
||||||
|
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
|
||||||
|
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
|
||||||
|
|
||||||
|
### Rate-limit distribuito via Redis
|
||||||
|
|
||||||
|
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# shared/middleware/rate_limit.py
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
|
||||||
|
class DistributedRateLimiter:
|
||||||
|
def __init__(self, redis: aioredis.Redis):
|
||||||
|
self._redis = redis
|
||||||
|
|
||||||
|
async def check(self, user_id: str, tier: str, service: str) -> bool:
|
||||||
|
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
||||||
|
max_req = limits.get(tier, 20)
|
||||||
|
key = f"rate:{service}:{user_id}"
|
||||||
|
|
||||||
|
pipe = self._redis.pipeline()
|
||||||
|
pipe.incr(key)
|
||||||
|
pipe.expire(key, 60)
|
||||||
|
count, _ = await pipe.execute()
|
||||||
|
|
||||||
|
return count <= max_req
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
||||||
|
|
||||||
|
`DeviceConnectionManager` è un **singleton in-memory**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class DeviceConnectionManager:
|
||||||
|
def __init__(self):
|
||||||
|
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
||||||
|
```
|
||||||
|
|
||||||
|
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
||||||
|
|
||||||
|
### La soluzione: Redis Pub/Sub + Registry
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────────────────────────────────────────────────────────┐
|
||||||
|
│ Redis │
|
||||||
|
│ │
|
||||||
|
│ Hash: ws:connections │
|
||||||
|
│ user_123 → instance_A │
|
||||||
|
│ user_456 → instance_B │
|
||||||
|
│ │
|
||||||
|
│ Pub/Sub channels: │
|
||||||
|
│ tool_call:{user_id} → tool call payloads │
|
||||||
|
│ tool_result:{call_id} → tool result payloads │
|
||||||
|
│ stream:{user_id} → text_chunk streaming │
|
||||||
|
└──────────────────────────────────────────────────────────────┘
|
||||||
|
|
||||||
|
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
||||||
|
┌───────────────────────┐ ┌───────────────────────┐
|
||||||
|
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
||||||
|
│ tool_call:user_123│ │ → user_123 è su A │
|
||||||
|
│ │ │ │
|
||||||
|
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
||||||
|
│ da Redis channel │ │ tool_call:user_123 │
|
||||||
|
│ │ │ {id, action, ...} │
|
||||||
|
│ 3. Invia al device │ │ │
|
||||||
|
│ via WS │ │ 4. SUBSCRIBE │
|
||||||
|
│ │ │ tool_result:{id} │
|
||||||
|
│ 4. Device risponde │ │ │
|
||||||
|
│ tool_result │──────────│► 5. Riceve risultato │
|
||||||
|
│ │ │ │
|
||||||
|
│ 5. PUBLISH │ │ │
|
||||||
|
│ tool_result:{id} │ │ │
|
||||||
|
└───────────────────────┘ └───────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### Implementazione: `RedisDeviceManager`
|
||||||
|
|
||||||
|
```python
|
||||||
|
# chat-service/app/core/device_manager.py
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import redis.asyncio as aioredis
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LocalConnection:
|
||||||
|
ws: WebSocket
|
||||||
|
device_id: str
|
||||||
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class RedisDeviceManager:
|
||||||
|
"""Device manager backed by Redis for cross-instance communication."""
|
||||||
|
|
||||||
|
def __init__(self, redis_url: str = "redis://redis:6379"):
|
||||||
|
self._redis = aioredis.from_url(redis_url)
|
||||||
|
self._pubsub = self._redis.pubsub()
|
||||||
|
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
||||||
|
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
||||||
|
|
||||||
|
async def start(self):
|
||||||
|
"""Avvia il listener Redis per tool_call in arrivo."""
|
||||||
|
asyncio.create_task(self._listen_tool_calls())
|
||||||
|
|
||||||
|
# ── Registrazione ──
|
||||||
|
|
||||||
|
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
||||||
|
# Registra localmente
|
||||||
|
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
||||||
|
# Registra in Redis quale istanza ha la connessione
|
||||||
|
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
||||||
|
# Sottoscrivi ai tool_call per questo utente
|
||||||
|
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
||||||
|
|
||||||
|
async def unregister(self, user_id: str):
|
||||||
|
conn = self._local.pop(user_id, None)
|
||||||
|
if conn:
|
||||||
|
for fut in conn.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
await self._redis.hdel("ws:connections", user_id)
|
||||||
|
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
||||||
|
|
||||||
|
# ── Presenza ──
|
||||||
|
|
||||||
|
async def is_online(self, user_id: str) -> bool:
|
||||||
|
return await self._redis.hexists("ws:connections", user_id)
|
||||||
|
|
||||||
|
# ── Tool-call round-trip (cross-instance) ──
|
||||||
|
|
||||||
|
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
||||||
|
"""
|
||||||
|
Invia un tool_call al device dell'utente.
|
||||||
|
Funziona sia che la WS sia locale che su un'altra istanza.
|
||||||
|
"""
|
||||||
|
call_id = payload["id"]
|
||||||
|
|
||||||
|
# Caso 1: connessione locale → invio diretto
|
||||||
|
if user_id in self._local:
|
||||||
|
conn = self._local[user_id]
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut: asyncio.Future[dict] = loop.create_future()
|
||||||
|
conn.pending_calls[call_id] = fut
|
||||||
|
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
||||||
|
return await asyncio.wait_for(fut, timeout=30.0)
|
||||||
|
|
||||||
|
# Caso 2: connessione remota → Redis pub/sub
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut = loop.create_future()
|
||||||
|
self._remote_futures[call_id] = fut
|
||||||
|
|
||||||
|
# Sottoscrivi al canale di risposta
|
||||||
|
result_channel = f"tool_result:{call_id}"
|
||||||
|
await self._pubsub.subscribe(result_channel)
|
||||||
|
|
||||||
|
# Pubblica il tool_call
|
||||||
|
await self._redis.publish(
|
||||||
|
f"tool_call:{user_id}",
|
||||||
|
json.dumps(payload),
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
return await asyncio.wait_for(fut, timeout=30.0)
|
||||||
|
finally:
|
||||||
|
self._remote_futures.pop(call_id, None)
|
||||||
|
await self._pubsub.unsubscribe(result_channel)
|
||||||
|
|
||||||
|
# ── Risoluzione tool_result (da WS locale) ──
|
||||||
|
|
||||||
|
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
||||||
|
conn = self._local.get(user_id)
|
||||||
|
if conn:
|
||||||
|
fut = conn.pending_calls.pop(call_id, None)
|
||||||
|
if fut and not fut.done():
|
||||||
|
fut.set_result(result)
|
||||||
|
|
||||||
|
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
||||||
|
"""Chiamato quando il device locale invia un tool_result."""
|
||||||
|
self.resolve_local(user_id, call_id, result)
|
||||||
|
# Pubblica anche su Redis per l'istanza remota che aspetta
|
||||||
|
await self._redis.publish(
|
||||||
|
f"tool_result:{call_id}",
|
||||||
|
json.dumps(result),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Listener Redis ──
|
||||||
|
|
||||||
|
async def _listen_tool_calls(self):
|
||||||
|
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
||||||
|
async for message in self._pubsub.listen():
|
||||||
|
if message["type"] != "message":
|
||||||
|
continue
|
||||||
|
channel = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
|
||||||
|
data = json.loads(message["data"])
|
||||||
|
|
||||||
|
if channel.startswith("tool_call:"):
|
||||||
|
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
||||||
|
user_id = channel.split(":", 1)[1]
|
||||||
|
conn = self._local.get(user_id)
|
||||||
|
if conn:
|
||||||
|
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
||||||
|
|
||||||
|
elif channel.startswith("tool_result:"):
|
||||||
|
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
||||||
|
call_id = channel.split(":", 1)[1]
|
||||||
|
fut = self._remote_futures.pop(call_id, None)
|
||||||
|
if fut and not fut.done():
|
||||||
|
fut.set_result(data)
|
||||||
|
|
||||||
|
# ── Stream cross-instance ──
|
||||||
|
|
||||||
|
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
||||||
|
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
||||||
|
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 4. Struttura Directory Proposta (MVP)
|
||||||
|
|
||||||
|
```
|
||||||
|
adiuva-api/
|
||||||
|
├── docker-compose.yml # Orchestrazione completa
|
||||||
|
├── docker-compose.dev.yml # Override per sviluppo locale
|
||||||
|
├── shared/ # Codice condiviso (montato come volume)
|
||||||
|
│ ├── auth.py # JWT verification (chiave pubblica)
|
||||||
|
│ ├── schemas.py # Pydantic schemas condivisi
|
||||||
|
│ ├── middleware/
|
||||||
|
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
|
||||||
|
│ │ └── sanitizer.py
|
||||||
|
│ └── models/
|
||||||
|
│ └── base.py # SQLAlchemy base condivisa
|
||||||
|
│
|
||||||
|
├── auth-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # users, refresh_tokens
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ └── auth.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── jwt_service.py # RS256 signing
|
||||||
|
│ └── user_service.py
|
||||||
|
│
|
||||||
|
├── chat-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # memory_*
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── device_ws.py # WS connection owner
|
||||||
|
│ │ └── chat.py # REST fallback
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── device_manager.py # RedisDeviceManager
|
||||||
|
│ │ ├── deep_agent.py # Home + floating chat
|
||||||
|
│ │ ├── memory_middleware.py
|
||||||
|
│ │ ├── ws_context.py
|
||||||
|
│ │ ├── output_formatter.py
|
||||||
|
│ │ └── llm.py
|
||||||
|
│ └── agents/ # Tool definitions (used by deep_agent)
|
||||||
|
│ ├── task_agent.py
|
||||||
|
│ ├── project_agent.py
|
||||||
|
│ ├── note_agent.py
|
||||||
|
│ └── timeline_agent.py
|
||||||
|
│
|
||||||
|
├── agent-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ ├── agents.py # catalog, can-create, trigger
|
||||||
|
│ │ └── agent_setup.py # journey start/message
|
||||||
|
│ ├── core/
|
||||||
|
│ │ ├── agent_runner.py # Batch classify → process
|
||||||
|
│ │ ├── agent_registry.py
|
||||||
|
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
|
||||||
|
│ │ └── llm.py
|
||||||
|
│ └── agents/
|
||||||
|
│ ├── task_agent.py # Tool definitions (batch context)
|
||||||
|
│ ├── project_agent.py
|
||||||
|
│ ├── note_agent.py
|
||||||
|
│ ├── timeline_agent.py
|
||||||
|
│ └── filesystem_agent.py
|
||||||
|
│
|
||||||
|
├── billing-service/
|
||||||
|
│ ├── Dockerfile
|
||||||
|
│ ├── requirements.txt
|
||||||
|
│ └── app/
|
||||||
|
│ ├── main.py
|
||||||
|
│ ├── config.py
|
||||||
|
│ ├── db.py
|
||||||
|
│ ├── models.py # subscriptions
|
||||||
|
│ ├── routes/
|
||||||
|
│ │ └── billing.py
|
||||||
|
│ └── services/
|
||||||
|
│ ├── stripe_service.py
|
||||||
|
│ └── tier_manager.py
|
||||||
|
│
|
||||||
|
└── infra/
|
||||||
|
├── traefik/
|
||||||
|
│ └── traefik.yml
|
||||||
|
├── keys/
|
||||||
|
│ ├── jwt_private.pem # Solo auth-service
|
||||||
|
│ └── jwt_public.pem # Tutti i servizi
|
||||||
|
└── alembic/ # Migrazioni condivise o per-servizio
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 5. Docker Compose — Configurazione MVP
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# docker-compose.yml
|
||||||
|
|
||||||
|
services:
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# API Gateway
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
traefik:
|
||||||
|
image: traefik:v3.2
|
||||||
|
command:
|
||||||
|
- "--api.insecure=true"
|
||||||
|
- "--providers.docker=true"
|
||||||
|
- "--providers.docker.exposedbydefault=false"
|
||||||
|
- "--entrypoints.web.address=:80"
|
||||||
|
- "--entrypoints.websecure.address=:443"
|
||||||
|
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
||||||
|
ports:
|
||||||
|
- "80:80"
|
||||||
|
- "443:443"
|
||||||
|
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
|
||||||
|
volumes:
|
||||||
|
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||||
|
- ./infra/certs:/certs:ro
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Auth Service (2 repliche)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
auth-service:
|
||||||
|
build: ./auth-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
||||||
|
SERVICE_NAME: auth
|
||||||
|
secrets:
|
||||||
|
- jwt_private_key
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
||||||
|
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Chat Service — Real-time WS + Chat (scalabile)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
chat-service:
|
||||||
|
build: ./chat-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: chat
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
# REST chat endpoint
|
||||||
|
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
|
||||||
|
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
||||||
|
# WebSocket route con sticky session
|
||||||
|
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
||||||
|
- "traefik.http.routers.ws.service=chat-ws"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
||||||
|
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Agent Service — Batch processing (scalabile indipendentemente)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
agent-service:
|
||||||
|
build: ./agent-service
|
||||||
|
deploy:
|
||||||
|
replicas: 2
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: agent
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
|
||||||
|
- "traefik.http.services.agents.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Billing Service (1 replica)
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
billing-service:
|
||||||
|
build: ./billing-service
|
||||||
|
deploy:
|
||||||
|
replicas: 1
|
||||||
|
env_file: .env
|
||||||
|
environment:
|
||||||
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
REDIS_URL: redis://redis:6379
|
||||||
|
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||||
|
SERVICE_NAME: billing
|
||||||
|
secrets:
|
||||||
|
- jwt_public_key
|
||||||
|
labels:
|
||||||
|
- "traefik.enable=true"
|
||||||
|
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
||||||
|
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
||||||
|
depends_on:
|
||||||
|
db:
|
||||||
|
condition: service_healthy
|
||||||
|
redis:
|
||||||
|
condition: service_healthy
|
||||||
|
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
# Infrastruttura
|
||||||
|
# ══════════════════════════════════════════════════════════
|
||||||
|
db:
|
||||||
|
image: pgvector/pgvector:pg16
|
||||||
|
environment:
|
||||||
|
POSTGRES_USER: postgres
|
||||||
|
POSTGRES_PASSWORD: postgres
|
||||||
|
POSTGRES_DB: adiuva
|
||||||
|
volumes:
|
||||||
|
- postgres_data:/var/lib/postgresql/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
redis:
|
||||||
|
image: redis:7-alpine
|
||||||
|
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||||
|
volumes:
|
||||||
|
- redis_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "redis-cli", "ping"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 3s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
volumes:
|
||||||
|
- qdrant_data:/qdrant/storage
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
secrets:
|
||||||
|
jwt_private_key:
|
||||||
|
file: ./infra/keys/jwt_private.pem
|
||||||
|
jwt_public_key:
|
||||||
|
file: ./infra/keys/jwt_public.pem
|
||||||
|
|
||||||
|
volumes:
|
||||||
|
postgres_data:
|
||||||
|
redis_data:
|
||||||
|
qdrant_data:
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 6. Configurazione Cloudflare + VPS
|
||||||
|
|
||||||
|
### 6.1 DNS
|
||||||
|
|
||||||
|
```
|
||||||
|
api.tuodominio.com → A record → IP del VPS
|
||||||
|
→ Proxy: ON (orange cloud)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.2 Cloudflare Settings
|
||||||
|
|
||||||
|
| Setting | Valore | Motivo |
|
||||||
|
|---------|--------|--------|
|
||||||
|
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
||||||
|
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
||||||
|
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
||||||
|
| Under Attack Mode | Off (attivare se necessario) | |
|
||||||
|
|
||||||
|
### 6.3 TLS sul VPS
|
||||||
|
|
||||||
|
Due opzioni:
|
||||||
|
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
||||||
|
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# traefik.yml — con Cloudflare Origin Certificate
|
||||||
|
entryPoints:
|
||||||
|
websecure:
|
||||||
|
address: ":443"
|
||||||
|
|
||||||
|
tls:
|
||||||
|
certificates:
|
||||||
|
- certFile: /certs/origin.pem
|
||||||
|
keyFile: /certs/origin-key.pem
|
||||||
|
```
|
||||||
|
|
||||||
|
### 6.4 Rete VPS
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
||||||
|
# https://www.cloudflare.com/ips/
|
||||||
|
ufw default deny incoming
|
||||||
|
ufw allow from 173.245.48.0/20 to any port 443
|
||||||
|
ufw allow from 103.21.244.0/22 to any port 443
|
||||||
|
# ... (tutti gli IP range di Cloudflare)
|
||||||
|
ufw allow ssh
|
||||||
|
ufw enable
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 7. Comunicazione Inter-Servizio
|
||||||
|
|
||||||
|
### 7.1 Redis Pub/Sub — Event Bus
|
||||||
|
|
||||||
|
```
|
||||||
|
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
||||||
|
│ Billing │ ────────────────────────► │ Auth │
|
||||||
|
│ Service │ │ Service │
|
||||||
|
└──────────┘ └──────────┘
|
||||||
|
|
||||||
|
┌──────────┐ tool_call:user_123 ┌──────────┐
|
||||||
|
│ Agent │ ────────────────────────► │ Chat │
|
||||||
|
│ Service │ │ Service │
|
||||||
|
│ (batch) │ ◄────────────────────────│ (ha WS) │
|
||||||
|
└──────────┘ tool_result:{call_id} └──────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 7.2 Health Checks e Service Discovery
|
||||||
|
|
||||||
|
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
||||||
|
- **Redis pub/sub** (tool-call cross-instance, tier events)
|
||||||
|
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
|
||||||
|
- **PostgreSQL** (dati persistenti condivisi)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 8. Piano di Migrazione Incrementale (MVP)
|
||||||
|
|
||||||
|
### Fase 1 — Preparazione (nel monolite attuale)
|
||||||
|
1. Aggiungere Redis al `docker-compose.yml` attuale
|
||||||
|
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
|
||||||
|
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
|
||||||
|
4. Estrarre `shared/` con auth verification, schemas, middleware
|
||||||
|
|
||||||
|
### Fase 2 — Auth Service (primo split)
|
||||||
|
1. Estrarre `auth.py` routes + models in `auth-service/`
|
||||||
|
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
||||||
|
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
|
||||||
|
4. Il monolite continua a servire tutto il resto
|
||||||
|
|
||||||
|
### Fase 3 — Billing Service
|
||||||
|
1. Estrarre billing routes, Stripe service, tier manager
|
||||||
|
2. Configurare Redis pub/sub per `tier_changed` events
|
||||||
|
3. Routare via Traefik
|
||||||
|
|
||||||
|
### Fase 4 — Split Chat + Agent (il più delicato)
|
||||||
|
1. Il monolite residuo contiene WS + chat + agents
|
||||||
|
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
|
||||||
|
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
|
||||||
|
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
|
||||||
|
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
|
||||||
|
|
||||||
|
### Fase 5 — Scaling test
|
||||||
|
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
|
||||||
|
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
|
||||||
|
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 9. Monitoraggio e Logging
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
# Aggiungere al docker-compose.yml
|
||||||
|
|
||||||
|
prometheus:
|
||||||
|
image: prom/prometheus:latest
|
||||||
|
volumes:
|
||||||
|
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
grafana:
|
||||||
|
image: grafana/grafana:latest
|
||||||
|
ports:
|
||||||
|
- "3000:3000"
|
||||||
|
volumes:
|
||||||
|
- grafana_data:/var/lib/grafana
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
loki:
|
||||||
|
image: grafana/loki:latest
|
||||||
|
restart: unless-stopped
|
||||||
|
```
|
||||||
|
|
||||||
|
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 10. Sizing VPS Minimo Consigliato (MVP)
|
||||||
|
|
||||||
|
| Componente | CPU | RAM | Note |
|
||||||
|
|---|---|---|---|
|
||||||
|
| Traefik | 0.25 | 128MB | |
|
||||||
|
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
|
||||||
|
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
|
||||||
|
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
|
||||||
|
| Billing Service | 0.25 | 128MB | |
|
||||||
|
| PostgreSQL | 1.0 | 1GB | |
|
||||||
|
| Redis | 0.25 | 256MB | |
|
||||||
|
| Qdrant | 0.5 | 512MB | |
|
||||||
|
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
|
||||||
|
|
||||||
|
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Riepilogo Architettura MVP
|
||||||
|
|
||||||
|
| Servizio | Repliche | Proprietario di |
|
||||||
|
|---|---|---|
|
||||||
|
| **Traefik** | 1 | Routing, TLS, sticky sessions |
|
||||||
|
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
|
||||||
|
| **Chat Service** | 2–N | WebSocket, home/floating chat, streaming |
|
||||||
|
| **Agent Service** | 2–N | Batch processing, directory scan, agent setup |
|
||||||
|
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
|
||||||
|
|
||||||
|
| Decisione | Scelta | Motivazione |
|
||||||
|
|---|---|---|
|
||||||
|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
||||||
|
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
||||||
|
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
|
||||||
|
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
|
||||||
|
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
|
||||||
|
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
|
||||||
|
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
|
||||||
|
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
|
||||||
|
| TLS | Cloudflare Origin Certificate | Zero maintenance |
|
||||||
|
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
||||||
|
| Storage / Plugin | Post-MVP | Non critici per il lancio |
|
||||||
56
logging.conf
Normal file
56
logging.conf
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
[loggers]
|
||||||
|
keys=root,uvicorn,uvicorn.error,uvicorn.access,sqlalchemy,watchfiles
|
||||||
|
|
||||||
|
[handlers]
|
||||||
|
keys=console,file
|
||||||
|
|
||||||
|
[formatters]
|
||||||
|
keys=default
|
||||||
|
|
||||||
|
[logger_root]
|
||||||
|
level=INFO
|
||||||
|
handlers=console,file
|
||||||
|
|
||||||
|
[logger_uvicorn]
|
||||||
|
level=INFO
|
||||||
|
handlers=
|
||||||
|
qualname=uvicorn
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[logger_uvicorn.error]
|
||||||
|
level=INFO
|
||||||
|
handlers=
|
||||||
|
qualname=uvicorn.error
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[logger_uvicorn.access]
|
||||||
|
level=INFO
|
||||||
|
handlers=
|
||||||
|
qualname=uvicorn.access
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[logger_sqlalchemy]
|
||||||
|
level=WARNING
|
||||||
|
handlers=
|
||||||
|
qualname=sqlalchemy
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[logger_watchfiles]
|
||||||
|
level=WARNING
|
||||||
|
handlers=
|
||||||
|
qualname=watchfiles
|
||||||
|
propagate=1
|
||||||
|
|
||||||
|
[handler_console]
|
||||||
|
class=StreamHandler
|
||||||
|
formatter=default
|
||||||
|
args=(sys.stderr,)
|
||||||
|
|
||||||
|
[handler_file]
|
||||||
|
class=logging.handlers.RotatingFileHandler
|
||||||
|
formatter=default
|
||||||
|
args=('logs/app.log', 'a', 10485760, 5, 'utf-8')
|
||||||
|
|
||||||
|
[formatter_default]
|
||||||
|
format=%(asctime)s %(levelname)s %(name)s: %(message)s
|
||||||
|
datefmt=%Y-%m-%d %H:%M:%S
|
||||||
@@ -1,22 +0,0 @@
|
|||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
langchain>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
pydantic-settings>=2.7.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
stripe>=11.0.0
|
|
||||||
boto3>=1.35.0
|
|
||||||
slowapi>=0.1.9
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
alembic>=1.14.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
httpx>=0.28.0
|
|
||||||
websockets>=14.0
|
|
||||||
pytest>=8.0.0
|
|
||||||
pytest-asyncio>=0.24.0
|
|
||||||
moto[s3]>=5.0.0
|
|
||||||
pinecone>=5.0.0
|
|
||||||
qdrant-client>=1.7.0
|
|
||||||
19
services/auth/.env.example
Normal file
19
services/auth/.env.example
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# ── Auth Service ──────────────────────────────────────────────────────────────
|
||||||
|
# This file contains env vars specific to the Auth Service.
|
||||||
|
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
|
||||||
|
# or from docker-compose environment.
|
||||||
|
|
||||||
|
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
|
||||||
|
# Generate keypair:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
#
|
||||||
|
# Paste PEM content with literal \n for newlines:
|
||||||
|
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
|
||||||
|
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
|
||||||
|
|
||||||
|
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
|
||||||
|
JWT_PRIVATE_KEY=
|
||||||
|
|
||||||
|
# PUBLIC KEY — used to VERIFY JWTs.
|
||||||
|
JWT_PUBLIC_KEY=
|
||||||
36
services/auth/Dockerfile
Normal file
36
services/auth/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
# Install shared + service deps in one layer
|
||||||
|
COPY services/auth/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Copy shared module (available to all services)
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Copy service source
|
||||||
|
COPY services/auth/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "30"]
|
||||||
16
services/auth/README.md
Normal file
16
services/auth/README.md
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Auth Service
|
||||||
|
|
||||||
|
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
|
||||||
|
|
||||||
|
## Tables owned
|
||||||
|
- `users`
|
||||||
|
- `refresh_tokens`
|
||||||
|
- `subscriptions` (read; Billing Service writes)
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `POST /auth/register`
|
||||||
|
- `POST /auth/login`
|
||||||
|
- `POST /auth/refresh`
|
||||||
|
- `GET /auth/me`
|
||||||
|
- `PUT /auth/me`
|
||||||
|
- `GET /auth/verify` (ForwardAuth for Traefik)
|
||||||
34
services/auth/app/config.py
Normal file
34
services/auth/app/config.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""Auth Service — local configuration.
|
||||||
|
|
||||||
|
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
|
||||||
|
These are NOT in shared/config.py to prevent other services from accessing them.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from pydantic import field_validator
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class AuthSettings(BaseSettings):
|
||||||
|
# RS256 private key (PEM format). Used to SIGN JWTs.
|
||||||
|
# Only the Auth Service has this. Generate with:
|
||||||
|
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||||
|
# Then set the env var (newlines as \n):
|
||||||
|
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
|
||||||
|
JWT_PRIVATE_KEY: str = ""
|
||||||
|
|
||||||
|
# RS256 public key (PEM format). Used to VERIFY JWTs.
|
||||||
|
# Derived from the private key:
|
||||||
|
# openssl rsa -in private.pem -pubout -out public.pem
|
||||||
|
JWT_PUBLIC_KEY: str = ""
|
||||||
|
|
||||||
|
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _expand_pem_newlines(cls, v: str) -> str:
|
||||||
|
if isinstance(v, str) and r"\n" in v:
|
||||||
|
return v.replace(r"\n", "\n")
|
||||||
|
return v
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
auth_settings = AuthSettings()
|
||||||
69
services/auth/app/deps.py
Normal file
69
services/auth/app/deps.py
Normal file
@@ -0,0 +1,69 @@
|
|||||||
|
"""Auth dependencies — JWT validation for the Auth Service.
|
||||||
|
|
||||||
|
This is the canonical get_current_user used by protected endpoints
|
||||||
|
within the Auth Service itself (/me, /me PUT).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import Depends, HTTPException, status
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import get_session
|
||||||
|
from shared.models import Subscription, User
|
||||||
|
from shared.schemas import UserProfile
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
|
|
||||||
|
async def get_current_user(
|
||||||
|
token: str = Depends(oauth2_scheme),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Validate a Bearer JWT and return the authenticated user.
|
||||||
|
|
||||||
|
The JWT is used for identity and expiry. Tier is fetched live from the
|
||||||
|
subscriptions table so upgrades/downgrades take effect immediately.
|
||||||
|
"""
|
||||||
|
credentials_exc = HTTPException(
|
||||||
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
|
detail="Could not validate credentials",
|
||||||
|
headers={"WWW-Authenticate": "Bearer"},
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id or not email:
|
||||||
|
raise credentials_exc
|
||||||
|
except JWTError:
|
||||||
|
raise credentials_exc
|
||||||
|
|
||||||
|
# Live tier lookup
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
# Fetch name/surname
|
||||||
|
user_result = await db.execute(
|
||||||
|
select(User.name, User.surname).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user_id,
|
||||||
|
email=email,
|
||||||
|
name=user_row.name if user_row else None,
|
||||||
|
surname=user_row.surname if user_row else None,
|
||||||
|
tier=tier,
|
||||||
|
) # type: ignore[arg-type]
|
||||||
62
services/auth/app/main.py
Normal file
62
services/auth/app/main.py
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
|
||||||
|
|
||||||
|
Standalone FastAPI service extracted from the adiuva-api monolith.
|
||||||
|
Owns: users, refresh_tokens, subscriptions (read).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so "shared" is importable.
|
||||||
|
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
|
||||||
|
# In local dev, we need to add the repo root (two levels up from this file).
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
yield
|
||||||
|
from shared.db import engine
|
||||||
|
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="Adiuva Auth Service",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
from app.routes import router
|
||||||
|
from app.verify import router as verify_router
|
||||||
|
|
||||||
|
app.include_router(router, prefix="/api/v1")
|
||||||
|
app.include_router(verify_router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok", "service": "auth", "version": app.version}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
249
services/auth/app/routes.py
Normal file
249
services/auth/app/routes.py
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
|
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
import bcrypt
|
||||||
|
from cryptography.fernet import Fernet
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from jose import jwt
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import get_session
|
||||||
|
from shared.models import RefreshToken, Subscription, User
|
||||||
|
from shared.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
from app.deps import get_current_user
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_password(password: str) -> str:
|
||||||
|
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _verify_password(password: str, hashed: str) -> bool:
|
||||||
|
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_token(plain_token: str) -> str:
|
||||||
|
"""SHA-256 of the plain refresh token string."""
|
||||||
|
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||||
|
"""Return (RS256-signed JWT, expires_at_ms)."""
|
||||||
|
now = int(time.time())
|
||||||
|
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||||
|
payload = {
|
||||||
|
"sub": user_id,
|
||||||
|
"email": email,
|
||||||
|
"tier": tier,
|
||||||
|
"exp": exp,
|
||||||
|
"iat": now,
|
||||||
|
}
|
||||||
|
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
|
||||||
|
return token, exp * 1000 # ms for client
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
|
||||||
|
"""Fetch authoritative tier from subscriptions table."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
return result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _RegisterRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class _LoginRequest(BaseModel):
|
||||||
|
email: str
|
||||||
|
password: str
|
||||||
|
|
||||||
|
|
||||||
|
class _RefreshRequest(BaseModel):
|
||||||
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateProfileRequest(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def register(
|
||||||
|
body: _RegisterRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Create a new account and return JWT tokens."""
|
||||||
|
existing = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
if existing.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||||
|
|
||||||
|
user = User(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
email=body.email,
|
||||||
|
name=body.name,
|
||||||
|
surname=body.surname,
|
||||||
|
password_hash=_hash_password(body.password),
|
||||||
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
|
)
|
||||||
|
db.add(user)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/login", response_model=AuthTokens)
|
||||||
|
async def login(
|
||||||
|
body: _LoginRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Validate credentials and return JWT tokens."""
|
||||||
|
result = await db.execute(select(User).where(User.email == body.email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not _verify_password(body.password, user.password_hash):
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
|
|
||||||
|
# Fetch live tier for the JWT claim
|
||||||
|
tier = await _get_live_tier(db, user.id)
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/refresh", response_model=AuthTokens)
|
||||||
|
async def refresh(
|
||||||
|
body: _RefreshRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Rotate a refresh token and return a new token pair."""
|
||||||
|
token_hash = _hash_token(body.refresh_token)
|
||||||
|
result = await db.execute(
|
||||||
|
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||||
|
)
|
||||||
|
rt = result.scalar_one_or_none()
|
||||||
|
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||||
|
|
||||||
|
await db.delete(rt)
|
||||||
|
|
||||||
|
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||||
|
user = user_result.scalar_one_or_none()
|
||||||
|
if user is None:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
|
|
||||||
|
# Fetch live tier for the new JWT
|
||||||
|
tier = await _get_live_tier(db, user.id)
|
||||||
|
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
|
new_rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=new_expires,
|
||||||
|
)
|
||||||
|
db.add(new_rt)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
||||||
|
return AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me", response_model=UserProfile)
|
||||||
|
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||||
|
"""Return the profile for the authenticated user."""
|
||||||
|
return current_user
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me", response_model=UserProfile)
|
||||||
|
async def update_profile(
|
||||||
|
body: _UpdateProfileRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's name and surname."""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if body.name is not None:
|
||||||
|
user.name = body.name
|
||||||
|
if body.surname is not None:
|
||||||
|
user.surname = body.surname
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(user)
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user.id,
|
||||||
|
email=user.email,
|
||||||
|
name=user.name,
|
||||||
|
surname=user.surname,
|
||||||
|
tier=current_user.tier,
|
||||||
|
)
|
||||||
66
services/auth/app/verify.py
Normal file
66
services/auth/app/verify.py
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
"""ForwardAuth verification endpoint for Traefik.
|
||||||
|
|
||||||
|
Traefik calls GET /api/v1/auth/verify on every request to a protected
|
||||||
|
service. This endpoint validates the JWT from the Authorization header
|
||||||
|
and returns identity headers that Traefik injects into downstream requests.
|
||||||
|
|
||||||
|
Downstream services NEVER validate JWTs themselves — they trust the
|
||||||
|
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Request, Response
|
||||||
|
from fastapi import status as http_status
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import Subscription
|
||||||
|
|
||||||
|
from app.config import auth_settings
|
||||||
|
|
||||||
|
router = APIRouter(tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/auth/verify")
|
||||||
|
async def verify(request: Request) -> Response:
|
||||||
|
"""Validate JWT and return identity headers for Traefik ForwardAuth.
|
||||||
|
|
||||||
|
Returns 200 with X-User-* headers on success, 401 on failure.
|
||||||
|
Traefik copies response headers to the downstream request.
|
||||||
|
"""
|
||||||
|
auth_header = request.headers.get("Authorization", "")
|
||||||
|
if not auth_header.startswith("Bearer "):
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
token = auth_header[7:] # strip "Bearer "
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
email: str | None = payload.get("email")
|
||||||
|
if not user_id or not email:
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
except JWTError:
|
||||||
|
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||||
|
|
||||||
|
# Live tier lookup from subscriptions table
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
status_code=http_status.HTTP_200_OK,
|
||||||
|
headers={
|
||||||
|
"X-User-Id": user_id,
|
||||||
|
"X-User-Email": email,
|
||||||
|
"X-User-Tier": tier,
|
||||||
|
},
|
||||||
|
)
|
||||||
11
services/auth/requirements.txt
Normal file
11
services/auth/requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
bcrypt>=4.2.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
36
services/batch-agent/Dockerfile
Normal file
36
services/batch-agent/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
# ── builder ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS builder
|
||||||
|
|
||||||
|
WORKDIR /build
|
||||||
|
|
||||||
|
COPY services/batch-agent/requirements.txt ./requirements.txt
|
||||||
|
RUN pip install --upgrade pip && \
|
||||||
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
|
# Shared module
|
||||||
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/batch-agent/app/ app/
|
||||||
|
|
||||||
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
|
USER appuser
|
||||||
|
|
||||||
|
EXPOSE 8000
|
||||||
|
|
||||||
|
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "2", \
|
||||||
|
"--timeout", "300"]
|
||||||
23
services/batch-agent/README.md
Normal file
23
services/batch-agent/README.md
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
# Batch Agent Service
|
||||||
|
|
||||||
|
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
|
||||||
|
|
||||||
|
## Tables owned
|
||||||
|
- `local_agent_configs`
|
||||||
|
- `cloud_agent_configs`
|
||||||
|
- `agent_run_logs`
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `GET /agents/catalog`
|
||||||
|
- `POST /agents/can-create`
|
||||||
|
- `POST /agents/trigger`
|
||||||
|
- `GET /agents/{id}/history`
|
||||||
|
|
||||||
|
## Redis channels
|
||||||
|
- Subscribe: `batch:request:{user_id}`
|
||||||
|
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
|
||||||
|
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
||||||
|
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
|
||||||
|
|
||||||
|
## TODO
|
||||||
|
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.
|
||||||
910
services/batch-agent/app/agent_runner.py
Normal file
910
services/batch-agent/app/agent_runner.py
Normal file
@@ -0,0 +1,910 @@
|
|||||||
|
"""Agent run orchestrator — adapted for Batch Agent Service.
|
||||||
|
|
||||||
|
Key changes from monolith app/core/agent_runner.py:
|
||||||
|
- No DeviceConnectionManager — tool calls go through Redis ws_context.
|
||||||
|
- set_current_user / clear_current_user replace set_client_executor.
|
||||||
|
- run_local_agent accepts a serialized dict (from Redis / REST) instead
|
||||||
|
of SQLAlchemy model objects.
|
||||||
|
- _finalize_run writes to PostgreSQL via shared.db.async_session.
|
||||||
|
- Cloud agent import path changed to app.integrations.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from shared.agents.note_agent import NOTE_TOOLS
|
||||||
|
from shared.agents.project_agent import PROJECT_TOOLS
|
||||||
|
from shared.agents.task_agent import TASK_TOOLS
|
||||||
|
from shared.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
|
from shared.llm import get_llm
|
||||||
|
from shared.ws_context import execute_on_client, set_current_user, clear_current_user
|
||||||
|
import app.tracing as tracing
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from shared.redis import redis_client, ws_out_channel
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Concurrency guard ─────────────────────────────────────────────────────
|
||||||
|
_running_agents: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
|
def is_agent_running(agent_id: str) -> bool:
|
||||||
|
return agent_id in _running_agents
|
||||||
|
|
||||||
|
|
||||||
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
_TOOL_CALL_TIMEOUT: int = 30
|
||||||
|
_MAX_PROCESSING_STEPS: int = 12
|
||||||
|
_MAX_SCAN_DEPTH: int = 5
|
||||||
|
|
||||||
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||||
|
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||||
|
"tasks": TASK_TOOLS,
|
||||||
|
"notes": NOTE_TOOLS,
|
||||||
|
"timelines": TIMELINE_TOOLS,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
||||||
|
|
||||||
|
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
||||||
|
"tasks": (
|
||||||
|
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
||||||
|
"assigned to someone, or tracked with a due date or status."
|
||||||
|
),
|
||||||
|
"notes": (
|
||||||
|
"Documentation, meeting notes, summaries, reference material — "
|
||||||
|
"written content meant to be read and referenced rather than acted on."
|
||||||
|
),
|
||||||
|
"timelines": (
|
||||||
|
"Project milestones, deadlines, scheduled events — "
|
||||||
|
"specific dates that mark a point in the progress of a project."
|
||||||
|
),
|
||||||
|
"projects": (
|
||||||
|
"High-level project entities — only relevant if the file clearly introduces "
|
||||||
|
"a new project or updates the scope of an existing one."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
_STEP1_SYSTEM_PROMPT = """\
|
||||||
|
You are a file classifier for a freelance project management tool.
|
||||||
|
|
||||||
|
Your job is to match a file to an existing project and identify which data domains to extract.
|
||||||
|
|
||||||
|
## Project matching rules (STRICT — follow in order)
|
||||||
|
|
||||||
|
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
||||||
|
that overlaps with the existing projects listed below.
|
||||||
|
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
||||||
|
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
||||||
|
when the file has zero meaningful connection to any listed project.
|
||||||
|
4. When in doubt, pick the closest match from the list.
|
||||||
|
|
||||||
|
## Response format
|
||||||
|
|
||||||
|
Respond ONLY with a JSON object — no markdown, no explanation:
|
||||||
|
|
||||||
|
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
||||||
|
|
||||||
|
## Domain definitions (only consider domains in the allowed list)
|
||||||
|
|
||||||
|
{domain_definitions}
|
||||||
|
|
||||||
|
## Existing projects
|
||||||
|
|
||||||
|
{projects_list}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
_PROCESSING_SYSTEM_PROMPT = """\
|
||||||
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
|
||||||
|
Your task: extract structured data from the file content and persist it using the available tools.
|
||||||
|
|
||||||
|
## Mandatory process — follow this order for EVERY item you extract
|
||||||
|
|
||||||
|
1. READ the existing records listed below for the relevant domain.
|
||||||
|
2. SEARCH for a match by title, topic, or semantic similarity.
|
||||||
|
3. If a match exists → call the update_* tool with the existing record's id.
|
||||||
|
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
||||||
|
|
||||||
|
NEVER call create_* without first checking the existing records.
|
||||||
|
NEVER duplicate a record that already exists under a different wording.
|
||||||
|
|
||||||
|
## Existing records (source of truth)
|
||||||
|
|
||||||
|
{existing_context}
|
||||||
|
|
||||||
|
## Context
|
||||||
|
|
||||||
|
Project: {project_context}
|
||||||
|
Domains to extract: {data_types}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Cloud processing prompt ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLOUD_PROCESSING_PROMPT = """\
|
||||||
|
You are a data extraction and management assistant for a freelance project
|
||||||
|
management tool.
|
||||||
|
|
||||||
|
Available tools:
|
||||||
|
Filesystem : read_file_content, list_directory, get_file_metadata
|
||||||
|
Tasks : list_tasks, create_task, update_task, add_task_comment
|
||||||
|
Notes : list_notes, get_note, create_note, update_note
|
||||||
|
Timelines : list_timelines, create_timeline, update_timeline
|
||||||
|
Projects : list_all_projects, get_project, create_project, update_project
|
||||||
|
|
||||||
|
Your task:
|
||||||
|
1. Read the full content of each file below using read_file_content.
|
||||||
|
2. For each piece of information found, ALWAYS try to match and update an
|
||||||
|
existing record before creating a new one.
|
||||||
|
3. ONLY act on these entity types: {data_types}.
|
||||||
|
4. Do NOT invent data. Only extract what is clearly present in the files.
|
||||||
|
5. If a file contains no relevant data for the target entity types, skip it.
|
||||||
|
|
||||||
|
{project_context}
|
||||||
|
|
||||||
|
Files to process:
|
||||||
|
{file_list}
|
||||||
|
|
||||||
|
{custom_prompt_section}
|
||||||
|
|
||||||
|
After processing all files, respond with a brief summary of what you updated
|
||||||
|
and what you created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM tool-calling loop ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_agent_with_tools(
|
||||||
|
*,
|
||||||
|
system_prompt: str,
|
||||||
|
user_message: str,
|
||||||
|
tools: list[Any],
|
||||||
|
max_steps: int,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Run an LLM agent with tool-calling, returning the final text response."""
|
||||||
|
callbacks = [langfuse_handler] if langfuse_handler else None
|
||||||
|
llm = get_llm(callbacks=callbacks)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
messages: list[Any] = [
|
||||||
|
SystemMessage(content=system_prompt),
|
||||||
|
HumanMessage(content=user_message),
|
||||||
|
]
|
||||||
|
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(max_steps):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_id = str(call.get("id", ""))
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:200],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool list builder ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _build_processing_tools(data_types: list[str]) -> list[Any]:
|
||||||
|
tools: list[Any] = list(FILESYSTEM_TOOLS)
|
||||||
|
for dt in data_types:
|
||||||
|
dt_tools = _DATA_TYPE_TOOLS.get(dt)
|
||||||
|
if dt_tools:
|
||||||
|
tools.extend(dt_tools)
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
# ── Code-based directory scanner ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _scan_directories(
|
||||||
|
paths: list[str],
|
||||||
|
extensions: list[str],
|
||||||
|
last_run_at: datetime | None,
|
||||||
|
) -> list[str]:
|
||||||
|
all_files: list[str] = []
|
||||||
|
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
|
||||||
|
|
||||||
|
async def _walk(path: str, depth: int) -> None:
|
||||||
|
if depth > _MAX_SCAN_DEPTH:
|
||||||
|
return
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="list_directory", data={"path": path})
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
|
||||||
|
return
|
||||||
|
for entry in result.get("entries", []):
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
if not entry_path:
|
||||||
|
continue
|
||||||
|
if entry.get("type") == "directory":
|
||||||
|
await _walk(entry_path, depth + 1)
|
||||||
|
elif entry.get("type") == "file":
|
||||||
|
if ext_set:
|
||||||
|
dot_pos = entry_path.rfind(".")
|
||||||
|
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
|
||||||
|
if file_ext not in ext_set:
|
||||||
|
continue
|
||||||
|
all_files.append(entry_path)
|
||||||
|
|
||||||
|
for root in paths:
|
||||||
|
await _walk(root, depth=0)
|
||||||
|
|
||||||
|
if last_run_at is None:
|
||||||
|
return all_files
|
||||||
|
|
||||||
|
last_run_ms = int(last_run_at.timestamp() * 1000)
|
||||||
|
filtered: list[str] = []
|
||||||
|
for file_path in all_files:
|
||||||
|
try:
|
||||||
|
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
||||||
|
modified_at = meta.get("modifiedAt")
|
||||||
|
if modified_at is None:
|
||||||
|
filtered.append(file_path)
|
||||||
|
continue
|
||||||
|
if isinstance(modified_at, (int, float)):
|
||||||
|
mod_ms = int(modified_at)
|
||||||
|
else:
|
||||||
|
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
|
||||||
|
if mod_ms > last_run_ms:
|
||||||
|
filtered.append(file_path)
|
||||||
|
except Exception:
|
||||||
|
filtered.append(file_path)
|
||||||
|
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
# ── Code-based entity fetchers ────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_projects() -> list[dict]:
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(action="select", table="projects")
|
||||||
|
return result.get("rows", [])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to fetch projects: %s", exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
_DOMAIN_TABLE: dict[str, str] = {
|
||||||
|
"tasks": "tasks",
|
||||||
|
"notes": "notes",
|
||||||
|
"timelines": "timelines",
|
||||||
|
"projects": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
|
||||||
|
table = _DOMAIN_TABLE.get(domain)
|
||||||
|
if not table:
|
||||||
|
return []
|
||||||
|
filters: dict[str, Any] = {}
|
||||||
|
if project_id != "standalone" and domain != "projects":
|
||||||
|
filters["projectId"] = project_id
|
||||||
|
try:
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="select",
|
||||||
|
table=table,
|
||||||
|
filters=filters if filters else None,
|
||||||
|
)
|
||||||
|
return result.get("rows", [])
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
||||||
|
if not rows:
|
||||||
|
return f"No existing {domain}."
|
||||||
|
lines: list[str] = []
|
||||||
|
for r in rows:
|
||||||
|
if domain == "tasks":
|
||||||
|
desc = r.get("description") or ""
|
||||||
|
desc_part = f" — {desc[:120]}" if desc else ""
|
||||||
|
assignee = r.get("assignee") or r.get("assignees") or ""
|
||||||
|
due = r.get("dueDate") or r.get("due_date") or ""
|
||||||
|
meta = ", ".join(filter(None, [
|
||||||
|
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
|
||||||
|
f"assignee: {assignee}" if assignee else "",
|
||||||
|
f"due: {due}" if due else "",
|
||||||
|
]))
|
||||||
|
lines.append(
|
||||||
|
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
|
||||||
|
f" ({meta}, id: {r['id']})"
|
||||||
|
)
|
||||||
|
elif domain == "notes":
|
||||||
|
snippet = (r.get("content") or "")[:200].replace("\n", " ")
|
||||||
|
snippet_part = f"\n Preview: {snippet}" if snippet else ""
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
|
||||||
|
)
|
||||||
|
elif domain == "timelines":
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
|
||||||
|
)
|
||||||
|
elif domain == "projects":
|
||||||
|
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
|
||||||
|
summary_part = f" — {summary}" if summary else ""
|
||||||
|
lines.append(
|
||||||
|
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
|
||||||
|
f" (id: {r['id']})"
|
||||||
|
)
|
||||||
|
return f"Existing {domain}:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _classify_file(
|
||||||
|
file_path: str,
|
||||||
|
file_content: str,
|
||||||
|
projects: list[dict],
|
||||||
|
config_data_types: list[str],
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
custom_system_prompt: str | None = None,
|
||||||
|
) -> tuple[str, list[str], str | None]:
|
||||||
|
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
||||||
|
|
||||||
|
if not file_content.strip():
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
valid_project_ids = {p["id"] for p in projects}
|
||||||
|
|
||||||
|
def _fmt_project(p: dict) -> str:
|
||||||
|
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
||||||
|
summary_part = f" — {summary[:100]}" if summary else ""
|
||||||
|
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
||||||
|
|
||||||
|
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
||||||
|
|
||||||
|
domain_definitions = "\n".join(
|
||||||
|
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
||||||
|
for d in config_data_types
|
||||||
|
if d in _DOMAIN_DESCRIPTIONS
|
||||||
|
)
|
||||||
|
|
||||||
|
if custom_system_prompt:
|
||||||
|
# Fixture-provided prompt takes absolute priority
|
||||||
|
system = custom_system_prompt.format_map(
|
||||||
|
{"domain_definitions": domain_definitions, "projects_list": projects_list}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
system = tracing.compile_prompt(
|
||||||
|
"batch_file_classifier",
|
||||||
|
fallback=_STEP1_SYSTEM_PROMPT,
|
||||||
|
variables={
|
||||||
|
"domain_definitions": domain_definitions,
|
||||||
|
"projects_list": projects_list,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_llm(callbacks=[langfuse_handler] if langfuse_handler else None)
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=system),
|
||||||
|
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
||||||
|
])
|
||||||
|
raw = _as_text(response.content).strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
raw_project_id: str = str(parsed.get("project_id") or "new")
|
||||||
|
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
||||||
|
new_project_name: str | None = (
|
||||||
|
str(parsed["new_project_name"]).strip() or None
|
||||||
|
if project_id == "new" and parsed.get("new_project_name")
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
domains: list[str] = [
|
||||||
|
d for d in parsed.get("domains", [])
|
||||||
|
if d in config_data_types
|
||||||
|
]
|
||||||
|
if not domains:
|
||||||
|
domains = list(config_data_types)
|
||||||
|
return project_id, domains, new_project_name
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
||||||
|
)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner (two-step per file) ────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_local_agent(user_id: str, trigger_data: dict[str, Any], *, langfuse_handler: Any | None = None) -> None:
|
||||||
|
"""Execute a local directory agent run.
|
||||||
|
|
||||||
|
In the microservice world, trigger_data is a serialized dict from
|
||||||
|
the REST route (forwarded via Redis), containing the agent config
|
||||||
|
fields and run_context.
|
||||||
|
|
||||||
|
set_current_user() must be called BEFORE this function.
|
||||||
|
"""
|
||||||
|
run_context: dict = trigger_data.get("run_context", {})
|
||||||
|
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
|
||||||
|
run_id = run_context.get("run_id")
|
||||||
|
|
||||||
|
_running_agents.add(agent_id)
|
||||||
|
|
||||||
|
# Extract config from trigger payload
|
||||||
|
directory_paths: list[str] = trigger_data.get("directory_paths", [])
|
||||||
|
if not directory_paths:
|
||||||
|
directory = trigger_data.get("directory", "")
|
||||||
|
if directory:
|
||||||
|
directory_paths = [directory]
|
||||||
|
|
||||||
|
data_types: list[str] = trigger_data.get("data_types", [])
|
||||||
|
file_extensions: list[str] = trigger_data.get("file_extensions", [])
|
||||||
|
prompt_template: str = trigger_data.get("prompt_template", "")
|
||||||
|
last_run_at_raw = trigger_data.get("last_run_at")
|
||||||
|
last_run_at: datetime | None = None
|
||||||
|
if last_run_at_raw:
|
||||||
|
if isinstance(last_run_at_raw, str):
|
||||||
|
last_run_at = datetime.fromisoformat(last_run_at_raw)
|
||||||
|
elif isinstance(last_run_at_raw, (int, float)):
|
||||||
|
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
custom_section = (
|
||||||
|
f"User instructions:\n{prompt_template}"
|
||||||
|
if prompt_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create or load run log
|
||||||
|
run_log_id = run_id
|
||||||
|
if not run_log_id:
|
||||||
|
async with async_session() as db:
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
try:
|
||||||
|
# ── Scan directories ─────────────────────────────────────────
|
||||||
|
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
|
||||||
|
file_paths = await _scan_directories(
|
||||||
|
paths=directory_paths,
|
||||||
|
extensions=file_extensions,
|
||||||
|
last_run_at=last_run_at,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
|
||||||
|
)
|
||||||
|
|
||||||
|
if not file_paths:
|
||||||
|
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Fetch all projects once ──────────────────────────────────
|
||||||
|
projects = await _fetch_projects()
|
||||||
|
|
||||||
|
for file_path in file_paths:
|
||||||
|
try:
|
||||||
|
file_result = await execute_on_client(
|
||||||
|
action="read_file_content", data={"path": file_path}
|
||||||
|
)
|
||||||
|
file_content: str = file_result.get("content", "")
|
||||||
|
if not file_content:
|
||||||
|
continue
|
||||||
|
|
||||||
|
items_processed += 1
|
||||||
|
|
||||||
|
# Step 1 — classify file
|
||||||
|
project_id, domains, new_project_name = await _classify_file(
|
||||||
|
file_path=file_path,
|
||||||
|
file_content=file_content,
|
||||||
|
projects=projects,
|
||||||
|
config_data_types=data_types,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Step 2 — resolve project_id, fetch entities, process
|
||||||
|
if project_id == "new":
|
||||||
|
proj_name = new_project_name or "Untitled Project"
|
||||||
|
try:
|
||||||
|
proj_result = await execute_on_client(
|
||||||
|
action="insert",
|
||||||
|
table="projects",
|
||||||
|
data={"name": proj_name, "clientId": None},
|
||||||
|
)
|
||||||
|
created = proj_result.get("row", {})
|
||||||
|
effective_project_id = created.get("id", "standalone")
|
||||||
|
if "id" in created:
|
||||||
|
projects.append(created)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
|
||||||
|
effective_project_id = "standalone"
|
||||||
|
proj_name = "unknown"
|
||||||
|
project_context = (
|
||||||
|
f"Project: {proj_name} (id: {effective_project_id}). "
|
||||||
|
"Always set projectId to this id on every record you create."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
effective_project_id = project_id
|
||||||
|
proj = next((p for p in projects if p["id"] == project_id), None)
|
||||||
|
proj_name = proj.get("name", project_id) if proj else project_id
|
||||||
|
project_context = (
|
||||||
|
f"Project: {proj_name} (id: {project_id}). "
|
||||||
|
"Always set projectId to this id on every record you create."
|
||||||
|
)
|
||||||
|
|
||||||
|
domains = [d for d in domains if d != "projects"]
|
||||||
|
|
||||||
|
existing_blocks: list[str] = []
|
||||||
|
for domain in domains:
|
||||||
|
rows = await _fetch_domain_entities(domain, effective_project_id)
|
||||||
|
existing_blocks.append(_format_entities_for_context(domain, rows))
|
||||||
|
|
||||||
|
existing_context = "\n\n".join(existing_blocks)
|
||||||
|
|
||||||
|
system_prompt = tracing.compile_prompt(
|
||||||
|
"batch_processing",
|
||||||
|
fallback=_PROCESSING_SYSTEM_PROMPT,
|
||||||
|
variables={
|
||||||
|
"existing_context": existing_context,
|
||||||
|
"project_context": project_context,
|
||||||
|
"data_types": ", ".join(domains),
|
||||||
|
"custom_prompt_section": custom_section,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
processing_tools = _build_processing_tools(domains)
|
||||||
|
|
||||||
|
result_text = await _run_agent_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_message=(
|
||||||
|
f"Process this file and extract relevant information.\n\n"
|
||||||
|
f"File: {file_path}\n\nContent:\n{file_content}"
|
||||||
|
),
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s file=%r result=%s",
|
||||||
|
run_log_id, file_path, result_text[:200],
|
||||||
|
)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Error processing '{file_path}': {exc}")
|
||||||
|
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
|
||||||
|
finally:
|
||||||
|
_running_agents.discard(agent_id)
|
||||||
|
|
||||||
|
# ── Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Notify Electron that the run is complete via Redis
|
||||||
|
if run_context:
|
||||||
|
try:
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps({
|
||||||
|
"type": "run_complete",
|
||||||
|
"run_context": run_context,
|
||||||
|
"status": final_status,
|
||||||
|
}))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||||
|
|
||||||
|
|
||||||
|
async def run_cloud_agent(user_id: str, config_id: str, *, langfuse_handler: Any | None = None) -> None:
|
||||||
|
"""Execute a cloud connector agent run.
|
||||||
|
|
||||||
|
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
|
||||||
|
messages from the provider, and runs LLM extraction.
|
||||||
|
|
||||||
|
set_current_user() must be called BEFORE this function.
|
||||||
|
"""
|
||||||
|
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||||
|
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
config = result.scalar_one_or_none()
|
||||||
|
if config is None:
|
||||||
|
logger.error("agent_runner: cloud config %s not found", config_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
# Create run log
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="cloud",
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
# ── Decrypt OAuth token ────────────────────────────────────────
|
||||||
|
if not config.oauth_token_encrypted:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Instantiate provider ──────────────────────────────────────
|
||||||
|
try:
|
||||||
|
provider = get_provider(config.provider, credentials_info)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── Fetch messages ────────────────────────────────────────────
|
||||||
|
since: datetime | None = config.last_run_at
|
||||||
|
if since is None:
|
||||||
|
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||||
|
if since.tzinfo is None:
|
||||||
|
since = since.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.provider == "gmail":
|
||||||
|
raw_messages = await provider.fetch_messages(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "outlook":
|
||||||
|
raw_messages = await provider.fetch_emails(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "teams":
|
||||||
|
raw_messages = await provider.fetch_messages(
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_messages = []
|
||||||
|
except RuntimeError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Provider fetch failed: {exc}"],
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud agent %s fetched %d item(s) from %s",
|
||||||
|
config.id, len(raw_messages), config.provider,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Extract + insert via LLM ─────────────────────────────────
|
||||||
|
try:
|
||||||
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
|
custom_section = (
|
||||||
|
f"User instructions:\n{config.prompt_template}"
|
||||||
|
if config.prompt_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
|
||||||
|
for msg in raw_messages:
|
||||||
|
content_text = msg.as_text
|
||||||
|
if not content_text:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
|
||||||
|
processing_prompt = tracing.compile_prompt(
|
||||||
|
"batch_cloud_processing",
|
||||||
|
fallback=_CLOUD_PROCESSING_PROMPT,
|
||||||
|
variables={
|
||||||
|
"data_types": ", ".join(config.data_types),
|
||||||
|
"project_context": "Determine the appropriate project from the message context.",
|
||||||
|
"file_list": f"Message from {config.provider} (id: {msg.id})",
|
||||||
|
"custom_prompt_section": custom_section,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
await _run_agent_with_tools(
|
||||||
|
system_prompt=processing_prompt,
|
||||||
|
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
||||||
|
tools=processing_tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"Agent run failed: {exc}")
|
||||||
|
|
||||||
|
# ── Persist refreshed token ───────────────────────────────────
|
||||||
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
|
if refreshed:
|
||||||
|
try:
|
||||||
|
new_encrypted = encrypt_token(refreshed)
|
||||||
|
async with async_session() as db:
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||||
|
)
|
||||||
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg_row:
|
||||||
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
|
||||||
|
|
||||||
|
# ── Finalise ──────────────────────────────────────────────────
|
||||||
|
if errors and items_processed == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log_id,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=0,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_run(
|
||||||
|
run_log_id: int | str,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
items_processed: int = 0,
|
||||||
|
items_created: int = 0,
|
||||||
|
errors: list[str] | None = None,
|
||||||
|
update_config_last_run: bool = False,
|
||||||
|
config_id: str | None = None,
|
||||||
|
config_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Persist the run outcome and optionally update last_run_at on the config."""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
|
||||||
|
)
|
||||||
|
managed = result.scalar_one_or_none()
|
||||||
|
if managed is None:
|
||||||
|
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
managed.status = status
|
||||||
|
managed.items_processed = items_processed
|
||||||
|
managed.items_created = items_created
|
||||||
|
managed.errors = errors or []
|
||||||
|
managed.completed_at = now
|
||||||
|
|
||||||
|
if update_config_last_run and config_id:
|
||||||
|
if config_type == "local":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
elif config_type == "cloud":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)
|
||||||
1
services/batch-agent/app/agents/__init__.py
Normal file
1
services/batch-agent/app/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Batch Agent Service domain agents and filesystem tools."""
|
||||||
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.ws_context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
|
from shared.ws_context import execute_on_client
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str:
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{path}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str:
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{path}' is empty or could not be read."
|
||||||
|
return content
|
||||||
|
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str:
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": path},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", path)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
FILESYSTEM_TOOLS: list[Any] = [
|
||||||
|
list_directory,
|
||||||
|
read_file_content,
|
||||||
|
get_file_metadata,
|
||||||
|
]
|
||||||
108
services/batch-agent/app/integrations/__init__.py
Normal file
108
services/batch-agent/app/integrations/__init__.py
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
"""Cloud provider integration utilities.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from shared.config instead of app.config.
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
* Shared message dataclasses (EmailMessage, ChatMessage)
|
||||||
|
* get_provider() — factory for Gmail/MS Graph clients
|
||||||
|
* encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EmailMessage:
|
||||||
|
id: str
|
||||||
|
subject: str
|
||||||
|
sender: str
|
||||||
|
body_text: str
|
||||||
|
date: datetime
|
||||||
|
labels: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{labels_str}\n"
|
||||||
|
f"Subject: {self.subject}\n\n"
|
||||||
|
f"{self.body_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ChatMessage:
|
||||||
|
id: str
|
||||||
|
content: str
|
||||||
|
sender: str
|
||||||
|
channel: str | None
|
||||||
|
date: datetime
|
||||||
|
|
||||||
|
@property
|
||||||
|
def as_text(self) -> str:
|
||||||
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
|
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||||
|
return (
|
||||||
|
f"From: {self.sender}\n"
|
||||||
|
f"Date: {date_str}{channel_str}\n\n"
|
||||||
|
f"{self.content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_fernet() -> Fernet:
|
||||||
|
key = settings.OAUTH_ENCRYPTION_KEY
|
||||||
|
if not key:
|
||||||
|
raise RuntimeError(
|
||||||
|
"OAUTH_ENCRYPTION_KEY is not set. "
|
||||||
|
"Generate one with: python -c \"from cryptography.fernet import Fernet; print(Fernet.generate_key().decode())\""
|
||||||
|
)
|
||||||
|
return Fernet(key.encode() if isinstance(key, str) else key)
|
||||||
|
|
||||||
|
|
||||||
|
def encrypt_token(token_info: dict) -> str:
|
||||||
|
if not isinstance(token_info, dict) or not token_info:
|
||||||
|
raise ValueError("token_info must be a non-empty dict")
|
||||||
|
plaintext = json.dumps(token_info).encode("utf-8")
|
||||||
|
return _get_fernet().encrypt(plaintext).decode("utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def decrypt_token(encrypted: str) -> dict:
|
||||||
|
try:
|
||||||
|
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||||
|
return json.loads(plaintext)
|
||||||
|
except (InvalidToken, json.JSONDecodeError) as exc:
|
||||||
|
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
def get_provider(
|
||||||
|
provider: str,
|
||||||
|
credentials_info: dict,
|
||||||
|
) -> "GmailClient | MSGraphClient":
|
||||||
|
if provider == "gmail":
|
||||||
|
from app.integrations.gmail import GmailClient
|
||||||
|
return GmailClient(credentials_info)
|
||||||
|
if provider in {"outlook", "teams"}:
|
||||||
|
from app.integrations.ms_graph import MSGraphClient
|
||||||
|
return MSGraphClient(credentials_info)
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown cloud provider {provider!r}. "
|
||||||
|
"Supported: 'gmail', 'outlook', 'teams'."
|
||||||
|
)
|
||||||
252
services/batch-agent/app/integrations/gmail.py
Normal file
252
services/batch-agent/app/integrations/gmail.py
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
"""Gmail API client for cloud agent integration.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import from app.integrations instead of
|
||||||
|
app.integrations (same relative path within the service).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import base64
|
||||||
|
import email
|
||||||
|
import html
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.integrations import EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
|
||||||
|
def _build_gmail_query(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
parts: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
labels: list[str] = cfg.get("labels", [])
|
||||||
|
if labels:
|
||||||
|
if len(labels) == 1:
|
||||||
|
parts.append(f"label:{labels[0]}")
|
||||||
|
else:
|
||||||
|
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||||
|
parts.append(f"({label_expr})")
|
||||||
|
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
for sender in senders:
|
||||||
|
parts.append(f"from:{sender}")
|
||||||
|
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
parts.append(f"after:{effective_since.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
parts.append(f"before:{to_dt.strftime(_GMAIL_DATE_FMT)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("gmail: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " ".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw_html: str) -> str:
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||||
|
decoded = html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_body(payload: dict[str, Any]) -> str:
|
||||||
|
mime_type: str = payload.get("mimeType", "")
|
||||||
|
body: dict = payload.get("body", {})
|
||||||
|
parts: list[dict] = payload.get("parts", [])
|
||||||
|
|
||||||
|
if mime_type == "text/plain":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return ""
|
||||||
|
|
||||||
|
if mime_type == "text/html":
|
||||||
|
data = body.get("data", "")
|
||||||
|
if data:
|
||||||
|
raw = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||||
|
return _strip_html(raw)
|
||||||
|
return ""
|
||||||
|
|
||||||
|
plain_fallback = ""
|
||||||
|
for part in parts:
|
||||||
|
part_mime = part.get("mimeType", "")
|
||||||
|
if part_mime == "text/plain":
|
||||||
|
return _parse_body(part)
|
||||||
|
if part_mime == "text/html" and not plain_fallback:
|
||||||
|
plain_fallback = _parse_body(part)
|
||||||
|
if part_mime.startswith("multipart/"):
|
||||||
|
nested = _parse_body(part)
|
||||||
|
if nested:
|
||||||
|
return nested
|
||||||
|
return plain_fallback
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_date(raw: str) -> datetime:
|
||||||
|
try:
|
||||||
|
parsed = email.utils.parsedate_to_datetime(raw)
|
||||||
|
if parsed.tzinfo is None:
|
||||||
|
parsed = parsed.replace(tzinfo=timezone.utc)
|
||||||
|
return parsed.astimezone(timezone.utc)
|
||||||
|
except Exception:
|
||||||
|
return datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
|
||||||
|
class GmailClient:
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
expiry_str: str | None = credentials_info.get("expiry")
|
||||||
|
expiry: datetime | None = None
|
||||||
|
if expiry_str:
|
||||||
|
try:
|
||||||
|
expiry = datetime.fromisoformat(
|
||||||
|
expiry_str.replace("Z", "+00:00")
|
||||||
|
).replace(tzinfo=timezone.utc)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
self._credentials = Credentials(
|
||||||
|
token=credentials_info.get("token"),
|
||||||
|
refresh_token=credentials_info.get("refresh_token"),
|
||||||
|
token_uri=credentials_info.get("token_uri", "https://oauth2.googleapis.com/token"),
|
||||||
|
client_id=credentials_info.get("client_id"),
|
||||||
|
client_secret=credentials_info.get("client_secret"),
|
||||||
|
scopes=credentials_info.get("scopes"),
|
||||||
|
expiry=expiry,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
query = _build_gmail_query(filter_config, since)
|
||||||
|
logger.debug("gmail: executing search query %r", query)
|
||||||
|
return await asyncio.to_thread(self._fetch_sync, query)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
creds = self._credentials
|
||||||
|
if not creds.valid and creds.expired:
|
||||||
|
return None
|
||||||
|
if creds.token != self._credentials_info.get("token"):
|
||||||
|
result = {
|
||||||
|
"token": creds.token,
|
||||||
|
"refresh_token": creds.refresh_token,
|
||||||
|
"token_uri": creds.token_uri,
|
||||||
|
"client_id": creds.client_id,
|
||||||
|
"client_secret": creds.client_secret,
|
||||||
|
"scopes": list(creds.scopes or []),
|
||||||
|
}
|
||||||
|
if creds.expiry:
|
||||||
|
result["expiry"] = creds.expiry.isoformat()
|
||||||
|
return result
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||||
|
import googleapiclient.discovery
|
||||||
|
import googleapiclient.errors
|
||||||
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
if self._credentials.expired and self._credentials.refresh_token:
|
||||||
|
try:
|
||||||
|
self._credentials.refresh(Request())
|
||||||
|
except Exception as exc:
|
||||||
|
raise RuntimeError(f"Gmail token refresh failed: {exc}") from exc
|
||||||
|
|
||||||
|
service = googleapiclient.discovery.build(
|
||||||
|
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||||
|
)
|
||||||
|
user_api = service.users()
|
||||||
|
|
||||||
|
ids: list[str] = []
|
||||||
|
page_token: str | None = None
|
||||||
|
while len(ids) < _MAX_MESSAGES:
|
||||||
|
batch_size = min(100, _MAX_MESSAGES - len(ids))
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"userId": "me",
|
||||||
|
"maxResults": batch_size,
|
||||||
|
}
|
||||||
|
if query:
|
||||||
|
kwargs["q"] = query
|
||||||
|
if page_token:
|
||||||
|
kwargs["pageToken"] = page_token
|
||||||
|
|
||||||
|
try:
|
||||||
|
resp = user_api.messages().list(**kwargs).execute()
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
raise RuntimeError(f"Gmail messages.list failed: {exc}") from exc
|
||||||
|
|
||||||
|
for msg in resp.get("messages", []):
|
||||||
|
ids.append(msg["id"])
|
||||||
|
|
||||||
|
page_token = resp.get("nextPageToken")
|
||||||
|
if not page_token:
|
||||||
|
break
|
||||||
|
|
||||||
|
if not ids:
|
||||||
|
return []
|
||||||
|
|
||||||
|
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||||
|
|
||||||
|
messages: list[EmailMessage] = []
|
||||||
|
for msg_id in ids:
|
||||||
|
try:
|
||||||
|
msg = user_api.messages().get(
|
||||||
|
userId="me", id=msg_id, format="full"
|
||||||
|
).execute()
|
||||||
|
|
||||||
|
headers: dict[str, str] = {
|
||||||
|
h["name"].lower(): h["value"]
|
||||||
|
for h in msg.get("payload", {}).get("headers", [])
|
||||||
|
}
|
||||||
|
subject = headers.get("subject", "(no subject)")
|
||||||
|
sender = headers.get("from", "unknown")
|
||||||
|
date_raw = headers.get("date", "")
|
||||||
|
date = _parse_date(date_raw) if date_raw else datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_text = _parse_body(msg.get("payload", {}))[:_BODY_TRUNCATE]
|
||||||
|
labels = msg.get("labelIds", [])
|
||||||
|
|
||||||
|
messages.append(EmailMessage(
|
||||||
|
id=msg_id,
|
||||||
|
subject=subject,
|
||||||
|
sender=sender,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
labels=labels,
|
||||||
|
))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
|
||||||
|
|
||||||
|
logger.info("gmail: returned %d message(s)", len(messages))
|
||||||
|
return messages
|
||||||
266
services/batch-agent/app/integrations/ms_graph.py
Normal file
266
services/batch-agent/app/integrations/ms_graph.py
Normal file
@@ -0,0 +1,266 @@
|
|||||||
|
"""Microsoft Graph API client for Outlook and Teams.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: import settings from shared.config.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from app.integrations import ChatMessage, EmailMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
_MAX_EMAILS = 200
|
||||||
|
_MAX_MESSAGES = 200
|
||||||
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
|
||||||
|
def _strip_html(raw: str) -> str:
|
||||||
|
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||||
|
import html as _html
|
||||||
|
decoded = _html.unescape(no_tags)
|
||||||
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def _odata_datetime(dt: datetime) -> str:
|
||||||
|
utc = dt.astimezone(timezone.utc)
|
||||||
|
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
|
|
||||||
|
def _build_email_filter(
|
||||||
|
filter_config: dict[str, Any] | None,
|
||||||
|
since: datetime | None,
|
||||||
|
) -> str:
|
||||||
|
clauses: list[str] = []
|
||||||
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
senders: list[str] = cfg.get("senders", [])
|
||||||
|
if senders:
|
||||||
|
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||||
|
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||||
|
|
||||||
|
date_range: dict = cfg.get("date_range", {})
|
||||||
|
from_str: str | None = date_range.get("from")
|
||||||
|
|
||||||
|
effective_since: datetime | None = since
|
||||||
|
if from_str:
|
||||||
|
try:
|
||||||
|
cfg_since = datetime.fromisoformat(from_str.replace("Z", "+00:00"))
|
||||||
|
if cfg_since.tzinfo is None:
|
||||||
|
cfg_since = cfg_since.replace(tzinfo=timezone.utc)
|
||||||
|
if effective_since is None or cfg_since > effective_since:
|
||||||
|
effective_since = cfg_since
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.from %r — ignoring", from_str)
|
||||||
|
|
||||||
|
if effective_since:
|
||||||
|
clauses.append(f"receivedDateTime ge {_odata_datetime(effective_since)}")
|
||||||
|
|
||||||
|
to_str: str | None = date_range.get("to")
|
||||||
|
if to_str:
|
||||||
|
try:
|
||||||
|
to_dt = datetime.fromisoformat(to_str.replace("Z", "+00:00"))
|
||||||
|
if to_dt.tzinfo is None:
|
||||||
|
to_dt = to_dt.replace(tzinfo=timezone.utc)
|
||||||
|
clauses.append(f"receivedDateTime le {_odata_datetime(to_dt)}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("ms_graph: invalid date_range.to %r — ignoring", to_str)
|
||||||
|
|
||||||
|
return " and ".join(clauses)
|
||||||
|
|
||||||
|
|
||||||
|
class MSGraphClient:
|
||||||
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
|
self._credentials_info = credentials_info
|
||||||
|
self._access_token: str = credentials_info.get("access_token", "")
|
||||||
|
self._original_access_token: str = self._access_token
|
||||||
|
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||||
|
|
||||||
|
def _auth_headers(self) -> dict[str, str]:
|
||||||
|
return {"Authorization": f"Bearer {self._access_token}"}
|
||||||
|
|
||||||
|
async def _refresh_access_token(self) -> None:
|
||||||
|
import msal
|
||||||
|
|
||||||
|
app = msal.ConfidentialClientApplication(
|
||||||
|
client_id=settings.MS_CLIENT_ID,
|
||||||
|
client_credential=settings.MS_CLIENT_SECRET,
|
||||||
|
authority=f"https://login.microsoftonline.com/{settings.MS_TENANT_ID}",
|
||||||
|
)
|
||||||
|
scopes: list[str] = self._credentials_info.get("scope", "").split()
|
||||||
|
if not scopes:
|
||||||
|
scopes = ["https://graph.microsoft.com/.default"]
|
||||||
|
|
||||||
|
result = app.acquire_token_by_refresh_token(
|
||||||
|
self._refresh_token,
|
||||||
|
scopes=scopes,
|
||||||
|
)
|
||||||
|
if "access_token" not in result:
|
||||||
|
error = result.get("error_description", result.get("error", "unknown"))
|
||||||
|
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||||
|
|
||||||
|
self._access_token = result["access_token"]
|
||||||
|
if "refresh_token" in result:
|
||||||
|
self._refresh_token = result["refresh_token"]
|
||||||
|
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||||
|
self._credentials_info["access_token"] = self._access_token
|
||||||
|
|
||||||
|
@property
|
||||||
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
if self._access_token != self._original_access_token:
|
||||||
|
return {**self._credentials_info, "access_token": self._access_token}
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _get(
|
||||||
|
self,
|
||||||
|
client: httpx.AsyncClient,
|
||||||
|
url: str,
|
||||||
|
params: dict[str, Any] | None = None,
|
||||||
|
*,
|
||||||
|
retry_on_401: bool = True,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||||
|
await self._refresh_access_token()
|
||||||
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
|
if resp.status_code == 429:
|
||||||
|
raise RuntimeError("MS Graph rate limit hit (429). Try again later.")
|
||||||
|
resp.raise_for_status()
|
||||||
|
return resp.json()
|
||||||
|
|
||||||
|
async def fetch_emails(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[EmailMessage]:
|
||||||
|
odata_filter = _build_email_filter(filter_config, since)
|
||||||
|
params: dict[str, Any] = {
|
||||||
|
"$top": 50,
|
||||||
|
"$select": "id,subject,from,receivedDateTime,body,bodyPreview",
|
||||||
|
"$orderby": "receivedDateTime desc",
|
||||||
|
}
|
||||||
|
if odata_filter:
|
||||||
|
params["$filter"] = odata_filter
|
||||||
|
|
||||||
|
emails: list[EmailMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/messages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(emails) < _MAX_EMAILS:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
for item in data.get("value", []):
|
||||||
|
emails.append(self._parse_email(item))
|
||||||
|
if len(emails) >= _MAX_EMAILS:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||||
|
return emails
|
||||||
|
|
||||||
|
async def fetch_messages(
|
||||||
|
self,
|
||||||
|
filter_config: dict[str, Any] | None = None,
|
||||||
|
since: datetime | None = None,
|
||||||
|
) -> list[ChatMessage]:
|
||||||
|
cfg = filter_config or {}
|
||||||
|
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||||
|
params: dict[str, Any] = {"$top": 50}
|
||||||
|
if since:
|
||||||
|
params["$filter"] = f"createdDateTime ge {_odata_datetime(since)}"
|
||||||
|
|
||||||
|
messages: list[ChatMessage] = []
|
||||||
|
url = f"{_GRAPH_BASE}/me/chats/getAllMessages"
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||||
|
while url and len(messages) < _MAX_MESSAGES:
|
||||||
|
try:
|
||||||
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
if exc.response.status_code in (403, 404):
|
||||||
|
logger.warning(
|
||||||
|
"ms_graph: /me/chats/getAllMessages not available (%d)",
|
||||||
|
exc.response.status_code,
|
||||||
|
)
|
||||||
|
break
|
||||||
|
raise
|
||||||
|
|
||||||
|
for item in data.get("value", []):
|
||||||
|
msg = self._parse_teams_message(item)
|
||||||
|
if channel_filter and msg.channel:
|
||||||
|
if not any(c in msg.channel.lower() for c in channel_filter):
|
||||||
|
continue
|
||||||
|
messages.append(msg)
|
||||||
|
if len(messages) >= _MAX_MESSAGES:
|
||||||
|
break
|
||||||
|
url = data.get("@odata.nextLink", "")
|
||||||
|
params = {}
|
||||||
|
|
||||||
|
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||||
|
return messages
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||||
|
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||||
|
sender_block = item.get("from", {}) or {}
|
||||||
|
sender_addr = (
|
||||||
|
(sender_block.get("emailAddress") or {}).get("address", "unknown")
|
||||||
|
)
|
||||||
|
date_str: str = item.get("receivedDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_body: str = body_block.get("content", "")
|
||||||
|
if content_type == "html":
|
||||||
|
body_text = _strip_html(raw_body)
|
||||||
|
else:
|
||||||
|
body_text = raw_body or item.get("bodyPreview", "")
|
||||||
|
body_text = body_text[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return EmailMessage(
|
||||||
|
id=item.get("id", ""),
|
||||||
|
subject=subject,
|
||||||
|
sender=sender_addr,
|
||||||
|
body_text=body_text,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_teams_message(item: dict[str, Any]) -> ChatMessage:
|
||||||
|
msg_id: str = item.get("id", "")
|
||||||
|
sender_block = (item.get("from") or {}).get("user") or {}
|
||||||
|
sender: str = sender_block.get("displayName", "unknown")
|
||||||
|
channel: str | None = (item.get("channelIdentity") or {}).get("channelId")
|
||||||
|
|
||||||
|
date_str: str = item.get("createdDateTime", "")
|
||||||
|
try:
|
||||||
|
date = datetime.fromisoformat(date_str.replace("Z", "+00:00"))
|
||||||
|
except Exception:
|
||||||
|
date = datetime.now(timezone.utc)
|
||||||
|
|
||||||
|
body_block = item.get("body", {}) or {}
|
||||||
|
content_type: str = body_block.get("contentType", "text")
|
||||||
|
raw_content: str = body_block.get("content", "")
|
||||||
|
content = _strip_html(raw_content) if content_type == "html" else raw_content
|
||||||
|
content = content[:_BODY_TRUNCATE]
|
||||||
|
|
||||||
|
return ChatMessage(
|
||||||
|
id=msg_id,
|
||||||
|
content=content,
|
||||||
|
sender=sender,
|
||||||
|
channel=channel,
|
||||||
|
date=date,
|
||||||
|
)
|
||||||
395
services/batch-agent/app/journey.py
Normal file
395
services/batch-agent/app/journey.py
Normal file
@@ -0,0 +1,395 @@
|
|||||||
|
"""Chatbot Journey — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
|
||||||
|
and app.llm instead of monolith paths. Session state is in-memory (could
|
||||||
|
be moved to Redis for horizontal scaling in the future).
|
||||||
|
|
||||||
|
Journey flow:
|
||||||
|
1. Redis consumer dispatches ``journey_start`` with basic agent config.
|
||||||
|
2. Server creates an in-memory session, runs the setup LLM with
|
||||||
|
file-system tools to explore the directory, returns first question.
|
||||||
|
3. ``journey_message`` frames drive the conversation.
|
||||||
|
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
|
||||||
|
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
|
||||||
|
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||||
|
from shared.llm import get_llm
|
||||||
|
import app.tracing as tracing
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
|
_MAX_TURNS: int = 15
|
||||||
|
_MAX_TOOL_STEPS: int = 6
|
||||||
|
|
||||||
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JourneySession:
|
||||||
|
session_id: str
|
||||||
|
user_id: str
|
||||||
|
agent_type: str # "local" | "cloud"
|
||||||
|
directory: str
|
||||||
|
data_types: list[str]
|
||||||
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
system_prompt: str = ""
|
||||||
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||||
|
|
||||||
|
|
||||||
|
# session_id → session
|
||||||
|
_sessions: dict[str, JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||||
|
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||||
|
s = _sessions.get(session_id)
|
||||||
|
if s is None or s.is_expired():
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
return None
|
||||||
|
if s.user_id != user_id:
|
||||||
|
return None
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
|
Your job is to understand exactly what data the user wants to extract from their
|
||||||
|
local directory and produce a concise prompt_template that a separate AI will use
|
||||||
|
as its instruction set.
|
||||||
|
|
||||||
|
You have access to file-system tools to explore the user's directory:
|
||||||
|
- list_directory: to see folder structure
|
||||||
|
- read_file_content: to peek at file contents
|
||||||
|
- get_file_metadata: to check file info
|
||||||
|
|
||||||
|
The user's configured directory is: {directory}
|
||||||
|
Target data types: {data_types}
|
||||||
|
|
||||||
|
IMPORTANT — project assignment is handled automatically. You MUST NOT ask the user
|
||||||
|
about projects, projectId, or how to link records to projects. Never include
|
||||||
|
projectId logic or project creation instructions in the generated prompt_template.
|
||||||
|
|
||||||
|
Start by exploring the directory to understand its structure. Then ask concise,
|
||||||
|
focused questions one at a time. Cover only the topics relevant to the target
|
||||||
|
data types listed above:
|
||||||
|
|
||||||
|
1. Content type and format — confirmed by your exploration.
|
||||||
|
2. For TASKS (if in scope): field mapping for title, status, priority, content,
|
||||||
|
dueDate (where is the date found? what's the fallback when absent?),
|
||||||
|
and assignee (is there a person name to assign?).
|
||||||
|
3. For NOTES when TASKS are also in scope: note vs task distinction —
|
||||||
|
what makes something a note rather than a task?
|
||||||
|
4. For TIMELINES (if in scope): the date source — what marks a milestone or event?
|
||||||
|
5. Exclusions and special handling applicable to the target data types.
|
||||||
|
|
||||||
|
Keep asking focused questions until you are at least 90% confident. Then stop and
|
||||||
|
output the final prompt_template immediately, wrapped between these exact markers
|
||||||
|
on their own lines:
|
||||||
|
|
||||||
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
|
The prompt_template must be concise (bullet points, ~15–25 lines maximum).
|
||||||
|
Specify only:
|
||||||
|
- Scope: what files/content qualify and what entity types to create.
|
||||||
|
- Field mapping rules per entity type (camelCase fields: title, status, priority,
|
||||||
|
dueDate, content, assignee, etc.).
|
||||||
|
- dueDate rule (if tasks in scope): source and fallback behaviour.
|
||||||
|
- Note vs task rule (if both in scope): the criterion that separates them.
|
||||||
|
- Timeline date rule (if timelines in scope): what constitutes a timeline event.
|
||||||
|
- Exclusion/filtering rules.
|
||||||
|
- 2–3 concrete mapping examples based on what you discovered.
|
||||||
|
|
||||||
|
{existing_section}Begin by exploring the directory, then ask your first question.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_system_prompt(
|
||||||
|
directory: str,
|
||||||
|
data_types: list[str],
|
||||||
|
existing_template: str | None = None,
|
||||||
|
) -> str:
|
||||||
|
existing_section = (
|
||||||
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
|
f"---\n{existing_template}\n---\n"
|
||||||
|
if existing_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
# Use Langfuse compile_prompt ({{variable}} syntax) with Python .format() fallback
|
||||||
|
return tracing.compile_prompt(
|
||||||
|
"journey_system",
|
||||||
|
fallback=_SYSTEM_PROMPT_TEMPLATE,
|
||||||
|
variables={
|
||||||
|
"directory": directory,
|
||||||
|
"data_types": ", ".join(data_types),
|
||||||
|
"existing_section": existing_section,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_template(text: str) -> str | None:
|
||||||
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _as_text(content: Any) -> str:
|
||||||
|
if content is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(content, str):
|
||||||
|
return content
|
||||||
|
if isinstance(content, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for item in content:
|
||||||
|
if isinstance(item, str):
|
||||||
|
parts.append(item)
|
||||||
|
elif isinstance(item, dict):
|
||||||
|
text = item.get("text")
|
||||||
|
if isinstance(text, str):
|
||||||
|
parts.append(text)
|
||||||
|
return "".join(parts)
|
||||||
|
return str(content)
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm_with_tools(
|
||||||
|
system_prompt: str,
|
||||||
|
history: list[dict[str, Any]],
|
||||||
|
tools: list[Any],
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
|
continue until a final text response is produced.
|
||||||
|
"""
|
||||||
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
|
for turn in history:
|
||||||
|
if turn["role"] == "user":
|
||||||
|
messages.append(HumanMessage(content=turn["content"]))
|
||||||
|
else:
|
||||||
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
|
callbacks = [langfuse_handler] if langfuse_handler else None
|
||||||
|
llm = get_llm(model=None, temperature=0.4, callbacks=callbacks)
|
||||||
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
for _ in range(_MAX_TOOL_STEPS):
|
||||||
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
messages.append(response)
|
||||||
|
|
||||||
|
if not response.tool_calls:
|
||||||
|
return _as_text(response.content)
|
||||||
|
|
||||||
|
for call in response.tool_calls:
|
||||||
|
call_name = str(call.get("name", ""))
|
||||||
|
call_args = call.get("args", {})
|
||||||
|
logger.info(
|
||||||
|
"journey: tool_call name=%s args=%s",
|
||||||
|
call_name,
|
||||||
|
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||||
|
)
|
||||||
|
|
||||||
|
tool_fn = tool_map.get(call_name)
|
||||||
|
if tool_fn is None:
|
||||||
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
|
else:
|
||||||
|
tool_output = await tool_fn.ainvoke(call_args)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey: tool_result name=%s output=%s",
|
||||||
|
call_name,
|
||||||
|
str(tool_output)[:800],
|
||||||
|
)
|
||||||
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
|
# Fallback: exceeded max tool steps.
|
||||||
|
final = await llm.ainvoke(messages)
|
||||||
|
return _as_text(final.content)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey handlers (called from redis_consumer) ────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_start(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_start`` request.
|
||||||
|
|
||||||
|
Creates a session, runs the setup LLM with directory exploration,
|
||||||
|
and returns the ``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
agent_type = frame.get("agent_type", "local")
|
||||||
|
directory = frame.get("directory", "")
|
||||||
|
data_types = frame.get("data_types", [])
|
||||||
|
existing_template = frame.get("existing_template")
|
||||||
|
|
||||||
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
|
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||||
|
|
||||||
|
session = JourneySession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=user_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
directory=directory,
|
||||||
|
data_types=data_types,
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
seed_history: list[dict[str, Any]] = [
|
||||||
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
|
]
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
history=seed_history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.extend(seed_history)
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
_sessions[session_id] = session
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey: session %s started for user %s (directory=%s)",
|
||||||
|
session_id,
|
||||||
|
user_id,
|
||||||
|
directory,
|
||||||
|
)
|
||||||
|
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def handle_journey_message(
|
||||||
|
user_id: str,
|
||||||
|
frame: dict[str, Any],
|
||||||
|
*,
|
||||||
|
langfuse_handler: Any | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Handle a ``journey_message`` request.
|
||||||
|
|
||||||
|
Appends the user message, calls the LLM, and returns the
|
||||||
|
``journey_reply`` payload.
|
||||||
|
"""
|
||||||
|
session_id = frame.get("session_id", "")
|
||||||
|
message = frame.get("message", "")
|
||||||
|
|
||||||
|
session = get_journey_session(session_id, user_id)
|
||||||
|
if session is None:
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
|
ai_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
if not done:
|
||||||
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
nudge_content = (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
)
|
||||||
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
|
nudge_reply = await _call_llm_with_tools(
|
||||||
|
system_prompt=session.system_prompt,
|
||||||
|
history=session.history,
|
||||||
|
tools=list(FILESYSTEM_TOOLS),
|
||||||
|
langfuse_handler=langfuse_handler,
|
||||||
|
)
|
||||||
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
|
prompt_template = _extract_template(nudge_reply)
|
||||||
|
if prompt_template is not None:
|
||||||
|
done = True
|
||||||
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
if _TEMPLATE_START in ai_reply
|
||||||
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
logger.info("journey: session %s completed for user %s", session_id, user_id)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": display_message,
|
||||||
|
"done": done,
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
}
|
||||||
76
services/batch-agent/app/llm.py
Normal file
76
services/batch-agent/app/llm.py
Normal file
@@ -0,0 +1,76 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Identical to services/chat/app/llm.py. Uses shared.config.settings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import warnings
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
warnings.filterwarnings(
|
||||||
|
"ignore",
|
||||||
|
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||||
|
category=UserWarning,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github/"):
|
||||||
|
return settings.GITHUB_TOKEN or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
return None
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
callbacks: list | None = None,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
if settings.GITHUB_TOKEN:
|
||||||
|
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
|
||||||
|
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
callbacks=callbacks,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
|
return response.data[0].embedding
|
||||||
79
services/batch-agent/app/main.py
Normal file
79
services/batch-agent/app/main.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Batch Agent Service — FastAPI application.
|
||||||
|
|
||||||
|
Owns: agent_runner (local directory + cloud connectors), journey builder,
|
||||||
|
filesystem_agent, integrations (Gmail, MS Graph).
|
||||||
|
|
||||||
|
Communicates with WS Gateway via Redis:
|
||||||
|
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
|
||||||
|
- Publishes to ws:out:{user_id} (journey replies + tool calls)
|
||||||
|
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
|
||||||
|
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so ``shared`` is importable when
|
||||||
|
# running locally (in Docker the COPY already places it at /app/shared/).
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.redis_consumer import start_consumer
|
||||||
|
from app.routes import router
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
# Initialise Langfuse tracing (no-op if keys are missing)
|
||||||
|
from app.tracing import init_langfuse
|
||||||
|
init_langfuse()
|
||||||
|
|
||||||
|
logger.info("batch-agent: starting Redis consumer")
|
||||||
|
task = asyncio.create_task(start_consumer())
|
||||||
|
yield
|
||||||
|
task.cancel()
|
||||||
|
try:
|
||||||
|
await task
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
from app.tracing import shutdown as shutdown_langfuse
|
||||||
|
shutdown_langfuse()
|
||||||
|
|
||||||
|
from shared.db import engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
from shared.redis import redis_client
|
||||||
|
await redis_client.aclose()
|
||||||
|
|
||||||
|
logger.info("batch-agent: Redis consumer stopped")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health() -> dict[str, str]:
|
||||||
|
return {"status": "ok", "service": "batch-agent"}
|
||||||
183
services/batch-agent/app/redis_consumer.py
Normal file
183
services/batch-agent/app/redis_consumer.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Redis consumer for the Batch Agent Service.
|
||||||
|
|
||||||
|
Subscribes to batch:request:* (pattern) and dispatches:
|
||||||
|
- journey_start → handle_journey_start
|
||||||
|
- journey_message → handle_journey_message
|
||||||
|
- agent_trigger → run_local_agent / run_cloud_agent
|
||||||
|
|
||||||
|
Results are published back to ws:out:{user_id} via Redis.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.redis import redis_client, batch_request_channel, ws_out_channel
|
||||||
|
|
||||||
|
import app.tracing as tracing
|
||||||
|
from shared.ws_context import set_current_user, clear_current_user
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||||
|
"""Publish a frame to the user's WS outbound channel."""
|
||||||
|
channel = ws_out_channel(user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(payload))
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle a journey_start request from WS Gateway."""
|
||||||
|
from app.journey import handle_journey_start
|
||||||
|
|
||||||
|
session_id = data.get("session_id", "")
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="journey_start",
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
input=data.get("directory", ""),
|
||||||
|
metadata={"data_types": data.get("data_types", [])},
|
||||||
|
tags=["journey"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
reply = await handle_journey_start(user_id, data, langfuse_handler=langfuse_handler)
|
||||||
|
tracing.link_prompt_to_trace(span, "journey_system")
|
||||||
|
span.update(output=reply.get("message", "")[:500])
|
||||||
|
await _publish_to_user(user_id, reply)
|
||||||
|
tracing.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey setup failed: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle a journey_message from WS Gateway."""
|
||||||
|
from app.journey import handle_journey_message
|
||||||
|
|
||||||
|
session_id = data.get("session_id", "")
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="journey_message",
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
input=data.get("message", "")[:200],
|
||||||
|
tags=["journey"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
reply = await handle_journey_message(user_id, data, langfuse_handler=langfuse_handler)
|
||||||
|
tracing.link_prompt_to_trace(span, "journey_system")
|
||||||
|
span.update(output=reply.get("message", "")[:500])
|
||||||
|
await _publish_to_user(user_id, reply)
|
||||||
|
tracing.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "journey_reply",
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": f"Journey processing failed: {exc}",
|
||||||
|
"done": True,
|
||||||
|
"prompt_template": None,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
|
||||||
|
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
|
||||||
|
from app.agent_runner import run_local_agent
|
||||||
|
|
||||||
|
run_context = data.get("run_context", {})
|
||||||
|
agent_id = run_context.get("agent_id", "")
|
||||||
|
set_current_user(user_id)
|
||||||
|
try:
|
||||||
|
with tracing.trace_span(
|
||||||
|
name="agent_trigger",
|
||||||
|
user_id=user_id,
|
||||||
|
trace_id=run_context.get("run_id"),
|
||||||
|
input={"agent_id": agent_id, "directory": data.get("directory", "")},
|
||||||
|
metadata={"data_types": data.get("data_types", [])},
|
||||||
|
tags=["batch", "agent_run"],
|
||||||
|
) as span:
|
||||||
|
langfuse_handler = tracing.get_langfuse_callback()
|
||||||
|
await run_local_agent(user_id, data, langfuse_handler=langfuse_handler)
|
||||||
|
tracing.link_prompt_to_trace(span, "batch_processing")
|
||||||
|
span.update(output={"status": "completed"})
|
||||||
|
tracing.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
|
||||||
|
await _publish_to_user(user_id, {
|
||||||
|
"type": "run_complete",
|
||||||
|
"status": "error",
|
||||||
|
"run_context": run_context,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
|
||||||
|
"""Route a batch request to the correct handler."""
|
||||||
|
msg_type = message_data.get("type", "")
|
||||||
|
|
||||||
|
if msg_type == "journey_start":
|
||||||
|
await _handle_journey_start(user_id, message_data)
|
||||||
|
elif msg_type == "journey_message":
|
||||||
|
await _handle_journey_message(user_id, message_data)
|
||||||
|
elif msg_type == "agent_trigger":
|
||||||
|
await _handle_agent_trigger(user_id, message_data)
|
||||||
|
elif msg_type == "device_online":
|
||||||
|
logger.info("batch-agent: device_online user=%s device=%s", user_id, message_data.get("device_id", "?"))
|
||||||
|
else:
|
||||||
|
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
|
||||||
|
|
||||||
|
|
||||||
|
async def start_consumer() -> None:
|
||||||
|
"""Subscribe to batch:request:* and dispatch incoming frames."""
|
||||||
|
pubsub = redis_client.pubsub()
|
||||||
|
await pubsub.psubscribe("batch:request:*")
|
||||||
|
logger.info("batch-agent: subscribed to batch:request:*")
|
||||||
|
|
||||||
|
try:
|
||||||
|
async for message in pubsub.listen():
|
||||||
|
if message["type"] != "pmessage":
|
||||||
|
continue
|
||||||
|
|
||||||
|
channel: str = message["channel"]
|
||||||
|
if isinstance(channel, bytes):
|
||||||
|
channel = channel.decode()
|
||||||
|
|
||||||
|
# Extract user_id from channel: batch:request:{user_id}
|
||||||
|
parts = channel.split(":", 2)
|
||||||
|
if len(parts) < 3:
|
||||||
|
continue
|
||||||
|
user_id = parts[2]
|
||||||
|
|
||||||
|
raw = message["data"]
|
||||||
|
if isinstance(raw, bytes):
|
||||||
|
raw = raw.decode()
|
||||||
|
|
||||||
|
try:
|
||||||
|
data = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("batch-agent: invalid JSON on channel %s", channel)
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Dispatch in a separate task to avoid blocking the consumer
|
||||||
|
asyncio.create_task(_dispatch(user_id, data))
|
||||||
|
except asyncio.CancelledError:
|
||||||
|
logger.info("batch-agent: consumer shutting down")
|
||||||
|
finally:
|
||||||
|
await pubsub.punsubscribe("batch:request:*")
|
||||||
208
services/batch-agent/app/routes.py
Normal file
208
services/batch-agent/app/routes.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
"""Agent REST routes — catalog, billing checks, trigger.
|
||||||
|
|
||||||
|
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
|
||||||
|
Agent trigger dispatches via Redis to the consumer instead of spawning
|
||||||
|
an in-process background task.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Header, HTTPException, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.models import AgentRunLog
|
||||||
|
from shared.redis import redis_client, batch_request_channel
|
||||||
|
|
||||||
|
from app.agent_runner import is_agent_running
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
# ── Tier feature limits ───────────────────────────────────────────────
|
||||||
|
# Mirrors app/billing/tier_manager.py FEATURES dict.
|
||||||
|
FEATURES: dict[str, dict] = {
|
||||||
|
"free": {"batch_active": 1, "batch_runs_per_day": 3},
|
||||||
|
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
|
||||||
|
"power": {"batch_active": 20, "batch_runs_per_day": 100},
|
||||||
|
"team": {"batch_active": -1, "batch_runs_per_day": -1},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms(dt: datetime) -> int:
|
||||||
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
def _to_data_types(values: list[str]) -> list[str]:
|
||||||
|
normalize = {
|
||||||
|
"task": "tasks", "tasks": "tasks",
|
||||||
|
"note": "notes", "notes": "notes",
|
||||||
|
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||||
|
"project": "projects", "projects": "projects",
|
||||||
|
}
|
||||||
|
seen: set[str] = set()
|
||||||
|
result: list[str] = []
|
||||||
|
for v in values:
|
||||||
|
mapped = normalize.get(v)
|
||||||
|
if mapped and mapped not in seen:
|
||||||
|
seen.add(mapped)
|
||||||
|
result.append(mapped)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
|
if limit != -1 and current_count >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
|
)
|
||||||
|
return limit
|
||||||
|
|
||||||
|
|
||||||
|
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||||
|
if limit == -1:
|
||||||
|
return
|
||||||
|
today_start = datetime.now(timezone.utc).replace(
|
||||||
|
hour=0, minute=0, second=0, microsecond=0
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.count(AgentRunLog.id)).where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.started_at >= today_start,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
runs_today: int = result.scalar_one()
|
||||||
|
|
||||||
|
if runs_today >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Daily batch run limit ({limit}) reached for your tier.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/catalog")
|
||||||
|
async def get_agent_catalog(
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
) -> list[dict]:
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"type": "local_directory",
|
||||||
|
"name": "Local Directory Monitor",
|
||||||
|
"description": "Watches local directories, extracts data from files using AI",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "gmail",
|
||||||
|
"name": "Gmail Connector",
|
||||||
|
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "teams",
|
||||||
|
"name": "Microsoft Teams Connector",
|
||||||
|
"description": "Monitors Teams messages, extracts action items",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "outlook",
|
||||||
|
"name": "Outlook Connector",
|
||||||
|
"description": "Scans Outlook inbox, extracts tasks/notes",
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Can-create check ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/can-create")
|
||||||
|
async def can_create_agent(
|
||||||
|
body: dict,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict:
|
||||||
|
active_agents = body.get("active_agents", 0)
|
||||||
|
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
|
||||||
|
allowed = limit == -1 or active_agents < limit
|
||||||
|
return {
|
||||||
|
"allowed": allowed,
|
||||||
|
"tier": x_user_tier,
|
||||||
|
"active_agents": active_agents,
|
||||||
|
"limit": limit,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trigger ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
|
||||||
|
async def trigger_agent_run(
|
||||||
|
body: dict,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict:
|
||||||
|
"""Trigger a local agent run — creates run log and dispatches via Redis."""
|
||||||
|
active_agents = body.get("active_agents", 0)
|
||||||
|
_enforce_agent_limit(x_user_tier, active_agents)
|
||||||
|
await _enforce_run_frequency(x_user_tier, x_user_id)
|
||||||
|
|
||||||
|
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
|
||||||
|
|
||||||
|
if is_agent_running(stable_agent_id):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_409_CONFLICT,
|
||||||
|
detail="Agent is already running.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Create run log in DB
|
||||||
|
async with async_session() as db:
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=stable_agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=x_user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
run_log_id = run_log.id
|
||||||
|
|
||||||
|
run_context = {
|
||||||
|
"type": "agent_batch",
|
||||||
|
"run_id": run_log_id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Dispatch to the Redis consumer for processing
|
||||||
|
trigger_data = {
|
||||||
|
"type": "agent_trigger",
|
||||||
|
"directory": body.get("directory", ""),
|
||||||
|
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
|
||||||
|
"data_types": _to_data_types(body.get("what_to_extract", [])),
|
||||||
|
"file_extensions": body.get("file_extensions", []),
|
||||||
|
"prompt_template": body.get("custom_agent_prompt", ""),
|
||||||
|
"device_id": body.get("device_id", ""),
|
||||||
|
"run_context": run_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
channel = batch_request_channel(x_user_id)
|
||||||
|
await redis_client.publish(channel, json.dumps(trigger_data))
|
||||||
|
|
||||||
|
return {
|
||||||
|
"id": run_log_id,
|
||||||
|
"agent_id": stable_agent_id,
|
||||||
|
"agent_type": "local",
|
||||||
|
"status": "running",
|
||||||
|
"items_processed": 0,
|
||||||
|
"items_created": 0,
|
||||||
|
"errors": [],
|
||||||
|
"started_at": _dt_ms(run_log.started_at),
|
||||||
|
"completed_at": None,
|
||||||
|
}
|
||||||
336
services/batch-agent/app/tracing.py
Normal file
336
services/batch-agent/app/tracing.py
Normal file
@@ -0,0 +1,336 @@
|
|||||||
|
"""Langfuse tracing & prompt management for the Batch Agent Service (v4 SDK).
|
||||||
|
|
||||||
|
Provides:
|
||||||
|
- ``init_langfuse()`` — initialise the singleton client at startup
|
||||||
|
- ``trace_span()`` — context manager that creates a trace + span
|
||||||
|
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
|
||||||
|
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
|
||||||
|
- ``flush()`` / ``shutdown()`` — lifecycle management
|
||||||
|
|
||||||
|
All functions gracefully degrade to no-ops when Langfuse is not configured,
|
||||||
|
so the service works identically with or without observability keys.
|
||||||
|
|
||||||
|
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── State ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_initialised: bool = False
|
||||||
|
_disabled: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def _is_configured() -> bool:
|
||||||
|
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
|
||||||
|
|
||||||
|
|
||||||
|
def init_langfuse() -> None:
|
||||||
|
"""Initialise the Langfuse singleton. Call once at startup."""
|
||||||
|
global _initialised, _disabled
|
||||||
|
|
||||||
|
if _initialised or _disabled:
|
||||||
|
return
|
||||||
|
|
||||||
|
if not _is_configured():
|
||||||
|
_disabled = True
|
||||||
|
logger.info("tracing: Langfuse keys not set — tracing disabled")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse
|
||||||
|
|
||||||
|
Langfuse(
|
||||||
|
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||||
|
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||||
|
host=settings.LANGFUSE_HOST,
|
||||||
|
)
|
||||||
|
_initialised = True
|
||||||
|
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
|
||||||
|
except Exception as exc:
|
||||||
|
_disabled = True
|
||||||
|
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client() -> Any | None:
|
||||||
|
"""Return the singleton Langfuse client, or *None* if disabled."""
|
||||||
|
if _disabled:
|
||||||
|
return None
|
||||||
|
if not _initialised:
|
||||||
|
init_langfuse()
|
||||||
|
if _disabled:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
from langfuse import get_client
|
||||||
|
return get_client()
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _NullSpan:
|
||||||
|
"""Drop-in replacement when Langfuse is disabled."""
|
||||||
|
|
||||||
|
def update(self, **_: Any) -> None: ...
|
||||||
|
def set_trace_io(self, **_: Any) -> None: ...
|
||||||
|
def score_trace(self, **_: Any) -> None: ...
|
||||||
|
|
||||||
|
|
||||||
|
# ── Trace context manager ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def trace_span(
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str | None = None,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
input: Any = None,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
):
|
||||||
|
"""Context manager that creates a Langfuse trace/span.
|
||||||
|
|
||||||
|
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
|
||||||
|
A ``CallbackHandler`` created inside this block auto-inherits the trace
|
||||||
|
context, so there is no need to pass trace IDs manually.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
yield _NullSpan()
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse, propagate_attributes
|
||||||
|
|
||||||
|
trace_ctx: dict[str, str] = {}
|
||||||
|
if trace_id is not None:
|
||||||
|
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
|
||||||
|
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name=name,
|
||||||
|
input=input,
|
||||||
|
metadata=metadata or {},
|
||||||
|
**({"trace_context": trace_ctx} if trace_ctx else {}),
|
||||||
|
) as span:
|
||||||
|
with propagate_attributes(
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
tags=tags or [],
|
||||||
|
):
|
||||||
|
yield span
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
|
||||||
|
yield _NullSpan()
|
||||||
|
|
||||||
|
|
||||||
|
# ── LangChain callback handler ──────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_langfuse_callback() -> Any | None:
|
||||||
|
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
|
||||||
|
|
||||||
|
Must be called inside a ``trace_span()`` block for proper linking.
|
||||||
|
Returns *None* when Langfuse is disabled.
|
||||||
|
"""
|
||||||
|
if _disabled and not _initialised:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse.langchain import CallbackHandler
|
||||||
|
return CallbackHandler()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Prompt management ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
fallback: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> str | None:
|
||||||
|
"""Fetch a managed prompt from Langfuse by name (without variable compilation).
|
||||||
|
|
||||||
|
Returns the raw prompt string, or *fallback* if the prompt is not
|
||||||
|
found or Langfuse is disabled.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
return prompt.prompt
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def compile_prompt(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
fallback: str,
|
||||||
|
variables: dict[str, str],
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> str:
|
||||||
|
"""Fetch a managed prompt from Langfuse and compile it with ``{{variables}}``.
|
||||||
|
|
||||||
|
If the prompt exists in Langfuse, uses the SDK's ``.compile(**variables)``
|
||||||
|
which replaces ``{{key}}`` placeholders. If Langfuse is disabled or the
|
||||||
|
prompt is not found, falls back to ``fallback.format(**variables)`` (Python
|
||||||
|
``{key}`` placeholders).
|
||||||
|
|
||||||
|
This means:
|
||||||
|
- Langfuse prompts use ``{{variable}}`` syntax.
|
||||||
|
- Hardcoded fallback strings use Python ``{variable}`` syntax.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return fallback.format(**variables)
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
prompt = lf.get_prompt(**kwargs)
|
||||||
|
return prompt.compile(**variables)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: compile_prompt(%s) failed, using fallback: %s", name, exc)
|
||||||
|
return fallback.format(**variables)
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_object(
|
||||||
|
name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
cache_ttl_seconds: int = 300,
|
||||||
|
) -> Any | None:
|
||||||
|
"""Fetch the raw Langfuse prompt *object* (not the compiled string).
|
||||||
|
|
||||||
|
Returns ``None`` when Langfuse is disabled or the prompt is not found.
|
||||||
|
Use this when you need to pass the prompt to ``start_observation(prompt=...)``
|
||||||
|
for linking the prompt to a trace in the Langfuse UI.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"name": name,
|
||||||
|
"cache_ttl_seconds": cache_ttl_seconds,
|
||||||
|
}
|
||||||
|
if version is not None:
|
||||||
|
kwargs["version"] = version
|
||||||
|
if label is not None:
|
||||||
|
kwargs["label"] = label
|
||||||
|
return lf.get_prompt(**kwargs)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: get_prompt_object(%s) failed: %s", name, exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def link_prompt_to_trace(
|
||||||
|
span: Any,
|
||||||
|
prompt_name: str,
|
||||||
|
*,
|
||||||
|
version: int | None = None,
|
||||||
|
label: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Link a Langfuse managed prompt to a span/observation.
|
||||||
|
|
||||||
|
Uses the SDK v4 ``prompt=`` parameter so that the prompt version
|
||||||
|
appears linked in the Langfuse UI with metrics tracking.
|
||||||
|
"""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None or isinstance(span, _NullSpan):
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt = get_prompt_object(prompt_name, version=version, label=label)
|
||||||
|
if prompt is not None:
|
||||||
|
span.update(prompt=prompt)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scoring helper ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def score_trace(
|
||||||
|
trace_id: str,
|
||||||
|
name: str,
|
||||||
|
value: float,
|
||||||
|
*,
|
||||||
|
comment: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: score_trace failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Shutdown ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def flush() -> None:
|
||||||
|
"""Flush pending Langfuse events."""
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is not None:
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: flush failed: %s", exc)
|
||||||
|
|
||||||
|
|
||||||
|
def shutdown() -> None:
|
||||||
|
"""Flush and close the Langfuse client."""
|
||||||
|
global _initialised, _disabled
|
||||||
|
lf = _get_client()
|
||||||
|
if lf is not None:
|
||||||
|
try:
|
||||||
|
lf.flush()
|
||||||
|
lf.shutdown()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("tracing: shutdown failed: %s", exc)
|
||||||
|
_initialised = False
|
||||||
|
_disabled = False
|
||||||
1
services/batch-agent/eval/__init__.py
Normal file
1
services/batch-agent/eval/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Batch Agent E2E evaluation harness."""
|
||||||
5
services/batch-agent/eval/__main__.py
Normal file
5
services/batch-agent/eval/__main__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Allow running the eval package as ``python -m eval``."""
|
||||||
|
|
||||||
|
from eval.cli import main
|
||||||
|
|
||||||
|
main()
|
||||||
285
services/batch-agent/eval/cli.py
Normal file
285
services/batch-agent/eval/cli.py
Normal file
@@ -0,0 +1,285 @@
|
|||||||
|
"""CLI entry point for the batch agent evaluation harness.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
# From services/batch-agent/:
|
||||||
|
python -m eval run # all agent fixtures, default model
|
||||||
|
python -m eval run --fixture=classify-invoices # single fixture
|
||||||
|
python -m eval run --models=gpt-4o,gpt-5.3-codex # multiple models
|
||||||
|
python -m eval run --mode=step1 # only step1 fixtures
|
||||||
|
python -m eval run --no-judge # skip LLM judge scoring
|
||||||
|
|
||||||
|
python -m eval interactive # interactive journey session
|
||||||
|
python -m eval interactive --fixture=journey-invoice-setup
|
||||||
|
python -m eval interactive --model=gpt-4o
|
||||||
|
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
|
||||||
|
|
||||||
|
python -m eval list # list all fixtures
|
||||||
|
python -m eval sync # sync fixtures to Langfuse datasets
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import asyncio
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
# Ensure the service root and repo root are in sys.path.
|
||||||
|
# Service root must come BEFORE repo root so its ``app/`` package
|
||||||
|
# shadows the monolith ``app/`` in the repo root.
|
||||||
|
_SERVICE_ROOT = Path(__file__).resolve().parent.parent
|
||||||
|
_REPO_ROOT = _SERVICE_ROOT.parent.parent
|
||||||
|
_sr = str(_SERVICE_ROOT)
|
||||||
|
_rr = str(_REPO_ROOT)
|
||||||
|
if _rr not in sys.path:
|
||||||
|
sys.path.insert(0, _rr)
|
||||||
|
# Always force service root to position 0 (python -m may have already
|
||||||
|
# added CWD further down the list, which loses to repo root).
|
||||||
|
if _sr in sys.path:
|
||||||
|
sys.path.remove(_sr)
|
||||||
|
sys.path.insert(0, _sr)
|
||||||
|
|
||||||
|
from eval.config import discover_fixtures, discover_journey_fixtures
|
||||||
|
from eval.runner import run_fixture_eval, print_results
|
||||||
|
from eval.interactive import run_interactive
|
||||||
|
from eval import langfuse_eval
|
||||||
|
|
||||||
|
|
||||||
|
def _setup_logging(verbose: bool) -> None:
|
||||||
|
level = logging.DEBUG if verbose else logging.INFO
|
||||||
|
logging.basicConfig(
|
||||||
|
level=level,
|
||||||
|
format="%(asctime)s %(name)-20s %(levelname)-5s %(message)s",
|
||||||
|
datefmt="%H:%M:%S",
|
||||||
|
)
|
||||||
|
# Quiet noisy libraries
|
||||||
|
for name in ("httpx", "httpcore", "openai", "litellm", "urllib3"):
|
||||||
|
logging.getLogger(name).setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
|
||||||
|
def _parse_args() -> argparse.Namespace:
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
description="Batch Agent E2E evaluation harness",
|
||||||
|
prog="python -m eval",
|
||||||
|
)
|
||||||
|
sub = parser.add_subparsers(dest="command", required=True)
|
||||||
|
|
||||||
|
# ── run ───────────────────────────────────────────────────────
|
||||||
|
run_cmd = sub.add_parser("run", help="Run evaluations")
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--fixture", "-f",
|
||||||
|
help="Run only the named fixture (default: all)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--models", "-m",
|
||||||
|
default="github_copilot/gpt-5.3-codex",
|
||||||
|
help="Comma-separated list of models to test (default: github_copilot/gpt-5.3-codex)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--mode",
|
||||||
|
default=None,
|
||||||
|
choices=["step1", "step2", "full"],
|
||||||
|
help="Only run fixtures with this mode (default: all)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--no-judge",
|
||||||
|
action="store_true",
|
||||||
|
help="Skip LLM-as-judge scoring",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--judge-model",
|
||||||
|
default="gpt-4o",
|
||||||
|
help="Model for LLM judge (default: gpt-4o)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument(
|
||||||
|
"--fixtures-dir",
|
||||||
|
default=None,
|
||||||
|
help="Path to fixtures directory (default: eval/fixtures/)",
|
||||||
|
)
|
||||||
|
run_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
|
||||||
|
# ── list ──────────────────────────────────────────────────────
|
||||||
|
list_cmd = sub.add_parser("list", help="List available fixtures")
|
||||||
|
list_cmd.add_argument("--fixtures-dir", default=None)
|
||||||
|
list_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
|
||||||
|
# ── sync ──────────────────────────────────────────────────────
|
||||||
|
sync_cmd = sub.add_parser("sync", help="Sync fixtures to Langfuse datasets")
|
||||||
|
sync_cmd.add_argument("--fixture", "-f", default=None, help="Sync only the named fixture")
|
||||||
|
sync_cmd.add_argument("--fixtures-dir", default=None)
|
||||||
|
sync_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
|
||||||
|
# ── interactive ───────────────────────────────────────────────
|
||||||
|
inter_cmd = sub.add_parser("interactive", help="Interactive journey session (human-in-the-loop)")
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--fixture", "-f",
|
||||||
|
help="Journey fixture to use (default: pick interactively)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--model", "-m",
|
||||||
|
default="github_copilot/gpt-5.3-codex",
|
||||||
|
help="Model for the journey AI (default: github_copilot/gpt-5.3-codex)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--judge-model",
|
||||||
|
default="gpt-4o",
|
||||||
|
help="Model for LLM judge (default: gpt-4o)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--fixtures-dir",
|
||||||
|
default=None,
|
||||||
|
help="Path to fixtures directory (default: eval/fixtures/)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument(
|
||||||
|
"--data-dir",
|
||||||
|
default=None,
|
||||||
|
help="Override sample data directory (e.g. path to private test files not in git)",
|
||||||
|
)
|
||||||
|
inter_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def _fixtures_dir(arg: str | None) -> Path | None:
|
||||||
|
if arg:
|
||||||
|
return Path(arg)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
async def _cmd_run(args: argparse.Namespace) -> None:
|
||||||
|
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
if not fixtures:
|
||||||
|
print("No fixtures found. Create YAML files in eval/fixtures/.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.fixture:
|
||||||
|
fixtures = [f for f in fixtures if f.name == args.fixture]
|
||||||
|
if not fixtures:
|
||||||
|
print(f"Fixture '{args.fixture}' not found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
models = [m.strip() for m in args.models.split(",")]
|
||||||
|
|
||||||
|
all_results = []
|
||||||
|
for fixture in fixtures:
|
||||||
|
if args.mode and fixture.mode != args.mode:
|
||||||
|
continue
|
||||||
|
results = await run_fixture_eval(
|
||||||
|
fixture,
|
||||||
|
models=models,
|
||||||
|
use_llm_judge=not args.no_judge,
|
||||||
|
judge_model=args.judge_model,
|
||||||
|
)
|
||||||
|
all_results.extend(results)
|
||||||
|
|
||||||
|
print_results(all_results)
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_list(args: argparse.Namespace) -> None:
|
||||||
|
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
|
||||||
|
if not fixtures and not journey_fixtures:
|
||||||
|
print("No fixtures found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if fixtures:
|
||||||
|
print(f"\n{'[Agent Fixtures]'}")
|
||||||
|
print(f"{'Name':<30} {'Mode':<6} {'Types':<25} {'Expected'}")
|
||||||
|
print("-" * 90)
|
||||||
|
for f in fixtures:
|
||||||
|
types = ", ".join(f.data_types)
|
||||||
|
n_expected = len(f.expected) + len(f.expected_classification)
|
||||||
|
print(f"{f.name:<30} {f.mode:<6} {types:<25} {n_expected}")
|
||||||
|
|
||||||
|
if journey_fixtures:
|
||||||
|
print(f"\n{'[Journey Fixtures]'}")
|
||||||
|
print(f"{'Name':<30} {'Types':<25} {'Messages':<10} {'Criteria'}")
|
||||||
|
print("-" * 90)
|
||||||
|
for f in journey_fixtures:
|
||||||
|
types = ", ".join(f.data_types)
|
||||||
|
print(f"{f.name:<30} {types:<25} {len(f.user_messages):<10} {len(f.expected_template_criteria)}")
|
||||||
|
|
||||||
|
print()
|
||||||
|
|
||||||
|
|
||||||
|
def _cmd_sync(args: argparse.Namespace) -> None:
|
||||||
|
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
|
||||||
|
if args.fixture:
|
||||||
|
fixtures = [f for f in fixtures if f.name == args.fixture]
|
||||||
|
journey_fixtures = [f for f in journey_fixtures if f.name == args.fixture]
|
||||||
|
|
||||||
|
if not fixtures and not journey_fixtures:
|
||||||
|
print("No fixtures to sync.")
|
||||||
|
return
|
||||||
|
|
||||||
|
for fixture in fixtures:
|
||||||
|
name = langfuse_eval.sync_fixture_to_dataset(fixture)
|
||||||
|
if name:
|
||||||
|
print(f"Synced: {fixture.name} → {name}")
|
||||||
|
else:
|
||||||
|
print(f"Skipped: {fixture.name} (Langfuse not configured)")
|
||||||
|
|
||||||
|
for fixture in journey_fixtures:
|
||||||
|
name = langfuse_eval.sync_journey_fixture_to_dataset(fixture)
|
||||||
|
if name:
|
||||||
|
print(f"Synced: {fixture.name} → {name}")
|
||||||
|
else:
|
||||||
|
print(f"Skipped: {fixture.name} (Langfuse not configured)")
|
||||||
|
|
||||||
|
|
||||||
|
async def _cmd_interactive(args: argparse.Namespace) -> None:
|
||||||
|
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||||
|
if not journey_fixtures:
|
||||||
|
print("No journey fixtures found. Create YAML files with type: journey in eval/fixtures/.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if args.fixture:
|
||||||
|
fixtures = [f for f in journey_fixtures if f.name == args.fixture]
|
||||||
|
if not fixtures:
|
||||||
|
print(f"Journey fixture '{args.fixture}' not found.")
|
||||||
|
return
|
||||||
|
fixture = fixtures[0]
|
||||||
|
elif len(journey_fixtures) == 1:
|
||||||
|
fixture = journey_fixtures[0]
|
||||||
|
else:
|
||||||
|
# Let user pick
|
||||||
|
print("\nAvailable journey fixtures:")
|
||||||
|
for i, f in enumerate(journey_fixtures, 1):
|
||||||
|
print(f" {i}. {f.name} — {f.description[:60]}")
|
||||||
|
print()
|
||||||
|
try:
|
||||||
|
choice = int(input("Pick a fixture number: ").strip()) - 1
|
||||||
|
fixture = journey_fixtures[choice]
|
||||||
|
except (ValueError, IndexError, EOFError, KeyboardInterrupt):
|
||||||
|
print("Invalid choice.")
|
||||||
|
return
|
||||||
|
|
||||||
|
await run_interactive(
|
||||||
|
fixture,
|
||||||
|
model=args.model,
|
||||||
|
judge_model=args.judge_model,
|
||||||
|
data_dir=Path(args.data_dir).resolve() if args.data_dir else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> None:
|
||||||
|
args = _parse_args()
|
||||||
|
_setup_logging(args.verbose)
|
||||||
|
|
||||||
|
if args.command == "run":
|
||||||
|
asyncio.run(_cmd_run(args))
|
||||||
|
elif args.command == "interactive":
|
||||||
|
asyncio.run(_cmd_interactive(args))
|
||||||
|
elif args.command == "list":
|
||||||
|
_cmd_list(args)
|
||||||
|
elif args.command == "sync":
|
||||||
|
_cmd_sync(args)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
220
services/batch-agent/eval/config.py
Normal file
220
services/batch-agent/eval/config.py
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
"""Eval configuration — YAML fixture loader and dataclasses.
|
||||||
|
|
||||||
|
Fixtures come in two families:
|
||||||
|
|
||||||
|
1. **Agent fixtures** — test the batch agent pipeline.
|
||||||
|
Three modes controlled by ``mode``:
|
||||||
|
|
||||||
|
``step1`` — classification prompt only.
|
||||||
|
``step2`` — processing prompt only.
|
||||||
|
``full`` — both steps in sequence.
|
||||||
|
|
||||||
|
2. **Journey fixtures** — test the prompt-template builder conversation
|
||||||
|
(unchanged).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
EvalMode = Literal["step1", "step2", "full"]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExpectedRecord:
|
||||||
|
"""A single expected extraction result.
|
||||||
|
|
||||||
|
Only the fields specified are checked — unspecified fields are ignored.
|
||||||
|
"""
|
||||||
|
|
||||||
|
table: str # tasks | notes | timelines | projects
|
||||||
|
fields: dict[str, Any] # field_name → expected_value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class ExpectedClassification:
|
||||||
|
"""Expected output of step-1 classification for one file."""
|
||||||
|
|
||||||
|
file: str # relative path to the sample file
|
||||||
|
project_id: str # expected matched project id, or "new"
|
||||||
|
domains: list[str] # expected domain list
|
||||||
|
new_project_name: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvalFixture:
|
||||||
|
"""A complete test scenario loaded from YAML.
|
||||||
|
|
||||||
|
``mode`` determines which pipeline steps are exercised:
|
||||||
|
|
||||||
|
- **step1**: only ``_classify_file``
|
||||||
|
- **step2**: only the processing LLM + tool loop
|
||||||
|
- **full**: both steps in sequence (``run_local_agent``)
|
||||||
|
"""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
mode: EvalMode
|
||||||
|
directory: str # relative path to sample files
|
||||||
|
data_types: list[str]
|
||||||
|
file_extensions: list[str]
|
||||||
|
models: list[str] # if empty, use CLI default
|
||||||
|
fixture_path: Path = field(default_factory=lambda: Path("."))
|
||||||
|
|
||||||
|
# ── Step-1 inputs (classification) ───────────────────────────
|
||||||
|
domain_definitions: str = ""
|
||||||
|
projects_list: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
custom_step1_prompt: str = ""
|
||||||
|
|
||||||
|
# ── Step-2 inputs (processing) ───────────────────────────────
|
||||||
|
existing_context: str = ""
|
||||||
|
project_context: str = ""
|
||||||
|
custom_prompt_section: str = ""
|
||||||
|
|
||||||
|
# ── Seed records for mock executor ───────────────────────────
|
||||||
|
seed_records: dict[str, list[dict]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
# ── Expected outputs ─────────────────────────────────────────
|
||||||
|
expected_classification: list[ExpectedClassification] = field(default_factory=list)
|
||||||
|
expected: list[ExpectedRecord] = field(default_factory=list)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fixture_dir(self) -> Path:
|
||||||
|
"""Absolute path to the sample files directory."""
|
||||||
|
return self.fixture_path.parent / self.directory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: Path) -> "EvalFixture":
|
||||||
|
"""Load a fixture from a YAML file."""
|
||||||
|
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
mode: EvalMode = raw.get("mode", "full")
|
||||||
|
|
||||||
|
# Parse expected records (step2/full)
|
||||||
|
expected: list[ExpectedRecord] = []
|
||||||
|
for table, records in (raw.get("expected") or {}).items():
|
||||||
|
for rec in records:
|
||||||
|
expected.append(ExpectedRecord(table=table, fields=rec))
|
||||||
|
|
||||||
|
# Parse expected classification (step1/full)
|
||||||
|
expected_classification: list[ExpectedClassification] = []
|
||||||
|
for item in raw.get("expected_classification") or []:
|
||||||
|
expected_classification.append(ExpectedClassification(
|
||||||
|
file=item["file"],
|
||||||
|
project_id=item["project_id"],
|
||||||
|
domains=item.get("domains", []),
|
||||||
|
new_project_name=item.get("new_project_name"),
|
||||||
|
))
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=raw["name"],
|
||||||
|
description=raw.get("description", ""),
|
||||||
|
mode=mode,
|
||||||
|
directory=raw.get("directory", "sample_files"),
|
||||||
|
data_types=raw.get("data_types", ["tasks"]),
|
||||||
|
file_extensions=raw.get("file_extensions", []),
|
||||||
|
models=raw.get("models", []),
|
||||||
|
fixture_path=path,
|
||||||
|
# Step-1 inputs
|
||||||
|
domain_definitions=raw.get("domain_definitions", ""),
|
||||||
|
projects_list=raw.get("projects_list", []),
|
||||||
|
# Step-2 inputs
|
||||||
|
existing_context=raw.get("existing_context", ""),
|
||||||
|
project_context=raw.get("project_context", ""),
|
||||||
|
custom_prompt_section=raw.get("custom_prompt_section", ""),
|
||||||
|
# Shared
|
||||||
|
seed_records=raw.get("seed_records", {}),
|
||||||
|
expected_classification=expected_classification,
|
||||||
|
expected=expected,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def discover_fixtures(fixtures_dir: Path | None = None) -> list[EvalFixture]:
|
||||||
|
"""Find and load all YAML fixtures in the fixtures directory."""
|
||||||
|
if fixtures_dir is None:
|
||||||
|
fixtures_dir = Path(__file__).parent / "fixtures"
|
||||||
|
|
||||||
|
fixtures: list[EvalFixture] = []
|
||||||
|
if not fixtures_dir.is_dir():
|
||||||
|
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
|
||||||
|
return fixtures
|
||||||
|
|
||||||
|
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
|
||||||
|
try:
|
||||||
|
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
||||||
|
if raw.get("type") == "journey":
|
||||||
|
continue # Skip journey fixtures
|
||||||
|
fixtures.append(EvalFixture.from_yaml(yaml_path))
|
||||||
|
logger.info("eval: loaded fixture %s from %s", fixtures[-1].name, yaml_path.name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("eval: failed to load fixture %s: %s", yaml_path.name, exc)
|
||||||
|
|
||||||
|
return fixtures
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey fixtures ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JourneyFixture:
|
||||||
|
"""A journey test scenario — tests the prompt_template builder conversation."""
|
||||||
|
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
directory: str # relative path to sample files
|
||||||
|
data_types: list[str]
|
||||||
|
expected_template_criteria: list[str] # what the template should contain/satisfy
|
||||||
|
user_messages: list[str] = field(default_factory=list) # for automated journey runs (unused in interactive mode)
|
||||||
|
models: list[str] = field(default_factory=list)
|
||||||
|
fixture_path: Path = field(default_factory=lambda: Path("."))
|
||||||
|
|
||||||
|
@property
|
||||||
|
def fixture_dir(self) -> Path:
|
||||||
|
"""Absolute path to the sample files directory."""
|
||||||
|
return self.fixture_path.parent / self.directory
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def from_yaml(cls, path: Path) -> "JourneyFixture":
|
||||||
|
"""Load a journey fixture from a YAML file."""
|
||||||
|
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
return cls(
|
||||||
|
name=raw["name"],
|
||||||
|
description=raw.get("description", ""),
|
||||||
|
directory=raw.get("directory", "sample_files"),
|
||||||
|
data_types=raw.get("data_types", ["tasks"]),
|
||||||
|
user_messages=raw.get("user_messages", []),
|
||||||
|
expected_template_criteria=raw.get("expected_template_criteria", []),
|
||||||
|
models=raw.get("models", []),
|
||||||
|
fixture_path=path,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def discover_journey_fixtures(fixtures_dir: Path | None = None) -> list[JourneyFixture]:
|
||||||
|
"""Find and load all journey YAML fixtures in the fixtures directory."""
|
||||||
|
if fixtures_dir is None:
|
||||||
|
fixtures_dir = Path(__file__).parent / "fixtures"
|
||||||
|
|
||||||
|
fixtures: list[JourneyFixture] = []
|
||||||
|
if not fixtures_dir.is_dir():
|
||||||
|
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
|
||||||
|
return fixtures
|
||||||
|
|
||||||
|
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
|
||||||
|
try:
|
||||||
|
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
||||||
|
if raw.get("type") != "journey":
|
||||||
|
continue
|
||||||
|
fixtures.append(JourneyFixture.from_yaml(yaml_path))
|
||||||
|
logger.info("eval: loaded journey fixture %s from %s", fixtures[-1].name, yaml_path.name)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("eval: failed to load journey fixture %s: %s", yaml_path.name, exc)
|
||||||
|
|
||||||
|
return fixtures
|
||||||
40
services/batch-agent/eval/fixtures/classify_invoices.yaml
Normal file
40
services/batch-agent/eval/fixtures/classify_invoices.yaml
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
# Fixture: classify-invoices (step1)
|
||||||
|
# Tests _STEP1_SYSTEM_PROMPT — file classification and project matching.
|
||||||
|
# Verifies that the LLM correctly matches files to existing projects
|
||||||
|
# and identifies the right data domains.
|
||||||
|
|
||||||
|
name: classify-invoices
|
||||||
|
mode: step1
|
||||||
|
description: >
|
||||||
|
Test file classification on Italian freelance invoices and meeting notes.
|
||||||
|
Verifies project matching and domain identification.
|
||||||
|
|
||||||
|
directory: sample_files/invoices
|
||||||
|
data_types: [tasks, notes, timelines]
|
||||||
|
file_extensions: [txt, md]
|
||||||
|
|
||||||
|
# ── Step-1 prompt variables ──────────────────────────────────────
|
||||||
|
domain_definitions: |
|
||||||
|
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
|
||||||
|
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
|
||||||
|
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
|
||||||
|
|
||||||
|
projects_list:
|
||||||
|
- id: "proj-web-redesign"
|
||||||
|
name: "Redesign Sito Web Corporate"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||||
|
- id: "proj-ecommerce"
|
||||||
|
name: "E-Commerce FashionStore"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||||
|
|
||||||
|
# ── Expected classification results ─────────────────────────────
|
||||||
|
expected_classification:
|
||||||
|
- file: "sample_files/invoices/fattura_042.txt"
|
||||||
|
project_id: "proj-web-redesign"
|
||||||
|
domains: [tasks, notes, timelines]
|
||||||
|
|
||||||
|
- file: "sample_files/invoices/meeting_ecommerce.md"
|
||||||
|
project_id: "proj-ecommerce"
|
||||||
|
domains: [tasks, notes, timelines]
|
||||||
108
services/batch-agent/eval/fixtures/full_invoices.yaml
Normal file
108
services/batch-agent/eval/fixtures/full_invoices.yaml
Normal file
@@ -0,0 +1,108 @@
|
|||||||
|
# Fixture: full-invoices (full)
|
||||||
|
# Tests both _STEP1_SYSTEM_PROMPT and _PROCESSING_SYSTEM_PROMPT in sequence
|
||||||
|
# via run_local_agent(). Verifies end-to-end classification + extraction.
|
||||||
|
|
||||||
|
name: full-invoices
|
||||||
|
mode: full
|
||||||
|
description: >
|
||||||
|
End-to-end test: classify Italian invoices/meeting notes into the
|
||||||
|
correct project, then extract tasks, notes, and timeline events.
|
||||||
|
|
||||||
|
directory: sample_files/invoices
|
||||||
|
data_types: [tasks, notes, timelines]
|
||||||
|
file_extensions: [txt, md]
|
||||||
|
|
||||||
|
# ── Step-1 prompt variables ──────────────────────────────────────
|
||||||
|
domain_definitions: |
|
||||||
|
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
|
||||||
|
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
|
||||||
|
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
|
||||||
|
|
||||||
|
projects_list:
|
||||||
|
- id: "proj-web-redesign"
|
||||||
|
name: "Redesign Sito Web Corporate"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||||
|
- id: "proj-ecommerce"
|
||||||
|
name: "E-Commerce FashionStore"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||||
|
|
||||||
|
# ── Step-2 prompt variables ──────────────────────────────────────
|
||||||
|
existing_context: |
|
||||||
|
Existing tasks:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
Existing notes:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
Existing timelines:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
project_context: ""
|
||||||
|
|
||||||
|
custom_prompt_section: |
|
||||||
|
User instructions:
|
||||||
|
Estrai i dati dai file come segue:
|
||||||
|
- TASK: ogni azione da fare, deliverable, o item con scadenza.
|
||||||
|
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
|
||||||
|
Mappa "media priorità" → priority: medium.
|
||||||
|
Mappa "bassa priorità" → priority: low.
|
||||||
|
Se un item è marcato come "completato" o [x], impostalo status: done.
|
||||||
|
Altrimenti status: todo.
|
||||||
|
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
|
||||||
|
- TIMELINE: date di scadenza, milestone, meeting futuri.
|
||||||
|
Imposta sempre isAiSuggested=1.
|
||||||
|
|
||||||
|
# ── Seed records (pre-existing DB state) ─────────────────────────
|
||||||
|
seed_records:
|
||||||
|
projects:
|
||||||
|
- id: "proj-web-redesign"
|
||||||
|
name: "Redesign Sito Web Corporate"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||||
|
- id: "proj-ecommerce"
|
||||||
|
name: "E-Commerce FashionStore"
|
||||||
|
status: "active"
|
||||||
|
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||||
|
tasks: []
|
||||||
|
notes: []
|
||||||
|
timelines: []
|
||||||
|
|
||||||
|
# ── Expected classification (step 1) ─────────────────────────────
|
||||||
|
expected_classification:
|
||||||
|
- file: "sample_files/invoices/fattura_042.txt"
|
||||||
|
project_id: "proj-web-redesign"
|
||||||
|
domains: [tasks, notes, timelines]
|
||||||
|
|
||||||
|
- file: "sample_files/invoices/meeting_ecommerce.md"
|
||||||
|
project_id: "proj-ecommerce"
|
||||||
|
domains: [tasks, notes, timelines]
|
||||||
|
|
||||||
|
# ── Expected extractions (step 2) ────────────────────────────────
|
||||||
|
expected:
|
||||||
|
tasks:
|
||||||
|
- title: "Sviluppo frontend React"
|
||||||
|
priority: "high"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Integrazione API backend"
|
||||||
|
priority: "medium"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Testing cross-browser e fix bug responsive"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Preparare wireframe homepage"
|
||||||
|
priority: "high"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Setup progetto Next.js e configurare CI/CD"
|
||||||
|
priority: "medium"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Ricerca plugin Stripe per gestione abbonamenti"
|
||||||
|
priority: "low"
|
||||||
|
status: "todo"
|
||||||
|
|
||||||
|
notes:
|
||||||
|
- title: "Meeting Kickoff Progetto E-Commerce"
|
||||||
|
|
||||||
|
timelines:
|
||||||
|
- title: "MVP E-Commerce pronto"
|
||||||
|
- title: "Meeting di revisione"
|
||||||
@@ -0,0 +1,28 @@
|
|||||||
|
# Journey Fixture: journey-invoice-setup
|
||||||
|
# Used by `python -m eval interactive` for human-in-the-loop testing
|
||||||
|
# of the journey chatbot's prompt-building conversation.
|
||||||
|
|
||||||
|
type: journey
|
||||||
|
name: journey-invoice-setup
|
||||||
|
description: >
|
||||||
|
Interactive test for the journey chatbot — explore a directory of
|
||||||
|
Italian invoices and meeting notes, answer the chatbot's questions,
|
||||||
|
and verify it produces a well-structured prompt_template for data
|
||||||
|
extraction.
|
||||||
|
|
||||||
|
directory: sample_files/invoices
|
||||||
|
data_types: [tasks, notes, timelines, projects]
|
||||||
|
|
||||||
|
# Criteria the generated prompt_template must satisfy
|
||||||
|
# Each is scored 0-1 by an LLM judge
|
||||||
|
expected_template_criteria:
|
||||||
|
- "Mentions creating tasks from action items and work descriptions"
|
||||||
|
- "Mentions creating notes from meeting summaries"
|
||||||
|
- "Mentions extracting timeline events from deadlines and meeting dates"
|
||||||
|
- "Mentions creating projects from relevant information"
|
||||||
|
- "Sets isAiSuggested=1 on all created records"
|
||||||
|
- "Does NOT include projectId assignment logic"
|
||||||
|
- "Uses camelCase field names (title, status, priority, dueDate, content)"
|
||||||
|
|
||||||
|
# Models to test (empty = use CLI --models default)
|
||||||
|
models: []
|
||||||
81
services/batch-agent/eval/fixtures/process_invoices.yaml
Normal file
81
services/batch-agent/eval/fixtures/process_invoices.yaml
Normal file
@@ -0,0 +1,81 @@
|
|||||||
|
# Fixture: process-invoices (step2)
|
||||||
|
# Tests _PROCESSING_SYSTEM_PROMPT — data extraction & tool calling.
|
||||||
|
# The classification step is skipped; prompt variables are injected directly.
|
||||||
|
|
||||||
|
name: process-invoices
|
||||||
|
mode: step2
|
||||||
|
description: >
|
||||||
|
Test data extraction from Italian freelance invoices.
|
||||||
|
Verifies correct record creation via tool calls with the right
|
||||||
|
fields, priorities, and status values.
|
||||||
|
|
||||||
|
directory: sample_files/invoices
|
||||||
|
data_types: [tasks, notes, timelines]
|
||||||
|
file_extensions: [txt, md]
|
||||||
|
|
||||||
|
# ── Step-2 prompt variables ──────────────────────────────────────
|
||||||
|
existing_context: |
|
||||||
|
Existing tasks:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
Existing notes:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
Existing timelines:
|
||||||
|
(none)
|
||||||
|
|
||||||
|
project_context: >
|
||||||
|
Project: Redesign Sito Web Corporate (id: proj-web-redesign).
|
||||||
|
Always set projectId to this id on every record you create.
|
||||||
|
|
||||||
|
custom_prompt_section: |
|
||||||
|
User instructions:
|
||||||
|
Estrai i dati dai file come segue:
|
||||||
|
- TASK: ogni azione da fare, deliverable, o item con scadenza.
|
||||||
|
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
|
||||||
|
Mappa "media priorità" → priority: medium.
|
||||||
|
Mappa "bassa priorità" → priority: low.
|
||||||
|
Se un item è marcato come "completato" o [x], impostalo status: done.
|
||||||
|
Altrimenti status: todo.
|
||||||
|
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
|
||||||
|
Il titolo deve essere descrittivo. Il content deve includere tutti i dettagli.
|
||||||
|
- TIMELINE: date di scadenza, milestone, meeting futuri.
|
||||||
|
Imposta sempre isAiSuggested=1.
|
||||||
|
|
||||||
|
# ── Seed records (pre-existing DB state) ─────────────────────────
|
||||||
|
seed_records:
|
||||||
|
projects:
|
||||||
|
- id: "proj-web-redesign"
|
||||||
|
name: "Redesign Sito Web Corporate"
|
||||||
|
status: "active"
|
||||||
|
tasks: []
|
||||||
|
notes: []
|
||||||
|
timelines: []
|
||||||
|
|
||||||
|
# ── Expected extractions ─────────────────────────────────────────
|
||||||
|
expected:
|
||||||
|
tasks:
|
||||||
|
- title: "Sviluppo frontend React"
|
||||||
|
priority: "high"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Integrazione API backend"
|
||||||
|
priority: "medium"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Testing cross-browser e fix bug responsive"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Preparare wireframe homepage"
|
||||||
|
priority: "high"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Setup progetto Next.js e configurare CI/CD"
|
||||||
|
priority: "medium"
|
||||||
|
status: "todo"
|
||||||
|
- title: "Ricerca plugin Stripe per gestione abbonamenti"
|
||||||
|
priority: "low"
|
||||||
|
status: "todo"
|
||||||
|
|
||||||
|
notes:
|
||||||
|
- title: "Meeting Kickoff Progetto E-Commerce"
|
||||||
|
|
||||||
|
timelines:
|
||||||
|
- title: "MVP E-Commerce pronto"
|
||||||
|
- title: "Meeting di revisione"
|
||||||
@@ -0,0 +1,18 @@
|
|||||||
|
FATTURA N. 2026-0042
|
||||||
|
Data: 15 Marzo 2026
|
||||||
|
Cliente: Studio Architettura Bianchi
|
||||||
|
|
||||||
|
Progetto: Redesign Sito Web Corporate
|
||||||
|
|
||||||
|
Descrizione lavori:
|
||||||
|
- Sviluppo frontend React (40 ore) — URGENTE, completare entro 20 marzo
|
||||||
|
- Integrazione API backend (20 ore) — priorità media
|
||||||
|
- Design UI/UX mockup homepage (8 ore) — completato
|
||||||
|
- Testing cross-browser e fix bug responsive (12 ore) — da iniziare
|
||||||
|
|
||||||
|
Totale: €4.800,00 + IVA
|
||||||
|
|
||||||
|
Note:
|
||||||
|
Meeting di revisione previsto per il 18 marzo alle 10:00.
|
||||||
|
Il cliente ha richiesto modifiche al layout mobile della sezione contatti.
|
||||||
|
Attendere conferma budget aggiuntivo per sezione blog.
|
||||||
@@ -0,0 +1,25 @@
|
|||||||
|
# Meeting Notes - Kickoff Progetto E-Commerce
|
||||||
|
|
||||||
|
**Data:** 10 Marzo 2026
|
||||||
|
**Partecipanti:** Marco R., Giulia T., Cliente (FashionStore srl)
|
||||||
|
|
||||||
|
## Decisioni prese
|
||||||
|
|
||||||
|
1. **Piattaforma**: Next.js + Stripe per i pagamenti
|
||||||
|
2. **Timeline**: MVP pronto entro 30 aprile 2026
|
||||||
|
3. **Budget**: €12.000 totale, €4.000 anticipo già ricevuto
|
||||||
|
|
||||||
|
## Action items
|
||||||
|
|
||||||
|
- [ ] Marco: preparare wireframe homepage entro 14 marzo — ALTA PRIORITÀ
|
||||||
|
- [ ] Giulia: setup progetto Next.js e configurare CI/CD — media priorità
|
||||||
|
- [ ] Marco: ricerca plugin Stripe per gestione abbonamenti — bassa priorità
|
||||||
|
- [x] Giulia: inviare contratto firmato al cliente — COMPLETATO
|
||||||
|
|
||||||
|
## Note aggiuntive
|
||||||
|
|
||||||
|
Il cliente vuole un design minimalista, ispirato a Zara.com.
|
||||||
|
Colori primari: nero, bianco, oro.
|
||||||
|
Font: Inter per body, Playfair Display per headings.
|
||||||
|
|
||||||
|
Prossimo meeting: 24 marzo 2026 ore 15:00.
|
||||||
471
services/batch-agent/eval/interactive.py
Normal file
471
services/batch-agent/eval/interactive.py
Normal file
@@ -0,0 +1,471 @@
|
|||||||
|
"""Interactive journey session — human-in-the-loop CLI conversation.
|
||||||
|
|
||||||
|
Flow:
|
||||||
|
1. Show the system prompt used by the journey AI.
|
||||||
|
2. Start the journey (AI explores files, asks first question).
|
||||||
|
3. User types responses in the terminal — AI replies.
|
||||||
|
4. User types `/done` to end the conversation.
|
||||||
|
5. User writes a comment about the interaction quality.
|
||||||
|
6. LLM judge scores the conversation + generated template.
|
||||||
|
7. Results are reported to Langfuse.
|
||||||
|
|
||||||
|
Usage::
|
||||||
|
|
||||||
|
python -m eval interactive # pick a fixture interactively
|
||||||
|
python -m eval interactive --fixture=journey-invoice-setup
|
||||||
|
python -m eval interactive --model=gpt-4o
|
||||||
|
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from eval.config import JourneyFixture, discover_journey_fixtures
|
||||||
|
from eval.mock_executor import MockExecutor
|
||||||
|
from eval import langfuse_eval
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Special commands ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_CMD_DONE = "/done"
|
||||||
|
_CMD_QUIT = "/quit"
|
||||||
|
_CMD_TEMPLATE = "/template"
|
||||||
|
_CMD_HELP = "/help"
|
||||||
|
|
||||||
|
_HELP_TEXT = f"""\
|
||||||
|
{_CMD_DONE} — End the conversation and proceed to evaluation
|
||||||
|
{_CMD_QUIT} — Abort without evaluation
|
||||||
|
{_CMD_TEMPLATE} — Show the generated template (if any)
|
||||||
|
{_CMD_HELP} — Show this help"""
|
||||||
|
|
||||||
|
# ── Terminal colours (ANSI) ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
_C_RESET = "\033[0m"
|
||||||
|
_C_BOLD = "\033[1m"
|
||||||
|
_C_DIM = "\033[2m"
|
||||||
|
_C_CYAN = "\033[36m"
|
||||||
|
_C_GREEN = "\033[32m"
|
||||||
|
_C_YELLOW = "\033[33m"
|
||||||
|
_C_MAGENTA = "\033[35m"
|
||||||
|
_C_RED = "\033[31m"
|
||||||
|
_C_BLUE = "\033[34m"
|
||||||
|
|
||||||
|
|
||||||
|
def _print_header(text: str) -> None:
|
||||||
|
print(f"\n{_C_BOLD}{_C_CYAN}{'═' * 80}")
|
||||||
|
print(f" {text}")
|
||||||
|
print(f"{'═' * 80}{_C_RESET}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_ai(text: str) -> None:
|
||||||
|
print(f"\n{_C_GREEN}{_C_BOLD}AI:{_C_RESET} {text}\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_system(text: str) -> None:
|
||||||
|
print(f"{_C_DIM}{text}{_C_RESET}")
|
||||||
|
|
||||||
|
|
||||||
|
def _print_score(label: str, score: float) -> None:
|
||||||
|
if score >= 0.7:
|
||||||
|
color = _C_GREEN
|
||||||
|
tag = "PASS"
|
||||||
|
elif score >= 0.4:
|
||||||
|
color = _C_YELLOW
|
||||||
|
tag = "PARTIAL"
|
||||||
|
else:
|
||||||
|
color = _C_RED
|
||||||
|
tag = "FAIL"
|
||||||
|
print(f" {color}{tag:>7}{_C_RESET} ({score:.1f}) {label}")
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class InteractiveResult:
|
||||||
|
fixture_name: str
|
||||||
|
model: str
|
||||||
|
judge_model: str
|
||||||
|
prompt_template: str | None
|
||||||
|
conversation: list[dict[str, str]]
|
||||||
|
user_comment: str
|
||||||
|
done: bool
|
||||||
|
criteria_scores: dict[str, float]
|
||||||
|
overall_score: float
|
||||||
|
judge_reasoning: str
|
||||||
|
elapsed_seconds: float
|
||||||
|
|
||||||
|
def summary(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"fixture": self.fixture_name,
|
||||||
|
"model": self.model,
|
||||||
|
"judge_model": self.judge_model,
|
||||||
|
"done": self.done,
|
||||||
|
"turns": len([c for c in self.conversation if c["role"] == "user"]),
|
||||||
|
"overall_score": round(self.overall_score, 3),
|
||||||
|
"user_comment": self.user_comment,
|
||||||
|
"criteria_scores": {k: round(v, 3) for k, v in self.criteria_scores.items()},
|
||||||
|
"elapsed_s": round(self.elapsed_seconds, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM judge ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_INTERACTIVE_JUDGE_SYSTEM = """\
|
||||||
|
You are an evaluation judge for AI-generated prompt templates produced during
|
||||||
|
an interactive conversation between a human and a journey chatbot.
|
||||||
|
|
||||||
|
The chatbot explored a directory and through multi-turn conversation with the
|
||||||
|
user produced a prompt_template — an instruction set for a data-extraction agent.
|
||||||
|
|
||||||
|
You have access to:
|
||||||
|
- The full conversation transcript
|
||||||
|
- The generated prompt_template (if any)
|
||||||
|
- The user's own comment about the interaction
|
||||||
|
- A list of quality criteria
|
||||||
|
|
||||||
|
Score each criterion from 0 to 1:
|
||||||
|
- 1.0: Fully satisfied
|
||||||
|
- 0.5: Partially satisfied
|
||||||
|
- 0.0: Not satisfied
|
||||||
|
|
||||||
|
Also provide an overall_quality score (0-1) evaluating the conversation flow,
|
||||||
|
how well the AI understood the user, and the template quality.
|
||||||
|
|
||||||
|
Respond with ONLY a JSON object:
|
||||||
|
{
|
||||||
|
"criteria_scores": {"criterion_1": 0.8, ...},
|
||||||
|
"overall_quality": 0.85,
|
||||||
|
"reasoning": "Brief explanation covering both conversation quality and template accuracy"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def _judge_interactive(
|
||||||
|
conversation: list[dict[str, str]],
|
||||||
|
prompt_template: str | None,
|
||||||
|
user_comment: str,
|
||||||
|
criteria: list[str],
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> tuple[dict[str, float], float, str]:
|
||||||
|
"""Score an interactive session. Returns (criteria_scores, overall_quality, reasoning)."""
|
||||||
|
from shared.llm import get_llm
|
||||||
|
|
||||||
|
llm = get_llm(model=judge_model, temperature=0)
|
||||||
|
|
||||||
|
conv_text = "\n".join(
|
||||||
|
f"{'USER' if t['role'] == 'user' else 'AI'}: {t['content']}"
|
||||||
|
for t in conversation
|
||||||
|
)
|
||||||
|
criteria_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(criteria))
|
||||||
|
|
||||||
|
user_content = (
|
||||||
|
f"## Conversation transcript\n```\n{conv_text}\n```\n\n"
|
||||||
|
f"## Generated prompt_template\n```\n{prompt_template or '(none — conversation did not complete)'}\n```\n\n"
|
||||||
|
f"## User's comment\n{user_comment}\n\n"
|
||||||
|
f"## Criteria to evaluate\n{criteria_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=_INTERACTIVE_JUDGE_SYSTEM),
|
||||||
|
HumanMessage(content=user_content),
|
||||||
|
])
|
||||||
|
raw = response.content.strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
|
||||||
|
scores_raw = parsed.get("criteria_scores", parsed.get("scores", {}))
|
||||||
|
criteria_scores: dict[str, float] = {}
|
||||||
|
for i, criterion in enumerate(criteria):
|
||||||
|
key_candidates = [f"criterion_{i+1}", criterion, criterion[:50], str(i + 1)]
|
||||||
|
score = 0.0
|
||||||
|
for key in key_candidates:
|
||||||
|
if key in scores_raw:
|
||||||
|
score = float(scores_raw[key])
|
||||||
|
break
|
||||||
|
if score == 0.0 and i < len(scores_raw):
|
||||||
|
score = float(list(scores_raw.values())[i])
|
||||||
|
criteria_scores[criterion] = score
|
||||||
|
|
||||||
|
overall = float(parsed.get("overall_quality", 0.0))
|
||||||
|
reasoning = str(parsed.get("reasoning", ""))
|
||||||
|
return criteria_scores, overall, reasoning
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("interactive judge failed: %s", exc)
|
||||||
|
return {c: 0.0 for c in criteria}, 0.0, f"Judge error: {exc}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Interactive session ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_interactive(
|
||||||
|
fixture: JourneyFixture,
|
||||||
|
*,
|
||||||
|
model: str = "gpt-4o",
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
data_dir: Path | None = None,
|
||||||
|
) -> InteractiveResult:
|
||||||
|
"""Run an interactive journey session in the terminal.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
data_dir :
|
||||||
|
If set, overrides the fixture's sample-file directory. The LLM
|
||||||
|
will explore this folder instead of the default
|
||||||
|
``fixtures/sample_files/…``. Useful for private test data that
|
||||||
|
shouldn't be committed to git.
|
||||||
|
"""
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.ws_context import set_current_user, clear_current_user
|
||||||
|
from app.journey import (
|
||||||
|
handle_journey_start,
|
||||||
|
handle_journey_message,
|
||||||
|
_build_system_prompt,
|
||||||
|
)
|
||||||
|
|
||||||
|
# When --data-dir is given, the MockExecutor's root becomes
|
||||||
|
# data_dir's parent and the journey directory is data_dir's name.
|
||||||
|
# This way the LLM sees a meaningful directory name (not ".") and
|
||||||
|
# MockExecutor resolves paths correctly.
|
||||||
|
# Otherwise, use the fixture's YAML parent and its relative path.
|
||||||
|
if data_dir:
|
||||||
|
mock_root = data_dir.parent
|
||||||
|
journey_directory = data_dir.name
|
||||||
|
else:
|
||||||
|
mock_root = fixture.fixture_path.parent
|
||||||
|
journey_directory = fixture.directory
|
||||||
|
|
||||||
|
mock = MockExecutor(
|
||||||
|
fixture_dir=mock_root,
|
||||||
|
seed_records={},
|
||||||
|
)
|
||||||
|
|
||||||
|
original_model = settings.LLM_MODEL
|
||||||
|
settings.LLM_MODEL = model
|
||||||
|
eval_user_id = f"interactive-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
# ── Show system prompt ───────────────────────────────────────
|
||||||
|
system_prompt = _build_system_prompt(journey_directory, fixture.data_types)
|
||||||
|
|
||||||
|
_print_header("SYSTEM PROMPT")
|
||||||
|
print(f"{_C_DIM}{system_prompt}{_C_RESET}")
|
||||||
|
|
||||||
|
_print_header(f"INTERACTIVE JOURNEY | fixture: {fixture.name} | model: {model}")
|
||||||
|
print(f" Data dir: {mock_root}")
|
||||||
|
print(f" Type your responses. Commands: {_CMD_DONE}, {_CMD_QUIT}, {_CMD_TEMPLATE}, {_CMD_HELP}")
|
||||||
|
print(f" Judge model: {judge_model}")
|
||||||
|
print(f" Criteria: {len(fixture.expected_template_criteria)}")
|
||||||
|
print()
|
||||||
|
|
||||||
|
conversation: list[dict[str, str]] = []
|
||||||
|
prompt_template: str | None = None
|
||||||
|
done = False
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
try:
|
||||||
|
set_current_user(eval_user_id)
|
||||||
|
|
||||||
|
with mock.patch():
|
||||||
|
# ── Start ────────────────────────────────────────────
|
||||||
|
_print_system("Starting journey... (AI is exploring your files)")
|
||||||
|
|
||||||
|
start_frame: dict[str, Any] = {
|
||||||
|
"agent_type": "local",
|
||||||
|
"directory": journey_directory,
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"session_id": f"interactive-{uuid.uuid4().hex[:8]}",
|
||||||
|
}
|
||||||
|
|
||||||
|
reply = await handle_journey_start(eval_user_id, start_frame)
|
||||||
|
session_id = reply["session_id"]
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
_print_ai(reply["message"])
|
||||||
|
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
_print_system("Journey completed on first reply (template generated).")
|
||||||
|
|
||||||
|
# ── Conversation loop ────────────────────────────────
|
||||||
|
while not done:
|
||||||
|
try:
|
||||||
|
user_input = input(f"{_C_BOLD}{_C_BLUE}YOU:{_C_RESET} ").strip()
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
print()
|
||||||
|
user_input = _CMD_QUIT
|
||||||
|
|
||||||
|
if not user_input:
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Handle commands
|
||||||
|
if user_input.lower() == _CMD_QUIT:
|
||||||
|
_print_system("Aborted — no evaluation will be performed.")
|
||||||
|
settings.LLM_MODEL = original_model
|
||||||
|
clear_current_user()
|
||||||
|
return InteractiveResult(
|
||||||
|
fixture_name=fixture.name, model=model, judge_model=judge_model,
|
||||||
|
prompt_template=None, conversation=conversation,
|
||||||
|
user_comment="(aborted)", done=False,
|
||||||
|
criteria_scores={}, overall_score=0.0,
|
||||||
|
judge_reasoning="Session aborted by user.",
|
||||||
|
elapsed_seconds=time.time() - start_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
if user_input.lower() == _CMD_HELP:
|
||||||
|
print(_HELP_TEXT)
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_input.lower() == _CMD_TEMPLATE:
|
||||||
|
if prompt_template:
|
||||||
|
print(f"\n{_C_MAGENTA}{prompt_template}{_C_RESET}\n")
|
||||||
|
else:
|
||||||
|
_print_system("No template generated yet.")
|
||||||
|
continue
|
||||||
|
|
||||||
|
if user_input.lower() == _CMD_DONE:
|
||||||
|
_print_system("Ending conversation...")
|
||||||
|
break
|
||||||
|
|
||||||
|
# ── Send message to AI ───────────────────────────
|
||||||
|
conversation.append({"role": "user", "content": user_input})
|
||||||
|
_print_system("AI is thinking...")
|
||||||
|
|
||||||
|
msg_frame: dict[str, Any] = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": user_input,
|
||||||
|
}
|
||||||
|
reply = await handle_journey_message(eval_user_id, msg_frame)
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
_print_ai(reply["message"])
|
||||||
|
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
_print_system("Journey completed — template generated!")
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("interactive journey failed: %s", exc)
|
||||||
|
_print_system(f"Error: {exc}")
|
||||||
|
finally:
|
||||||
|
settings.LLM_MODEL = original_model
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
turns = len([c for c in conversation if c["role"] == "user"])
|
||||||
|
|
||||||
|
# ── Show template if generated ───────────────────────────────
|
||||||
|
if prompt_template:
|
||||||
|
_print_header("GENERATED TEMPLATE")
|
||||||
|
print(f"{_C_MAGENTA}{prompt_template}{_C_RESET}\n")
|
||||||
|
else:
|
||||||
|
_print_system("No template was generated during this session.")
|
||||||
|
|
||||||
|
# ── User comment ─────────────────────────────────────────────
|
||||||
|
_print_header("YOUR EVALUATION")
|
||||||
|
print(" Write your comment about this interaction (press Enter twice to finish):")
|
||||||
|
print()
|
||||||
|
comment_lines: list[str] = []
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
line = input()
|
||||||
|
if line == "" and comment_lines and comment_lines[-1] == "":
|
||||||
|
comment_lines.pop() # remove trailing empty
|
||||||
|
break
|
||||||
|
comment_lines.append(line)
|
||||||
|
except (EOFError, KeyboardInterrupt):
|
||||||
|
pass
|
||||||
|
user_comment = "\n".join(comment_lines).strip() or "(no comment)"
|
||||||
|
|
||||||
|
# ── Judge ────────────────────────────────────────────────────
|
||||||
|
_print_header("LLM JUDGE EVALUATION")
|
||||||
|
_print_system(f"Scoring with {judge_model}...")
|
||||||
|
|
||||||
|
criteria_scores, overall_quality, judge_reasoning = await _judge_interactive(
|
||||||
|
conversation=conversation,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
user_comment=user_comment,
|
||||||
|
criteria=fixture.expected_template_criteria,
|
||||||
|
judge_model=judge_model,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Display scores ───────────────────────────────────────────
|
||||||
|
print()
|
||||||
|
for criterion, score in criteria_scores.items():
|
||||||
|
_print_score(criterion, score)
|
||||||
|
|
||||||
|
overall = (
|
||||||
|
sum(criteria_scores.values()) / len(criteria_scores)
|
||||||
|
if criteria_scores
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n {_C_BOLD}Criteria avg: {overall:.2f}{_C_RESET}")
|
||||||
|
print(f" {_C_BOLD}Overall quality: {overall_quality:.2f}{_C_RESET}")
|
||||||
|
print(f" {_C_BOLD}Turns: {turns}{_C_RESET}")
|
||||||
|
print(f" {_C_BOLD}Time: {elapsed:.1f}s{_C_RESET}")
|
||||||
|
print(f"\n {_C_DIM}Judge: {judge_reasoning}{_C_RESET}")
|
||||||
|
print(f" {_C_DIM}Your comment: {user_comment}{_C_RESET}\n")
|
||||||
|
|
||||||
|
result = InteractiveResult(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
judge_model=judge_model,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
conversation=conversation,
|
||||||
|
user_comment=user_comment,
|
||||||
|
done=done,
|
||||||
|
criteria_scores=criteria_scores,
|
||||||
|
overall_score=overall_quality,
|
||||||
|
judge_reasoning=judge_reasoning,
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Report to Langfuse ───────────────────────────────────────
|
||||||
|
trace_id = langfuse_eval.log_eval_trace(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant="interactive",
|
||||||
|
prompt_template=prompt_template or "(not generated)",
|
||||||
|
actual_mutations=[{
|
||||||
|
"conversation": conversation[:30],
|
||||||
|
"user_comment": user_comment,
|
||||||
|
}],
|
||||||
|
scores_summary=result.summary(),
|
||||||
|
langfuse_prompt_names=["journey_system"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if trace_id:
|
||||||
|
from eval.scorer import EvalScores
|
||||||
|
scores_obj = EvalScores(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant="interactive",
|
||||||
|
precision=overall,
|
||||||
|
recall=float(done),
|
||||||
|
f1=overall,
|
||||||
|
llm_judge_score=overall_quality,
|
||||||
|
llm_judge_reasoning=judge_reasoning,
|
||||||
|
)
|
||||||
|
langfuse_eval.post_eval_scores(scores_obj, trace_id=trace_id)
|
||||||
|
_print_system(f"Results reported to Langfuse (trace: {trace_id})")
|
||||||
|
else:
|
||||||
|
_print_system("Langfuse not configured — results not reported.")
|
||||||
|
|
||||||
|
return result
|
||||||
385
services/batch-agent/eval/journey_runner.py
Normal file
385
services/batch-agent/eval/journey_runner.py
Normal file
@@ -0,0 +1,385 @@
|
|||||||
|
"""Journey eval runner — tests the prompt_template builder conversation.
|
||||||
|
|
||||||
|
For each (journey_fixture × model) combination:
|
||||||
|
1. Build a MockExecutor (for filesystem tools used during journey)
|
||||||
|
2. Patch execute_on_client
|
||||||
|
3. Override LLM_MODEL
|
||||||
|
4. Call handle_journey_start to kick off the conversation
|
||||||
|
5. Feed simulated user_messages via handle_journey_message
|
||||||
|
6. Collect the generated prompt_template
|
||||||
|
7. Score it against expected_template_criteria (via LLM judge)
|
||||||
|
8. Report to Langfuse
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
from eval.config import JourneyFixture
|
||||||
|
from eval.mock_executor import MockExecutor
|
||||||
|
from eval import langfuse_eval
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result type ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class JourneyEvalResult:
|
||||||
|
"""Result of one journey eval run."""
|
||||||
|
|
||||||
|
fixture_name: str
|
||||||
|
model: str
|
||||||
|
prompt_template: str | None # the generated template (None if journey failed)
|
||||||
|
conversation_turns: int
|
||||||
|
done: bool # whether journey reached completion
|
||||||
|
criteria_scores: dict[str, float] # criterion → 0-1 score
|
||||||
|
overall_score: float # average of criteria scores
|
||||||
|
judge_reasoning: str
|
||||||
|
elapsed_seconds: float
|
||||||
|
|
||||||
|
def summary(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"fixture": self.fixture_name,
|
||||||
|
"model": self.model,
|
||||||
|
"done": self.done,
|
||||||
|
"turns": self.conversation_turns,
|
||||||
|
"overall_score": round(self.overall_score, 3),
|
||||||
|
"criteria_scores": {k: round(v, 3) for k, v in self.criteria_scores.items()},
|
||||||
|
"elapsed_s": round(self.elapsed_seconds, 1),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM judge for template quality ──────────────────────────────────────
|
||||||
|
|
||||||
|
_JOURNEY_JUDGE_SYSTEM = """\
|
||||||
|
You are an evaluation judge for AI-generated prompt templates.
|
||||||
|
|
||||||
|
A journey chatbot explored a user's directory structure and through
|
||||||
|
conversation produced a prompt_template — an instruction set for a
|
||||||
|
data-extraction agent.
|
||||||
|
|
||||||
|
Your task: evaluate the generated template against a list of criteria.
|
||||||
|
Score each criterion from 0 to 1:
|
||||||
|
- 1.0: Fully satisfied, clearly present in the template
|
||||||
|
- 0.5: Partially satisfied or ambiguously addressed
|
||||||
|
- 0.0: Not satisfied, missing from the template
|
||||||
|
|
||||||
|
Respond with ONLY a JSON object:
|
||||||
|
{
|
||||||
|
"scores": {"criterion_1": 0.8, "criterion_2": 1.0, ...},
|
||||||
|
"reasoning": "Brief explanation"
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def _judge_template(
|
||||||
|
prompt_template: str,
|
||||||
|
criteria: list[str],
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> tuple[dict[str, float], str]:
|
||||||
|
"""Use an LLM to evaluate a generated prompt_template against criteria.
|
||||||
|
|
||||||
|
Returns (criteria_scores, reasoning).
|
||||||
|
"""
|
||||||
|
from shared.llm import get_llm
|
||||||
|
|
||||||
|
llm = get_llm(model=judge_model, temperature=0)
|
||||||
|
|
||||||
|
criteria_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(criteria))
|
||||||
|
user_content = (
|
||||||
|
f"## Generated prompt_template\n```\n{prompt_template}\n```\n\n"
|
||||||
|
f"## Criteria to evaluate\n{criteria_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=_JOURNEY_JUDGE_SYSTEM),
|
||||||
|
HumanMessage(content=user_content),
|
||||||
|
])
|
||||||
|
raw = response.content.strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
|
||||||
|
scores_raw = parsed.get("scores", {})
|
||||||
|
# Map criterion keys back to the original criteria text
|
||||||
|
criteria_scores: dict[str, float] = {}
|
||||||
|
for i, criterion in enumerate(criteria):
|
||||||
|
# Try matching by index key or exact criterion text
|
||||||
|
key_candidates = [
|
||||||
|
f"criterion_{i+1}",
|
||||||
|
criterion,
|
||||||
|
criterion[:50],
|
||||||
|
str(i + 1),
|
||||||
|
]
|
||||||
|
score = 0.0
|
||||||
|
for key in key_candidates:
|
||||||
|
if key in scores_raw:
|
||||||
|
score = float(scores_raw[key])
|
||||||
|
break
|
||||||
|
# If no match found, try values in order
|
||||||
|
if score == 0.0 and i < len(scores_raw):
|
||||||
|
score = float(list(scores_raw.values())[i])
|
||||||
|
criteria_scores[criterion] = score
|
||||||
|
|
||||||
|
reasoning = str(parsed.get("reasoning", ""))
|
||||||
|
return criteria_scores, reasoning
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("journey_eval: LLM judge failed: %s", exc)
|
||||||
|
return {c: 0.0 for c in criteria}, f"Judge error: {exc}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey runner ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_single_journey_eval(
|
||||||
|
fixture: JourneyFixture,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
data_dir: Path | None = None,
|
||||||
|
) -> JourneyEvalResult:
|
||||||
|
"""Execute one journey eval: start \u2192 messages \u2192 score template."""
|
||||||
|
from shared.config import settings
|
||||||
|
|
||||||
|
# When data_dir is given, use its parent as MockExecutor root
|
||||||
|
# and its name as the journey directory so the LLM sees a
|
||||||
|
# meaningful path (not ".").
|
||||||
|
if data_dir:
|
||||||
|
mock_root = data_dir.parent
|
||||||
|
journey_directory = data_dir.name
|
||||||
|
else:
|
||||||
|
mock_root = fixture.fixture_path.parent
|
||||||
|
journey_directory = fixture.directory
|
||||||
|
|
||||||
|
mock = MockExecutor(
|
||||||
|
fixture_dir=mock_root,
|
||||||
|
seed_records={},
|
||||||
|
)
|
||||||
|
|
||||||
|
original_model = settings.LLM_MODEL
|
||||||
|
settings.LLM_MODEL = model
|
||||||
|
|
||||||
|
eval_user_id = f"eval-journey-{uuid.uuid4().hex[:8]}"
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey_eval: starting %s | model=%s",
|
||||||
|
fixture.name, model,
|
||||||
|
)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
prompt_template: str | None = None
|
||||||
|
conversation: list[dict[str, str]] = []
|
||||||
|
done = False
|
||||||
|
|
||||||
|
try:
|
||||||
|
from shared.ws_context import set_current_user, clear_current_user
|
||||||
|
from app.journey import handle_journey_start, handle_journey_message, _sessions
|
||||||
|
|
||||||
|
set_current_user(eval_user_id)
|
||||||
|
with mock.patch():
|
||||||
|
# ── Start the journey ────────────────────────────────
|
||||||
|
start_frame: dict[str, Any] = {
|
||||||
|
"agent_type": "local",
|
||||||
|
"directory": journey_directory,
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"session_id": f"eval-{uuid.uuid4().hex[:8]}",
|
||||||
|
}
|
||||||
|
|
||||||
|
reply = await handle_journey_start(eval_user_id, start_frame)
|
||||||
|
session_id = reply["session_id"]
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey_eval: start reply (%d chars), done=%s",
|
||||||
|
len(reply["message"]), reply["done"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
else:
|
||||||
|
# ── Send user messages ───────────────────────────
|
||||||
|
for i, user_msg in enumerate(fixture.user_messages):
|
||||||
|
if done:
|
||||||
|
break
|
||||||
|
|
||||||
|
conversation.append({"role": "user", "content": user_msg})
|
||||||
|
|
||||||
|
msg_frame: dict[str, Any] = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": user_msg,
|
||||||
|
}
|
||||||
|
reply = await handle_journey_message(eval_user_id, msg_frame)
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey_eval: turn %d reply (%d chars), done=%s",
|
||||||
|
i + 1, len(reply["message"]), reply["done"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
|
||||||
|
# If not done after all user messages, send a final nudge
|
||||||
|
if not done:
|
||||||
|
nudge = "Please generate the final prompt_template now. I'm satisfied with the configuration."
|
||||||
|
conversation.append({"role": "user", "content": nudge})
|
||||||
|
|
||||||
|
nudge_frame: dict[str, Any] = {
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": nudge,
|
||||||
|
}
|
||||||
|
reply = await handle_journey_message(eval_user_id, nudge_frame)
|
||||||
|
conversation.append({"role": "assistant", "content": reply["message"]})
|
||||||
|
if reply["done"]:
|
||||||
|
prompt_template = reply.get("prompt_template")
|
||||||
|
done = True
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("journey_eval: pipeline failed for %s/%s: %s", fixture.name, model, exc)
|
||||||
|
finally:
|
||||||
|
settings.LLM_MODEL = original_model
|
||||||
|
from shared.ws_context import clear_current_user
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
turns = len([c for c in conversation if c["role"] == "user"])
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"journey_eval: completed in %.1fs — %d turns, done=%s, template=%s",
|
||||||
|
elapsed, turns, done, "yes" if prompt_template else "no",
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Score the template ───────────────────────────────────────
|
||||||
|
criteria_scores: dict[str, float] = {}
|
||||||
|
judge_reasoning = ""
|
||||||
|
|
||||||
|
if prompt_template and fixture.expected_template_criteria:
|
||||||
|
criteria_scores, judge_reasoning = await _judge_template(
|
||||||
|
prompt_template,
|
||||||
|
fixture.expected_template_criteria,
|
||||||
|
judge_model=judge_model,
|
||||||
|
)
|
||||||
|
elif not prompt_template:
|
||||||
|
criteria_scores = {c: 0.0 for c in fixture.expected_template_criteria}
|
||||||
|
judge_reasoning = "No prompt_template was generated — journey did not complete."
|
||||||
|
|
||||||
|
overall = (
|
||||||
|
sum(criteria_scores.values()) / len(criteria_scores)
|
||||||
|
if criteria_scores
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
|
||||||
|
result = JourneyEvalResult(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
conversation_turns=turns,
|
||||||
|
done=done,
|
||||||
|
criteria_scores=criteria_scores,
|
||||||
|
overall_score=overall,
|
||||||
|
judge_reasoning=judge_reasoning,
|
||||||
|
elapsed_seconds=elapsed,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Report to Langfuse ───────────────────────────────────────
|
||||||
|
trace_id = langfuse_eval.log_eval_trace(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant="journey",
|
||||||
|
prompt_template=prompt_template or "(not generated)",
|
||||||
|
actual_mutations=[{"conversation": conversation[:20]}],
|
||||||
|
scores_summary=result.summary(),
|
||||||
|
langfuse_prompt_names=["journey_system"],
|
||||||
|
)
|
||||||
|
|
||||||
|
if trace_id:
|
||||||
|
from eval.scorer import EvalScores
|
||||||
|
scores_obj = EvalScores(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant="journey",
|
||||||
|
precision=overall,
|
||||||
|
recall=float(done),
|
||||||
|
f1=overall,
|
||||||
|
llm_judge_score=overall,
|
||||||
|
llm_judge_reasoning=judge_reasoning,
|
||||||
|
)
|
||||||
|
langfuse_eval.post_eval_scores(scores_obj, trace_id=trace_id)
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
async def run_journey_fixture_eval(
|
||||||
|
fixture: JourneyFixture,
|
||||||
|
models: list[str],
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
data_dir: Path | None = None,
|
||||||
|
) -> list[JourneyEvalResult]:
|
||||||
|
"""Run all models for a journey fixture."""
|
||||||
|
langfuse_eval.sync_journey_fixture_to_dataset(fixture)
|
||||||
|
|
||||||
|
results: list[JourneyEvalResult] = []
|
||||||
|
for model in models:
|
||||||
|
result = await run_single_journey_eval(
|
||||||
|
fixture, model, judge_model=judge_model,
|
||||||
|
data_dir=data_dir,
|
||||||
|
)
|
||||||
|
results.append(result)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_journey_results(results: list[JourneyEvalResult]) -> None:
|
||||||
|
"""Print a formatted summary of journey eval results."""
|
||||||
|
if not results:
|
||||||
|
print("\nNo journey eval results.")
|
||||||
|
return
|
||||||
|
|
||||||
|
print("\n" + "=" * 95)
|
||||||
|
print(f"{'Fixture':<25} {'Model':<25} {'Done':>5} {'Turns':>6} {'Score':>7} {'Time':>7}")
|
||||||
|
print("-" * 95)
|
||||||
|
|
||||||
|
for r in results:
|
||||||
|
done_str = "yes" if r.done else "NO"
|
||||||
|
print(
|
||||||
|
f"{r.fixture_name:<25} {r.model:<25} {done_str:>5} "
|
||||||
|
f"{r.conversation_turns:>6} {r.overall_score:>7.2f} {r.elapsed_seconds:>6.1f}s"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * 95)
|
||||||
|
|
||||||
|
# Criteria breakdown
|
||||||
|
for r in results:
|
||||||
|
if r.criteria_scores:
|
||||||
|
print(f"\n[{r.model}] Criteria scores:")
|
||||||
|
for criterion, score in r.criteria_scores.items():
|
||||||
|
indicator = "PASS" if score >= 0.7 else "PARTIAL" if score >= 0.4 else "FAIL"
|
||||||
|
print(f" {indicator:>7} ({score:.1f}) {criterion}")
|
||||||
|
|
||||||
|
if r.judge_reasoning:
|
||||||
|
print(f" Judge: {r.judge_reasoning}")
|
||||||
|
|
||||||
|
if r.prompt_template:
|
||||||
|
preview = r.prompt_template[:200].replace("\n", " ")
|
||||||
|
print(f" Template preview: {preview}...")
|
||||||
|
|
||||||
|
print()
|
||||||
327
services/batch-agent/eval/langfuse_eval.py
Normal file
327
services/batch-agent/eval/langfuse_eval.py
Normal file
@@ -0,0 +1,327 @@
|
|||||||
|
"""Langfuse evaluation integration — datasets, runs, and scoring.
|
||||||
|
|
||||||
|
Uses the Langfuse Python SDK v4 (OpenTelemetry-based) to:
|
||||||
|
|
||||||
|
1. **Sync fixtures → Langfuse datasets**: Each YAML fixture becomes a dataset,
|
||||||
|
each prompt variant + expected pair becomes a dataset item.
|
||||||
|
|
||||||
|
2. **Track eval runs**: Each (fixture × model × prompt_variant) execution
|
||||||
|
is recorded as a trace with linked scores.
|
||||||
|
|
||||||
|
3. **Post scores**: precision, recall, F1, field_accuracy, llm_judge are
|
||||||
|
posted as numeric scores on the trace.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from eval.config import EvalFixture
|
||||||
|
from eval.scorer import EvalScores
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_langfuse():
|
||||||
|
"""Get or create a Langfuse client instance (SDK v4)."""
|
||||||
|
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
os.environ.setdefault("LANGFUSE_SECRET_KEY", settings.LANGFUSE_SECRET_KEY)
|
||||||
|
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", settings.LANGFUSE_PUBLIC_KEY)
|
||||||
|
if settings.LANGFUSE_HOST:
|
||||||
|
os.environ.setdefault("LANGFUSE_HOST", settings.LANGFUSE_HOST)
|
||||||
|
from langfuse import get_client
|
||||||
|
return get_client()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to create client: %s", exc)
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def sync_fixture_to_dataset(fixture: EvalFixture) -> str | None:
|
||||||
|
"""Create or update a Langfuse dataset from a fixture.
|
||||||
|
|
||||||
|
Each prompt variant becomes a separate dataset item with:
|
||||||
|
- input: {directory, data_types, prompt_template, seed_records}
|
||||||
|
- expected_output: {expected records}
|
||||||
|
|
||||||
|
Returns the dataset name, or None if Langfuse is unavailable.
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
logger.info("langfuse_eval: Langfuse not configured — skipping dataset sync")
|
||||||
|
return None
|
||||||
|
|
||||||
|
dataset_name = f"batch-eval-{fixture.name}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
lf.create_dataset(
|
||||||
|
name=dataset_name,
|
||||||
|
description=fixture.description,
|
||||||
|
metadata={
|
||||||
|
"data_types": ",".join(fixture.data_types),
|
||||||
|
"file_extensions": ",".join(fixture.file_extensions) if fixture.file_extensions else "",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
# Dataset may already exist — that's fine
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Build expected_output appropriate to the fixture's mode
|
||||||
|
expected_output: dict[str, Any] = {}
|
||||||
|
if fixture.mode in ("step1", "full") and fixture.expected_classification:
|
||||||
|
expected_output["classifications"] = [
|
||||||
|
{"file": ec.file, "project_id": ec.project_id, "domains": ec.domains}
|
||||||
|
for ec in fixture.expected_classification
|
||||||
|
]
|
||||||
|
if fixture.mode in ("step2", "full") and fixture.expected:
|
||||||
|
for rec in fixture.expected:
|
||||||
|
expected_output.setdefault(rec.table, []).append(rec.fields)
|
||||||
|
|
||||||
|
item_id = f"{fixture.name}--{fixture.mode}"
|
||||||
|
try:
|
||||||
|
lf.create_dataset_item(
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
id=item_id,
|
||||||
|
input={
|
||||||
|
"directory": fixture.directory,
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"mode": fixture.mode,
|
||||||
|
"seed_records": fixture.seed_records,
|
||||||
|
},
|
||||||
|
expected_output=expected_output,
|
||||||
|
metadata={"mode": fixture.mode},
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"langfuse_eval: failed to upsert dataset item %s: %s", item_id, exc
|
||||||
|
)
|
||||||
|
|
||||||
|
lf.flush()
|
||||||
|
logger.info("langfuse_eval: synced fixture '%s' → dataset '%s'", fixture.name, dataset_name)
|
||||||
|
return dataset_name
|
||||||
|
|
||||||
|
|
||||||
|
def sync_journey_fixture_to_dataset(fixture) -> str | None:
|
||||||
|
"""Create or update a Langfuse dataset from a journey fixture.
|
||||||
|
|
||||||
|
Each journey fixture becomes a single dataset item with:
|
||||||
|
- input: {directory, data_types, user_messages}
|
||||||
|
- expected_output: {criteria}
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
logger.info("langfuse_eval: Langfuse not configured — skipping journey dataset sync")
|
||||||
|
return None
|
||||||
|
|
||||||
|
dataset_name = f"journey-eval-{fixture.name}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
lf.create_dataset(
|
||||||
|
name=dataset_name,
|
||||||
|
description=fixture.description,
|
||||||
|
metadata={"type": "journey", "data_types": ",".join(fixture.data_types)},
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Dataset may already exist
|
||||||
|
|
||||||
|
item_id = f"{fixture.name}--journey"
|
||||||
|
try:
|
||||||
|
lf.create_dataset_item(
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
id=item_id,
|
||||||
|
input={
|
||||||
|
"directory": fixture.directory,
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"user_messages": fixture.user_messages,
|
||||||
|
},
|
||||||
|
expected_output={
|
||||||
|
"criteria": fixture.expected_template_criteria,
|
||||||
|
},
|
||||||
|
metadata={"type": "journey"},
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to upsert journey dataset item %s: %s", item_id, exc)
|
||||||
|
|
||||||
|
lf.flush()
|
||||||
|
logger.info("langfuse_eval: synced journey fixture '%s' → dataset '%s'", fixture.name, dataset_name)
|
||||||
|
return dataset_name
|
||||||
|
|
||||||
|
|
||||||
|
def create_eval_run(
|
||||||
|
dataset_name: str,
|
||||||
|
run_name: str,
|
||||||
|
*,
|
||||||
|
metadata: dict[str, Any] | None = None,
|
||||||
|
) -> str:
|
||||||
|
"""Create a dataset run in Langfuse. Returns the run name.
|
||||||
|
|
||||||
|
Note: In SDK v4, dataset runs are created implicitly via
|
||||||
|
dataset.run_experiment(). This function is kept for backwards
|
||||||
|
compatibility but may not create a run.
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
return run_name
|
||||||
|
|
||||||
|
try:
|
||||||
|
if hasattr(lf, "create_dataset_run"):
|
||||||
|
lf.create_dataset_run(
|
||||||
|
dataset_name=dataset_name,
|
||||||
|
run_name=run_name,
|
||||||
|
metadata=metadata or {},
|
||||||
|
)
|
||||||
|
lf.flush()
|
||||||
|
else:
|
||||||
|
logger.debug("langfuse_eval: create_dataset_run not available in SDK v4")
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to create run %s: %s", run_name, exc)
|
||||||
|
|
||||||
|
return run_name
|
||||||
|
|
||||||
|
|
||||||
|
def post_eval_scores(
|
||||||
|
scores: EvalScores,
|
||||||
|
*,
|
||||||
|
trace_id: str | None = None,
|
||||||
|
dataset_name: str | None = None,
|
||||||
|
run_name: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Post evaluation scores to Langfuse.
|
||||||
|
|
||||||
|
If trace_id is provided, scores are attached to that trace.
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
score_data = [
|
||||||
|
("precision", scores.precision),
|
||||||
|
("recall", scores.recall),
|
||||||
|
("f1", scores.f1),
|
||||||
|
]
|
||||||
|
# Only post field_accuracy when there are field-level scores (step2/full)
|
||||||
|
if scores.field_scores:
|
||||||
|
score_data.append(("field_accuracy", scores.field_accuracy))
|
||||||
|
if scores.llm_judge_score is not None:
|
||||||
|
score_data.append(("llm_judge", scores.llm_judge_score))
|
||||||
|
|
||||||
|
for name, value in score_data:
|
||||||
|
try:
|
||||||
|
lf.create_score(
|
||||||
|
name=name,
|
||||||
|
value=value,
|
||||||
|
trace_id=trace_id,
|
||||||
|
data_type="NUMERIC",
|
||||||
|
comment=f"{scores.fixture_name} | {scores.model} | {scores.prompt_variant}",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to post score %s: %s", name, exc)
|
||||||
|
|
||||||
|
lf.flush()
|
||||||
|
logger.info(
|
||||||
|
"langfuse_eval: posted %d scores for %s/%s/%s",
|
||||||
|
len(score_data), scores.fixture_name, scores.model, scores.prompt_variant,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def log_eval_trace(
|
||||||
|
*,
|
||||||
|
fixture_name: str,
|
||||||
|
model: str,
|
||||||
|
prompt_variant: str,
|
||||||
|
prompt_template: str,
|
||||||
|
actual_mutations: list[dict],
|
||||||
|
scores_summary: dict[str, Any],
|
||||||
|
step1_results: list[dict] | None = None,
|
||||||
|
dataset_name: str | None = None,
|
||||||
|
run_name: str | None = None,
|
||||||
|
dataset_item_id: str | None = None,
|
||||||
|
langfuse_prompt_names: list[str] | None = None,
|
||||||
|
) -> str | None:
|
||||||
|
"""Create a Langfuse trace for one eval execution and link it to a dataset run.
|
||||||
|
|
||||||
|
Uses SDK v4 observation API (traces are created implicitly by root spans).
|
||||||
|
``langfuse_prompt_names`` can contain one or two prompt names to link
|
||||||
|
(e.g. ``["batch_file_classifier", "batch_processing"]`` for full mode).
|
||||||
|
Each prompt gets its own generation-type observation for per-version
|
||||||
|
metrics tracking.
|
||||||
|
|
||||||
|
Returns the trace_id, or None if Langfuse is unavailable.
|
||||||
|
"""
|
||||||
|
lf = _get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import propagate_attributes
|
||||||
|
|
||||||
|
# Fetch prompt objects for linking
|
||||||
|
prompt_objs: list[tuple[str, Any]] = []
|
||||||
|
for pname in (langfuse_prompt_names or []):
|
||||||
|
try:
|
||||||
|
obj = lf.get_prompt(name=pname, cache_ttl_seconds=300)
|
||||||
|
prompt_objs.append((pname, obj))
|
||||||
|
logger.info("langfuse_eval: linked prompt '%s' (type=%s)", pname, type(obj).__name__)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: prompt '%s' not found — %s", pname, exc)
|
||||||
|
|
||||||
|
# Build trace output dict
|
||||||
|
trace_output: dict[str, Any] = {"scores": scores_summary}
|
||||||
|
if step1_results:
|
||||||
|
trace_output["classifications"] = step1_results
|
||||||
|
if actual_mutations:
|
||||||
|
trace_output["mutations"] = actual_mutations[:50]
|
||||||
|
|
||||||
|
with propagate_attributes(
|
||||||
|
trace_name=f"eval-{fixture_name}",
|
||||||
|
metadata={
|
||||||
|
"eval": "true",
|
||||||
|
"fixture": fixture_name,
|
||||||
|
"model": model,
|
||||||
|
"prompt_variant": prompt_variant,
|
||||||
|
},
|
||||||
|
tags=["eval", f"model:{model}", f"variant:{prompt_variant}"],
|
||||||
|
):
|
||||||
|
# Root span for the eval run
|
||||||
|
span = lf.start_observation(name=f"eval-{fixture_name}")
|
||||||
|
span.update(
|
||||||
|
input={
|
||||||
|
"prompt_template": prompt_template,
|
||||||
|
"model": model,
|
||||||
|
"prompt_variant": prompt_variant,
|
||||||
|
},
|
||||||
|
output=trace_output,
|
||||||
|
)
|
||||||
|
trace_id = span.trace_id
|
||||||
|
|
||||||
|
# Create a generation-type observation per linked prompt
|
||||||
|
for pname, pobj in prompt_objs:
|
||||||
|
gen = lf.start_observation(
|
||||||
|
name=f"prompt-{pname}",
|
||||||
|
prompt=pobj,
|
||||||
|
as_type="generation",
|
||||||
|
)
|
||||||
|
gen.end()
|
||||||
|
|
||||||
|
# Link to dataset run if available
|
||||||
|
if dataset_name and run_name and dataset_item_id:
|
||||||
|
try:
|
||||||
|
dataset = lf.get_dataset(dataset_name)
|
||||||
|
for item in dataset.items:
|
||||||
|
if item.id == dataset_item_id:
|
||||||
|
item.link(span, run_name)
|
||||||
|
break
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to link trace to dataset run: %s", exc)
|
||||||
|
|
||||||
|
span.end()
|
||||||
|
|
||||||
|
lf.flush()
|
||||||
|
return trace_id
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse_eval: failed to create eval trace: %s", exc)
|
||||||
|
return None
|
||||||
258
services/batch-agent/eval/mock_executor.py
Normal file
258
services/batch-agent/eval/mock_executor.py
Normal file
@@ -0,0 +1,258 @@
|
|||||||
|
"""Mock executor — intercepts execute_on_client for offline E2E testing.
|
||||||
|
|
||||||
|
Patches ``execute_on_client`` at all usage sites so agent pipeline runs don't
|
||||||
|
require a live Electron client or Redis. Instead:
|
||||||
|
|
||||||
|
- **Filesystem actions** (list_directory, read_file_content, get_file_metadata)
|
||||||
|
are served from local fixture files on disk.
|
||||||
|
- **Read actions** (select, get) return preseeded records from an in-memory
|
||||||
|
store provided by the test fixture.
|
||||||
|
- **Write actions** (insert, update, delete) are captured as *mutations* and
|
||||||
|
stored for later comparison against expected results.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from contextlib import contextmanager, asynccontextmanager
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Mutation:
|
||||||
|
"""A single recorded write operation."""
|
||||||
|
|
||||||
|
action: str # insert | update | delete
|
||||||
|
table: str
|
||||||
|
data: dict[str, Any]
|
||||||
|
timestamp: float = field(default_factory=time.time)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fake DB helpers (used to bypass async_session in full mode) ───────
|
||||||
|
|
||||||
|
class _FakeRow:
|
||||||
|
"""Mimics an AgentRunLog row returned by SQLAlchemy."""
|
||||||
|
id = 0
|
||||||
|
status = "running"
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
errors: list[str] = []
|
||||||
|
completed_at = None
|
||||||
|
|
||||||
|
def __setattr__(self, name: str, value: Any) -> None:
|
||||||
|
object.__setattr__(self, name, value)
|
||||||
|
|
||||||
|
|
||||||
|
class _FakeResult:
|
||||||
|
"""Mimics a SQLAlchemy ``Result`` with ``scalar_one_or_none``."""
|
||||||
|
def __init__(self, row: _FakeRow) -> None:
|
||||||
|
self._row = row
|
||||||
|
|
||||||
|
def scalar_one_or_none(self) -> _FakeRow:
|
||||||
|
return self._row
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MockExecutor:
|
||||||
|
"""In-memory executor that replaces Redis-based tool round-trip.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
fixture_dir : Path
|
||||||
|
Directory containing sample files for filesystem tool calls.
|
||||||
|
seed_records : dict[str, list[dict]]
|
||||||
|
Pre-existing records per table, e.g. ``{"tasks": [...], "projects": [...]}``.
|
||||||
|
The executor returns these for ``select`` / ``get`` actions and auto-updates
|
||||||
|
them on ``insert`` / ``update`` / ``delete`` so subsequent selects reflect changes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
fixture_dir: Path
|
||||||
|
seed_records: dict[str, list[dict]] = field(default_factory=dict)
|
||||||
|
mutations: list[Mutation] = field(default_factory=list)
|
||||||
|
_id_counter: int = field(default=1000, repr=False)
|
||||||
|
|
||||||
|
# ── Public API ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def reset(self) -> None:
|
||||||
|
"""Clear recorded mutations (keep seed_records intact)."""
|
||||||
|
self.mutations.clear()
|
||||||
|
|
||||||
|
def get_mutations(self, *, table: str | None = None, action: str | None = None) -> list[Mutation]:
|
||||||
|
"""Filter mutations by table and/or action."""
|
||||||
|
result = self.mutations
|
||||||
|
if table:
|
||||||
|
result = [m for m in result if m.table == table]
|
||||||
|
if action:
|
||||||
|
result = [m for m in result if m.action == action]
|
||||||
|
return result
|
||||||
|
|
||||||
|
def created_records(self, table: str) -> list[dict]:
|
||||||
|
"""Return data dicts of all inserts into *table*."""
|
||||||
|
return [m.data for m in self.mutations if m.table == table and m.action == "insert"]
|
||||||
|
|
||||||
|
def updated_records(self, table: str) -> list[dict]:
|
||||||
|
"""Return data dicts of all updates to *table*."""
|
||||||
|
return [m.data for m in self.mutations if m.table == table and m.action == "update"]
|
||||||
|
|
||||||
|
# ── Context manager for patching ──────────────────────────────
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def patch(self):
|
||||||
|
"""Patch execute_on_client and DB session at all usage sites."""
|
||||||
|
mock_fn = AsyncMock(side_effect=self._handle)
|
||||||
|
targets = [
|
||||||
|
"shared.ws_context.execute_on_client",
|
||||||
|
"app.agent_runner.execute_on_client",
|
||||||
|
"app.agents.filesystem_agent.execute_on_client",
|
||||||
|
]
|
||||||
|
|
||||||
|
# Mock async_session so run_local_agent / _finalize_run skip real DB
|
||||||
|
fake_row = _FakeRow()
|
||||||
|
fake_db = AsyncMock()
|
||||||
|
fake_db.commit = AsyncMock()
|
||||||
|
fake_db.refresh = AsyncMock()
|
||||||
|
fake_db.execute = AsyncMock(return_value=_FakeResult(fake_row))
|
||||||
|
fake_db.add = lambda obj: None # noqa: ARG005
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def _fake_session():
|
||||||
|
yield fake_db
|
||||||
|
|
||||||
|
patches = [patch(t, new=mock_fn) for t in targets]
|
||||||
|
patches.append(patch("app.agent_runner.async_session", _fake_session))
|
||||||
|
for p in patches:
|
||||||
|
p.start()
|
||||||
|
try:
|
||||||
|
yield mock_fn
|
||||||
|
finally:
|
||||||
|
for p in patches:
|
||||||
|
p.stop()
|
||||||
|
|
||||||
|
# ── Internal dispatch ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _handle(
|
||||||
|
self,
|
||||||
|
action: str,
|
||||||
|
table: str | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
vector: list[float] | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
# Filesystem
|
||||||
|
if action == "list_directory":
|
||||||
|
return self._list_directory(data or {})
|
||||||
|
if action == "read_file_content":
|
||||||
|
return self._read_file(data or {})
|
||||||
|
if action == "get_file_metadata":
|
||||||
|
return self._get_file_metadata(data or {})
|
||||||
|
|
||||||
|
# CRUD
|
||||||
|
if action == "select":
|
||||||
|
return self._select(table or "", filters)
|
||||||
|
if action == "get":
|
||||||
|
return self._get(table or "", data or {})
|
||||||
|
if action == "insert":
|
||||||
|
return self._insert(table or "", data or {})
|
||||||
|
if action == "update":
|
||||||
|
return self._update(table or "", data or {})
|
||||||
|
if action == "delete":
|
||||||
|
return self._delete(table or "", data or {})
|
||||||
|
|
||||||
|
# Vector (no-op for eval)
|
||||||
|
if action in ("vector_upsert", "vector_search"):
|
||||||
|
return {"rows": []}
|
||||||
|
|
||||||
|
return {"error": f"Unknown action: {action}"}
|
||||||
|
|
||||||
|
# ── Filesystem handlers ───────────────────────────────────────
|
||||||
|
|
||||||
|
def _list_directory(self, data: dict) -> dict:
|
||||||
|
rel_path = data.get("path", "")
|
||||||
|
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
||||||
|
if not abs_path.is_dir():
|
||||||
|
return {"entries": []}
|
||||||
|
entries: list[dict] = []
|
||||||
|
for child in sorted(abs_path.iterdir()):
|
||||||
|
entry_type = "directory" if child.is_dir() else "file"
|
||||||
|
# Return paths relative to fixture_dir but with the original prefix
|
||||||
|
entry_path = rel_path.rstrip("/\\") + "/" + child.name
|
||||||
|
entries.append({
|
||||||
|
"name": child.name,
|
||||||
|
"path": entry_path,
|
||||||
|
"type": entry_type,
|
||||||
|
})
|
||||||
|
return {"entries": entries}
|
||||||
|
|
||||||
|
def _read_file(self, data: dict) -> dict:
|
||||||
|
rel_path = data.get("path", "")
|
||||||
|
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
||||||
|
if not abs_path.is_file():
|
||||||
|
return {"content": "", "error": f"File not found: {rel_path}"}
|
||||||
|
return {"content": abs_path.read_text(encoding="utf-8", errors="replace")}
|
||||||
|
|
||||||
|
def _get_file_metadata(self, data: dict) -> dict:
|
||||||
|
rel_path = data.get("path", "")
|
||||||
|
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
||||||
|
if not abs_path.exists():
|
||||||
|
return {"error": f"Not found: {rel_path}"}
|
||||||
|
stat = abs_path.stat()
|
||||||
|
return {
|
||||||
|
"path": rel_path,
|
||||||
|
"size": stat.st_size,
|
||||||
|
"modifiedAt": int(stat.st_mtime * 1000),
|
||||||
|
"createdAt": int(stat.st_ctime * 1000),
|
||||||
|
"isDirectory": abs_path.is_dir(),
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── CRUD handlers ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _select(self, table: str, filters: dict | None) -> dict:
|
||||||
|
rows = list(self.seed_records.get(table, []))
|
||||||
|
if filters:
|
||||||
|
rows = [
|
||||||
|
r for r in rows
|
||||||
|
if all(r.get(k) == v for k, v in filters.items() if v is not None)
|
||||||
|
]
|
||||||
|
return {"rows": rows}
|
||||||
|
|
||||||
|
def _get(self, table: str, data: dict) -> dict:
|
||||||
|
record_id = data.get("id", "")
|
||||||
|
rows = self.seed_records.get(table, [])
|
||||||
|
for r in rows:
|
||||||
|
if r.get("id") == record_id:
|
||||||
|
return {"row": r}
|
||||||
|
return {"row": None}
|
||||||
|
|
||||||
|
def _insert(self, table: str, data: dict) -> dict:
|
||||||
|
self._id_counter += 1
|
||||||
|
record = {**data, "id": str(self._id_counter)}
|
||||||
|
# Add to seed so subsequent selects can find it
|
||||||
|
self.seed_records.setdefault(table, []).append(record)
|
||||||
|
self.mutations.append(Mutation(action="insert", table=table, data=record))
|
||||||
|
return {"row": record}
|
||||||
|
|
||||||
|
def _update(self, table: str, data: dict) -> dict:
|
||||||
|
record_id = data.get("id", "")
|
||||||
|
rows = self.seed_records.get(table, [])
|
||||||
|
for r in rows:
|
||||||
|
if r.get("id") == record_id:
|
||||||
|
r.update({k: v for k, v in data.items() if v is not None and v != ""})
|
||||||
|
self.mutations.append(Mutation(action="update", table=table, data=dict(r)))
|
||||||
|
return {"row": r}
|
||||||
|
# Record not found — still log the mutation
|
||||||
|
self.mutations.append(Mutation(action="update", table=table, data=data))
|
||||||
|
return {"row": data}
|
||||||
|
|
||||||
|
def _delete(self, table: str, data: dict) -> dict:
|
||||||
|
record_id = data.get("id", "")
|
||||||
|
rows = self.seed_records.get(table, [])
|
||||||
|
self.seed_records[table] = [r for r in rows if r.get("id") != record_id]
|
||||||
|
self.mutations.append(Mutation(action="delete", table=table, data={"id": record_id}))
|
||||||
|
return {"deleted": True}
|
||||||
2
services/batch-agent/eval/requirements.txt
Normal file
2
services/batch-agent/eval/requirements.txt
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
# Extra dependencies for the eval harness (on top of the service requirements.txt)
|
||||||
|
pyyaml>=6.0.0
|
||||||
545
services/batch-agent/eval/runner.py
Normal file
545
services/batch-agent/eval/runner.py
Normal file
@@ -0,0 +1,545 @@
|
|||||||
|
"""Eval runner — orchestrates fixture → mock → agent pipeline → scoring.
|
||||||
|
|
||||||
|
Supports three eval modes:
|
||||||
|
|
||||||
|
- **step1**: Test classification prompt only (``_STEP1_SYSTEM_PROMPT``).
|
||||||
|
Calls the LLM with fixture-provided ``domain_definitions`` and
|
||||||
|
``projects_list`` and compares output against ``expected_classification``.
|
||||||
|
|
||||||
|
- **step2**: Test processing prompt only (``_PROCESSING_SYSTEM_PROMPT``).
|
||||||
|
Compiles the prompt with fixture-provided ``existing_context``,
|
||||||
|
``project_context``, ``data_types``, and ``custom_prompt_section``,
|
||||||
|
then runs the tool-calling loop. Mutations are scored against
|
||||||
|
``expected`` records.
|
||||||
|
|
||||||
|
- **full**: Run ``run_local_agent()`` end-to-end (both steps).
|
||||||
|
Scored on both classification and extraction.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import copy
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from eval.config import EvalFixture, ExpectedClassification
|
||||||
|
from eval.mock_executor import MockExecutor
|
||||||
|
from eval.scorer import (
|
||||||
|
EvalScores,
|
||||||
|
FieldScore,
|
||||||
|
compute_precision_recall,
|
||||||
|
llm_judge_score,
|
||||||
|
score_field_match,
|
||||||
|
)
|
||||||
|
from eval import langfuse_eval
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 1 runner ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_step1(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
model: str,
|
||||||
|
mock: MockExecutor,
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Run step-1 classification for every file in the fixture directory.
|
||||||
|
|
||||||
|
Scans the directory recursively, classifies each file, and returns
|
||||||
|
a list of result dicts:
|
||||||
|
``[{file, project_id, domains, new_project_name}, ...]``
|
||||||
|
"""
|
||||||
|
from app.agent_runner import _classify_file
|
||||||
|
|
||||||
|
# Build project name lookup for display
|
||||||
|
proj_names: dict[str, str] = {
|
||||||
|
p.get("id", ""): p.get("name", "") for p in fixture.projects_list
|
||||||
|
}
|
||||||
|
|
||||||
|
# Discover all files in the fixture directory
|
||||||
|
all_files = await _scan_fixture_files(mock, fixture.directory)
|
||||||
|
print(f"\n Scanning {len(all_files)} files in {fixture.directory}\n")
|
||||||
|
|
||||||
|
results: list[dict[str, Any]] = []
|
||||||
|
for i, file_path in enumerate(all_files, 1):
|
||||||
|
file_result = await mock._handle(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": file_path},
|
||||||
|
)
|
||||||
|
file_content: str = file_result.get("content", "")
|
||||||
|
if not file_content.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
project_id, domains, new_name = await _classify_file(
|
||||||
|
file_path=file_path,
|
||||||
|
file_content=file_content,
|
||||||
|
projects=fixture.projects_list,
|
||||||
|
config_data_types=fixture.data_types,
|
||||||
|
custom_system_prompt=fixture.custom_step1_prompt or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
short_name = file_path.rsplit("/", 1)[-1] if "/" in file_path else file_path
|
||||||
|
proj_label = proj_names.get(project_id, new_name or "?")
|
||||||
|
print(f" [{i}/{len(all_files)}] {short_name} → {project_id} ({proj_label}) {domains}")
|
||||||
|
|
||||||
|
results.append({
|
||||||
|
"file": file_path,
|
||||||
|
"project_id": project_id,
|
||||||
|
"domains": domains,
|
||||||
|
"new_project_name": new_name,
|
||||||
|
})
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
async def _scan_fixture_files(mock: MockExecutor, directory: str) -> list[str]:
|
||||||
|
"""Recursively list all files under *directory* via the mock executor."""
|
||||||
|
files: list[str] = []
|
||||||
|
|
||||||
|
async def _walk(path: str) -> None:
|
||||||
|
result = await mock._handle(action="list_directory", data={"path": path})
|
||||||
|
for entry in result.get("entries", []):
|
||||||
|
if entry.get("type") == "directory":
|
||||||
|
await _walk(entry["path"])
|
||||||
|
elif entry.get("type") == "file":
|
||||||
|
files.append(entry["path"])
|
||||||
|
|
||||||
|
await _walk(directory)
|
||||||
|
return sorted(files)
|
||||||
|
|
||||||
|
|
||||||
|
def _score_step1(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
results: list[dict[str, Any]],
|
||||||
|
) -> tuple[float, float, float, str]:
|
||||||
|
"""Score step-1 results. Returns (precision, recall, f1, reasoning).
|
||||||
|
|
||||||
|
Files with expected classifications are scored (OK/FAIL).
|
||||||
|
Files without expectations are shown as informational (INFO).
|
||||||
|
"""
|
||||||
|
if not fixture.expected_classification:
|
||||||
|
return 0.0, 0.0, 0.0, "No expected classifications"
|
||||||
|
|
||||||
|
# Build project name lookup
|
||||||
|
proj_names: dict[str, str] = {
|
||||||
|
p.get("id", ""): p.get("name", "") for p in fixture.projects_list
|
||||||
|
}
|
||||||
|
proj_names["new"] = "(new project)"
|
||||||
|
|
||||||
|
def _proj_label(pid: str, new_name: str | None = None) -> str:
|
||||||
|
name = proj_names.get(pid, "?")
|
||||||
|
if pid == "new" and new_name:
|
||||||
|
return f"new → \"{new_name}\""
|
||||||
|
return f"{pid} ({name})" if name and name != "?" else pid
|
||||||
|
|
||||||
|
def _short_file(path: str) -> str:
|
||||||
|
"""Use just the filename for cleaner display."""
|
||||||
|
return path.rsplit("/", 1)[-1] if "/" in path else path
|
||||||
|
|
||||||
|
expected_files = {ec.file for ec in fixture.expected_classification}
|
||||||
|
total = len(fixture.expected_classification)
|
||||||
|
matched = 0
|
||||||
|
|
||||||
|
scored_lines: list[str] = []
|
||||||
|
info_lines: list[str] = []
|
||||||
|
|
||||||
|
# Score expected files
|
||||||
|
for ec in fixture.expected_classification:
|
||||||
|
actual = next((r for r in results if r["file"] == ec.file), None)
|
||||||
|
fname = _short_file(ec.file)
|
||||||
|
if actual is None:
|
||||||
|
scored_lines.append(f" MISS {fname}")
|
||||||
|
scored_lines.append(f" expected: {_proj_label(ec.project_id)}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
pid_ok = actual["project_id"] == ec.project_id
|
||||||
|
domains_ok = set(actual["domains"]) == set(ec.domains) if ec.domains else True
|
||||||
|
|
||||||
|
if pid_ok and domains_ok:
|
||||||
|
matched += 1
|
||||||
|
scored_lines.append(f" OK {fname}")
|
||||||
|
scored_lines.append(f" project: {_proj_label(actual['project_id'])}")
|
||||||
|
scored_lines.append(f" domains: {actual['domains']}")
|
||||||
|
else:
|
||||||
|
scored_lines.append(f" FAIL {fname}")
|
||||||
|
if not pid_ok:
|
||||||
|
scored_lines.append(f" project: {_proj_label(actual['project_id'])} (expected: {_proj_label(ec.project_id)})")
|
||||||
|
else:
|
||||||
|
scored_lines.append(f" project: {_proj_label(actual['project_id'])}")
|
||||||
|
if not domains_ok:
|
||||||
|
scored_lines.append(f" domains: {actual['domains']} (expected: {ec.domains})")
|
||||||
|
else:
|
||||||
|
scored_lines.append(f" domains: {actual['domains']}")
|
||||||
|
|
||||||
|
# Show unscored files
|
||||||
|
for r in results:
|
||||||
|
if r["file"] not in expected_files:
|
||||||
|
fname = _short_file(r["file"])
|
||||||
|
proj = _proj_label(r["project_id"], r.get("new_project_name"))
|
||||||
|
info_lines.append(f" · {fname}")
|
||||||
|
info_lines.append(f" project: {proj} | domains: {r['domains']}")
|
||||||
|
|
||||||
|
precision = matched / total if total > 0 else 0.0
|
||||||
|
recall = precision
|
||||||
|
f1 = precision
|
||||||
|
|
||||||
|
parts: list[str] = []
|
||||||
|
if scored_lines:
|
||||||
|
parts.append(f"Scored ({matched}/{total}):")
|
||||||
|
parts.extend(scored_lines)
|
||||||
|
if info_lines:
|
||||||
|
parts.append(f"\nOther files ({len(info_lines) // 2}):")
|
||||||
|
parts.extend(info_lines)
|
||||||
|
|
||||||
|
return precision, recall, f1, "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Step 2 runner ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_step2(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
model: str,
|
||||||
|
mock: MockExecutor,
|
||||||
|
) -> None:
|
||||||
|
"""Run step-2 processing for each file in the fixture directory.
|
||||||
|
|
||||||
|
Compiles ``_PROCESSING_SYSTEM_PROMPT`` with fixture-provided variables
|
||||||
|
and runs the tool-calling loop. Mutations are captured by the mock.
|
||||||
|
"""
|
||||||
|
from app.agent_runner import (
|
||||||
|
_PROCESSING_SYSTEM_PROMPT,
|
||||||
|
_build_processing_tools,
|
||||||
|
_run_agent_with_tools,
|
||||||
|
_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
from app import tracing
|
||||||
|
|
||||||
|
# Compile the processing prompt with fixture variables
|
||||||
|
system_prompt = tracing.compile_prompt(
|
||||||
|
"batch_processing",
|
||||||
|
fallback=_PROCESSING_SYSTEM_PROMPT,
|
||||||
|
variables={
|
||||||
|
"existing_context": fixture.existing_context,
|
||||||
|
"project_context": fixture.project_context,
|
||||||
|
"data_types": ", ".join(fixture.data_types),
|
||||||
|
"custom_prompt_section": fixture.custom_prompt_section,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
tools = _build_processing_tools(fixture.data_types)
|
||||||
|
|
||||||
|
# Scan files in the fixture directory
|
||||||
|
file_entries = await mock._handle(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": fixture.directory},
|
||||||
|
)
|
||||||
|
for entry in file_entries.get("entries", []):
|
||||||
|
if entry.get("type") != "file":
|
||||||
|
continue
|
||||||
|
# Filter by extension if specified
|
||||||
|
if fixture.file_extensions:
|
||||||
|
ext = entry["name"].rsplit(".", 1)[-1] if "." in entry["name"] else ""
|
||||||
|
if ext not in fixture.file_extensions:
|
||||||
|
continue
|
||||||
|
|
||||||
|
file_result = await mock._handle(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": entry["path"]},
|
||||||
|
)
|
||||||
|
file_content: str = file_result.get("content", "")
|
||||||
|
if not file_content.strip():
|
||||||
|
continue
|
||||||
|
|
||||||
|
await _run_agent_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_message=(
|
||||||
|
f"Process this file and extract relevant information.\n\n"
|
||||||
|
f"File: {entry['path']}\n\nContent:\n{file_content}"
|
||||||
|
),
|
||||||
|
tools=tools,
|
||||||
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Full runner ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_full(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
model: str,
|
||||||
|
mock: MockExecutor,
|
||||||
|
user_id: str,
|
||||||
|
) -> None:
|
||||||
|
"""Run the full two-step pipeline via ``run_local_agent``."""
|
||||||
|
from app.agent_runner import run_local_agent
|
||||||
|
|
||||||
|
trigger_data: dict[str, Any] = {
|
||||||
|
"type": "agent_trigger",
|
||||||
|
"directory": fixture.directory,
|
||||||
|
"directory_paths": [fixture.directory],
|
||||||
|
"data_types": fixture.data_types,
|
||||||
|
"file_extensions": fixture.file_extensions,
|
||||||
|
"prompt_template": fixture.custom_prompt_section,
|
||||||
|
"device_id": "eval-harness",
|
||||||
|
"run_context": {
|
||||||
|
"agent_id": f"eval-{fixture.name}",
|
||||||
|
"run_id": None,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
with mock.patch():
|
||||||
|
await run_local_agent(user_id, trigger_data)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scoring helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _score_mutations(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
mock: MockExecutor,
|
||||||
|
) -> tuple[list[FieldScore], float, float, float, int, int]:
|
||||||
|
"""Score mutations against expected records.
|
||||||
|
|
||||||
|
Returns (field_scores, precision, recall, f1, extra, missing).
|
||||||
|
"""
|
||||||
|
all_field_scores: list[FieldScore] = []
|
||||||
|
total_expected = 0
|
||||||
|
total_actual = 0
|
||||||
|
total_matched = 0
|
||||||
|
total_extra = 0
|
||||||
|
total_missing = 0
|
||||||
|
|
||||||
|
expected_by_table: dict[str, list[dict]] = {}
|
||||||
|
for rec in fixture.expected:
|
||||||
|
expected_by_table.setdefault(rec.table, []).append(rec.fields)
|
||||||
|
|
||||||
|
tables = set(expected_by_table.keys()) | {m.table for m in mock.mutations}
|
||||||
|
for table in tables:
|
||||||
|
expected_records = expected_by_table.get(table, [])
|
||||||
|
actual_records = mock.created_records(table) + mock.updated_records(table)
|
||||||
|
|
||||||
|
field_scores, extra, missing = score_field_match(expected_records, actual_records, table)
|
||||||
|
all_field_scores.extend(field_scores)
|
||||||
|
|
||||||
|
matched = sum(1 for s in field_scores if s.best_match is not None)
|
||||||
|
total_expected += len(expected_records)
|
||||||
|
total_actual += len(actual_records)
|
||||||
|
total_matched += matched
|
||||||
|
total_extra += extra
|
||||||
|
total_missing += missing
|
||||||
|
|
||||||
|
precision, recall, f1 = compute_precision_recall(total_expected, total_actual, total_matched)
|
||||||
|
return all_field_scores, precision, recall, f1, total_extra, total_missing
|
||||||
|
|
||||||
|
|
||||||
|
# ── Main entry point ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_single_eval(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
model: str,
|
||||||
|
*,
|
||||||
|
use_llm_judge: bool = True,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> EvalScores:
|
||||||
|
"""Execute one eval run for a fixture + model. Mode is read from the fixture."""
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.ws_context import set_current_user, clear_current_user
|
||||||
|
|
||||||
|
seed = copy.deepcopy(fixture.seed_records)
|
||||||
|
mock = MockExecutor(
|
||||||
|
fixture_dir=fixture.fixture_path.parent,
|
||||||
|
seed_records=seed,
|
||||||
|
)
|
||||||
|
|
||||||
|
original_model = settings.LLM_MODEL
|
||||||
|
settings.LLM_MODEL = model
|
||||||
|
eval_user_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"eval: starting %s | mode=%s | model=%s",
|
||||||
|
fixture.name, fixture.mode, model,
|
||||||
|
)
|
||||||
|
start_time = time.time()
|
||||||
|
|
||||||
|
step1_results: list[dict[str, Any]] = []
|
||||||
|
step1_reasoning = ""
|
||||||
|
|
||||||
|
try:
|
||||||
|
set_current_user(eval_user_id)
|
||||||
|
|
||||||
|
if fixture.mode == "step1":
|
||||||
|
with mock.patch():
|
||||||
|
step1_results = await _run_step1(fixture, model, mock)
|
||||||
|
|
||||||
|
elif fixture.mode == "step2":
|
||||||
|
with mock.patch():
|
||||||
|
await _run_step2(fixture, model, mock)
|
||||||
|
|
||||||
|
elif fixture.mode == "full":
|
||||||
|
with mock.patch():
|
||||||
|
# Step 1 — classification (independent from run_local_agent)
|
||||||
|
if fixture.expected_classification:
|
||||||
|
step1_results = await _run_step1(fixture, model, mock)
|
||||||
|
|
||||||
|
# Step 2 — full pipeline (run_local_agent handles both steps)
|
||||||
|
await _run_full(fixture, model, mock, eval_user_id)
|
||||||
|
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("eval: pipeline failed for %s/%s: %s", fixture.name, model, exc)
|
||||||
|
finally:
|
||||||
|
settings.LLM_MODEL = original_model
|
||||||
|
clear_current_user()
|
||||||
|
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
logger.info("eval: completed in %.1fs — %d mutations", elapsed, len(mock.mutations))
|
||||||
|
|
||||||
|
# ── Score ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
if fixture.mode == "step1":
|
||||||
|
s1_precision, s1_recall, s1_f1, step1_reasoning = _score_step1(fixture, step1_results)
|
||||||
|
scores = EvalScores(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant=fixture.mode,
|
||||||
|
precision=s1_precision,
|
||||||
|
recall=s1_recall,
|
||||||
|
f1=s1_f1,
|
||||||
|
llm_judge_reasoning=step1_reasoning,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# step2 or full — score mutations
|
||||||
|
field_scores, precision, recall, f1, extra, missing = _score_mutations(fixture, mock)
|
||||||
|
scores = EvalScores(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant=fixture.mode,
|
||||||
|
field_scores=field_scores,
|
||||||
|
precision=precision,
|
||||||
|
recall=recall,
|
||||||
|
f1=f1,
|
||||||
|
extra_records=extra,
|
||||||
|
missing_records=missing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Add step1 classification scores for full mode
|
||||||
|
if fixture.mode == "full" and fixture.expected_classification:
|
||||||
|
s1_p, s1_r, s1_f1, step1_reasoning = _score_step1(fixture, step1_results)
|
||||||
|
scores.llm_judge_reasoning = f"Step1 classification:\n{step1_reasoning}"
|
||||||
|
|
||||||
|
# Optional LLM judge for extraction quality
|
||||||
|
if use_llm_judge and fixture.expected:
|
||||||
|
all_expected = [r.fields for r in fixture.expected]
|
||||||
|
all_actual = [m.data for m in mock.mutations if m.action in ("insert", "update")]
|
||||||
|
judge_score, reasoning = await llm_judge_score(
|
||||||
|
all_expected, all_actual, judge_model=judge_model,
|
||||||
|
)
|
||||||
|
scores.llm_judge_score = judge_score
|
||||||
|
if step1_reasoning:
|
||||||
|
scores.llm_judge_reasoning += f"\n\nLLM judge:\n{reasoning}"
|
||||||
|
else:
|
||||||
|
scores.llm_judge_reasoning = reasoning
|
||||||
|
|
||||||
|
# ── Report to Langfuse ────────────────────────────────────────
|
||||||
|
prompt_names = {
|
||||||
|
"step1": ["batch_file_classifier"],
|
||||||
|
"step2": ["batch_processing"],
|
||||||
|
"full": ["batch_file_classifier", "batch_processing"],
|
||||||
|
}.get(fixture.mode, ["batch_processing"])
|
||||||
|
|
||||||
|
trace_id = langfuse_eval.log_eval_trace(
|
||||||
|
fixture_name=fixture.name,
|
||||||
|
model=model,
|
||||||
|
prompt_variant=fixture.mode,
|
||||||
|
prompt_template=fixture.custom_prompt_section or "(default)",
|
||||||
|
actual_mutations=[{"action": m.action, "table": m.table, "data": m.data} for m in mock.mutations],
|
||||||
|
scores_summary=scores.summary(),
|
||||||
|
step1_results=step1_results or None,
|
||||||
|
langfuse_prompt_names=prompt_names,
|
||||||
|
)
|
||||||
|
|
||||||
|
if trace_id:
|
||||||
|
langfuse_eval.post_eval_scores(scores, trace_id=trace_id)
|
||||||
|
|
||||||
|
# For full mode, post classification scores separately
|
||||||
|
if fixture.mode == "full" and fixture.expected_classification:
|
||||||
|
s1_p, s1_r, s1_f1, _ = _score_step1(fixture, step1_results)
|
||||||
|
for name, value in [
|
||||||
|
("classification_precision", s1_p),
|
||||||
|
("classification_recall", s1_r),
|
||||||
|
("classification_f1", s1_f1),
|
||||||
|
]:
|
||||||
|
try:
|
||||||
|
from langfuse import get_client
|
||||||
|
lf = get_client()
|
||||||
|
if lf:
|
||||||
|
lf.create_score(
|
||||||
|
name=name,
|
||||||
|
value=value,
|
||||||
|
trace_id=trace_id,
|
||||||
|
data_type="NUMERIC",
|
||||||
|
comment=f"{fixture.name} | {model} | full",
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return scores
|
||||||
|
|
||||||
|
|
||||||
|
async def run_fixture_eval(
|
||||||
|
fixture: EvalFixture,
|
||||||
|
models: list[str],
|
||||||
|
*,
|
||||||
|
use_llm_judge: bool = True,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> list[EvalScores]:
|
||||||
|
"""Run all models for a fixture."""
|
||||||
|
langfuse_eval.sync_fixture_to_dataset(fixture)
|
||||||
|
|
||||||
|
results: list[EvalScores] = []
|
||||||
|
for model in models:
|
||||||
|
scores = await run_single_eval(
|
||||||
|
fixture, model,
|
||||||
|
use_llm_judge=use_llm_judge,
|
||||||
|
judge_model=judge_model,
|
||||||
|
)
|
||||||
|
results.append(scores)
|
||||||
|
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def print_results(results: list[EvalScores]) -> None:
|
||||||
|
"""Print a formatted summary table of eval results."""
|
||||||
|
if not results:
|
||||||
|
print("\nNo eval results.")
|
||||||
|
return
|
||||||
|
|
||||||
|
W = 90
|
||||||
|
|
||||||
|
print("\n" + "=" * W)
|
||||||
|
print(f"{'Fixture':<25} {'Mode':<6} {'Model':<25} {'P':>6} {'R':>6} {'F1':>6} {'FA':>6} {'LLM':>6}")
|
||||||
|
print("-" * W)
|
||||||
|
|
||||||
|
for s in results:
|
||||||
|
llm_str = f"{s.llm_judge_score:.2f}" if s.llm_judge_score is not None else " --"
|
||||||
|
fa_str = f"{s.field_accuracy:.2f}" if s.field_scores else " --"
|
||||||
|
print(
|
||||||
|
f"{s.fixture_name:<25} {s.prompt_variant:<6} {s.model:<25} "
|
||||||
|
f"{s.precision:>6.2f} {s.recall:>6.2f} {s.f1:>6.2f} "
|
||||||
|
f"{fa_str:>6} {llm_str:>6}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print("=" * W)
|
||||||
|
|
||||||
|
for s in results:
|
||||||
|
if s.llm_judge_reasoning:
|
||||||
|
print(f"\n{'─' * W}")
|
||||||
|
print(f" {s.fixture_name} | {s.model} | {s.prompt_variant}")
|
||||||
|
print(f"{'─' * W}")
|
||||||
|
print(s.llm_judge_reasoning)
|
||||||
|
|
||||||
|
print()
|
||||||
268
services/batch-agent/eval/scorer.py
Normal file
268
services/batch-agent/eval/scorer.py
Normal file
@@ -0,0 +1,268 @@
|
|||||||
|
"""Scoring functions for batch agent evaluation.
|
||||||
|
|
||||||
|
Two scoring strategies:
|
||||||
|
|
||||||
|
1. **FieldMatchScorer** — deterministic check: for each expected record,
|
||||||
|
find the best-matching actual record and compare specified fields.
|
||||||
|
Returns precision, recall, and per-field accuracy.
|
||||||
|
|
||||||
|
2. **LLMJudgeScorer** — uses a secondary LLM to semantically evaluate
|
||||||
|
whether the actual extractions satisfy the expected intent, even if
|
||||||
|
wording differs. Returns a 0-1 score + reasoning.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from difflib import SequenceMatcher
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Result types ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class FieldScore:
|
||||||
|
"""Score for a single expected record against its best match."""
|
||||||
|
|
||||||
|
expected: dict[str, Any]
|
||||||
|
best_match: dict[str, Any] | None
|
||||||
|
matched_fields: dict[str, bool]
|
||||||
|
similarity: float # 0-1 overall similarity
|
||||||
|
|
||||||
|
@property
|
||||||
|
def field_accuracy(self) -> float:
|
||||||
|
if not self.matched_fields:
|
||||||
|
return 0.0
|
||||||
|
return sum(self.matched_fields.values()) / len(self.matched_fields)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class EvalScores:
|
||||||
|
"""Aggregated scores for one eval run."""
|
||||||
|
|
||||||
|
fixture_name: str
|
||||||
|
model: str
|
||||||
|
prompt_variant: str
|
||||||
|
field_scores: list[FieldScore] = field(default_factory=list)
|
||||||
|
precision: float = 0.0
|
||||||
|
recall: float = 0.0
|
||||||
|
f1: float = 0.0
|
||||||
|
llm_judge_score: float | None = None
|
||||||
|
llm_judge_reasoning: str = ""
|
||||||
|
extra_records: int = 0 # records created but not expected
|
||||||
|
missing_records: int = 0 # expected but not found
|
||||||
|
|
||||||
|
@property
|
||||||
|
def field_accuracy(self) -> float:
|
||||||
|
if not self.field_scores:
|
||||||
|
return 0.0
|
||||||
|
return sum(s.field_accuracy for s in self.field_scores) / len(self.field_scores)
|
||||||
|
|
||||||
|
def summary(self) -> dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"fixture": self.fixture_name,
|
||||||
|
"model": self.model,
|
||||||
|
"prompt_variant": self.prompt_variant,
|
||||||
|
"precision": round(self.precision, 3),
|
||||||
|
"recall": round(self.recall, 3),
|
||||||
|
"f1": round(self.f1, 3),
|
||||||
|
"field_accuracy": round(self.field_accuracy, 3),
|
||||||
|
"llm_judge_score": round(self.llm_judge_score, 3) if self.llm_judge_score is not None else None,
|
||||||
|
"extra_records": self.extra_records,
|
||||||
|
"missing_records": self.missing_records,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Field Match Scorer ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize(value: Any) -> str:
|
||||||
|
"""Normalize a value for comparison."""
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
return str(value).strip().lower()
|
||||||
|
|
||||||
|
|
||||||
|
def _text_similarity(a: str, b: str) -> float:
|
||||||
|
"""Fuzzy text similarity using SequenceMatcher."""
|
||||||
|
if not a and not b:
|
||||||
|
return 1.0
|
||||||
|
if not a or not b:
|
||||||
|
return 0.0
|
||||||
|
return SequenceMatcher(None, a.lower(), b.lower()).ratio()
|
||||||
|
|
||||||
|
|
||||||
|
def _find_best_match(
|
||||||
|
expected: dict[str, Any],
|
||||||
|
actuals: list[dict[str, Any]],
|
||||||
|
) -> tuple[dict[str, Any] | None, float]:
|
||||||
|
"""Find the actual record most similar to expected, return (match, similarity)."""
|
||||||
|
if not actuals:
|
||||||
|
return None, 0.0
|
||||||
|
|
||||||
|
best_match = None
|
||||||
|
best_score = 0.0
|
||||||
|
|
||||||
|
# Primary matching key: title or name
|
||||||
|
expected_title = _normalize(expected.get("title", expected.get("name", "")))
|
||||||
|
|
||||||
|
for actual in actuals:
|
||||||
|
actual_title = _normalize(actual.get("title", actual.get("name", "")))
|
||||||
|
sim = _text_similarity(expected_title, actual_title)
|
||||||
|
if sim > best_score:
|
||||||
|
best_score = sim
|
||||||
|
best_match = actual
|
||||||
|
|
||||||
|
return best_match, best_score
|
||||||
|
|
||||||
|
|
||||||
|
def _compare_fields(
|
||||||
|
expected: dict[str, Any],
|
||||||
|
actual: dict[str, Any],
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Compare each expected field against the actual record."""
|
||||||
|
results: dict[str, bool] = {}
|
||||||
|
for key, expected_val in expected.items():
|
||||||
|
actual_val = actual.get(key)
|
||||||
|
# Exact match for non-string types
|
||||||
|
if not isinstance(expected_val, str):
|
||||||
|
results[key] = actual_val == expected_val
|
||||||
|
else:
|
||||||
|
# Fuzzy match for strings (threshold: 0.7)
|
||||||
|
results[key] = _text_similarity(
|
||||||
|
_normalize(expected_val), _normalize(actual_val)
|
||||||
|
) >= 0.7
|
||||||
|
return results
|
||||||
|
|
||||||
|
|
||||||
|
def score_field_match(
|
||||||
|
expected_records: list[dict[str, Any]],
|
||||||
|
actual_records: list[dict[str, Any]],
|
||||||
|
table: str,
|
||||||
|
) -> tuple[list[FieldScore], int, int]:
|
||||||
|
"""Score actual extractions against expected records for one table.
|
||||||
|
|
||||||
|
Returns (field_scores, extra_count, missing_count).
|
||||||
|
"""
|
||||||
|
field_scores: list[FieldScore] = []
|
||||||
|
matched_actuals: set[int] = set()
|
||||||
|
|
||||||
|
for exp in expected_records:
|
||||||
|
# Find best match among unmatched actuals
|
||||||
|
candidates = [
|
||||||
|
(i, a) for i, a in enumerate(actual_records) if i not in matched_actuals
|
||||||
|
]
|
||||||
|
if not candidates:
|
||||||
|
field_scores.append(FieldScore(
|
||||||
|
expected=exp, best_match=None, matched_fields={}, similarity=0.0,
|
||||||
|
))
|
||||||
|
continue
|
||||||
|
|
||||||
|
best_idx, best_match = None, None
|
||||||
|
best_sim = 0.0
|
||||||
|
for idx, actual in candidates:
|
||||||
|
_, sim = _find_best_match(exp, [actual])
|
||||||
|
if sim > best_sim:
|
||||||
|
best_sim = sim
|
||||||
|
best_idx = idx
|
||||||
|
best_match = actual
|
||||||
|
|
||||||
|
if best_sim >= 0.5 and best_match is not None:
|
||||||
|
matched_actuals.add(best_idx)
|
||||||
|
matched_fields = _compare_fields(exp, best_match)
|
||||||
|
field_scores.append(FieldScore(
|
||||||
|
expected=exp, best_match=best_match,
|
||||||
|
matched_fields=matched_fields, similarity=best_sim,
|
||||||
|
))
|
||||||
|
else:
|
||||||
|
field_scores.append(FieldScore(
|
||||||
|
expected=exp, best_match=None, matched_fields={}, similarity=0.0,
|
||||||
|
))
|
||||||
|
|
||||||
|
extra_count = len(actual_records) - len(matched_actuals)
|
||||||
|
missing_count = sum(1 for s in field_scores if s.best_match is None)
|
||||||
|
|
||||||
|
return field_scores, extra_count, missing_count
|
||||||
|
|
||||||
|
|
||||||
|
def compute_precision_recall(
|
||||||
|
expected_count: int,
|
||||||
|
actual_count: int,
|
||||||
|
matched_count: int,
|
||||||
|
) -> tuple[float, float, float]:
|
||||||
|
"""Compute precision, recall, F1."""
|
||||||
|
precision = matched_count / actual_count if actual_count > 0 else 0.0
|
||||||
|
recall = matched_count / expected_count if expected_count > 0 else 0.0
|
||||||
|
f1 = (
|
||||||
|
2 * precision * recall / (precision + recall)
|
||||||
|
if (precision + recall) > 0
|
||||||
|
else 0.0
|
||||||
|
)
|
||||||
|
return precision, recall, f1
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM Judge Scorer ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_JUDGE_SYSTEM_PROMPT = """\
|
||||||
|
You are an evaluation judge for a data extraction system.
|
||||||
|
|
||||||
|
Your task is to compare the EXPECTED extractions against the ACTUAL extractions
|
||||||
|
produced by an AI agent, and assess quality on a 0-1 scale.
|
||||||
|
|
||||||
|
Scoring criteria:
|
||||||
|
- 1.0: All expected records found with correct fields, no significant extras
|
||||||
|
- 0.8: Most expected records found, minor field differences or extras
|
||||||
|
- 0.6: Core extractions present but some missing or incorrect
|
||||||
|
- 0.4: Partial match — several expected records missing or wrong
|
||||||
|
- 0.2: Poor quality — most expected records missing or incorrect
|
||||||
|
- 0.0: Complete failure — no meaningful overlap
|
||||||
|
|
||||||
|
Consider semantic equivalence: "Send invoice" and "Email the invoice" are matches.
|
||||||
|
Ignore field ordering and formatting differences.
|
||||||
|
|
||||||
|
Respond with ONLY a JSON object:
|
||||||
|
{"score": 0.85, "reasoning": "Brief explanation of the score"}
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
async def llm_judge_score(
|
||||||
|
expected: list[dict[str, Any]],
|
||||||
|
actual: list[dict[str, Any]],
|
||||||
|
*,
|
||||||
|
judge_model: str = "gpt-4o-mini",
|
||||||
|
) -> tuple[float, str]:
|
||||||
|
"""Use an LLM to semantically evaluate extraction quality.
|
||||||
|
|
||||||
|
Returns (score, reasoning).
|
||||||
|
"""
|
||||||
|
from shared.llm import get_llm
|
||||||
|
|
||||||
|
llm = get_llm(model=judge_model, temperature=0)
|
||||||
|
|
||||||
|
user_content = (
|
||||||
|
f"## Expected extractions\n```json\n{json.dumps(expected, indent=2, default=str)}\n```\n\n"
|
||||||
|
f"## Actual extractions\n```json\n{json.dumps(actual, indent=2, default=str)}\n```"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke([
|
||||||
|
SystemMessage(content=_JUDGE_SYSTEM_PROMPT),
|
||||||
|
HumanMessage(content=user_content),
|
||||||
|
])
|
||||||
|
raw = response.content.strip()
|
||||||
|
if raw.startswith("```"):
|
||||||
|
raw = raw.split("```")[1]
|
||||||
|
if raw.startswith("json"):
|
||||||
|
raw = raw[4:]
|
||||||
|
parsed = json.loads(raw.strip())
|
||||||
|
return float(parsed.get("score", 0.0)), str(parsed.get("reasoning", ""))
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("eval: LLM judge failed: %s", exc)
|
||||||
|
return 0.0, f"Judge error: {exc}"
|
||||||
21
services/batch-agent/requirements.txt
Normal file
21
services/batch-agent/requirements.txt
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
redis>=5.0.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
langchain-core>=0.3.0
|
||||||
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.3.0
|
||||||
|
litellm>=1.50.0
|
||||||
|
openai>=1.50.0
|
||||||
|
httpx>=0.27.0
|
||||||
|
langfuse>=3.0.0
|
||||||
|
croniter>=2.0.0
|
||||||
|
google-api-python-client>=2.130.0
|
||||||
|
google-auth>=2.30.0
|
||||||
|
msal>=1.28.0
|
||||||
@@ -3,29 +3,34 @@ FROM python:3.12-slim AS builder
|
|||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
||||||
COPY requirements.txt .
|
COPY services/billing/requirements.txt ./requirements.txt
|
||||||
RUN pip install --upgrade pip && \
|
RUN pip install --upgrade pip && \
|
||||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
FROM python:3.12-slim AS runtime
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
# Non-root user
|
|
||||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
# Copy installed packages from builder
|
|
||||||
COPY --from=builder /install /usr/local
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
# Copy application source
|
# Shared module
|
||||||
COPY app/ app/
|
COPY shared/ shared/
|
||||||
|
|
||||||
|
# Service source
|
||||||
|
COPY services/billing/app/ app/
|
||||||
|
|
||||||
# Ensure appuser owns the working directory
|
|
||||||
RUN chown -R appuser:appgroup /app
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
USER appuser
|
USER appuser
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "2"]
|
# Billing is lightweight — single worker is fine
|
||||||
|
CMD ["gunicorn", "app.main:app", \
|
||||||
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
|
"--bind", "0.0.0.0:8000", \
|
||||||
|
"--workers", "1", \
|
||||||
|
"--timeout", "30"]
|
||||||
15
services/billing/README.md
Normal file
15
services/billing/README.md
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
# Billing Service
|
||||||
|
|
||||||
|
Owns: Stripe integration, tier management, subscription CRUD.
|
||||||
|
|
||||||
|
## Tables owned (write)
|
||||||
|
- `subscriptions`
|
||||||
|
|
||||||
|
## Endpoints
|
||||||
|
- `POST /billing/checkout`
|
||||||
|
- `POST /billing/webhook` (Stripe, no JWT auth)
|
||||||
|
- `GET /billing/subscription`
|
||||||
|
- `DELETE /billing/subscription`
|
||||||
|
|
||||||
|
## Redis channels
|
||||||
|
- Publish: `tier:changed:{user_id}` on tier change
|
||||||
53
services/billing/app/main.py
Normal file
53
services/billing/app/main.py
Normal file
@@ -0,0 +1,53 @@
|
|||||||
|
"""Billing Service — FastAPI application.
|
||||||
|
|
||||||
|
Owns: Stripe checkout/webhook, subscription management, tier feature matrix,
|
||||||
|
quota enforcement.
|
||||||
|
|
||||||
|
Downstream services query this service (or read the user's tier from
|
||||||
|
the X-User-Tier header injected by Traefik) for billing decisions.
|
||||||
|
The webhook endpoint is exposed WITHOUT ForwardAuth so Stripe can reach it.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
from contextlib import asynccontextmanager
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
|
||||||
|
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||||
|
if _repo_root not in sys.path:
|
||||||
|
sys.path.insert(0, _repo_root)
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
from app.routes import router
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||||
|
logger.info("billing: service started")
|
||||||
|
yield
|
||||||
|
logger.info("billing: service stopped")
|
||||||
|
|
||||||
|
|
||||||
|
app = FastAPI(title="Adiuva Billing Service", lifespan=lifespan)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["GET", "POST", "DELETE"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
app.include_router(router)
|
||||||
|
|
||||||
|
|
||||||
|
@app.get("/health")
|
||||||
|
async def health() -> dict[str, str]:
|
||||||
|
return {"status": "ok", "service": "billing"}
|
||||||
134
services/billing/app/routes.py
Normal file
134
services/billing/app/routes.py
Normal file
@@ -0,0 +1,134 @@
|
|||||||
|
"""Billing routes: Stripe checkout, webhook, subscription, tier query.
|
||||||
|
|
||||||
|
Adapted for the Billing microservice:
|
||||||
|
- Authenticated routes use Traefik-injected headers (X-User-Id, X-User-Tier)
|
||||||
|
- Webhook route has NO auth (Stripe signature verification only)
|
||||||
|
- Added /tier/{user_id} for internal service-to-service tier lookups
|
||||||
|
- Added /features/{tier} for feature matrix queries
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Header, HTTPException, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from shared.db import async_session
|
||||||
|
from shared.schemas import BillingTier
|
||||||
|
|
||||||
|
from app.stripe_service import stripe_service
|
||||||
|
from app.tier_manager import tier_manager, FEATURES, RATE_LIMITS
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CheckoutRequest(BaseModel):
|
||||||
|
tier: BillingTier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Checkout ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/checkout")
|
||||||
|
async def create_checkout(
|
||||||
|
body: _CheckoutRequest,
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Create a Stripe checkout session for a tier upgrade."""
|
||||||
|
url = stripe_service.create_checkout_session(x_user_id, body.tier)
|
||||||
|
return {"checkout_url": url}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Webhook (NO auth — Stripe signature only) ─────────────────────────
|
||||||
|
|
||||||
|
@router.post("/webhook")
|
||||||
|
async def stripe_webhook(
|
||||||
|
request: Request,
|
||||||
|
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Handle Stripe webhook events.
|
||||||
|
|
||||||
|
This endpoint is exposed without ForwardAuth in Traefik config
|
||||||
|
so Stripe can reach it directly.
|
||||||
|
"""
|
||||||
|
payload = await request.body()
|
||||||
|
async with async_session() as db:
|
||||||
|
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Subscription CRUD ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/subscription")
|
||||||
|
async def get_subscription(
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return the current subscription info for the authenticated user."""
|
||||||
|
async with async_session() as db:
|
||||||
|
sub = await stripe_service.get_subscription(x_user_id, db)
|
||||||
|
if sub is None:
|
||||||
|
return {
|
||||||
|
"tier": x_user_tier,
|
||||||
|
"status": "free",
|
||||||
|
"stripe_subscription_id": None,
|
||||||
|
"current_period_end": None,
|
||||||
|
}
|
||||||
|
return sub
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/subscription")
|
||||||
|
async def cancel_subscription(
|
||||||
|
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Cancel the active subscription."""
|
||||||
|
async with async_session() as db:
|
||||||
|
await stripe_service.cancel_subscription(x_user_id, db)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier query (internal, service-to-service) ─────────────────────────
|
||||||
|
|
||||||
|
@router.get("/tier/{user_id}")
|
||||||
|
async def get_user_tier(user_id: str) -> dict[str, str]:
|
||||||
|
"""Return the billing tier for a given user_id.
|
||||||
|
|
||||||
|
Used by other services for tier lookups. Protected by Traefik
|
||||||
|
ForwardAuth — only internal services should call this.
|
||||||
|
"""
|
||||||
|
async with async_session() as db:
|
||||||
|
tier = await tier_manager.get_tier(user_id, db)
|
||||||
|
return {"user_id": user_id, "tier": tier}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Feature matrix (public, cacheable) ────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/features/{tier}")
|
||||||
|
async def get_tier_features(tier: str) -> dict[str, Any]:
|
||||||
|
"""Return the feature matrix for a tier."""
|
||||||
|
if tier not in FEATURES:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail=f"Unknown tier: {tier}",
|
||||||
|
)
|
||||||
|
return {
|
||||||
|
"tier": tier,
|
||||||
|
"features": FEATURES[tier],
|
||||||
|
"rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/features")
|
||||||
|
async def get_all_features() -> dict[str, Any]:
|
||||||
|
"""Return the full feature matrix for all tiers."""
|
||||||
|
return {
|
||||||
|
"tiers": {
|
||||||
|
tier: {
|
||||||
|
"features": features,
|
||||||
|
"rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]),
|
||||||
|
}
|
||||||
|
for tier, features in FEATURES.items()
|
||||||
|
},
|
||||||
|
}
|
||||||
240
services/billing/app/stripe_service.py
Normal file
240
services/billing/app/stripe_service.py
Normal file
@@ -0,0 +1,240 @@
|
|||||||
|
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||||
|
|
||||||
|
Adapted for the Billing microservice — uses shared.models and shared.db.
|
||||||
|
All Stripe calls are gracefully stubbed when STRIPE_SECRET_KEY is not
|
||||||
|
configured, enabling local development without live credentials.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.models import Subscription
|
||||||
|
|
||||||
|
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||||
|
TIER_PRICE_IDS: dict[str, str] = {
|
||||||
|
"pro": "price_pro_monthly",
|
||||||
|
"power": "price_power_monthly",
|
||||||
|
"team": "price_team_monthly",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class StripeService:
|
||||||
|
"""Wraps all Stripe interactions and owns subscription persistence."""
|
||||||
|
|
||||||
|
# ── Internal helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _configured(self) -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Public API ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_checkout_session(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
tier: str,
|
||||||
|
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
|
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||||
|
) -> str:
|
||||||
|
"""Create a Stripe checkout session and return the URL."""
|
||||||
|
if tier == "free":
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Cannot create a checkout session for the free tier",
|
||||||
|
)
|
||||||
|
|
||||||
|
price_id = TIER_PRICE_IDS.get(tier)
|
||||||
|
if not price_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail=f"Unknown tier: {tier}",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self._configured():
|
||||||
|
return "https://stripe.com/stub-checkout"
|
||||||
|
|
||||||
|
s = self._client()
|
||||||
|
session = s.checkout.Session.create(
|
||||||
|
payment_method_types=["card"],
|
||||||
|
mode="subscription",
|
||||||
|
line_items=[{"price": price_id, "quantity": 1}],
|
||||||
|
success_url=success_url,
|
||||||
|
cancel_url=cancel_url,
|
||||||
|
metadata={"user_id": user_id, "tier": tier},
|
||||||
|
)
|
||||||
|
return session.url
|
||||||
|
|
||||||
|
async def handle_webhook(
|
||||||
|
self,
|
||||||
|
payload: bytes,
|
||||||
|
sig_header: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> None:
|
||||||
|
"""Process a Stripe webhook event.
|
||||||
|
|
||||||
|
Verifies the signature, then dispatches on event type.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
event = s.Webhook.construct_event(
|
||||||
|
payload, sig_header, settings.STRIPE_WEBHOOK_SECRET
|
||||||
|
)
|
||||||
|
except stripe_lib.error.SignatureVerificationError:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
|
detail="Invalid Stripe signature",
|
||||||
|
)
|
||||||
|
|
||||||
|
event_type: str = event["type"]
|
||||||
|
data: dict[str, Any] = event["data"]["object"]
|
||||||
|
|
||||||
|
if event_type == "checkout.session.completed":
|
||||||
|
user_id = data.get("metadata", {}).get("user_id")
|
||||||
|
tier = data.get("metadata", {}).get("tier", "free")
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if user_id:
|
||||||
|
await self._upsert_subscription(
|
||||||
|
db, user_id, sub_id, tier, "active", period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.updated":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
new_status = data.get("status", "active")
|
||||||
|
period_end_ts = data.get("current_period_end")
|
||||||
|
period_end = (
|
||||||
|
datetime.fromtimestamp(period_end_ts, tz=timezone.utc)
|
||||||
|
if period_end_ts
|
||||||
|
else None
|
||||||
|
)
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status=new_status, current_period_end=period_end
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "customer.subscription.deleted":
|
||||||
|
sub_id = data.get("id")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, tier="free", status="canceled"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif event_type == "invoice.payment_failed":
|
||||||
|
sub_id = data.get("subscription")
|
||||||
|
if sub_id:
|
||||||
|
await self._update_subscription_by_stripe_id(
|
||||||
|
db, sub_id, status="past_due"
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def get_subscription(
|
||||||
|
self, user_id: str, db: AsyncSession
|
||||||
|
) -> dict[str, Any] | None:
|
||||||
|
"""Return the subscription record for user_id, or None."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"tier": sub.tier,
|
||||||
|
"stripe_subscription_id": sub.stripe_subscription_id,
|
||||||
|
"status": sub.status,
|
||||||
|
"current_period_end": (
|
||||||
|
int(sub.current_period_end.timestamp() * 1000)
|
||||||
|
if sub.current_period_end
|
||||||
|
else None
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||||
|
"""Cancel the user's Stripe subscription and downgrade to free."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None or not sub.stripe_subscription_id:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_404_NOT_FOUND,
|
||||||
|
detail="No active subscription found",
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._configured():
|
||||||
|
s = self._client()
|
||||||
|
s.Subscription.cancel(sub.stripe_subscription_id)
|
||||||
|
|
||||||
|
sub.tier = "free"
|
||||||
|
sub.status = "canceled"
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _upsert_subscription(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
user_id: str,
|
||||||
|
stripe_subscription_id: str | None,
|
||||||
|
tier: str,
|
||||||
|
sub_status: str,
|
||||||
|
current_period_end: datetime | None,
|
||||||
|
) -> None:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
sub = Subscription(user_id=user_id)
|
||||||
|
db.add(sub)
|
||||||
|
sub.stripe_subscription_id = stripe_subscription_id
|
||||||
|
sub.tier = tier
|
||||||
|
sub.status = sub_status
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
async def _update_subscription_by_stripe_id(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
stripe_subscription_id: str,
|
||||||
|
*,
|
||||||
|
tier: str | None = None,
|
||||||
|
status: str | None = None,
|
||||||
|
current_period_end: datetime | None = None,
|
||||||
|
) -> None:
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription).where(
|
||||||
|
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
sub = result.scalar_one_or_none()
|
||||||
|
if sub is None:
|
||||||
|
return
|
||||||
|
if tier is not None:
|
||||||
|
sub.tier = tier
|
||||||
|
if status is not None:
|
||||||
|
sub.status = status
|
||||||
|
if current_period_end is not None:
|
||||||
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
stripe_service = StripeService()
|
||||||
178
services/billing/app/tier_manager.py
Normal file
178
services/billing/app/tier_manager.py
Normal file
@@ -0,0 +1,178 @@
|
|||||||
|
"""Tier manager: feature matrix and quota enforcement.
|
||||||
|
|
||||||
|
Single source of truth for what each billing tier allows.
|
||||||
|
Other services can query the /tier/{user_id} endpoint or rely on the
|
||||||
|
X-User-Tier header injected by Traefik.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import HTTPException, status
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from shared.config import settings
|
||||||
|
from shared.models import Subscription
|
||||||
|
from shared.schemas import BillingTier
|
||||||
|
|
||||||
|
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||||
|
FEATURES: dict[str, dict[str, Any]] = {
|
||||||
|
"free": {
|
||||||
|
"agents": 3,
|
||||||
|
"batch_active": 2,
|
||||||
|
"batch_runs_per_day": 5,
|
||||||
|
"cloud_storage_gb": 0,
|
||||||
|
"backup_gb": 0,
|
||||||
|
"providers": 1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"pro": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": 10,
|
||||||
|
"batch_runs_per_day": 50,
|
||||||
|
"cloud_storage_gb": 5,
|
||||||
|
"backup_gb": 5,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"power": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1,
|
||||||
|
"batch_runs_per_day": -1,
|
||||||
|
"cloud_storage_gb": 25,
|
||||||
|
"backup_gb": 25,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": False,
|
||||||
|
},
|
||||||
|
"team": {
|
||||||
|
"agents": -1,
|
||||||
|
"batch_active": -1,
|
||||||
|
"batch_runs_per_day": -1,
|
||||||
|
"cloud_storage_gb": -1,
|
||||||
|
"backup_gb": -1,
|
||||||
|
"providers": -1,
|
||||||
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
|
"sso": True,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
# Requests-per-minute limit per tier.
|
||||||
|
RATE_LIMITS: dict[str, int] = {
|
||||||
|
"free": 20,
|
||||||
|
"pro": 60,
|
||||||
|
"power": 120,
|
||||||
|
"team": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class TierManager:
|
||||||
|
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||||
|
|
||||||
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
|
"""Return the current billing tier for user_id from the DB."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
tier: str | None = result.scalar_one_or_none()
|
||||||
|
if tier is None or tier not in FEATURES:
|
||||||
|
return "power" if settings.ENV == "dev" else "free"
|
||||||
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
|
def get_features(self, tier: BillingTier) -> dict[str, Any]:
|
||||||
|
"""Return the full feature dict for a tier."""
|
||||||
|
return FEATURES.get(tier, FEATURES["free"])
|
||||||
|
|
||||||
|
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||||
|
"""Return True if tier has feature enabled."""
|
||||||
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
|
if value is None:
|
||||||
|
return False
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return value
|
||||||
|
return value != 0
|
||||||
|
|
||||||
|
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||||
|
"""Raise HTTP 403 if tier does not have feature."""
|
||||||
|
if not self.check_feature(tier, feature):
|
||||||
|
detail = (
|
||||||
|
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||||
|
if tier_name
|
||||||
|
else f"Feature '{feature}' is not available on your current tier."
|
||||||
|
)
|
||||||
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
|
"""Return the requests-per-minute limit for tier."""
|
||||||
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
|
def enforce_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user would exceed their cloud storage quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def enforce_backup_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the user would exceed their backup quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True if the user can store additional_bytes more data."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
return False
|
||||||
|
if limit_gb == -1:
|
||||||
|
return True
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
tier_manager = TierManager()
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user