Compare commits
24 Commits
70c19d3064
...
feature/mi
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2b7d302ef2 | ||
|
|
7f6ea29525 | ||
|
|
48036397f1 | ||
|
|
57b5648915 | ||
|
|
7e4374c69b | ||
|
|
fe0dd038ee | ||
|
|
d3f7099d93 | ||
|
|
63fa119543 | ||
|
|
d856dfd28c | ||
|
|
ccba54ac24 | ||
|
|
55500cc818 | ||
|
|
75a826c9d8 | ||
|
|
971f1dd84f | ||
|
|
333bba6fdd | ||
|
|
229e20d073 | ||
|
|
0b491b3643 | ||
|
|
0d5fa3e569 | ||
|
|
aff68a9051 | ||
|
|
5e9ef2809e | ||
|
|
90018af311 | ||
|
|
1e2e395676 | ||
|
|
59d3a53980 | ||
|
|
9feeaa79c8 | ||
|
|
aa219a4d08 |
114
.env.example
114
.env.example
@@ -2,94 +2,66 @@
|
||||
ENV=dev
|
||||
|
||||
# ── Database ──────────────────────────────────────────────────────────────────
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
||||
|
||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||
JWT_SECRET=replace-with-a-long-random-secret
|
||||
JWT_ALGORITHM=HS256
|
||||
# ── Redis ─────────────────────────────────────────────────────────────────────
|
||||
REDIS_URL=redis://localhost:6379/0
|
||||
|
||||
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
|
||||
# Generate keypair:
|
||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||
# openssl rsa -in private.pem -pubout -out public.pem
|
||||
# Paste PEM content with literal \n for newlines.
|
||||
#
|
||||
# Private key — ONLY used by the Auth Service (JWT signing).
|
||||
JWT_PRIVATE_KEY=
|
||||
# Public key — used by all services / Traefik ForwardAuth (JWT verification).
|
||||
JWT_PUBLIC_KEY=
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||
|
||||
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||
#
|
||||
# API keys — only the key(s) matching your chosen provider(s) are required.
|
||||
# The correct key is picked automatically from the model prefix (e.g.
|
||||
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
|
||||
OPENAI_API_KEY=
|
||||
ANTHROPIC_API_KEY=
|
||||
GOOGLE_API_KEY=
|
||||
CEREBRAS_API_KEY=
|
||||
GROQ_API_KEY=
|
||||
DEEPSEEK_API_KEY=
|
||||
|
||||
# Default model used by any agent that does not have a specific override below.
|
||||
LLM_MODEL=gpt-5-mini
|
||||
LLM_EMBED_MODEL=text-embedding-3-small
|
||||
|
||||
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
|
||||
# In Docker, point this to a named-volume path so tokens survive restarts.
|
||||
# GITHUB_COPILOT_TOKEN_DIR=
|
||||
|
||||
# ── Per-agent model overrides ─────────────────────────────────────────────────
|
||||
# Leave a value empty to fall back to LLM_MODEL.
|
||||
# Each agent resolves its API key from the model prefix automatically.
|
||||
#
|
||||
# Intent classifier — routes user messages to the right domain agent.
|
||||
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
|
||||
LLM_MODEL_CLASSIFIER=
|
||||
|
||||
# Home-agent — handles chat from the home screen (all tools available).
|
||||
LLM_MODEL_HOME_AGENT=
|
||||
|
||||
# Floating-agent — handles contextual chat triggered from a task/project/note.
|
||||
LLM_MODEL_FLOATING_AGENT=
|
||||
|
||||
# Unified-processor — processes local directory files (local agent runner).
|
||||
LLM_MODEL_UNIFIED_PROCESSOR=
|
||||
|
||||
# Cloud-processor — fetches and processes data from cloud connectors.
|
||||
LLM_MODEL_CLOUD_PROCESSOR=
|
||||
|
||||
# Brief-agent — produces home and project text briefs.
|
||||
# A small model (e.g. gpt-4o-mini) is sufficient.
|
||||
# LLM_MODEL_BRIEF_AGENT=
|
||||
|
||||
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
|
||||
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
|
||||
# LLM_MODEL_TASK_BRIEF_AGENT=
|
||||
|
||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||
LLM_MODEL_SETUP_AGENT=
|
||||
|
||||
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
|
||||
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
|
||||
LLM_MODEL_MEMORY_EXTRACTOR=
|
||||
|
||||
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
|
||||
# Defaults to gpt-4o-mini when empty.
|
||||
LLM_MODEL_MEMORY_MINER=
|
||||
|
||||
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
|
||||
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
|
||||
LLM_MODEL_MEMORY_AUDITOR=
|
||||
|
||||
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
|
||||
SCHEDULER_ENABLED=true
|
||||
LLM_MODEL=gpt-4o
|
||||
|
||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||
STRIPE_SECRET_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
||||
S3_BUCKET=adiuva
|
||||
S3_REGION=us-east-1
|
||||
S3_ENDPOINT_URL=
|
||||
AWS_ACCESS_KEY_ID=
|
||||
AWS_SECRET_ACCESS_KEY=
|
||||
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
|
||||
|
||||
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
||||
LANGFUSE_SECRET_KEY=
|
||||
LANGFUSE_PUBLIC_KEY=
|
||||
# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
|
||||
# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
|
||||
# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
|
||||
# ── Vector Store ──────────────────────────────────────────────────────────────
|
||||
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
|
||||
PINECONE_API_KEY=
|
||||
PINECONE_INDEX=adiuva
|
||||
QDRANT_URL=
|
||||
QDRANT_API_KEY=
|
||||
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
||||
|
||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||
# Comma-separated list parsed by Settings (override default if needed)
|
||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||
|
||||
# ── Langfuse (observability) ─────────────────────────────────────────────────
|
||||
LANGFUSE_SECRET_KEY=sk-lf-...
|
||||
LANGFUSE_PUBLIC_KEY=pk-lf-...
|
||||
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL
|
||||
|
||||
# ── Cloudflare (Traefik ACME DNS-01 challenge) ───────────────────────────────
|
||||
CF_DNS_API_TOKEN=
|
||||
ACME_EMAIL=
|
||||
|
||||
# ── PostgreSQL (used by docker-compose) ──────────────────────────────────────
|
||||
POSTGRES_USER=postgres
|
||||
POSTGRES_PASSWORD=postgres
|
||||
POSTGRES_DB=adiuva
|
||||
@@ -48,23 +48,23 @@ jobs:
|
||||
key: ${{ secrets.SSH_KEY }}
|
||||
script: |
|
||||
set -e
|
||||
DEPLOY_DIR="/opt/adiuvai-api"
|
||||
DEPLOY_DIR="/opt/adiuva-api"
|
||||
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||
TAG="${{ gitea.ref_name }}"
|
||||
|
||||
# ── Pull latest code ──
|
||||
cd /tmp && rm -rf adiuvai-api-deploy
|
||||
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
|
||||
cd /tmp && rm -rf adiuva-api-deploy
|
||||
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
|
||||
|
||||
# ── Sync source (preserve .env) ──
|
||||
cp -rf /tmp/adiuvai-api-deploy/app/ \
|
||||
/tmp/adiuvai-api-deploy/alembic/ \
|
||||
/tmp/adiuvai-api-deploy/alembic.ini \
|
||||
/tmp/adiuvai-api-deploy/Dockerfile \
|
||||
/tmp/adiuvai-api-deploy/docker-compose.yml \
|
||||
/tmp/adiuvai-api-deploy/requirements.txt \
|
||||
cp -rf /tmp/adiuva-api-deploy/app/ \
|
||||
/tmp/adiuva-api-deploy/alembic/ \
|
||||
/tmp/adiuva-api-deploy/alembic.ini \
|
||||
/tmp/adiuva-api-deploy/Dockerfile \
|
||||
/tmp/adiuva-api-deploy/docker-compose.yml \
|
||||
/tmp/adiuva-api-deploy/requirements.txt \
|
||||
"$DEPLOY_DIR/"
|
||||
rm -rf /tmp/adiuvai-api-deploy
|
||||
rm -rf /tmp/adiuva-api-deploy
|
||||
|
||||
# ── Verify .env ──
|
||||
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||
|
||||
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
||||
- uses: actions/checkout@v4
|
||||
|
||||
- name: Build image
|
||||
run: docker build -t adiuvai-api:ci .
|
||||
run: docker build -t adiuva-api:ci .
|
||||
|
||||
- name: Verify gunicorn installed
|
||||
run: docker run --rm adiuvai-api:ci gunicorn --version
|
||||
run: docker run --rm adiuva-api:ci gunicorn --version
|
||||
|
||||
10
.gitignore
vendored
10
.gitignore
vendored
@@ -13,6 +13,9 @@ env/
|
||||
# Environment variables
|
||||
.env
|
||||
|
||||
# Cryptographic keys
|
||||
*.pem
|
||||
|
||||
# IDE
|
||||
.vscode/
|
||||
.idea/
|
||||
@@ -21,18 +24,17 @@ env/
|
||||
.pytest_cache/
|
||||
htmlcov/
|
||||
.coverage
|
||||
tests/fixtures/private*/
|
||||
|
||||
# Docker
|
||||
*.log
|
||||
|
||||
# OS
|
||||
.DS_Store
|
||||
|
||||
# Smoke scripts (dev-only, not for CI)
|
||||
scripts/smoke_*.py
|
||||
Thumbs.db
|
||||
|
||||
# Claude Code
|
||||
.claude/
|
||||
logs/
|
||||
|
||||
# Eval private test data
|
||||
services/batch-agent/eval/fixtures/private_data/
|
||||
|
||||
796
README.md
796
README.md
@@ -1,5 +1,793 @@
|
||||
## DEV
|
||||
Run in DEV with command:
|
||||
# Adiuva Cloud API
|
||||
|
||||
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
|
||||
|
||||
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
|
||||
|
||||
---
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Overview](#overview)
|
||||
- [Architecture](#architecture)
|
||||
- [Key Features](#key-features)
|
||||
- [Tech Stack](#tech-stack)
|
||||
- [Getting Started](#getting-started)
|
||||
- [Docker Deployment](#docker-deployment)
|
||||
- [Environment Variables](#environment-variables)
|
||||
- [API Reference](#api-reference)
|
||||
- [Data Model](#data-model)
|
||||
- [AI Agent System](#ai-agent-system)
|
||||
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
||||
- [Middleware](#middleware)
|
||||
- [Storage Layer](#storage-layer)
|
||||
- [Billing & Tiers](#billing--tiers)
|
||||
- [Plugin Marketplace](#plugin-marketplace)
|
||||
- [Testing](#testing)
|
||||
- [Project Structure](#project-structure)
|
||||
- [License](#license)
|
||||
|
||||
---
|
||||
|
||||
## Overview
|
||||
|
||||
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
|
||||
|
||||
### Design Principles
|
||||
|
||||
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
|
||||
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
||||
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
|
||||
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
||||
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
||||
|
||||
---
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf
|
||||
```
|
||||
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
|
||||
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
|
||||
│ Desktop App │────▶│ │
|
||||
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
|
||||
└──────────────┘ │ │
|
||||
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
||||
│ │ Auth Routes │ │ Chat Routes │ │
|
||||
│ │ Billing Routes │ │ ↓ │ │
|
||||
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
||||
│ │ Backup Routes │ │ ↓ classify intent │ │
|
||||
│ │ Plugin Routes │ │ Agent Registry │ │
|
||||
│ │ Vector Routes │ │ ↓ │ │
|
||||
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
|
||||
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
|
||||
│ │ (GPT-4o + LangChain) │ │
|
||||
│ └────────────────────────────┘ │
|
||||
└────────────────────────────────────────────────────────┘
|
||||
│ │ │
|
||||
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
|
||||
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
|
||||
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
|
||||
│ Billing, │ │ backups) │ │ (Vectors) │
|
||||
│ Metadata) │ └───────────────┘ └────────────────┘
|
||||
└────────────┘
|
||||
│
|
||||
┌────────▼───┐
|
||||
│ Stripe │
|
||||
│ (Billing, │
|
||||
│ Connect) │
|
||||
└────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Key Features
|
||||
|
||||
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
||||
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
||||
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
||||
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
||||
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
||||
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
|
||||
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
|
||||
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
||||
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
||||
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
||||
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
||||
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
|
||||
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
||||
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
|
||||
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
|
||||
|
||||
---
|
||||
|
||||
## Tech Stack
|
||||
|
||||
| Package | Version | Purpose |
|
||||
|---|---|---|
|
||||
| `fastapi` | ≥ 0.115.0 | Web framework |
|
||||
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
|
||||
| `gunicorn` | ≥ 22.0.0 | Production process manager |
|
||||
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
|
||||
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
|
||||
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
|
||||
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
|
||||
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
||||
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
||||
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
||||
| `boto3` | ≥ 1.35.0 | AWS S3 client |
|
||||
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
||||
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
||||
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
||||
| `alembic` | ≥ 1.14.0 | Database migration management |
|
||||
| `bcrypt` | ≥ 4.2.0 | Password hashing |
|
||||
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
|
||||
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
||||
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
||||
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
||||
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
|
||||
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
|
||||
| `pytest` | ≥ 8.0.0 | Test framework |
|
||||
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
||||
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
||||
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
|
||||
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
||||
|
||||
---
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- Python 3.12+
|
||||
- PostgreSQL 16+
|
||||
- An OpenAI API key (for LLM features)
|
||||
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
||||
- AWS credentials (optional — needed for S3 storage in production)
|
||||
|
||||
### Installation
|
||||
|
||||
```bash
|
||||
# Clone the repository
|
||||
git clone <repo-url> && cd adiuva-api
|
||||
|
||||
# Create a virtual environment
|
||||
python -m venv .venv && source .venv/bin/activate
|
||||
|
||||
# Install dependencies
|
||||
pip install -r requirements.txt
|
||||
|
||||
# Configure environment
|
||||
cp .env.example .env
|
||||
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
|
||||
```
|
||||
|
||||
### Database Setup
|
||||
|
||||
```bash
|
||||
# Start PostgreSQL (or use the Docker Compose database)
|
||||
docker compose up db -d
|
||||
|
||||
# Run migrations
|
||||
alembic upgrade head
|
||||
```
|
||||
|
||||
### Run the Development Server
|
||||
|
||||
```bash
|
||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||
```
|
||||
|
||||
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
|
||||
|
||||
---
|
||||
|
||||
## Docker Deployment
|
||||
|
||||
### Quick Start
|
||||
|
||||
```bash
|
||||
docker compose up --build
|
||||
```
|
||||
|
||||
This starts two services:
|
||||
|
||||
- **app** — FastAPI server on port `8000`
|
||||
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
||||
|
||||
The compose file also includes optional services for fully local deployments:
|
||||
|
||||
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
|
||||
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
|
||||
|
||||
### Dockerfile Details
|
||||
|
||||
The Dockerfile uses a multi-stage build:
|
||||
|
||||
1. **Builder stage** — Installs Python dependencies into a virtual environment.
|
||||
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
|
||||
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
|
||||
|
||||
```bash
|
||||
# Production command (run by the container)
|
||||
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## Homelab / Self-Hosted Deployment
|
||||
|
||||
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
|
||||
|
||||
### 1. Start all services
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
|
||||
|
||||
### 2. Create the MinIO bucket
|
||||
|
||||
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
|
||||
|
||||
```bash
|
||||
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
|
||||
docker compose exec minio mc mb local/adiuva
|
||||
```
|
||||
|
||||
### 3. Configure your `.env`
|
||||
|
||||
```bash
|
||||
# Database (uses the compose PostgreSQL)
|
||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||
|
||||
# S3 → MinIO
|
||||
S3_BUCKET=adiuva
|
||||
S3_REGION=us-east-1
|
||||
S3_ENDPOINT_URL=http://minio:9000
|
||||
AWS_ACCESS_KEY_ID=minioadmin
|
||||
AWS_SECRET_ACCESS_KEY=minioadmin
|
||||
|
||||
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
|
||||
QDRANT_URL=http://qdrant:6333
|
||||
QDRANT_API_KEY=
|
||||
PINECONE_API_KEY=
|
||||
|
||||
# Billing — leave empty to stub (no Stripe needed)
|
||||
STRIPE_SECRET_KEY=
|
||||
STRIPE_WEBHOOK_SECRET=
|
||||
|
||||
# LLM — the only external service
|
||||
OPENAI_API_KEY=sk-...
|
||||
LLM_MODEL=gpt-4o
|
||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||
|
||||
# Auth
|
||||
JWT_SECRET=your-secret-here
|
||||
ENV=dev
|
||||
```
|
||||
|
||||
### 4. Run migrations
|
||||
|
||||
```bash
|
||||
docker compose exec app alembic upgrade head
|
||||
```
|
||||
|
||||
### What runs where
|
||||
|
||||
| Service | Runs on | Port | Notes |
|
||||
|---|---|---|---|
|
||||
| FastAPI app | Docker | 8000 | API server |
|
||||
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
|
||||
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
|
||||
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
|
||||
| Stripe | — | — | Stubbed when keys are empty |
|
||||
| OpenAI / LLM | Cloud | — | Only external dependency |
|
||||
|
||||
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
|
||||
|
||||
---
|
||||
|
||||
## Environment Variables
|
||||
|
||||
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
|
||||
|
||||
| Variable | Type | Default | Description |
|
||||
|---|---|---|---|
|
||||
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
|
||||
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
|
||||
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
|
||||
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
||||
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
||||
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
||||
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
|
||||
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
|
||||
| `S3_REGION` | `str` | `us-east-1` | AWS region |
|
||||
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
|
||||
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
|
||||
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
|
||||
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
|
||||
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
|
||||
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
|
||||
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
|
||||
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
|
||||
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
||||
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
||||
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
||||
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
|
||||
|
||||
---
|
||||
|
||||
## API Reference
|
||||
|
||||
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
|
||||
|
||||
### Health
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
|
||||
|
||||
### Auth
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
|
||||
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
|
||||
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
|
||||
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
|
||||
|
||||
### Chat
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
||||
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
||||
|
||||
### Plans
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
||||
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
||||
|
||||
### Storage (Cloud Records)
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
|
||||
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
|
||||
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
|
||||
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
|
||||
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
|
||||
|
||||
### Vectors (Cloud Vector Store)
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
|
||||
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
|
||||
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
|
||||
|
||||
### Backup
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
|
||||
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
|
||||
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
|
||||
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
|
||||
|
||||
### Plugins (Marketplace)
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
|
||||
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
|
||||
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
|
||||
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
|
||||
|
||||
### Billing
|
||||
|
||||
| Method | Path | Auth | Description |
|
||||
|---|---|---|---|
|
||||
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
|
||||
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
|
||||
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
|
||||
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
|
||||
|
||||
---
|
||||
|
||||
## Data Model
|
||||
|
||||
9 tables managed by Alembic migrations. Source: `app/models.py`
|
||||
|
||||
### Tables
|
||||
|
||||
| Table | Primary Key | Key Columns | Purpose |
|
||||
|---|---|---|---|
|
||||
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
||||
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
||||
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
||||
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
|
||||
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
|
||||
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
|
||||
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
|
||||
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
|
||||
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
|
||||
|
||||
### Enum Types
|
||||
|
||||
| Enum | Values |
|
||||
|---|---|
|
||||
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
||||
| `plugin_status` | `pending_review`, `approved`, `rejected` |
|
||||
| `review_decision` | `approved`, `rejected` |
|
||||
|
||||
### Migrations
|
||||
|
||||
| Version | Description |
|
||||
|---|---|
|
||||
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
|
||||
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
|
||||
|
||||
---
|
||||
|
||||
## AI Agent System
|
||||
|
||||
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
|
||||
|
||||
### Architecture
|
||||
|
||||
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
|
||||
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
||||
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
||||
|
||||
### Registered Agents
|
||||
|
||||
| Agent | Registry Name | Tools | Description |
|
||||
|---|---|---|---|
|
||||
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
||||
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
||||
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
|
||||
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
||||
|
||||
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
||||
|
||||
### Switching LLM Providers
|
||||
|
||||
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
|
||||
|
||||
```bash
|
||||
# OpenAI (default)
|
||||
LLM_MODEL=gpt-4o
|
||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||
|
||||
# Anthropic
|
||||
LLM_MODEL=anthropic/claude-3.5-sonnet
|
||||
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
|
||||
|
||||
# Google Gemini
|
||||
LLM_MODEL=gemini/gemini-pro
|
||||
LLM_ROUTER_MODEL=gemini/gemini-flash
|
||||
|
||||
# Local Ollama
|
||||
LLM_MODEL=ollama/llama3
|
||||
LLM_ROUTER_MODEL=ollama/llama3
|
||||
|
||||
# AWS Bedrock
|
||||
LLM_MODEL=bedrock/anthropic.claude-v2
|
||||
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
|
||||
```
|
||||
|
||||
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
|
||||
|
||||
---
|
||||
|
||||
## Orchestration & Execution Plans
|
||||
|
||||
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
||||
|
||||
### Orchestrator
|
||||
|
||||
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
|
||||
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
|
||||
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
|
||||
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
|
||||
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
|
||||
|
||||
### Execution Plans
|
||||
|
||||
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
|
||||
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
|
||||
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
|
||||
|
||||
### Built-in Templates (6)
|
||||
|
||||
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
||||
|
||||
### Built-in Playbooks (2)
|
||||
|
||||
| Playbook | Description |
|
||||
|---|---|
|
||||
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
|
||||
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
|
||||
|
||||
---
|
||||
|
||||
## Middleware
|
||||
|
||||
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
|
||||
|
||||
### JWT Authentication
|
||||
|
||||
Source: `app/api/middleware/auth.py`
|
||||
|
||||
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
|
||||
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
|
||||
- Falls back to `free` when no subscription row exists.
|
||||
- Raises `401 Unauthorized` on invalid or expired tokens.
|
||||
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||
|
||||
### Tier-Based Rate Limiter
|
||||
|
||||
Source: `app/api/middleware/rate_limit.py`
|
||||
|
||||
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
|
||||
- Per-user 60-second window sized by subscription tier:
|
||||
|
||||
| Tier | Requests / Minute |
|
||||
|---|---|
|
||||
| Free | 20 |
|
||||
| Pro | 60 |
|
||||
| Power | 120 |
|
||||
| Team | 200 |
|
||||
|
||||
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
|
||||
- **Exempt paths:** register, login, webhook, health
|
||||
|
||||
### Response Sanitizer
|
||||
|
||||
Source: `app/api/middleware/sanitizer.py`
|
||||
|
||||
- Runs only on `/api/v1/chat` endpoints.
|
||||
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
||||
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
||||
- Logs sanitization events as `WARNING`.
|
||||
- Binary responses (storage, backup) are never touched.
|
||||
|
||||
---
|
||||
|
||||
## Storage Layer
|
||||
|
||||
### Blob Store
|
||||
|
||||
Source: `app/storage/blob_store.py`
|
||||
|
||||
- S3-backed storage for E2E encrypted blobs.
|
||||
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
|
||||
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
|
||||
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
|
||||
- The backend **never inspects or decrypts blob content**.
|
||||
|
||||
### Vector Store
|
||||
|
||||
Source: `app/storage/vector_store.py`
|
||||
|
||||
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
|
||||
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
|
||||
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
|
||||
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
|
||||
- Methods: `upsert()`, `search()`, `delete()`
|
||||
|
||||
### Encryption Utilities
|
||||
|
||||
Source: `app/storage/encryption.py`
|
||||
|
||||
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
|
||||
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
|
||||
- **No decryption key ever reaches the backend.**
|
||||
|
||||
---
|
||||
|
||||
## Billing & Tiers
|
||||
|
||||
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
||||
|
||||
### Feature Matrix
|
||||
|
||||
| Feature | Free | Pro | Power | Team |
|
||||
|---|---|---|---|---|
|
||||
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
||||
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
||||
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
||||
| Batch Builder | — | — | ✓ | ✓ |
|
||||
| Plugin Marketplace | — | — | ✓ | ✓ |
|
||||
| SSO | — | — | — | ✓ |
|
||||
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
||||
|
||||
### Stripe Integration
|
||||
|
||||
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
|
||||
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
|
||||
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
|
||||
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
|
||||
|
||||
### Tier Manager
|
||||
|
||||
- `get_tier(user_id)` — Returns the user's current billing tier.
|
||||
- `check_feature(tier, feature)` — Boolean feature gate check.
|
||||
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
||||
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
|
||||
|
||||
---
|
||||
|
||||
## Plugin Marketplace
|
||||
|
||||
Source: `app/marketplace/`
|
||||
|
||||
### Plugin Registry
|
||||
|
||||
- PostgreSQL-backed catalog of submitted and approved plugins.
|
||||
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
|
||||
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
|
||||
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
|
||||
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
|
||||
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
|
||||
|
||||
### Review Queue
|
||||
|
||||
- Automated security checklist before human review:
|
||||
- Plugin ID must match `^[a-z0-9-]+$`
|
||||
- Permissions must be from the allowed set only
|
||||
- No binary blobs in the manifest
|
||||
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
|
||||
- `get_pending(db)` — Lists plugins awaiting review.
|
||||
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
||||
|
||||
### Revenue Sharing
|
||||
|
||||
- **70% developer / 30% platform** split on all paid plugin sales.
|
||||
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
|
||||
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
|
||||
- Gracefully stubs transfers when Stripe is not configured.
|
||||
|
||||
### Seed Plugins
|
||||
|
||||
| Plugin | Category | Price |
|
||||
|---|---|---|
|
||||
| GitHub Sync | Productivity | Free |
|
||||
| Slack Notifier | Communication | €4.99 |
|
||||
| Time Tracker | Productivity | €9.99 |
|
||||
|
||||
---
|
||||
|
||||
## Testing
|
||||
|
||||
### Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
pytest
|
||||
|
||||
# Run a specific test file
|
||||
pytest tests/test_auth.py
|
||||
|
||||
# Run with verbose output
|
||||
pytest -v
|
||||
```
|
||||
|
||||
### Test Infrastructure
|
||||
|
||||
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
||||
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
|
||||
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
||||
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
||||
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
|
||||
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
||||
- **No external dependencies** — all tests run fully offline.
|
||||
|
||||
### Test Coverage
|
||||
|
||||
| File | Coverage |
|
||||
|---|---|
|
||||
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
||||
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
|
||||
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
|
||||
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
|
||||
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
|
||||
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
|
||||
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
|
||||
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
|
||||
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
||||
|
||||
---
|
||||
|
||||
## Project Structure
|
||||
|
||||
```
|
||||
adiuva-api/
|
||||
├── alembic.ini # Alembic configuration
|
||||
├── BACKEND_PLAN.md # Architecture & design decisions
|
||||
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
||||
├── Dockerfile # Multi-stage production build
|
||||
├── requirements.txt # Python dependencies
|
||||
│
|
||||
├── alembic/ # Database migrations
|
||||
│ ├── env.py # Alembic environment config
|
||||
│ ├── script.py.mako # Migration template
|
||||
│ └── versions/
|
||||
│ ├── 001_initial_schema.py # Tables, indexes, FKs
|
||||
│ └── 002_seed_plugins.py # Seed marketplace plugins
|
||||
│
|
||||
├── app/ # Application source
|
||||
│ ├── main.py # FastAPI app factory, middleware, routes
|
||||
│ ├── db.py # Async SQLAlchemy engine & session
|
||||
│ ├── models.py # SQLAlchemy ORM models (9 tables)
|
||||
│ ├── schemas.py # Pydantic request/response schemas
|
||||
│ │
|
||||
│ ├── config/
|
||||
│ │ └── settings.py # Pydantic Settings (env vars)
|
||||
│ │
|
||||
│ ├── agents/ # LLM-powered domain agents
|
||||
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
||||
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
||||
│ │ ├── timeline_agent.py # Milestones (4 tools)
|
||||
│ │ └── note_agent.py # Markdown notes (5 tools)
|
||||
│ │
|
||||
│ ├── core/ # Orchestration engine
|
||||
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||
│ │ ├── llm.py # LiteLLM factory (get_llm)
|
||||
│ │ ├── orchestrator.py # Intent classification & routing
|
||||
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||
│ │
|
||||
│ ├── api/ # HTTP layer
|
||||
│ │ ├── deps.py # Shared FastAPI dependencies
|
||||
│ │ ├── middleware/
|
||||
│ │ │ ├── auth.py # JWT validation, live tier lookup
|
||||
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
||||
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
||||
│ │ └── routes/
|
||||
│ │ ├── auth.py # Register, login, refresh, me
|
||||
│ │ ├── chat.py # Chat + WebSocket streaming
|
||||
│ │ ├── plans.py # Execution plan playbooks
|
||||
│ │ ├── storage.py # E2E encrypted record CRUD
|
||||
│ │ ├── vectors.py # Vector upsert, search, delete
|
||||
│ │ ├── backup.py # Encrypted backup management
|
||||
│ │ ├── plugins.py # Marketplace browse & install
|
||||
│ │ └── billing.py # Stripe checkout & webhooks
|
||||
│ │
|
||||
│ ├── storage/ # Storage backends
|
||||
│ │ ├── blob_store.py # S3 blob storage
|
||||
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
|
||||
│ │ └── encryption.py # Checksum verification utilities
|
||||
│ │
|
||||
│ ├── billing/ # Subscription management
|
||||
│ │ ├── stripe_service.py # Stripe API integration
|
||||
│ │ └── tier_manager.py # Feature matrix & quota enforcement
|
||||
│ │
|
||||
│ └── marketplace/ # Plugin ecosystem
|
||||
│ ├── plugin_registry.py # Catalog CRUD & search
|
||||
│ ├── plugin_review.py # Security checklist & review queue
|
||||
│ └── revenue_share.py # 70/30 split & Stripe Connect
|
||||
│
|
||||
└── tests/ # Test suite
|
||||
├── conftest.py # Fixtures: DB, S3, auth, seeds
|
||||
├── test_auth.py
|
||||
├── test_orchestrator.py
|
||||
├── test_agents.py
|
||||
├── test_storage.py
|
||||
├── test_backup.py
|
||||
├── test_plugins.py
|
||||
├── test_agent_registry.py
|
||||
├── test_execution_plan.py
|
||||
└── test_middleware.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
*To be determined.*
|
||||
|
||||
@@ -16,7 +16,7 @@ import re
|
||||
from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy import engine_from_config, pool
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
|
||||
# Alembic Config object (gives access to alembic.ini values).
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
"""Initial schema: users, refresh_tokens, subscriptions.
|
||||
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||
|
||||
Revision ID: 001
|
||||
Revises:
|
||||
@@ -27,6 +28,18 @@ def upgrade() -> None:
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
# ── users ─────────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
@@ -75,10 +88,122 @@ def upgrade() -> None:
|
||||
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||
|
||||
# ── storage_records ───────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"storage_records",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("table_name", sa.String(100), nullable=False),
|
||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||
sa.Column("checksum", sa.String(64), nullable=False),
|
||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||
|
||||
# ── backup_metadata ───────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"backup_metadata",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||
sa.Column("version", sa.Integer, nullable=False),
|
||||
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||
sa.Column("checksum", sa.String(64), nullable=False),
|
||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||
|
||||
# ── plugins ───────────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugins",
|
||||
sa.Column("id", sa.String(255), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
||||
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||
)
|
||||
|
||||
# ── plugin_installations ──────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugin_installations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||
)
|
||||
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||
|
||||
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"plugin_reviews",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
||||
sa.Column("notes", sa.Text, nullable=True),
|
||||
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||
)
|
||||
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||
|
||||
# ── revenue_events ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"revenue_events",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("revenue_events")
|
||||
op.drop_table("plugin_reviews")
|
||||
op.drop_table("plugin_installations")
|
||||
op.drop_table("plugins")
|
||||
op.drop_table("backup_metadata")
|
||||
op.drop_table("storage_records")
|
||||
op.drop_table("subscriptions")
|
||||
op.drop_table("refresh_tokens")
|
||||
op.drop_table("users")
|
||||
|
||||
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||
|
||||
92
alembic/versions/002_seed_plugins.py
Normal file
92
alembic/versions/002_seed_plugins.py
Normal file
@@ -0,0 +1,92 @@
|
||||
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
||||
|
||||
Revision ID: 002
|
||||
Revises: 001
|
||||
Create Date: 2026-03-03
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from datetime import datetime, timezone
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision: str = "002"
|
||||
down_revision: Union[str, None] = "001"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
_SEED_PLUGINS = [
|
||||
{
|
||||
"id": "plugin-github-sync",
|
||||
"name": "GitHub Sync",
|
||||
"description": "Sync tasks with GitHub Issues and pull requests.",
|
||||
"version": "1.0.0",
|
||||
"author_name": "Adiuva",
|
||||
"category": "productivity",
|
||||
"price_cents": 0,
|
||||
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
},
|
||||
{
|
||||
"id": "plugin-slack-notify",
|
||||
"name": "Slack Notifier",
|
||||
"description": "Post task and timeline updates to Slack channels.",
|
||||
"version": "1.2.0",
|
||||
"author_name": "Adiuva",
|
||||
"category": "communication",
|
||||
"price_cents": 499,
|
||||
"permissions": json.dumps(["read:tasks", "read:timelines"]),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
},
|
||||
{
|
||||
"id": "plugin-time-tracker",
|
||||
"name": "Time Tracker",
|
||||
"description": "Track time spent on tasks with automatic reporting.",
|
||||
"version": "0.9.1",
|
||||
"author_name": "Third Party",
|
||||
"category": "productivity",
|
||||
"price_cents": 999,
|
||||
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||
"status": "approved",
|
||||
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||
"install_count": 0,
|
||||
"avg_rating": 0.0,
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
plugins = sa.table(
|
||||
"plugins",
|
||||
sa.column("id", sa.String),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("description", sa.Text),
|
||||
sa.column("version", sa.String),
|
||||
sa.column("author_name", sa.String),
|
||||
sa.column("category", sa.String),
|
||||
sa.column("price_cents", sa.Integer),
|
||||
sa.column("permissions", sa.Text),
|
||||
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
||||
sa.column("s3_package_key", sa.String),
|
||||
sa.column("install_count", sa.Integer),
|
||||
sa.column("avg_rating", sa.Float),
|
||||
)
|
||||
op.bulk_insert(plugins, _SEED_PLUGINS)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
"DELETE FROM plugins WHERE id IN ("
|
||||
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
||||
")"
|
||||
)
|
||||
@@ -14,7 +14,7 @@ from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "003"
|
||||
down_revision: Union[str, None] = "001"
|
||||
down_revision: Union[str, None] = "002"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Phase 1 — confirm pgvector activation on memory_associative.
|
||||
|
||||
Migration 004 created the embedding column as vector(1536) and added the
|
||||
IVFFlat index. This migration is the Phase-1 checkpoint:
|
||||
1. Ensures the pgvector extension is enabled (idempotent).
|
||||
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
||||
memory_associative_embedding_idx (creates it only if absent).
|
||||
|
||||
Revision ID: 005
|
||||
Revises: 9a1f2d0b6c7e
|
||||
Create Date: 2026-04-15
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "005"
|
||||
down_revision: Union[str, None] = "e04100e88ace"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||
|
||||
# Ensure the canonical Phase-1 IVFFlat index exists.
|
||||
# 004 may have created ix_memory_associative_embedding; this adds the
|
||||
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
||||
op.execute(
|
||||
"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF NOT EXISTS (
|
||||
SELECT 1
|
||||
FROM pg_indexes
|
||||
WHERE tablename = 'memory_associative'
|
||||
AND indexname = 'memory_associative_embedding_idx'
|
||||
) THEN
|
||||
CREATE INDEX memory_associative_embedding_idx
|
||||
ON memory_associative
|
||||
USING ivfflat (embedding vector_cosine_ops)
|
||||
WITH (lists = 100);
|
||||
END IF;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
||||
@@ -1,74 +0,0 @@
|
||||
"""Add memory_relations table (Phase 3 — relational tier).
|
||||
|
||||
Revision ID: 006
|
||||
Revises: 1f5975a4f3f4
|
||||
Create Date: 2026-04-16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "006"
|
||||
down_revision: Union[str, None] = "1f5975a4f3f4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"memory_relations",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
||||
sa.Column(
|
||||
"user_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("subject_label", sa.String(128), nullable=False),
|
||||
sa.Column("subject_type", sa.String(32), nullable=False),
|
||||
sa.Column("predicate", sa.String(64), nullable=False),
|
||||
sa.Column("object_label", sa.String(128), nullable=False),
|
||||
sa.Column("object_type", sa.String(32), nullable=False),
|
||||
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
|
||||
sa.Column(
|
||||
"source_episode_id",
|
||||
postgresql.UUID(as_uuid=False),
|
||||
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column(
|
||||
"updated_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
op.create_index(
|
||||
"memory_relations_user_subject_idx",
|
||||
"memory_relations",
|
||||
["user_id", "subject_label"],
|
||||
)
|
||||
op.create_index(
|
||||
"memory_relations_user_predicate_idx",
|
||||
"memory_relations",
|
||||
["user_id", "predicate"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
|
||||
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
|
||||
op.drop_table("memory_relations")
|
||||
@@ -1,38 +0,0 @@
|
||||
"""add extraction_queue
|
||||
|
||||
Revision ID: 1f5975a4f3f4
|
||||
Revises: 005
|
||||
Create Date: 2026-04-16 17:26:25.790870
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = '1f5975a4f3f4'
|
||||
down_revision: Union[str, None] = '005'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
'extraction_queue',
|
||||
sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
|
||||
sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
|
||||
sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
|
||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
)
|
||||
op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
|
||||
op.drop_table('extraction_queue')
|
||||
@@ -1,107 +0,0 @@
|
||||
"""Restore agent config tables and add agent_config column.
|
||||
|
||||
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
|
||||
ORM models are still active. This migration recreates them with agent_config
|
||||
added to local_agent_configs.
|
||||
|
||||
Revision ID: a3b9c0d1e2f3
|
||||
Revises: 9a1f2d0b6c7e
|
||||
Create Date: 2026-04-07 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "a3b9c0d1e2f3"
|
||||
down_revision: Union[str, None] = "9a1f2d0b6c7e"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Recreate enum types (idempotent — they may already exist from migration 003)
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
op.execute("""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
""")
|
||||
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
existing = set(inspector.get_table_names())
|
||||
|
||||
# ── local_agent_configs (with agent_config column) ────────────────────
|
||||
if "local_agent_configs" not in existing:
|
||||
op.create_table(
|
||||
"local_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("device_id", sa.String(255), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("agent_config", sa.JSON, nullable=True),
|
||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||
|
||||
# ── cloud_agent_configs ───────────────────────────────────────────────
|
||||
if "cloud_agent_configs" not in existing:
|
||||
op.create_table(
|
||||
"cloud_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column(
|
||||
"provider",
|
||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||
op.drop_table("cloud_agent_configs")
|
||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||
op.drop_table("local_agent_configs")
|
||||
@@ -1,56 +0,0 @@
|
||||
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
|
||||
|
||||
Revision ID: b4c0d1e2f3a4
|
||||
Revises: a3b9c0d1e2f3
|
||||
Create Date: 2026-04-10 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "b4c0d1e2f3a4"
|
||||
down_revision: Union[str, None] = "a3b9c0d1e2f3"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ── users: make password_hash nullable (social users have no password) ──
|
||||
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
|
||||
|
||||
# ── users: add avatar_url ─────────────────────────────────────────────
|
||||
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
|
||||
|
||||
# ── oauth_accounts ────────────────────────────────────────────────────
|
||||
op.create_table(
|
||||
"oauth_accounts",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("provider", sa.String(50), nullable=False),
|
||||
sa.Column("provider_user_id", sa.String(255), nullable=False),
|
||||
sa.Column("provider_email", sa.String(255), nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
server_default=sa.text("now()"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||
)
|
||||
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
|
||||
op.drop_table("oauth_accounts")
|
||||
op.drop_column("users", "avatar_url")
|
||||
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Add onboarding_completed_at column to users table.
|
||||
|
||||
Revision ID: c5d1e2f3a4b5
|
||||
Revises: b4c0d1e2f3a4
|
||||
Create Date: 2026-04-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c5d1e2f3a4b5"
|
||||
down_revision: Union[str, None] = "b4c0d1e2f3a4"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"users",
|
||||
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("users", "onboarding_completed_at")
|
||||
@@ -1,46 +0,0 @@
|
||||
"""Add token tracking columns for folder integration.
|
||||
|
||||
Revision ID: d6e3f4a5b6c7
|
||||
Revises: 006
|
||||
Create Date: 2026-05-11 00:00:00.000000
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects.postgresql import UUID
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "d6e3f4a5b6c7"
|
||||
down_revision: Union[str, None] = "006"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"agent_run_logs",
|
||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
op.create_table(
|
||||
"monthly_token_usage",
|
||||
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("year_month", sa.String(7), nullable=False),
|
||||
sa.Column("feature", sa.String(64), nullable=False),
|
||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_monthly_token_usage_user_month",
|
||||
"monthly_token_usage",
|
||||
["user_id", "year_month"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
|
||||
op.drop_table("monthly_token_usage")
|
||||
op.drop_column("agent_run_logs", "tokens_used")
|
||||
@@ -1,34 +0,0 @@
|
||||
"""avatar_url_varchar_to_text
|
||||
|
||||
Revision ID: e04100e88ace
|
||||
Revises: c5d1e2f3a4b5
|
||||
Create Date: 2026-04-13 09:13:06.733674
|
||||
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = 'e04100e88ace'
|
||||
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.alter_column('users', 'avatar_url',
|
||||
existing_type=sa.VARCHAR(length=2048),
|
||||
type_=sa.Text(),
|
||||
existing_nullable=True)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column('users', 'avatar_url',
|
||||
existing_type=sa.Text(),
|
||||
type_=sa.VARCHAR(length=2048),
|
||||
existing_nullable=True)
|
||||
@@ -1,5 +0,0 @@
|
||||
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
||||
|
||||
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
||||
|
||||
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Client agent — read-only tools for the clients table."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
|
||||
@tool
|
||||
async def list_clients(search: str = "", limit: int = 20) -> str:
|
||||
"""List clients, optionally filtered by a name/email substring search.
|
||||
|
||||
search: optional substring to match against client name or email.
|
||||
limit: max rows to return (default 20).
|
||||
"""
|
||||
filters: dict[str, Any] = {"limit": limit}
|
||||
if search:
|
||||
filters["search"] = search
|
||||
|
||||
result = await execute_on_client(action="select", table="clients", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No clients found."
|
||||
lines = [
|
||||
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
|
||||
f"company: {r.get('company', '')})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_client(id: str) -> str:
|
||||
"""Get full details for one client by UUID.
|
||||
|
||||
id: the client's UUID.
|
||||
"""
|
||||
if not id:
|
||||
return "Client id is required."
|
||||
|
||||
result = await execute_on_client(action="get", table="clients", data={"id": id})
|
||||
row = result.get("row") or result.get("rows", [None])[0] if result else None
|
||||
if not row:
|
||||
return f"Client '{id}' not found."
|
||||
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
|
||||
|
||||
|
||||
CLIENT_TOOLS: list[Any] = [list_clients, get_client]
|
||||
@@ -1,194 +0,0 @@
|
||||
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||
|
||||
These tools delegate to the Electron client via ``execute_on_client()`` using
|
||||
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
||||
handles actual disk I/O and responds with ``tool_result`` frames.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# Max characters returned by read_file_content in journey (exploration) tools.
|
||||
# The journey only needs to understand file structure, not full content.
|
||||
_JOURNEY_READ_MAX_CHARS: int = 4000
|
||||
|
||||
|
||||
def _resolve_path(path: str, base: str) -> str:
|
||||
"""Resolve *path* against *base* when *path* is relative.
|
||||
|
||||
The LLM often passes ``"."`` meaning "the configured directory".
|
||||
Without this, Electron resolves ``"."`` relative to its own CWD instead
|
||||
of the user's chosen directory.
|
||||
"""
|
||||
if os.path.isabs(path):
|
||||
return path
|
||||
return str(Path(base) / path)
|
||||
|
||||
|
||||
@tool
|
||||
async def list_directory(path: str) -> str:
|
||||
"""List files and folders in a local directory on the user's device.
|
||||
|
||||
Returns a formatted listing of entries with name, type (file/directory),
|
||||
and full path.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="list_directory",
|
||||
data={"path": path},
|
||||
)
|
||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||
if not entries:
|
||||
return f"Directory '{path}' is empty or does not exist."
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
entry_type = entry.get("type", "unknown")
|
||||
entry_name = entry.get("name", "")
|
||||
entry_path = entry.get("path", "")
|
||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def read_file_content(path: str) -> str:
|
||||
"""Read the text content of a local file on the user's device.
|
||||
|
||||
Returns the file content as a string. Large files may be truncated
|
||||
by the Electron client.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="read_file_content",
|
||||
data={"path": path},
|
||||
)
|
||||
content: str = result.get("content", "")
|
||||
if not content:
|
||||
return f"File '{path}' is empty or could not be read."
|
||||
return content
|
||||
|
||||
|
||||
@tool
|
||||
async def get_file_metadata(path: str) -> str:
|
||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||
|
||||
Returns a formatted summary of the file's metadata.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="get_file_metadata",
|
||||
data={"path": path},
|
||||
)
|
||||
size = result.get("size", "unknown")
|
||||
created = result.get("createdAt", "unknown")
|
||||
modified = result.get("modifiedAt", "unknown")
|
||||
extension = result.get("extension", "unknown")
|
||||
name = result.get("name", path)
|
||||
return (
|
||||
f"File: {name}\n"
|
||||
f" Extension: {extension}\n"
|
||||
f" Size: {size} bytes\n"
|
||||
f" Created: {created}\n"
|
||||
f" Modified: {modified}"
|
||||
)
|
||||
|
||||
|
||||
FILESYSTEM_TOOLS: list[Any] = [
|
||||
list_directory,
|
||||
read_file_content,
|
||||
get_file_metadata,
|
||||
]
|
||||
|
||||
|
||||
def make_directory_tools(base_directory: str) -> list[Any]:
|
||||
"""Return filesystem tools that resolve relative paths against *base_directory*.
|
||||
|
||||
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
|
||||
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
|
||||
from the LLM are resolved to the correct absolute path before being sent to
|
||||
the Electron client, preventing it from falling back to its own CWD.
|
||||
"""
|
||||
|
||||
def _compact_for_journey(raw: str) -> str:
|
||||
"""Strip HTML noise and truncate for journey exploration.
|
||||
|
||||
The journey LLM only needs to understand file structure (headers,
|
||||
first paragraphs). Full CSS/style blocks are pure noise that eat
|
||||
up context window budget.
|
||||
"""
|
||||
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
|
||||
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
|
||||
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
|
||||
if len(text) > _JOURNEY_READ_MAX_CHARS:
|
||||
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
|
||||
return text
|
||||
|
||||
@tool
|
||||
async def list_directory(path: str) -> str: # noqa: F811
|
||||
"""List files and folders in a local directory on the user's device.
|
||||
|
||||
Returns a formatted listing of entries with name, type (file/directory),
|
||||
and full path.
|
||||
"""
|
||||
resolved = _resolve_path(path, base_directory)
|
||||
result = await execute_on_client(
|
||||
action="list_directory",
|
||||
data={"path": resolved},
|
||||
)
|
||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||
if not entries:
|
||||
return f"Directory '{resolved}' is empty or does not exist."
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
entry_type = entry.get("type", "unknown")
|
||||
entry_name = entry.get("name", "")
|
||||
entry_path = entry.get("path", "")
|
||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||
|
||||
@tool
|
||||
async def read_file_content(path: str) -> str: # noqa: F811
|
||||
"""Read the text content of a local file on the user's device.
|
||||
|
||||
Returns the file content as a string. Large files may be truncated
|
||||
by the Electron client.
|
||||
"""
|
||||
resolved = _resolve_path(path, base_directory)
|
||||
result = await execute_on_client(
|
||||
action="read_file_content",
|
||||
data={"path": resolved},
|
||||
)
|
||||
content: str = result.get("content", "")
|
||||
if not content:
|
||||
return f"File '{resolved}' is empty or could not be read."
|
||||
return _compact_for_journey(content)
|
||||
|
||||
@tool
|
||||
async def get_file_metadata(path: str) -> str: # noqa: F811
|
||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||
|
||||
Returns a formatted summary of the file's metadata.
|
||||
"""
|
||||
resolved = _resolve_path(path, base_directory)
|
||||
result = await execute_on_client(
|
||||
action="get_file_metadata",
|
||||
data={"path": resolved},
|
||||
)
|
||||
size = result.get("size", "unknown")
|
||||
created = result.get("createdAt", "unknown")
|
||||
modified = result.get("modifiedAt", "unknown")
|
||||
extension = result.get("extension", "unknown")
|
||||
name = result.get("name", resolved)
|
||||
return (
|
||||
f"File: {name}\n"
|
||||
f" Extension: {extension}\n"
|
||||
f" Size: {size} bytes\n"
|
||||
f" Created: {created}\n"
|
||||
f" Modified: {modified}"
|
||||
)
|
||||
|
||||
return [list_directory, read_file_content, get_file_metadata]
|
||||
@@ -1,168 +0,0 @@
|
||||
"""Scoped file-read and search tools for the project folder feature."""
|
||||
from __future__ import annotations
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
# Cap returned slice size to keep tool output under control.
|
||||
_MAX_RETURN_CHARS = 50_000
|
||||
_MAX_SEARCH_MATCHES = 20
|
||||
|
||||
|
||||
def _is_unsafe_path(rel: str) -> bool:
|
||||
if not rel:
|
||||
return True
|
||||
norm = rel.replace("\\", "/")
|
||||
if norm.startswith("/"):
|
||||
return True
|
||||
# Windows drive letter
|
||||
if len(rel) >= 2 and rel[1] == ":":
|
||||
return True
|
||||
parts = norm.split("/")
|
||||
return ".." in parts
|
||||
|
||||
|
||||
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
|
||||
"""Return the raw Electron tool_result dict for a file read."""
|
||||
return await execute_on_client(
|
||||
action="read_project_folder_file",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"relativePath": relative_path,
|
||||
"offset": offset,
|
||||
"length": length,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _decode(result: dict) -> tuple[str, str, int]:
|
||||
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
|
||||
extracts text from base64. For images, returns a placeholder string.
|
||||
For text, content is already a sliced utf-8 string.
|
||||
"""
|
||||
kind = result.get("kind", "text")
|
||||
content = result.get("content", "") or ""
|
||||
total = int(result.get("totalSize", 0) or 0)
|
||||
if kind == "image":
|
||||
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
|
||||
if kind == "pdf":
|
||||
return (_extract_pdf_text(content), kind, total)
|
||||
if kind == "docx":
|
||||
return (_extract_docx_text(content), kind, total)
|
||||
return (content, kind, total)
|
||||
|
||||
|
||||
@tool
|
||||
async def read_project_folder_file(
|
||||
project_id: str,
|
||||
relative_path: str,
|
||||
offset: int = 0,
|
||||
length: int = _MAX_RETURN_CHARS,
|
||||
) -> str:
|
||||
"""Read a slice of a file inside the project's linked folder.
|
||||
|
||||
Args:
|
||||
project_id: project ID.
|
||||
relative_path: path relative to the linked folder root.
|
||||
offset: char offset to start reading from (0 = beginning).
|
||||
length: max chars to return. Default 50000. Use smaller values to save tokens.
|
||||
|
||||
Returns text content slice with a header showing position. Header tells you
|
||||
when more content is available; call again with the suggested next offset.
|
||||
|
||||
For PDF / DOCX files the backend extracts text first, then applies offset/length
|
||||
on the extracted text. For images returns a placeholder; navigate with the
|
||||
manifest summary instead.
|
||||
"""
|
||||
if _is_unsafe_path(relative_path):
|
||||
return "Access denied"
|
||||
|
||||
result = await _fetch_file(project_id, relative_path, offset, length)
|
||||
text, kind, total_size = _decode(result)
|
||||
|
||||
if not text and kind in ("missing", "error"):
|
||||
return f"File not found or unreadable: {relative_path}"
|
||||
|
||||
if kind in ("pdf", "docx"):
|
||||
# Backend extracted full text — apply offset/length on chars.
|
||||
sliced = text[offset:offset + length]
|
||||
slice_end = min(offset + length, len(text))
|
||||
header = (
|
||||
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
|
||||
f"totalChars={len(text)}]"
|
||||
)
|
||||
if slice_end < len(text):
|
||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||
return header + "\n" + sliced
|
||||
|
||||
if kind == "text":
|
||||
slice_end = offset + len(text)
|
||||
header = (
|
||||
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
|
||||
f"totalBytes={total_size}]"
|
||||
)
|
||||
if slice_end < total_size:
|
||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
||||
return header + "\n" + text
|
||||
|
||||
# image or unknown
|
||||
return text
|
||||
|
||||
|
||||
@tool
|
||||
async def search_project_folder_file(
|
||||
project_id: str,
|
||||
relative_path: str,
|
||||
query: str,
|
||||
context_lines: int = 3,
|
||||
) -> str:
|
||||
"""Search a project folder file for a query string (case-insensitive substring).
|
||||
|
||||
Args:
|
||||
project_id: project ID.
|
||||
relative_path: path relative to the linked folder root.
|
||||
query: text to search for.
|
||||
context_lines: number of lines of context around each match (default 3).
|
||||
|
||||
Returns matching line ranges with surrounding context and 1-based line numbers.
|
||||
Capped at 20 matches; if more exist the header shows the total.
|
||||
|
||||
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
|
||||
Images and binary files are not searchable.
|
||||
"""
|
||||
if _is_unsafe_path(relative_path):
|
||||
return "Access denied"
|
||||
if not query:
|
||||
return "Empty query."
|
||||
|
||||
# For text we still need full file; pass length=very large.
|
||||
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
|
||||
text, kind, _ = _decode(result)
|
||||
|
||||
if not text and kind in ("missing", "error"):
|
||||
return f"File not found or unreadable: {relative_path}"
|
||||
if kind == "image":
|
||||
return "Cannot search inside images."
|
||||
|
||||
lines = text.splitlines()
|
||||
q = query.lower()
|
||||
matches = [i for i, line in enumerate(lines) if q in line.lower()]
|
||||
if not matches:
|
||||
return f"No matches for '{query}' in {relative_path}."
|
||||
|
||||
shown = matches[:_MAX_SEARCH_MATCHES]
|
||||
snippets: list[str] = []
|
||||
for i in shown:
|
||||
start = max(0, i - context_lines)
|
||||
end = min(len(lines), i + context_lines + 1)
|
||||
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
|
||||
snippets.append(block)
|
||||
|
||||
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
|
||||
body = "\n---\n".join(snippets)
|
||||
return header + "\n" + body
|
||||
|
||||
|
||||
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]
|
||||
@@ -1,206 +0,0 @@
|
||||
"""Note agent — Markdown note management (list, get, create, update, propose edit)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.note_summarizer import generate_note_summary
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
|
||||
def _fmt_summary(row: dict) -> str:
|
||||
summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
|
||||
if summary:
|
||||
return f" — {summary}"
|
||||
snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
|
||||
return f" — {snippet}" if snippet else ""
|
||||
|
||||
|
||||
@tool
|
||||
async def list_notes(project_id: str = "") -> str:
|
||||
"""List notes with AI summaries, optionally scoped to a project by project_id.
|
||||
|
||||
Returns id, title, and ai_summary for each note so you can decide which
|
||||
note to read in full with get_note before creating or updating.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="notes",
|
||||
filters={"projectId": normalized_project_id or None},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No notes found."
|
||||
lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
|
||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def get_note(note_id: str) -> str:
|
||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
||||
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
||||
row = result.get("row")
|
||||
if not row:
|
||||
return f"Note {note_id} not found."
|
||||
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_note(
|
||||
title: str,
|
||||
content: str,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Create a new note.
|
||||
title: note heading (required)
|
||||
content: Markdown body text (required)
|
||||
project_id: optional UUID linking this note to a project
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="notes",
|
||||
data={
|
||||
"title": title,
|
||||
"content": content,
|
||||
"projectId": project_id or None,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
note_id: str = row["id"]
|
||||
# Generate summary asynchronously — fire-and-forget.
|
||||
asyncio.create_task(_refresh_summary(note_id, title, content))
|
||||
return f"Note created: '{row['title']}' (id: {note_id})."
|
||||
|
||||
|
||||
@tool
|
||||
async def update_note(
|
||||
note_id: str,
|
||||
title: str = "",
|
||||
content: str = "",
|
||||
) -> str:
|
||||
"""Update an existing note directly (no approval required).
|
||||
Use propose_note_edit instead when human review is needed.
|
||||
note_id: UUID of the note (required)
|
||||
If you need to preserve existing content, call get_note first.
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if content:
|
||||
updates["content"] = content
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="notes",
|
||||
data={"id": note_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
if content:
|
||||
new_title = title or row.get("title", "")
|
||||
asyncio.create_task(_refresh_summary(note_id, new_title, content))
|
||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||
|
||||
|
||||
@tool
|
||||
async def propose_note_edit(
|
||||
note_id: str,
|
||||
edit_type: str,
|
||||
proposed_content: str,
|
||||
reasoning: str = "",
|
||||
anchor_before: str = "",
|
||||
anchor_text: str = "",
|
||||
agent_id: str = "",
|
||||
run_id: str = "",
|
||||
) -> str:
|
||||
"""Propose an AI edit to an existing note, pending human approval.
|
||||
|
||||
Use this instead of update_note when review_required is true.
|
||||
The user will see the proposal highlighted before it is merged.
|
||||
|
||||
note_id: UUID of the target note (required)
|
||||
edit_type: 'append' | 'insert' | 'replace'
|
||||
- append: adds proposed_content at the end of the note
|
||||
- insert: inserts proposed_content immediately after anchor_before text
|
||||
- replace: replaces the first occurrence of anchor_text with proposed_content
|
||||
proposed_content: the new Markdown text to add or substitute (required)
|
||||
reasoning: brief explanation shown to the user (recommended)
|
||||
anchor_before: for 'insert' — the text snippet that precedes the insertion point
|
||||
anchor_text: for 'replace' — the exact text to be replaced
|
||||
agent_id: agent identifier (for traceability)
|
||||
run_id: run identifier (for traceability)
|
||||
"""
|
||||
if edit_type not in ("append", "insert", "replace"):
|
||||
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
|
||||
|
||||
result = await execute_on_client(
|
||||
action="propose_note_edit",
|
||||
data={
|
||||
"noteId": note_id,
|
||||
"type": edit_type,
|
||||
"proposedContent": proposed_content,
|
||||
"reasoning": reasoning or None,
|
||||
"anchorBefore": anchor_before or None,
|
||||
"anchorText": anchor_text or None,
|
||||
"agentId": agent_id or None,
|
||||
"runId": run_id or None,
|
||||
},
|
||||
)
|
||||
edit_id = result.get("id", "?")
|
||||
return (
|
||||
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
|
||||
f"Status: pending user approval."
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_note(note_id: str) -> str:
|
||||
"""Delete a note permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
||||
return f"Note {note_id} deleted."
|
||||
|
||||
|
||||
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
|
||||
"""Generate and persist the AI summary for a note. Fire-and-forget."""
|
||||
try:
|
||||
summary = await generate_note_summary(title, content)
|
||||
if summary:
|
||||
await execute_on_client(
|
||||
action="update",
|
||||
table="notes",
|
||||
data={
|
||||
"id": note_id,
|
||||
"updates": {
|
||||
"aiSummary": summary,
|
||||
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
|
||||
},
|
||||
},
|
||||
)
|
||||
except Exception:
|
||||
pass # fire-and-forget; errors logged by generate_note_summary
|
||||
|
||||
|
||||
NOTE_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
create_note,
|
||||
update_note,
|
||||
propose_note_edit,
|
||||
delete_note,
|
||||
]
|
||||
|
||||
NOTE_READ_TOOLS: list[Any] = [
|
||||
list_notes,
|
||||
get_note,
|
||||
]
|
||||
@@ -1,63 +0,0 @@
|
||||
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import async_session
|
||||
|
||||
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
|
||||
# Each tool closure captures the user_id bound at factory time.
|
||||
|
||||
|
||||
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
|
||||
"""Return a query_relations tool bound to *user_id*."""
|
||||
|
||||
@tool
|
||||
async def query_relations(
|
||||
subject_label: str = "",
|
||||
predicate: str = "",
|
||||
object_label: str = "",
|
||||
limit: int = 10,
|
||||
) -> str:
|
||||
"""Query the relational memory graph for entity relationships.
|
||||
|
||||
Returns rows where subject ↔ predicate ↔ object match the given filters.
|
||||
All parameters are optional — omit to retrieve all relations up to limit.
|
||||
|
||||
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
|
||||
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
|
||||
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
|
||||
limit: max rows to return (default 10).
|
||||
"""
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(
|
||||
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
|
||||
trace_id or "-", user_id, subject_label, predicate, object_label,
|
||||
)
|
||||
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
rows = await memory.query_relations(
|
||||
user_id=user_id,
|
||||
subject=subject_label or None,
|
||||
predicate=predicate or None,
|
||||
object_=object_label or None,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return "No relational memory entries found for the given filters."
|
||||
|
||||
lines = [
|
||||
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
|
||||
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
|
||||
|
||||
return query_relations
|
||||
@@ -1,358 +0,0 @@
|
||||
"""Task agent — full CRUD for tasks and task comments."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
|
||||
# ── Task tools ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignee: str = "",
|
||||
search: str = "",
|
||||
order_by: str = "",
|
||||
order_dir: str = "",
|
||||
due_date_from: int = -1,
|
||||
due_date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> str:
|
||||
"""List tasks with optional filters. Returns up to `limit` results (default 50).
|
||||
|
||||
project_id: UUID of the project to scope results to.
|
||||
status: filter by status — todo | in_progress | done.
|
||||
priority: filter by priority — high | medium | low.
|
||||
assignee: substring to match against assignee names. OMIT unless the user explicitly
|
||||
names a person or refers to themselves ("my tasks", "assigned to me", "mine").
|
||||
Do NOT default to the current user.
|
||||
search: substring search across title and description.
|
||||
order_by: sort field — dueDate | priority | createdAt | completedAt.
|
||||
order_dir: asc (default) | desc.
|
||||
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
||||
limit: max rows to return (default 50). Use with offset to paginate.
|
||||
offset: skip first N rows (default 0).
|
||||
|
||||
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
||||
Tip — prefer count_tasks for "how many" questions to avoid listing rows.
|
||||
Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.)
|
||||
take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"priority": priority or None,
|
||||
"search": search or None,
|
||||
"orderBy": order_by or None,
|
||||
"orderDir": order_dir or None,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if assignee:
|
||||
filters["assignee"] = assignee
|
||||
if due_date_from != -1:
|
||||
filters["dueDateFrom"] = due_date_from
|
||||
if due_date_to != -1:
|
||||
filters["dueDateTo"] = due_date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
|
||||
result = await execute_on_client(action="select", table="tasks", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks found matching the given filters."
|
||||
lines = [
|
||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, "
|
||||
f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def count_tasks(
|
||||
project_id: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignee: str = "",
|
||||
search: str = "",
|
||||
due_date_from: int = -1,
|
||||
due_date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
) -> str:
|
||||
"""Count tasks matching the given filters without returning rows.
|
||||
|
||||
Use this instead of list_tasks for "how many" questions — it is much cheaper.
|
||||
Same filter parameters as list_tasks (no limit/offset/order_by needed).
|
||||
assignee: OMIT unless the user explicitly names a person or refers to themselves
|
||||
("my tasks"). Do NOT default to the current user.
|
||||
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"status": status or None,
|
||||
"priority": priority or None,
|
||||
"search": search or None,
|
||||
}
|
||||
if assignee:
|
||||
filters["assignee"] = assignee
|
||||
if due_date_from != -1:
|
||||
filters["dueDateFrom"] = due_date_from
|
||||
if due_date_to != -1:
|
||||
filters["dueDateTo"] = due_date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
|
||||
result = await execute_on_client(action="count", table="tasks", filters=filters)
|
||||
return f"Task count: {result.get('count', 0)}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_task(
|
||||
title: str,
|
||||
description: str = "",
|
||||
status: str = "todo",
|
||||
priority: str = "medium",
|
||||
assignees: str = "[]",
|
||||
due_date: int = 0,
|
||||
project_id: str = "",
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a new task.
|
||||
title: task title (required)
|
||||
description: optional details
|
||||
status: todo | in_progress | done (default: todo)
|
||||
priority: high | medium | low (default: medium)
|
||||
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
|
||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||
project_id: optional UUID of the parent project
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
|
||||
completedAt is set automatically when status is 'done'.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="tasks",
|
||||
data={
|
||||
"title": title,
|
||||
"description": description or None,
|
||||
"status": status,
|
||||
"priority": priority,
|
||||
"assignee": assignees,
|
||||
"dueDate": due_date or None,
|
||||
"projectId": project_id or None,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return (
|
||||
f"Task created: '{row['title']}' "
|
||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})"
|
||||
)
|
||||
|
||||
|
||||
@tool
|
||||
async def update_task(
|
||||
task_id: str,
|
||||
title: str = "",
|
||||
description: str = "",
|
||||
status: str = "",
|
||||
priority: str = "",
|
||||
assignees: str = "",
|
||||
due_date: int = -1,
|
||||
project_id: str = "",
|
||||
) -> str:
|
||||
"""Update fields on an existing task. Only pass fields you want to change.
|
||||
task_id: the task's UUID (required)
|
||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||
|
||||
completedAt is managed automatically:
|
||||
- setting status to 'done' records the current timestamp
|
||||
- changing status away from 'done' clears completedAt
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if description:
|
||||
updates["description"] = description
|
||||
if status:
|
||||
updates["status"] = status
|
||||
if priority:
|
||||
updates["priority"] = priority
|
||||
if assignees:
|
||||
updates["assignee"] = assignees
|
||||
if due_date != -1:
|
||||
updates["dueDate"] = due_date or None
|
||||
if project_id:
|
||||
updates["projectId"] = project_id
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="tasks",
|
||||
data={"id": task_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task(task_id: str) -> str:
|
||||
"""Delete a task permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
||||
return f"Task {task_id} deleted."
|
||||
|
||||
|
||||
@tool
|
||||
async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str:
|
||||
"""List all tasks whose due date falls on today's date.
|
||||
|
||||
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
||||
Always pass the user's timezone so 'today' is computed in their local time.
|
||||
include_done: set True to also include already-completed tasks due today (default False).
|
||||
"""
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
tz = ZoneInfo(user_timezone or "UTC")
|
||||
except Exception:
|
||||
tz = timezone.utc
|
||||
now_local = datetime.now(tz=tz)
|
||||
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
||||
start_ms = int(start_dt.timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1
|
||||
filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms}
|
||||
if not include_done:
|
||||
filters["status"] = "todo"
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="tasks",
|
||||
filters=filters,
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No tasks are due today."
|
||||
lines = [
|
||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
# ── Task comment tools ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@tool
|
||||
async def list_task_comments(task_id: str) -> str:
|
||||
"""List all comments on a task by its UUID."""
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="taskComments",
|
||||
filters={"taskId": task_id},
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return f"No comments found for task {task_id}."
|
||||
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
||||
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
||||
"""Add a comment to a task.
|
||||
task_id: UUID of the task to comment on
|
||||
author: name or ID of the comment author
|
||||
content: comment text
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="taskComments",
|
||||
data={"taskId": task_id, "author": author, "content": content},
|
||||
)
|
||||
row = result.get("row", {})
|
||||
row_author = row.get("author", author)
|
||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
||||
row_comment_id = row.get("id", "unknown")
|
||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_task_comment(comment_id: str) -> str:
|
||||
"""Delete a task comment by its UUID."""
|
||||
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
||||
return f"Comment {comment_id} deleted."
|
||||
|
||||
|
||||
# ── Agent ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
TASK_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
count_tasks,
|
||||
create_task,
|
||||
update_task,
|
||||
delete_task,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
add_task_comment,
|
||||
delete_task_comment,
|
||||
]
|
||||
|
||||
TASK_READ_TOOLS: list[Any] = [
|
||||
list_tasks,
|
||||
count_tasks,
|
||||
list_tasks_due_today,
|
||||
list_task_comments,
|
||||
]
|
||||
@@ -1,270 +0,0 @@
|
||||
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from app.core.ws_context import execute_on_client
|
||||
|
||||
_UUID_RE = re.compile(
|
||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
||||
)
|
||||
|
||||
|
||||
def _is_uuid(value: str) -> bool:
|
||||
return bool(_UUID_RE.match(value))
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines(
|
||||
project_id: str = "",
|
||||
type: str = "",
|
||||
is_completed: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
order_by: str = "",
|
||||
order_dir: str = "",
|
||||
date_from: int = -1,
|
||||
date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
limit: int = 50,
|
||||
offset: int = 0,
|
||||
) -> str:
|
||||
"""List timeline events (milestones, checkpoints, activities) with optional filters.
|
||||
|
||||
project_id: UUID to scope results to a specific project.
|
||||
type: filter by event type — milestone | checkpoint | activity.
|
||||
is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default).
|
||||
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
||||
order_by: sort field — date (default) | createdAt | completedAt.
|
||||
order_dir: asc (default) | desc.
|
||||
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
limit: max rows to return (default 50). Use with offset to paginate.
|
||||
offset: skip first N rows (default 0).
|
||||
|
||||
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
||||
Tip — prefer count_timelines for "how many" questions to avoid listing rows.
|
||||
Tip — for natural-language windows ("today", "this week", "last month", etc.)
|
||||
take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {
|
||||
"projectId": normalized_project_id or None,
|
||||
"orderBy": order_by or None,
|
||||
"orderDir": order_dir or None,
|
||||
"limit": limit,
|
||||
"offset": offset,
|
||||
}
|
||||
if type:
|
||||
filters["type"] = type
|
||||
if is_completed != -1:
|
||||
filters["isCompleted"] = is_completed
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
if date_from != -1:
|
||||
filters["dateFrom"] = date_from
|
||||
if date_to != -1:
|
||||
filters["dateTo"] = date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
|
||||
result = await execute_on_client(action="select", table="timelines", filters=filters)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No timeline events found."
|
||||
lines = [
|
||||
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
||||
f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, "
|
||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def count_timelines(
|
||||
project_id: str = "",
|
||||
type: str = "",
|
||||
is_completed: int = -1,
|
||||
is_ai_suggested: int = -1,
|
||||
date_from: int = -1,
|
||||
date_to: int = -1,
|
||||
created_at_from: int = -1,
|
||||
created_at_to: int = -1,
|
||||
completed_at_from: int = -1,
|
||||
completed_at_to: int = -1,
|
||||
) -> str:
|
||||
"""Count timeline events matching the given filters without returning rows.
|
||||
|
||||
Use this instead of list_timelines for "how many" questions — it is much cheaper.
|
||||
Same filter parameters as list_timelines (no limit/offset/order_by needed).
|
||||
|
||||
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
||||
Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block;
|
||||
do not compute boundaries from the current UTC instant.
|
||||
"""
|
||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
||||
filters: dict[str, Any] = {"projectId": normalized_project_id or None}
|
||||
if type:
|
||||
filters["type"] = type
|
||||
if is_completed != -1:
|
||||
filters["isCompleted"] = is_completed
|
||||
if is_ai_suggested != -1:
|
||||
filters["isAiSuggested"] = is_ai_suggested
|
||||
if date_from != -1:
|
||||
filters["dateFrom"] = date_from
|
||||
if date_to != -1:
|
||||
filters["dateTo"] = date_to
|
||||
if created_at_from != -1:
|
||||
filters["createdAtFrom"] = created_at_from
|
||||
if created_at_to != -1:
|
||||
filters["createdAtTo"] = created_at_to
|
||||
if completed_at_from != -1:
|
||||
filters["completedAtFrom"] = completed_at_from
|
||||
if completed_at_to != -1:
|
||||
filters["completedAtTo"] = completed_at_to
|
||||
|
||||
result = await execute_on_client(action="count", table="timelines", filters=filters)
|
||||
return f"Timeline event count: {result.get('count', 0)}"
|
||||
|
||||
|
||||
@tool
|
||||
async def create_timeline(
|
||||
project_id: str,
|
||||
title: str,
|
||||
date: int,
|
||||
type: str = "milestone",
|
||||
is_completed: int = 0,
|
||||
is_ai_suggested: int = 0,
|
||||
) -> str:
|
||||
"""Create a project timeline event.
|
||||
project_id: REQUIRED UUID of the parent project
|
||||
title: descriptive name for the event
|
||||
date: Unix timestamp in milliseconds for the event date
|
||||
type: milestone (default) | checkpoint | activity
|
||||
is_completed: 1 if already completed, 0 if not (default 0)
|
||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||
|
||||
completedAt is set automatically when is_completed is 1.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="insert",
|
||||
table="timelines",
|
||||
data={
|
||||
"projectId": project_id,
|
||||
"title": title,
|
||||
"date": date,
|
||||
"type": type,
|
||||
"isCompleted": is_completed,
|
||||
"isAiSuggested": is_ai_suggested,
|
||||
},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})"
|
||||
|
||||
|
||||
@tool
|
||||
async def update_timeline(
|
||||
timeline_id: str,
|
||||
title: str = "",
|
||||
date: int = -1,
|
||||
is_completed: int = -1,
|
||||
) -> str:
|
||||
"""Update a timeline event. Only pass fields that should change.
|
||||
timeline_id: UUID of the event (required)
|
||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||
is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged
|
||||
|
||||
completedAt is managed automatically:
|
||||
- setting is_completed to 1 records the current timestamp
|
||||
- setting is_completed to 0 clears completedAt
|
||||
"""
|
||||
updates: dict[str, Any] = {}
|
||||
if title:
|
||||
updates["title"] = title
|
||||
if date != -1:
|
||||
updates["date"] = date
|
||||
if is_completed != -1:
|
||||
updates["isCompleted"] = is_completed
|
||||
result = await execute_on_client(
|
||||
action="update",
|
||||
table="timelines",
|
||||
data={"id": timeline_id, "updates": updates},
|
||||
)
|
||||
row = result["row"]
|
||||
return f"Timeline event updated: '{row['title']}' (id: {row['id']})"
|
||||
|
||||
|
||||
@tool
|
||||
async def delete_timeline(timeline_id: str) -> str:
|
||||
"""Delete a timeline event permanently by its UUID."""
|
||||
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||
return f"Timeline event {timeline_id} deleted."
|
||||
|
||||
|
||||
@tool
|
||||
async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str:
|
||||
"""List all timeline events whose date falls on today.
|
||||
|
||||
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
||||
Always pass the user's timezone so 'today' is computed in their local time.
|
||||
include_completed: set False to exclude already-completed events (default True).
|
||||
"""
|
||||
try:
|
||||
from zoneinfo import ZoneInfo
|
||||
tz = ZoneInfo(user_timezone or "UTC")
|
||||
except Exception:
|
||||
tz = timezone.utc
|
||||
now_local = datetime.now(tz=tz)
|
||||
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
||||
start_ms = int(start_dt.timestamp() * 1000)
|
||||
end_ms = start_ms + 86_400_000 - 1
|
||||
filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms}
|
||||
if not include_completed:
|
||||
filters["isCompleted"] = 0
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table="timelines",
|
||||
filters=filters,
|
||||
)
|
||||
rows = result.get("rows", [])
|
||||
if not rows:
|
||||
return "No timeline events today."
|
||||
lines = [
|
||||
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
||||
f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})"
|
||||
for r in rows
|
||||
]
|
||||
return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
TIMELINE_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
count_timelines,
|
||||
list_timelines_today,
|
||||
create_timeline,
|
||||
update_timeline,
|
||||
delete_timeline,
|
||||
]
|
||||
|
||||
TIMELINE_READ_TOOLS: list[Any] = [
|
||||
list_timelines,
|
||||
count_timelines,
|
||||
list_timelines_today,
|
||||
]
|
||||
@@ -1,14 +0,0 @@
|
||||
"""Shared FastAPI dependencies.
|
||||
|
||||
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
||||
(the canonical location per Step 9). This module re-exports them so that all
|
||||
existing route imports (``from app.api.deps import get_current_user``) continue
|
||||
to work without modification.
|
||||
|
||||
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
||||
instead of reading it from the JWT payload.
|
||||
"""
|
||||
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
||||
|
||||
__all__ = ["get_current_user", "oauth2_scheme"]
|
||||
@@ -1,19 +0,0 @@
|
||||
"""API middleware package.
|
||||
|
||||
Exports the three middleware components introduced in Step 9:
|
||||
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
||||
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
||||
- Sanitizer: ``SanitizerMiddleware``
|
||||
"""
|
||||
|
||||
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
|
||||
__all__ = [
|
||||
"get_current_user",
|
||||
"oauth2_scheme",
|
||||
"TierRateLimitMiddleware",
|
||||
"limiter",
|
||||
"SanitizerMiddleware",
|
||||
]
|
||||
@@ -1,103 +0,0 @@
|
||||
"""Auth middleware — JWT validation dependency.
|
||||
|
||||
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||
without requiring token re-issue.
|
||||
|
||||
Exempt routes (no JWT required):
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import get_session
|
||||
from app.schemas import UserProfile
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
The JWT is used for identity and expiry only. The tier is fetched live
|
||||
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||
|
||||
Raises HTTP 401 on any invalid or expired token.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
# Live tier lookup — subscription row is the authoritative source.
|
||||
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
||||
# block local development when no Stripe subscription exists.
|
||||
from app.models import Subscription, User # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
||||
user_result = await db.execute(
|
||||
select(
|
||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||
User.password_hash,
|
||||
).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
# Convert onboarding_completed_at to epoch ms (int) or None.
|
||||
onboarding_ms: int | None = None
|
||||
if user_row and user_row.onboarding_completed_at is not None:
|
||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||
|
||||
# Load decrypted core memory.
|
||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||
|
||||
memory_dict: dict[str, str] = {}
|
||||
try:
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(user_id)
|
||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||
except Exception:
|
||||
pass # Non-critical — return empty memory on failure
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
avatar_url=user_row.avatar_url if user_row else None,
|
||||
has_password=bool(user_row.password_hash) if user_row else False,
|
||||
tier=tier,
|
||||
onboarding_completed_at=onboarding_ms,
|
||||
memory=memory_dict,
|
||||
) # type: ignore[arg-type]
|
||||
@@ -1,129 +0,0 @@
|
||||
"""Tier-aware rate limiting middleware.
|
||||
|
||||
Uses a per-user sliding-window counter (in-process, no Redis required).
|
||||
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
||||
|
||||
Limits (requests per minute):
|
||||
- free: 20
|
||||
- pro: 60
|
||||
- power: 120
|
||||
- team: 200
|
||||
|
||||
Exempt paths bypass the limiter entirely:
|
||||
- POST /api/v1/auth/register
|
||||
- POST /api/v1/auth/login
|
||||
- POST /api/v1/billing/webhook
|
||||
- GET /api/v1/health
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
|
||||
from fastapi import Request, Response
|
||||
from jose import JWTError, jwt
|
||||
from slowapi import Limiter
|
||||
from slowapi.util import get_remote_address
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
_TIER_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||
{
|
||||
"/api/v1/auth/register",
|
||||
"/api/v1/auth/login",
|
||||
"/api/v1/billing/webhook",
|
||||
"/api/v1/health",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _get_user_id_from_jwt(request: Request) -> str:
|
||||
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return get_remote_address(request)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
return payload.get("sub") or get_remote_address(request)
|
||||
except JWTError:
|
||||
return get_remote_address(request)
|
||||
|
||||
|
||||
# Exported Limiter instance — available for optional route-level decoration.
|
||||
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
||||
|
||||
|
||||
class TierRateLimitMiddleware(BaseHTTPMiddleware):
|
||||
"""Sliding-window rate limiter applied globally across all non-exempt routes.
|
||||
|
||||
Each authenticated user gets their own 60-second window sized by tier.
|
||||
Unauthenticated requests pass through (the auth dependency will reject them
|
||||
with 401 before the route handler runs).
|
||||
"""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
# user_id → list of request timestamps (float, seconds since epoch)
|
||||
self._window: dict[str, list[float]] = defaultdict(list)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||
if request.url.path in _EXEMPT_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
|
||||
auth = request.headers.get("Authorization", "")
|
||||
token = auth.removeprefix("Bearer ").strip()
|
||||
if not token:
|
||||
return await call_next(request)
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str = payload.get("sub") or get_remote_address(request)
|
||||
tier: str = payload.get("tier", "free")
|
||||
except JWTError:
|
||||
return await call_next(request)
|
||||
|
||||
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
|
||||
now = time.monotonic()
|
||||
window_start = now - 60.0
|
||||
|
||||
# Slide the window: discard timestamps older than 60 seconds.
|
||||
timestamps = [t for t in self._window[user_id] if t > window_start]
|
||||
|
||||
if len(timestamps) >= limit:
|
||||
retry_after = max(1, int(60 - (now - min(timestamps))))
|
||||
return Response(
|
||||
content=json.dumps(
|
||||
{
|
||||
"detail": (
|
||||
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
|
||||
f"Retry in {retry_after}s."
|
||||
)
|
||||
}
|
||||
),
|
||||
status_code=429,
|
||||
headers={
|
||||
"Retry-After": str(retry_after),
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
)
|
||||
|
||||
timestamps.append(now)
|
||||
self._window[user_id] = timestamps
|
||||
return await call_next(request)
|
||||
@@ -1,138 +0,0 @@
|
||||
"""Response sanitizer middleware.
|
||||
|
||||
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
||||
that could reveal server-side prompt IP:
|
||||
- System prompt openers ("You are a/an/the …")
|
||||
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
||||
- LangChain tool schema fragments (``"type": "function"``)
|
||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||
- Exact-match known prompt fingerprints
|
||||
|
||||
The middleware only activates for paths under /api/v1/chat.
|
||||
|
||||
Any sanitisation event is logged as a WARNING with the request path and the
|
||||
names of the fields that were modified.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
|
||||
from fastapi import Request, Response
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.types import ASGIApp
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Detection patterns — order matters: fingerprints checked first (exact),
|
||||
# then compiled regexes.
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FINGERPRINTS: tuple[str, ...] = (
|
||||
"You are an intent classifier",
|
||||
"Respond with just the agent name",
|
||||
"Summarize these agent results",
|
||||
"Available agents:",
|
||||
"route to:",
|
||||
)
|
||||
|
||||
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
||||
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
||||
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
||||
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
||||
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
||||
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
||||
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
||||
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_text(text: str) -> tuple[str, bool]:
|
||||
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
||||
|
||||
Returns ``(cleaned_text, was_changed)``.
|
||||
"""
|
||||
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
||||
for fp in _FINGERPRINTS:
|
||||
if fp in text:
|
||||
return "[REDACTED]", True
|
||||
|
||||
changed = False
|
||||
for pattern in _PATTERNS:
|
||||
new_text, n = pattern.subn("[REDACTED]", text)
|
||||
if n:
|
||||
text = new_text
|
||||
changed = True
|
||||
|
||||
return text, changed
|
||||
|
||||
|
||||
class SanitizerMiddleware(BaseHTTPMiddleware):
|
||||
"""Strip prompt IP from /api/v1/chat JSON responses."""
|
||||
|
||||
def __init__(self, app: ASGIApp) -> None:
|
||||
super().__init__(app)
|
||||
|
||||
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||
response: Response = await call_next(request)
|
||||
|
||||
# Only process chat endpoint responses.
|
||||
if not request.url.path.startswith("/api/v1/chat"):
|
||||
return response
|
||||
|
||||
# Read body — collect streaming chunks.
|
||||
body_bytes = b""
|
||||
async for chunk in response.body_iterator:
|
||||
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||
|
||||
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
|
||||
try:
|
||||
body = json.loads(body_bytes.decode("utf-8"))
|
||||
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||
return Response(
|
||||
content=body_bytes,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
if not isinstance(body, dict):
|
||||
return Response(
|
||||
content=body_bytes,
|
||||
status_code=response.status_code,
|
||||
headers=dict(response.headers),
|
||||
media_type=response.media_type,
|
||||
)
|
||||
|
||||
# Walk top-level string fields and sanitise.
|
||||
sanitised_fields: list[str] = []
|
||||
for key, value in body.items():
|
||||
if isinstance(value, str):
|
||||
cleaned, changed = _sanitize_text(value)
|
||||
if changed:
|
||||
body[key] = cleaned
|
||||
sanitised_fields.append(key)
|
||||
|
||||
if sanitised_fields:
|
||||
logger.warning(
|
||||
"Sanitizer redacted prompt fragments",
|
||||
extra={
|
||||
"path": request.url.path,
|
||||
"fields": sanitised_fields,
|
||||
},
|
||||
)
|
||||
|
||||
new_body = json.dumps(body).encode("utf-8")
|
||||
headers = dict(response.headers)
|
||||
headers["content-length"] = str(len(new_body))
|
||||
|
||||
return Response(
|
||||
content=new_body,
|
||||
status_code=response.status_code,
|
||||
headers=headers,
|
||||
media_type="application/json",
|
||||
)
|
||||
@@ -1,513 +0,0 @@
|
||||
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
|
||||
|
||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||
frames to the functions exported here.
|
||||
|
||||
Journey flow:
|
||||
1. FE sends ``journey_start`` frame with basic agent info (directory,
|
||||
data_types, schedule).
|
||||
2. Server creates an in-memory session, sets up a WS executor so the
|
||||
setup LLM can use file-system tools, does a first directory scrape,
|
||||
and sends back a ``journey_reply`` with the first question.
|
||||
3. FE sends ``journey_message`` frames for each user reply.
|
||||
4. Server appends the user message, calls the LLM (which may read files
|
||||
via tools), and sends back a ``journey_reply``.
|
||||
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
|
||||
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
||||
6. Server parses and validates the JSON with Pydantic, sends
|
||||
``journey_reply`` with ``done=True`` and the serialised config.
|
||||
FE stores it locally.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from app.agents.filesystem_agent import make_directory_tools
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
from app.schemas import AgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||
|
||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||
|
||||
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
|
||||
_CONFIG_START = "AGENT_CONFIG_START"
|
||||
_CONFIG_END = "AGENT_CONFIG_END"
|
||||
|
||||
# Minimum turns before we consider nudging the LLM to wrap up.
|
||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
||||
_MAX_TURNS: int = 15
|
||||
# Max tool-calling steps per LLM invocation.
|
||||
_MAX_TOOL_STEPS: int = 6
|
||||
|
||||
# ── In-memory session store ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class JourneySession:
|
||||
session_id: str
|
||||
user_id: str
|
||||
agent_type: str # "local" | "cloud"
|
||||
directory: str
|
||||
data_types: list[str]
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
system_prompt: str = ""
|
||||
langfuse_prompt: Any = None
|
||||
created_at: float = field(default_factory=time.monotonic)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||
|
||||
|
||||
# session_id → session
|
||||
_sessions: dict[str, JourneySession] = {}
|
||||
|
||||
|
||||
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||
s = _sessions.get(session_id)
|
||||
if s is None or s.is_expired():
|
||||
_sessions.pop(session_id, None)
|
||||
return None
|
||||
if s.user_id != user_id:
|
||||
return None
|
||||
return s
|
||||
|
||||
|
||||
# ── System prompt ─────────────────────────────────────────────────────────
|
||||
|
||||
_JOURNEY_SYSTEM_PROMPT = """\
|
||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||
Your job is to understand what files the user has in their directory and produce a
|
||||
structured AgentConfig JSON that the extraction agent will use as its instruction set.
|
||||
|
||||
You have access to file-system tools to explore the user's directory:
|
||||
- list_directory: see folder structure and file names
|
||||
- read_file_content: peek at a file's content
|
||||
- get_file_metadata: check file size, extension, dates
|
||||
|
||||
The user's configured directory is: {directory}
|
||||
Target data types: {data_types}
|
||||
|
||||
## Your process
|
||||
|
||||
### Step 1 — Explore the directory
|
||||
Use list_directory and read_file_content to understand what types of files are present
|
||||
(HTML emails, plain-text documents, CSVs, etc.).
|
||||
|
||||
### Step 2 — Identify content types
|
||||
For each distinct file type found, decide:
|
||||
- A short id (e.g. "email_html", "plain_text", "csv")
|
||||
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
|
||||
- A human-readable label and optional detection_hint
|
||||
|
||||
### Step 3 — Ask focused questions (one at a time)
|
||||
Cover these topics based on what you discovered:
|
||||
1. How to map content to entity types (task / note / timeline entry)
|
||||
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
|
||||
3. Priority or status rules (e.g. "urgent" in subject → high priority)
|
||||
4. Date extraction (e.g. "by Friday" → dueDate)
|
||||
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
||||
|
||||
### Step 4 — Produce the AgentConfig JSON
|
||||
Once you are ≥ 90% confident, output the final config between these exact markers
|
||||
(each on its own line):
|
||||
|
||||
{config_start}
|
||||
{{
|
||||
"content_types": [
|
||||
{{
|
||||
"id": "email_html",
|
||||
"label": "Email HTML",
|
||||
"detection_hint": "HTML file with From/To/Subject headers",
|
||||
"preprocessing": "email_html",
|
||||
"extraction_prompt": "Detailed extraction instructions for this content type..."
|
||||
}}
|
||||
],
|
||||
"global_rules": [
|
||||
"If the file cannot be matched to any project, do not create any entity."
|
||||
],
|
||||
"data_types": {data_types_json}
|
||||
}}
|
||||
{config_end}
|
||||
|
||||
## Rules for the extraction_prompt field
|
||||
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
|
||||
- Include field mapping rules based on what you found in the directory
|
||||
- Include priority/status/date rules if applicable
|
||||
- Do NOT include projectId logic — the runner handles project assignment automatically
|
||||
- Do NOT mention isAiSuggested — the runner always sets it to 1
|
||||
|
||||
## Constraints
|
||||
- Never ask about projects, projectId, or how to link records to projects
|
||||
- Never include projectId or project creation logic in the generated config
|
||||
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
|
||||
|
||||
{existing_section}\
|
||||
Begin by exploring the directory, then ask your first question.\
|
||||
"""
|
||||
|
||||
|
||||
def _build_system_prompt(
|
||||
directory: str,
|
||||
data_types: list[str],
|
||||
existing_config: str | None = None,
|
||||
) -> tuple[str, Any]:
|
||||
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
||||
existing_section = (
|
||||
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
|
||||
f"```json\n{existing_config}\n```\n"
|
||||
if existing_config
|
||||
else ""
|
||||
)
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
||||
)
|
||||
compiled = compile_prompt(
|
||||
template,
|
||||
prompt_obj,
|
||||
directory=directory,
|
||||
data_types=", ".join(data_types),
|
||||
data_types_json=json.dumps(data_types),
|
||||
config_start=_CONFIG_START,
|
||||
config_end=_CONFIG_END,
|
||||
existing_section=existing_section,
|
||||
)
|
||||
return compiled, prompt_obj
|
||||
|
||||
|
||||
# ── AgentConfig extraction ────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _extract_agent_config(text: str) -> str | None:
|
||||
"""Return validated AgentConfig JSON string from between markers, or None.
|
||||
|
||||
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
||||
returning. Returns None if markers are absent or JSON is invalid.
|
||||
"""
|
||||
if _CONFIG_START not in text or _CONFIG_END not in text:
|
||||
return None
|
||||
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
|
||||
end_idx = text.index(_CONFIG_END)
|
||||
raw = text[start_idx:end_idx].strip()
|
||||
if not raw:
|
||||
return None
|
||||
try:
|
||||
parsed = AgentConfig.model_validate_json(raw)
|
||||
return parsed.model_dump_json()
|
||||
except Exception as exc:
|
||||
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
# ── LLM call with tool support ───────────────────────────────────────────
|
||||
|
||||
|
||||
def _as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
async def _call_llm_with_tools(
|
||||
system_prompt: str,
|
||||
history: list[dict[str, Any]],
|
||||
tools: list[Any],
|
||||
*,
|
||||
user_id: str = "",
|
||||
session_id: str = "",
|
||||
langfuse_prompt: Any = None,
|
||||
) -> str:
|
||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||
|
||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||
continue until a final text response is produced.
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||
for turn in history:
|
||||
if turn["role"] == "user":
|
||||
messages.append(HumanMessage(content=turn["content"]))
|
||||
else:
|
||||
messages.append(AIMessage(content=turn["content"]))
|
||||
|
||||
llm = get_agent_llm("setup", temperature=0.4)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
|
||||
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||
_lf_ctx.__enter__()
|
||||
|
||||
_span_ctx = (
|
||||
lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="journey-setup",
|
||||
input=history[-1]["content"] if history else "",
|
||||
)
|
||||
if lf else None
|
||||
)
|
||||
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||
|
||||
try:
|
||||
for step in range(_MAX_TOOL_STEPS):
|
||||
_gen_ctx = (
|
||||
lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="journey-setup-llm",
|
||||
model=model_for_agent("setup"),
|
||||
prompt=langfuse_prompt,
|
||||
input=messages,
|
||||
)
|
||||
if lf else None
|
||||
)
|
||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
if _gen_ctx:
|
||||
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||
_gen_ctx.__exit__(None, None, None)
|
||||
|
||||
resp_text = _as_text(response.content)
|
||||
|
||||
# Guard against empty responses (e.g. model returned finish_reason
|
||||
# 'error' which LiteLLM maps to 'stop' with empty content).
|
||||
if not response.tool_calls and not resp_text.strip():
|
||||
logger.warning(
|
||||
"agent_setup: journey LLM returned empty response at step %d — retrying",
|
||||
step,
|
||||
)
|
||||
# Drop the empty AIMessage so we don't pollute history, and retry.
|
||||
continue
|
||||
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
if _span:
|
||||
_span.update(output=resp_text)
|
||||
return resp_text
|
||||
|
||||
for call in response.tool_calls:
|
||||
call_name = str(call.get("name", ""))
|
||||
call_args = call.get("args", {})
|
||||
logger.info(
|
||||
"agent_setup: journey tool_call name=%s args=%s",
|
||||
call_name,
|
||||
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||
)
|
||||
|
||||
tool_fn = tool_map.get(call_name)
|
||||
if tool_fn is None:
|
||||
tool_output = f"Unknown tool: {call_name}"
|
||||
else:
|
||||
tool_output = await tool_fn.ainvoke(call_args)
|
||||
|
||||
logger.info(
|
||||
"agent_setup: journey tool_result name=%s output=%s",
|
||||
call_name,
|
||||
str(tool_output)[:800],
|
||||
)
|
||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||
|
||||
# Fallback: exceeded max steps.
|
||||
final = await llm.ainvoke(messages)
|
||||
final_text = _as_text(final.content)
|
||||
if _span:
|
||||
_span.update(output=final_text)
|
||||
return final_text or (
|
||||
"Sorry, I had trouble processing the files. "
|
||||
"Could you try again? If the issue persists, the files might be too large for me to analyse."
|
||||
)
|
||||
finally:
|
||||
if _span_ctx:
|
||||
_span_ctx.__exit__(None, None, None)
|
||||
_lf_ctx.__exit__(None, None, None)
|
||||
if lf:
|
||||
lf.flush()
|
||||
|
||||
|
||||
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||
|
||||
|
||||
async def handle_journey_start(
|
||||
user_id: str,
|
||||
frame: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Handle a ``journey_start`` WS frame.
|
||||
|
||||
Creates a session, runs the setup LLM with directory exploration,
|
||||
and returns the ``journey_reply`` payload.
|
||||
"""
|
||||
agent_type = frame.get("agent_type", "local")
|
||||
directory = frame.get("directory", "")
|
||||
data_types = frame.get("data_types", [])
|
||||
existing_config = frame.get("existing_config")
|
||||
|
||||
# Use the session_id provided by the FE so the reply matches the
|
||||
# listener key; fall back to a generated one if absent.
|
||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
|
||||
|
||||
session = JourneySession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
agent_type=agent_type,
|
||||
directory=directory,
|
||||
data_types=data_types,
|
||||
system_prompt=system_prompt,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
)
|
||||
|
||||
# Seed with an initial user message — some providers require at least one
|
||||
# user/input message to be present.
|
||||
seed_history: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||
]
|
||||
ai_reply = await _call_llm_with_tools(
|
||||
system_prompt=system_prompt,
|
||||
history=seed_history,
|
||||
tools=make_directory_tools(directory),
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
)
|
||||
|
||||
session.history.extend(seed_history)
|
||||
session.history.append({"role": "assistant", "content": ai_reply})
|
||||
_sessions[session_id] = session
|
||||
|
||||
logger.info(
|
||||
"agent_setup: journey session %s started for user %s (directory=%s)",
|
||||
session_id,
|
||||
user_id,
|
||||
directory,
|
||||
)
|
||||
|
||||
# Check if the LLM produced the config on the first turn (unlikely but possible).
|
||||
agent_config = _extract_agent_config(ai_reply)
|
||||
done = agent_config is not None
|
||||
|
||||
display_message = ai_reply
|
||||
if done:
|
||||
display_message = (
|
||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||
or "Here is your agent configuration. You can save it or continue refining."
|
||||
)
|
||||
_sessions.pop(session_id, None)
|
||||
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": display_message,
|
||||
"done": done,
|
||||
"agent_config": agent_config,
|
||||
}
|
||||
|
||||
|
||||
async def handle_journey_message(
|
||||
user_id: str,
|
||||
frame: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Handle a ``journey_message`` WS frame.
|
||||
|
||||
Appends the user message, calls the LLM, and returns the
|
||||
``journey_reply`` payload.
|
||||
"""
|
||||
session_id = frame.get("session_id", "")
|
||||
message = frame.get("message", "")
|
||||
|
||||
session = get_journey_session(session_id, user_id)
|
||||
if session is None:
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": "Journey session not found or expired. Please start a new setup.",
|
||||
"done": True,
|
||||
"agent_config": None,
|
||||
}
|
||||
|
||||
# Append user turn.
|
||||
session.history.append({"role": "user", "content": message})
|
||||
|
||||
# Call the LLM with tools.
|
||||
session_tools = make_directory_tools(session.directory)
|
||||
ai_reply = await _call_llm_with_tools(
|
||||
system_prompt=session.system_prompt,
|
||||
history=session.history,
|
||||
tools=session_tools,
|
||||
user_id=session.user_id,
|
||||
session_id=session_id,
|
||||
langfuse_prompt=session.langfuse_prompt,
|
||||
)
|
||||
|
||||
session.history.append({"role": "assistant", "content": ai_reply})
|
||||
|
||||
# Check if the LLM produced the final config.
|
||||
agent_config = _extract_agent_config(ai_reply)
|
||||
done = agent_config is not None
|
||||
|
||||
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
|
||||
if not done:
|
||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||
if turns >= _MAX_TURNS:
|
||||
nudge_content = (
|
||||
"[System: You have enough information. Please generate the final "
|
||||
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
||||
)
|
||||
session.history.append({"role": "user", "content": nudge_content})
|
||||
|
||||
nudge_reply = await _call_llm_with_tools(
|
||||
system_prompt=session.system_prompt,
|
||||
history=session.history,
|
||||
tools=session_tools,
|
||||
user_id=session.user_id,
|
||||
session_id=session_id,
|
||||
langfuse_prompt=session.langfuse_prompt,
|
||||
)
|
||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||
|
||||
agent_config = _extract_agent_config(nudge_reply)
|
||||
if agent_config is not None:
|
||||
done = True
|
||||
ai_reply = nudge_reply
|
||||
|
||||
display_message = ai_reply
|
||||
if done:
|
||||
display_message = (
|
||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||
if _CONFIG_START in ai_reply
|
||||
else "Here is your agent configuration. You can save it or continue refining."
|
||||
)
|
||||
_sessions.pop(session_id, None)
|
||||
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
||||
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": display_message,
|
||||
"done": done,
|
||||
"agent_config": agent_config,
|
||||
}
|
||||
@@ -1,257 +0,0 @@
|
||||
"""Agent routes.
|
||||
|
||||
Backend responsibilities are intentionally minimal:
|
||||
GET /agents/catalog — static catalog for UI display
|
||||
POST /agents/can-create — billing eligibility check
|
||||
POST /agents/trigger — trigger a local agent run
|
||||
|
||||
Agent configuration is owned by the Electron app and is not persisted
|
||||
in backend agent-config tables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import FEATURES
|
||||
from app.core.agent_runner import is_agent_running, run_local_agent
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.note_summarizer import generate_note_summary
|
||||
from app.db import get_session
|
||||
from app.models import AgentRunLog, LocalAgentConfig
|
||||
from app.schemas import (
|
||||
AgentCatalogItem,
|
||||
AgentCreationCheckRequest,
|
||||
AgentCreationCheckResponse,
|
||||
AgentRunLogResponse,
|
||||
AgentTriggerRequest,
|
||||
UserProfile,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
|
||||
|
||||
# ── Datetime helpers ──────────────────────────────────────────────────
|
||||
|
||||
def _dt_ms(dt: datetime) -> int:
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
|
||||
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
def _to_data_types(values: list[str]) -> list[str]:
|
||||
normalize = {
|
||||
"task": "tasks", "tasks": "tasks",
|
||||
"note": "notes", "notes": "notes",
|
||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||
"project": "projects", "projects": "projects",
|
||||
}
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for v in values:
|
||||
mapped = normalize.get(v)
|
||||
if mapped and mapped not in seen:
|
||||
seen.add(mapped)
|
||||
result.append(mapped)
|
||||
return result
|
||||
|
||||
|
||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||
return AgentRunLogResponse(
|
||||
id=log.id,
|
||||
agent_id=log.agent_id,
|
||||
agent_type=log.agent_type, # type: ignore[arg-type]
|
||||
status=log.status, # type: ignore[arg-type]
|
||||
items_processed=log.items_processed,
|
||||
items_created=log.items_created,
|
||||
errors=log.errors or [],
|
||||
started_at=_dt_ms(log.started_at),
|
||||
completed_at=_dt_ms_opt(log.completed_at),
|
||||
)
|
||||
|
||||
|
||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||
if limit != -1 and current_count >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||
)
|
||||
return limit
|
||||
|
||||
|
||||
async def _enforce_run_frequency(
|
||||
tier: str,
|
||||
user_id: str,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||
if limit == -1:
|
||||
return # unlimited
|
||||
|
||||
today_start = datetime.now(timezone.utc).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
result = await db.execute(
|
||||
select(func.count(AgentRunLog.id)).where(
|
||||
AgentRunLog.user_id == user_id,
|
||||
AgentRunLog.started_at >= today_start,
|
||||
)
|
||||
)
|
||||
runs_today: int = result.scalar_one()
|
||||
|
||||
if runs_today >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
||||
)
|
||||
|
||||
|
||||
# ── Catalog ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog", response_model=list[AgentCatalogItem])
|
||||
async def get_agent_catalog(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> list[AgentCatalogItem]:
|
||||
"""Return the static list of available agent types and their descriptions."""
|
||||
return [
|
||||
AgentCatalogItem(
|
||||
type="local_directory",
|
||||
name="Local Directory Monitor",
|
||||
description="Watches local directories, extracts data from files using AI",
|
||||
),
|
||||
AgentCatalogItem(
|
||||
type="gmail",
|
||||
name="Gmail Connector",
|
||||
description="Scans Gmail inbox, extracts tasks/notes from emails",
|
||||
),
|
||||
AgentCatalogItem(
|
||||
type="teams",
|
||||
name="Microsoft Teams Connector",
|
||||
description="Monitors Teams messages, extracts action items",
|
||||
),
|
||||
AgentCatalogItem(
|
||||
type="outlook",
|
||||
name="Outlook Connector",
|
||||
description="Scans Outlook inbox, extracts tasks/notes",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
||||
async def can_create_agent(
|
||||
body: AgentCreationCheckRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> AgentCreationCheckResponse:
|
||||
"""Check if the user can create one more agent based on billing tier.
|
||||
|
||||
Since configuration is client-owned, the Electron app sends its current
|
||||
active agent count and the backend applies tier limits.
|
||||
"""
|
||||
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||
allowed = limit == -1 or body.active_agents < limit
|
||||
return AgentCreationCheckResponse(
|
||||
allowed=allowed,
|
||||
tier=current_user.tier,
|
||||
active_agents=body.active_agents,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||
async def trigger_agent_run(
|
||||
body: AgentTriggerRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AgentRunLogResponse:
|
||||
"""Trigger a local agent run using client-provided configuration."""
|
||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||
|
||||
last_run_dt = (
|
||||
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
||||
if body.last_run_at
|
||||
else None
|
||||
)
|
||||
config = LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
device_id=body.device_id,
|
||||
name="Local Directory Monitor",
|
||||
directory_paths=[body.directory],
|
||||
data_types=_to_data_types(body.what_to_extract),
|
||||
prompt_template=body.custom_agent_prompt or "",
|
||||
agent_config=body.agent_config,
|
||||
file_extensions=[],
|
||||
schedule_cron=body.batch_interval,
|
||||
enabled=True,
|
||||
last_run_at=last_run_dt,
|
||||
)
|
||||
|
||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||
stable_agent_id = body.agent_id or config.id
|
||||
|
||||
if is_agent_running(stable_agent_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
||||
)
|
||||
|
||||
run_log = AgentRunLog(
|
||||
agent_id=stable_agent_id,
|
||||
agent_type="local",
|
||||
user_id=current_user.id,
|
||||
status="running",
|
||||
)
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
|
||||
run_context = {
|
||||
"type": "agent_batch",
|
||||
"run_id": run_log.id,
|
||||
"agent_id": stable_agent_id,
|
||||
}
|
||||
|
||||
asyncio.create_task(
|
||||
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
||||
)
|
||||
|
||||
return _to_run_log_response(run_log)
|
||||
|
||||
|
||||
# ── Note summary endpoint ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class NoteSummarizeRequest(BaseModel):
|
||||
title: str
|
||||
content: str
|
||||
|
||||
|
||||
class NoteSummarizeResponse(BaseModel):
|
||||
summary: str
|
||||
|
||||
|
||||
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
|
||||
async def summarize_note(
|
||||
body: NoteSummarizeRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> NoteSummarizeResponse:
|
||||
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
|
||||
summary = await generate_note_summary(body.title, body.content)
|
||||
return NoteSummarizeResponse(summary=summary)
|
||||
@@ -1,795 +0,0 @@
|
||||
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
|
||||
|
||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||
SHA-256 hashes so plaintext never reaches the DB.
|
||||
|
||||
OAuth (Google):
|
||||
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
|
||||
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Literal
|
||||
|
||||
import bcrypt
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
||||
from app.config.settings import settings
|
||||
from app.core.llm import get_llm
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.models import OAuthAccount, RefreshToken, User
|
||||
from app.schemas import AuthTokens, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── OAuth provider registry ───────────────────────────────────────────
|
||||
|
||||
def _get_google_provider() -> GoogleOAuthProvider:
|
||||
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||
raise HTTPException(
|
||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||
"Google login is not configured on this server",
|
||||
)
|
||||
return GoogleOAuthProvider(
|
||||
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
|
||||
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||
)
|
||||
|
||||
|
||||
_PROVIDERS = {"google": _get_google_provider}
|
||||
|
||||
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
|
||||
# Production note: replace with Redis for multi-process deployments.
|
||||
_pending_states: dict[str, tuple[str, float]] = {}
|
||||
_STATE_TTL_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def _hash_token(plain_token: str) -> str:
|
||||
"""SHA-256 of the plain refresh token string."""
|
||||
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||
"""Return (signed JWT, expires_at_ms)."""
|
||||
now = int(time.time())
|
||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": exp,
|
||||
"iat": now,
|
||||
}
|
||||
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||
return token, exp * 1000 # ms for client
|
||||
|
||||
|
||||
# ── Request bodies ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
|
||||
|
||||
class _LoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class _RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
body: _RegisterRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Create a new account and return JWT tokens."""
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=body.email,
|
||||
name=body.name,
|
||||
surname=body.surname,
|
||||
password_hash=_hash_password(body.password),
|
||||
tier="free",
|
||||
encryption_key=Fernet.generate_key().decode(),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush() # get user.id without committing
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokens)
|
||||
async def login(
|
||||
body: _LoginRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate credentials and return JWT tokens."""
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not _verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokens)
|
||||
async def refresh(
|
||||
body: _RefreshRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Rotate a refresh token and return a new token pair."""
|
||||
token_hash = _hash_token(body.refresh_token)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
rt = result.scalar_one_or_none()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||
|
||||
# Rotate: delete old token, issue new one.
|
||||
await db.delete(rt)
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
new_rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=new_expires,
|
||||
)
|
||||
db.add(new_rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
class _UpdateProfileRequest(BaseModel):
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||
"""Return the profile for the authenticated user."""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.put("/me", response_model=UserProfile)
|
||||
async def update_profile(
|
||||
body: _UpdateProfileRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update the authenticated user's name and surname."""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
if body.name is not None:
|
||||
user.name = body.name
|
||||
if body.surname is not None:
|
||||
user.surname = body.surname
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return UserProfile(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
surname=user.surname,
|
||||
avatar_url=user.avatar_url,
|
||||
tier=current_user.tier,
|
||||
)
|
||||
|
||||
|
||||
# ── OAuth helpers ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
|
||||
"""Create a refresh token row and return (plain_token, AuthTokens)."""
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return plain_token, AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
# ── OAuth request/response schemas ───────────────────────────────────
|
||||
|
||||
|
||||
class _OAuthAuthorizeResponse(BaseModel):
|
||||
url: str
|
||||
state: str
|
||||
|
||||
|
||||
class _OAuthCallbackRequest(BaseModel):
|
||||
code: str
|
||||
state: str
|
||||
|
||||
|
||||
# ── OAuth routes ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get(
|
||||
"/oauth/{provider}/web-callback",
|
||||
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
|
||||
include_in_schema=False,
|
||||
)
|
||||
async def oauth_web_callback(
|
||||
provider: Literal["google"],
|
||||
code: str,
|
||||
state: str,
|
||||
) -> RedirectResponse:
|
||||
"""Google redirects here after user consent.
|
||||
|
||||
This endpoint immediately redirects to the Electron deep-link URI so the
|
||||
desktop app receives the authorization code. It is intentionally simple —
|
||||
no state validation here (the Electron app + backend callback do that).
|
||||
|
||||
Registered in Google Cloud Console as:
|
||||
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
|
||||
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
|
||||
"""
|
||||
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
|
||||
deep_link = f"adiuvai://oauth/callback?{params}"
|
||||
return RedirectResponse(url=deep_link, status_code=302)
|
||||
|
||||
|
||||
@router.get(
|
||||
"/oauth/{provider}/authorize",
|
||||
response_model=_OAuthAuthorizeResponse,
|
||||
summary="Start OAuth flow — returns the provider consent-screen URL",
|
||||
)
|
||||
async def oauth_authorize(
|
||||
provider: Literal["google"],
|
||||
) -> _OAuthAuthorizeResponse:
|
||||
"""Generate a PKCE state + code_challenge and return the authorization URL.
|
||||
|
||||
The client opens this URL in the system browser. After the user grants
|
||||
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
|
||||
with ``code`` and ``state`` query params. The client then calls
|
||||
``POST /auth/oauth/{provider}/callback`` with those values.
|
||||
"""
|
||||
provider_factory = _PROVIDERS.get(provider)
|
||||
if provider_factory is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||
|
||||
oauth_provider = provider_factory()
|
||||
state = str(uuid.uuid4())
|
||||
code_verifier, code_challenge = generate_pkce_pair()
|
||||
|
||||
# Purge expired states to prevent unbounded growth.
|
||||
now = time.time()
|
||||
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
|
||||
for s in expired:
|
||||
del _pending_states[s]
|
||||
|
||||
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
|
||||
|
||||
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
|
||||
return _OAuthAuthorizeResponse(url=url, state=state)
|
||||
|
||||
|
||||
@router.post(
|
||||
"/oauth/{provider}/callback",
|
||||
response_model=AuthTokens,
|
||||
summary="Complete OAuth flow — exchange code and issue JWT tokens",
|
||||
)
|
||||
async def oauth_callback(
|
||||
provider: Literal["google"],
|
||||
body: _OAuthCallbackRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate state, exchange the authorization code, and sign in (or register) the user.
|
||||
|
||||
Resolution order:
|
||||
1. ``oauth_accounts`` row match → existing user, log in.
|
||||
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
|
||||
3. No match → create new user (password_hash=None, avatar from provider).
|
||||
"""
|
||||
provider_factory = _PROVIDERS.get(provider)
|
||||
if provider_factory is None:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||
|
||||
# Validate state (CSRF protection).
|
||||
now = time.time()
|
||||
entry = _pending_states.pop(body.state, None)
|
||||
if entry is None or entry[1] < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
||||
|
||||
code_verifier, _ = entry
|
||||
|
||||
oauth_provider = provider_factory()
|
||||
|
||||
# Exchange code for tokens.
|
||||
try:
|
||||
token_data = await oauth_provider.exchange_code(
|
||||
code=body.code,
|
||||
code_verifier=code_verifier,
|
||||
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||
)
|
||||
except Exception:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
|
||||
)
|
||||
|
||||
access_token_google = token_data.get("access_token")
|
||||
if not access_token_google:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
|
||||
|
||||
# Fetch user identity.
|
||||
try:
|
||||
userinfo = await oauth_provider.get_userinfo(access_token_google)
|
||||
except Exception:
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
|
||||
|
||||
# ── Resolution order ──────────────────────────────────────────────
|
||||
|
||||
# 1. Existing OAuth link?
|
||||
oauth_result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
OAuthAccount.provider == provider,
|
||||
OAuthAccount.provider_user_id == userinfo.provider_user_id,
|
||||
)
|
||||
)
|
||||
oauth_account = oauth_result.scalar_one_or_none()
|
||||
|
||||
if oauth_account is not None:
|
||||
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
|
||||
user = user_result.scalar_one()
|
||||
# Backfill avatar if the user doesn't have one yet.
|
||||
if user.avatar_url is None and userinfo.avatar_url:
|
||||
user.avatar_url = userinfo.avatar_url
|
||||
await db.commit()
|
||||
plain_token, tokens = await _issue_refresh_token(user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
# 2. Email match with a verified Google email → link accounts.
|
||||
if userinfo.email_verified:
|
||||
email_result = await db.execute(select(User).where(User.email == userinfo.email))
|
||||
existing_user = email_result.scalar_one_or_none()
|
||||
|
||||
if existing_user is not None:
|
||||
new_link = OAuthAccount(
|
||||
user_id=existing_user.id,
|
||||
provider=provider,
|
||||
provider_user_id=userinfo.provider_user_id,
|
||||
provider_email=userinfo.email,
|
||||
)
|
||||
db.add(new_link)
|
||||
if existing_user.avatar_url is None and userinfo.avatar_url:
|
||||
existing_user.avatar_url = userinfo.avatar_url
|
||||
plain_token, tokens = await _issue_refresh_token(existing_user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
# Guard: if the email is already taken but we couldn't auto-link (e.g.
|
||||
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
|
||||
if not userinfo.email_verified:
|
||||
conflict = await db.execute(select(User).where(User.email == userinfo.email))
|
||||
if conflict.scalar_one_or_none() is not None:
|
||||
raise HTTPException(
|
||||
status.HTTP_409_CONFLICT,
|
||||
"An account with this email already exists. "
|
||||
"Please sign in with your password.",
|
||||
)
|
||||
|
||||
# 3. New user — social-only account (no password).
|
||||
new_user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=userinfo.email,
|
||||
name=userinfo.name,
|
||||
password_hash=None,
|
||||
avatar_url=userinfo.avatar_url,
|
||||
tier="free",
|
||||
encryption_key=Fernet.generate_key().decode(),
|
||||
)
|
||||
db.add(new_user)
|
||||
await db.flush() # populate new_user.id
|
||||
|
||||
new_oauth = OAuthAccount(
|
||||
user_id=new_user.id,
|
||||
provider=provider,
|
||||
provider_user_id=userinfo.provider_user_id,
|
||||
provider_email=userinfo.email,
|
||||
)
|
||||
db.add(new_oauth)
|
||||
|
||||
plain_token, tokens = await _issue_refresh_token(new_user, db)
|
||||
await db.commit()
|
||||
return tokens
|
||||
|
||||
|
||||
# ── Onboarding helpers ────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
|
||||
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
|
||||
|
||||
# We can't call the FastAPI dependency directly, but we can replicate
|
||||
# the core logic inline. Instead, we just re-query the same way.
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
user_result = await db.execute(
|
||||
select(
|
||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||
User.password_hash,
|
||||
).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
onboarding_ms: int | None = None
|
||||
if user_row and user_row.onboarding_completed_at is not None:
|
||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||
|
||||
memory_dict: dict[str, str] = {}
|
||||
try:
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(user_id)
|
||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
avatar_url=user_row.avatar_url if user_row else None,
|
||||
has_password=bool(user_row.password_hash) if user_row else False,
|
||||
tier=tier,
|
||||
onboarding_completed_at=onboarding_ms,
|
||||
memory=memory_dict,
|
||||
)
|
||||
|
||||
|
||||
# ── Onboarding routes ────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _UpdateMemoryRequest(BaseModel):
|
||||
memory: dict[str, str] = Field(default_factory=dict)
|
||||
mark_onboarded: bool = False
|
||||
|
||||
|
||||
@router.put("/me/memory", response_model=UserProfile)
|
||||
async def update_memory(
|
||||
body: _UpdateMemoryRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update core memory key/value pairs and optionally mark onboarding complete."""
|
||||
mw = MemoryMiddleware(db)
|
||||
for key, value in body.memory.items():
|
||||
await mw.update_core(current_user.id, key, value)
|
||||
if body.mark_onboarded:
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.onboarding_completed_at = datetime.now(timezone.utc)
|
||||
await db.commit()
|
||||
return await _build_profile(current_user.id, current_user.email, db)
|
||||
|
||||
|
||||
@router.post("/me/onboarding/reset")
|
||||
async def reset_onboarding(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
):
|
||||
"""Reset onboarding so the wizard runs again on next login."""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.onboarding_completed_at = None
|
||||
await db.commit()
|
||||
return {"status": "reset"}
|
||||
|
||||
|
||||
class _NormalizeRequest(BaseModel):
|
||||
inputs: dict[str, str]
|
||||
|
||||
|
||||
class _NormalizeResponse(BaseModel):
|
||||
normalized: dict[str, str]
|
||||
|
||||
|
||||
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
|
||||
async def normalize_onboarding(
|
||||
body: _NormalizeRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _NormalizeResponse:
|
||||
"""One-shot LLM normalization for free-text onboarding answers."""
|
||||
if not body.inputs:
|
||||
return _NormalizeResponse(normalized={})
|
||||
try:
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||
prompt = (
|
||||
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
|
||||
"Return a JSON object with the same keys and normalized values.\n"
|
||||
"Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
|
||||
f"Input: {json.dumps(body.inputs)}"
|
||||
)
|
||||
response = await llm.ainvoke(
|
||||
[
|
||||
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
|
||||
{"role": "user", "content": prompt},
|
||||
],
|
||||
)
|
||||
normalized = json.loads(response.content)
|
||||
return _NormalizeResponse(normalized=normalized)
|
||||
except Exception:
|
||||
# LLM failure must never block onboarding — return inputs unchanged
|
||||
return _NormalizeResponse(normalized=body.inputs)
|
||||
|
||||
|
||||
# ── Password management ───────────────────────────────────────────────
|
||||
|
||||
|
||||
class _ChangePasswordRequest(BaseModel):
|
||||
current_password: str = Field(min_length=1)
|
||||
new_password: str = Field(min_length=8)
|
||||
|
||||
|
||||
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
||||
async def change_password(
|
||||
body: _ChangePasswordRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Change the authenticated user's password.
|
||||
|
||||
Requires the current password for verification.
|
||||
Returns 400 for social-only users (no password set).
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
if user.password_hash is None:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"This account uses social login and has no password to change",
|
||||
)
|
||||
|
||||
if not _verify_password(body.current_password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
||||
|
||||
user.password_hash = _hash_password(body.new_password)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── OAuth account management ─────────────────────────────────────────
|
||||
|
||||
|
||||
@router.get("/me/oauth-accounts", response_model=list[dict])
|
||||
async def list_oauth_accounts(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[dict]:
|
||||
"""List all OAuth providers linked to the authenticated user."""
|
||||
result = await db.execute(
|
||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||
)
|
||||
accounts = result.scalars().all()
|
||||
return [
|
||||
{
|
||||
"provider": a.provider,
|
||||
"provider_email": a.provider_email,
|
||||
"created_at": int(a.created_at.timestamp() * 1000),
|
||||
}
|
||||
for a in accounts
|
||||
]
|
||||
|
||||
|
||||
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
||||
async def unlink_oauth_account(
|
||||
provider: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Unlink an OAuth provider from the authenticated user.
|
||||
|
||||
Refuses if the user has no password and this is their only login method.
|
||||
"""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
oauth_result = await db.execute(
|
||||
select(OAuthAccount).where(
|
||||
OAuthAccount.user_id == current_user.id,
|
||||
OAuthAccount.provider == provider,
|
||||
)
|
||||
)
|
||||
account = oauth_result.scalar_one_or_none()
|
||||
if account is None:
|
||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
||||
|
||||
# Safety: don't let users lock themselves out.
|
||||
all_oauth = await db.execute(
|
||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||
)
|
||||
oauth_count = len(all_oauth.scalars().all())
|
||||
|
||||
if user.password_hash is None and oauth_count <= 1:
|
||||
raise HTTPException(
|
||||
status.HTTP_400_BAD_REQUEST,
|
||||
"Cannot unlink the only login method. Set a password first.",
|
||||
)
|
||||
|
||||
await db.delete(account)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── Avatar update ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _UpdateAvatarRequest(BaseModel):
|
||||
avatar_url: str = Field(min_length=1)
|
||||
|
||||
|
||||
@router.put("/me/avatar", response_model=UserProfile)
|
||||
async def update_avatar(
|
||||
body: _UpdateAvatarRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update the authenticated user's avatar URL.
|
||||
|
||||
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
||||
to its own storage and passes the resulting URL here.
|
||||
"""
|
||||
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
||||
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
user.avatar_url = body.avatar_url
|
||||
await db.commit()
|
||||
|
||||
return await _build_profile(current_user.id, current_user.email, db)
|
||||
|
||||
|
||||
# ── Account deletion ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.delete("/me", status_code=status.HTTP_200_OK)
|
||||
async def delete_account(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Permanently delete the authenticated user's account.
|
||||
|
||||
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
||||
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
||||
is cancelled if active.
|
||||
"""
|
||||
# Cancel Stripe subscription if present.
|
||||
try:
|
||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
except HTTPException:
|
||||
pass # No subscription — that's fine
|
||||
|
||||
# Delete all memory rows (core, associative, episodic, proactive).
|
||||
try:
|
||||
from app.models import ( # noqa: PLC0415
|
||||
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
||||
)
|
||||
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
||||
await db.execute(
|
||||
model.__table__.delete().where(model.user_id == current_user.id)
|
||||
)
|
||||
except Exception:
|
||||
pass # Non-critical — cascade on User will handle most
|
||||
|
||||
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
await db.delete(user)
|
||||
await db.commit()
|
||||
|
||||
return {"ok": True}
|
||||
@@ -1,132 +0,0 @@
|
||||
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||
|
||||
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||
and delegates everything else to the service singleton.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.db import get_session
|
||||
from app.schemas import BillingTier, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||
|
||||
|
||||
# ── Request bodies ─────────────────────────────────────────────────────
|
||||
|
||||
class _CheckoutRequest(BaseModel):
|
||||
tier: BillingTier
|
||||
|
||||
|
||||
# ── Routes ─────────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/checkout", response_model=dict)
|
||||
async def create_checkout(
|
||||
body: _CheckoutRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> dict[str, str]:
|
||||
"""Create a Stripe checkout session for a tier upgrade.
|
||||
|
||||
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||
"""
|
||||
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||
return {"checkout_url": url}
|
||||
|
||||
|
||||
@router.post("/webhook", response_model=dict)
|
||||
async def stripe_webhook(
|
||||
request: Request,
|
||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Handle Stripe webhook events.
|
||||
|
||||
No JWT auth — authenticated via Stripe signature verification instead.
|
||||
Returns 200 immediately when Stripe is not configured (local dev).
|
||||
"""
|
||||
payload = await request.body()
|
||||
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/subscription", response_model=dict)
|
||||
async def get_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, Any]:
|
||||
"""Return the current subscription info for the authenticated user."""
|
||||
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||
if sub is None:
|
||||
return {
|
||||
"tier": current_user.tier,
|
||||
"status": "free",
|
||||
"stripe_subscription_id": None,
|
||||
"current_period_end": None,
|
||||
}
|
||||
return sub
|
||||
|
||||
|
||||
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||
async def cancel_subscription(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Cancel the active subscription."""
|
||||
await stripe_service.cancel_subscription(current_user.id, db)
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
@router.get("/invoices", response_model=list[dict])
|
||||
async def list_invoices(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Return billing history (invoices) from Stripe.
|
||||
|
||||
Returns an empty list when Stripe is not configured.
|
||||
"""
|
||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||
return invoices
|
||||
|
||||
|
||||
# ── Quota check ────────────────────────────────────────────────────────
|
||||
|
||||
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
|
||||
|
||||
|
||||
class QuotaCheckRequest(BaseModel):
|
||||
feature: str
|
||||
estimated_files: int
|
||||
|
||||
|
||||
@router.post("/quota/check")
|
||||
async def quota_check(
|
||||
payload: QuotaCheckRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict:
|
||||
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
|
||||
if payload.feature != "folder_index":
|
||||
raise HTTPException(status_code=400, detail="Unknown feature")
|
||||
try:
|
||||
await check_folder_quota(
|
||||
user_id=current_user.id,
|
||||
tier=current_user.tier,
|
||||
estimated_files=payload.estimated_files,
|
||||
db=db,
|
||||
)
|
||||
except QuotaExceeded as exc:
|
||||
raise HTTPException(
|
||||
status_code=402,
|
||||
detail={"reason": exc.reason, "message": str(exc)},
|
||||
)
|
||||
return {"ok": True}
|
||||
@@ -1,116 +0,0 @@
|
||||
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
|
||||
|
||||
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from typing import Literal
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
from app.core.deep_agent import run_home
|
||||
from app.core.llm import embed
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import async_session
|
||||
from app.schemas import ChatRequest, UserProfile
|
||||
|
||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||
|
||||
|
||||
# ── Embed helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _EmbedRequest(BaseModel):
|
||||
text: str
|
||||
|
||||
|
||||
class _EmbedResponse(BaseModel):
|
||||
vector: list[float]
|
||||
|
||||
|
||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def chat(
|
||||
body: ChatRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> JSONResponse:
|
||||
"""REST fallback for home chat when websocket streaming is unavailable."""
|
||||
response = await run_home(
|
||||
user_id=current_user.id,
|
||||
message=body.message,
|
||||
context=body.context.model_dump(),
|
||||
)
|
||||
return JSONResponse(content={"response": response})
|
||||
|
||||
|
||||
class _BriefRequest(BaseModel):
|
||||
mode: Literal["home", "project"]
|
||||
project_id: str | None = None
|
||||
|
||||
|
||||
class _BriefResponse(BaseModel):
|
||||
response: str
|
||||
|
||||
|
||||
@router.post("/brief", response_model=_BriefResponse)
|
||||
async def brief(
|
||||
body: _BriefRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _BriefResponse:
|
||||
"""REST fallback for brief when the device WebSocket is not ready."""
|
||||
if body.mode == "project":
|
||||
if not body.project_id:
|
||||
raise HTTPException(status_code=422, detail="project_id required for project mode")
|
||||
try:
|
||||
uuid.UUID(body.project_id)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="project_id must be a valid UUID")
|
||||
|
||||
request_id = str(uuid.uuid4())
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
current_user.id,
|
||||
"",
|
||||
trace_id=request_id,
|
||||
session_id=request_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "user_id": current_user.id},
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
chunks: list[str] = []
|
||||
if body.mode == "project":
|
||||
stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type]
|
||||
else:
|
||||
stream = run_home_brief(current_user.id, context)
|
||||
|
||||
async for event_type, data in stream:
|
||||
if event_type == "token" and data:
|
||||
chunks.append(str(data))
|
||||
|
||||
return _BriefResponse(response="".join(chunks))
|
||||
|
||||
|
||||
@router.post("/embed", response_model=_EmbedResponse)
|
||||
async def embed_text(
|
||||
body: _EmbedRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
) -> _EmbedResponse:
|
||||
"""Generate a 1536-dim embedding vector for the given text.
|
||||
|
||||
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||
Used by Electron (vectordb.ts) for local note search.
|
||||
"""
|
||||
vector = await embed(body.text)
|
||||
return _EmbedResponse(vector=vector)
|
||||
@@ -1,845 +0,0 @@
|
||||
"""Device WebSocket endpoint.
|
||||
|
||||
Persistent connection from Electron devices to the backend.
|
||||
|
||||
WS /api/v1/ws/device?token=<jwt>
|
||||
|
||||
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
|
||||
available during the WebSocket handshake).
|
||||
|
||||
Protocol:
|
||||
1. Client connects → JWT validated → connection accepted.
|
||||
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
|
||||
3. Backend registers the connection in ``DeviceConnectionManager``.
|
||||
4. Session enters message dispatch loop + heartbeat.
|
||||
|
||||
Incoming frame dispatch:
|
||||
- ``tool_result`` → resolves a pending tool-call Future.
|
||||
- ``journey_start`` → starts a guided setup journey session.
|
||||
- ``journey_message`` → continues a journey conversation.
|
||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||
- unknown types → logged, ignored.
|
||||
|
||||
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||
|
||||
On disconnect:
|
||||
- Unregisters from DeviceConnectionManager.
|
||||
- Marks all in-progress AgentRunLog rows for this user as ``error``
|
||||
with message "device disconnected".
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from uuid import uuid4
|
||||
|
||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import update
|
||||
|
||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_runner import trigger_pending_runs
|
||||
from app.core.agent_session_buffer import session_buffer
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
|
||||
from app.core.output_formatter import extract_canvas_block
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.core.output_formatter import StreamFormatter
|
||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||
from app.db import async_session
|
||||
from app.models import AgentRunLog
|
||||
from app.schemas import WsFrameType, WsStreamEnd
|
||||
from app.schemas.contextual import ContextualScope, render_scope_block
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||
|
||||
# ── v7 folder index session state ─────────────────────────────────────
|
||||
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
||||
_index_sessions: dict[str, dict] = {}
|
||||
|
||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||
|
||||
|
||||
@router.websocket("/device")
|
||||
async def device_ws(websocket: WebSocket) -> None:
|
||||
"""Persistent WebSocket endpoint for Electron device connections.
|
||||
|
||||
Authentication is via ``?token=<jwt>`` query parameter.
|
||||
"""
|
||||
# ── 1. Authenticate before accepting ─────────────────────────────
|
||||
token = websocket.query_params.get("token", "")
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
if not user_id:
|
||||
raise JWTError("missing sub")
|
||||
except JWTError:
|
||||
await websocket.close(code=1008) # Policy Violation
|
||||
return
|
||||
|
||||
await websocket.accept()
|
||||
|
||||
# ── 2. Await device_hello frame ───────────────────────────────────
|
||||
try:
|
||||
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
try:
|
||||
hello = json.loads(raw)
|
||||
if hello.get("type") != WsFrameType.device_hello:
|
||||
raise ValueError("expected device_hello as first frame")
|
||||
device_id: str = hello["device_id"]
|
||||
agent_ids: list[str] = hello.get("agent_ids", [])
|
||||
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
||||
await websocket.close(code=1008)
|
||||
return
|
||||
|
||||
# ── 3. Register connection ────────────────────────────────────────
|
||||
device_manager.register(user_id, device_id, websocket)
|
||||
logger.info(
|
||||
"device_ws: connected user=%s device=%s agents=%s",
|
||||
user_id,
|
||||
device_id,
|
||||
agent_ids,
|
||||
)
|
||||
|
||||
# Trigger any overdue agent runs now that the device is connected.
|
||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||
|
||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||
try:
|
||||
await asyncio.gather(
|
||||
_message_loop(websocket, user_id),
|
||||
_heartbeat_loop(websocket),
|
||||
)
|
||||
except WebSocketDisconnect:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
|
||||
finally:
|
||||
device_manager.unregister(user_id)
|
||||
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
|
||||
await _mark_runs_disconnected(user_id)
|
||||
|
||||
|
||||
# ── Message dispatch loop ─────────────────────────────────────────────
|
||||
|
||||
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
"""Receive frames from Electron and dispatch to the appropriate handler."""
|
||||
async for raw in websocket.iter_text():
|
||||
try:
|
||||
frame: dict = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("device_ws: invalid JSON from user=%s", user_id)
|
||||
continue
|
||||
|
||||
frame_type = frame.get("type")
|
||||
|
||||
if frame_type == WsFrameType.tool_result:
|
||||
call_id = frame.get("id")
|
||||
if call_id:
|
||||
device_manager.resolve_pending_call(user_id, call_id, frame)
|
||||
else:
|
||||
logger.warning(
|
||||
"device_ws: tool_result missing id from user=%s", user_id
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.home_request:
|
||||
asyncio.create_task(
|
||||
_handle_home_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.brief_request:
|
||||
asyncio.create_task(
|
||||
_handle_brief_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.task_brief_request:
|
||||
asyncio.create_task(
|
||||
_handle_task_brief_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.journey_start:
|
||||
asyncio.create_task(
|
||||
_handle_journey_start(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.journey_message:
|
||||
asyncio.create_task(
|
||||
_handle_journey_message(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_start:
|
||||
asyncio.create_task(
|
||||
_handle_index_session_start(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_file_batch:
|
||||
asyncio.create_task(
|
||||
_handle_index_file_batch(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.index_session_cancel:
|
||||
await _handle_index_session_cancel(websocket, frame)
|
||||
|
||||
elif frame_type == WsFrameType.contextual_request:
|
||||
asyncio.create_task(
|
||||
_handle_contextual_request(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == WsFrameType.contextual_scope_update:
|
||||
asyncio.create_task(
|
||||
_handle_contextual_scope_update(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
|
||||
else:
|
||||
logger.debug(
|
||||
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
|
||||
)
|
||||
|
||||
|
||||
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||
|
||||
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||
async def _executor(payload: dict) -> dict:
|
||||
payload["type"] = WsFrameType.tool_call
|
||||
await websocket.send_text(json.dumps(payload))
|
||||
future = device_manager.create_pending_call(user_id, payload["id"])
|
||||
return await future
|
||||
return _executor
|
||||
|
||||
|
||||
async def _handle_home_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||
logger.info(
|
||||
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
project_id,
|
||||
message[:200],
|
||||
)
|
||||
|
||||
# ── Memory: enrich context before LLM call ────────────────────────
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
message,
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"conversation_history": frame.get("conversation_history", []),
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
# Collect text chunks to build the full response for episode storage
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: home_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# ── Memory: store episode after response ──────────────────────────
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.store_episode(
|
||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||
)
|
||||
logger.info(
|
||||
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
len("".join(response_chunks)),
|
||||
)
|
||||
|
||||
|
||||
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
|
||||
|
||||
|
||||
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
|
||||
"""Return a session-scoped buffer proxy for the given user+session.
|
||||
|
||||
Returns a _ContextualBufferProxy that exposes append_system_message().
|
||||
Defined at module level so tests can monkeypatch it.
|
||||
The channel kwarg is accepted for forward-compatibility.
|
||||
"""
|
||||
from app.core.agent_session_buffer import ContextualBufferProxy # noqa: PLC0415
|
||||
return ContextualBufferProxy(session_buffer, user_id, session_id)
|
||||
|
||||
|
||||
async def _handle_contextual_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a contextual_request frame — runs the contextual agent and streams frames."""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
message: str = frame.get("message", "")
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope_payload: dict = frame.get("scope", {})
|
||||
logger.info(
|
||||
"device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
message[:200],
|
||||
)
|
||||
|
||||
scope = ContextualScope.model_validate(scope_payload)
|
||||
|
||||
# Enrich context with memory before the LLM call.
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
message,
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"conversation_history": frame.get("conversation_history", []),
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
try:
|
||||
event_stream = run_contextual_stream(
|
||||
user_id=user_id,
|
||||
message=message,
|
||||
context=context,
|
||||
scope=scope,
|
||||
)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: contextual_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# Store episode so the contextual agent can recall prior turns.
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
await memory.store_episode(
|
||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
||||
)
|
||||
logger.info(
|
||||
"device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
|
||||
user_id,
|
||||
request_id,
|
||||
session_id,
|
||||
len("".join(response_chunks)),
|
||||
)
|
||||
|
||||
|
||||
async def _handle_contextual_scope_update(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a contextual_scope_update frame.
|
||||
|
||||
Injects a synthetic system message into the session buffer so the next
|
||||
agent turn knows the user navigated. No LLM call is made.
|
||||
"""
|
||||
session_id: str = frame.get("session_id") or str(uuid4())
|
||||
scope = ContextualScope.model_validate(frame.get("scope", {}))
|
||||
block = render_scope_block(scope)
|
||||
buf = get_session_buffer(user_id, session_id, channel="contextual")
|
||||
buf.append_system_message(
|
||||
f"User navigated to a new view. {block} Treat this as the new active context."
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.contextual_scope_ack,
|
||||
"session_id": session_id,
|
||||
}))
|
||||
logger.info(
|
||||
"device_ws: contextual_scope_update user=%s session=%s page=%s",
|
||||
user_id, session_id, scope.page,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_brief_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a brief_request frame — streams plain-text brief back on the socket.
|
||||
|
||||
No episode storage — briefs are not conversations.
|
||||
"""
|
||||
import uuid as _uuid
|
||||
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
session_id = frame.get("session_id") or str(uuid4())
|
||||
mode: str = frame.get("mode", "home")
|
||||
project_id: str | None = frame.get("project_id")
|
||||
|
||||
logger.info(
|
||||
"device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
|
||||
user_id, request_id, mode, project_id,
|
||||
)
|
||||
|
||||
# Validate project_id for project mode before touching LLM.
|
||||
if mode == "project":
|
||||
try:
|
||||
if not project_id:
|
||||
raise ValueError("project_id required for project mode")
|
||||
_uuid.UUID(project_id)
|
||||
except (ValueError, AttributeError) as exc:
|
||||
logger.warning(
|
||||
"device_ws: brief_request invalid project_id user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
return
|
||||
|
||||
# Enrich context with memory (no user message — use empty string as probe).
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
"",
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
if mode == "project":
|
||||
event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type]
|
||||
else:
|
||||
event_stream = run_home_brief(user_id, context)
|
||||
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: brief_request failed user=%s req=%s: %s",
|
||||
user_id, request_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
logger.info(
|
||||
"device_ws: brief_request_end user=%s req=%s mode=%s",
|
||||
user_id, request_id, mode,
|
||||
)
|
||||
|
||||
|
||||
# ── v6 Task Brief Handler ────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_task_brief_request(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
|
||||
|
||||
Streams the briefing markdown back to the client.
|
||||
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
|
||||
"""
|
||||
request_id = frame.get("request_id") or str(uuid4())
|
||||
session_id = frame.get("session_id") or str(uuid4())
|
||||
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
|
||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
||||
|
||||
logger.info(
|
||||
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
|
||||
user_id, request_id, task_id, project_id,
|
||||
)
|
||||
|
||||
if not task_id:
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
|
||||
)
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
memory = MemoryMiddleware(db)
|
||||
memory_context = await memory.enrich_context(
|
||||
user_id,
|
||||
f"task brief: {task_id}",
|
||||
trace_id=request_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
context: dict = {
|
||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
||||
"format_prefs": frame.get("format_prefs"),
|
||||
**memory_context,
|
||||
}
|
||||
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
response_chunks: list[str] = []
|
||||
|
||||
try:
|
||||
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
|
||||
formatter = StreamFormatter(request_id=request_id)
|
||||
async for ws_frame in formatter.format(event_stream):
|
||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
elif ws_frame.type == "stream_start":
|
||||
await websocket.send_text(ws_frame.model_dump_json())
|
||||
# stream_end is emitted below with mutations — skip formatter's version
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
|
||||
user_id, request_id, task_id, exc,
|
||||
)
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
||||
)
|
||||
return
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
# Extract canvas block then emit stream_end with optional mutations.
|
||||
full_response = "".join(response_chunks)
|
||||
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
|
||||
|
||||
mutations: list[dict] = []
|
||||
if canvas_content:
|
||||
mutations.append({
|
||||
"type": "canvas_draft",
|
||||
"content": canvas_content,
|
||||
"kind": canvas_kind,
|
||||
})
|
||||
|
||||
await websocket.send_text(
|
||||
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
|
||||
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
|
||||
)
|
||||
|
||||
|
||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_journey_start(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a journey_start frame — explores directory and sends first question."""
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
reply = await handle_journey_start(user_id, frame)
|
||||
await websocket.send_text(json.dumps(reply))
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "journey_reply",
|
||||
"session_id": frame.get("session_id", ""),
|
||||
"message": f"Failed to start journey: {exc}",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
}))
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
async def _handle_journey_message(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Handle a journey_message frame — continues the journey conversation."""
|
||||
executor = await _make_ws_executor(websocket, user_id)
|
||||
set_client_executor(executor)
|
||||
try:
|
||||
reply = await handle_journey_message(user_id, frame)
|
||||
await websocket.send_text(json.dumps(reply))
|
||||
except Exception as exc:
|
||||
session_id = frame.get("session_id", "")
|
||||
logger.error(
|
||||
"device_ws: journey_message failed user=%s session=%s: %s",
|
||||
user_id, session_id, exc,
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": f"Journey error: {exc}",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
}))
|
||||
finally:
|
||||
clear_client_executor()
|
||||
|
||||
|
||||
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
||||
|
||||
|
||||
async def _handle_index_session_start(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Register a new folder index session. No response sent — client is declaring intent."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
||||
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
|
||||
|
||||
if not session_id:
|
||||
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
||||
return
|
||||
|
||||
_index_sessions[session_id] = {
|
||||
"user_id": user_id,
|
||||
"project_id": project_id,
|
||||
"processed": 0,
|
||||
"total": total,
|
||||
"cancelled": False,
|
||||
}
|
||||
logger.info(
|
||||
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
||||
user_id, session_id, project_id, total,
|
||||
)
|
||||
|
||||
|
||||
async def _handle_index_session_cancel(
|
||||
websocket: WebSocket,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
session = _index_sessions.get(session_id)
|
||||
if session:
|
||||
session["cancelled"] = True
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "cancelled",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
||||
|
||||
|
||||
async def _handle_index_file_batch(
|
||||
websocket: WebSocket,
|
||||
user_id: str,
|
||||
frame: dict,
|
||||
) -> None:
|
||||
"""Process a batch of files for an index session, streaming results back."""
|
||||
# Lazy imports to avoid heavy load at module startup.
|
||||
from app.core.folder_indexer import ( # noqa: PLC0415
|
||||
summarize_image,
|
||||
summarize_pdf,
|
||||
summarize_docx,
|
||||
summarize_text,
|
||||
)
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.billing.quota import add_token_usage # noqa: PLC0415
|
||||
|
||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
||||
files: list[dict] = frame.get("files", [])
|
||||
|
||||
session = _index_sessions.get(session_id)
|
||||
if not session or session.get("cancelled"):
|
||||
return
|
||||
|
||||
async with async_session() as db:
|
||||
tier = await tier_manager.get_tier(user_id, db)
|
||||
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
cap: int | None = None if raw_cap == -1 else raw_cap
|
||||
|
||||
for file_info in files:
|
||||
if session.get("cancelled"):
|
||||
return
|
||||
|
||||
# Electron's toSnakeCase converts payload keys, so accept both forms.
|
||||
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
|
||||
kind: str = file_info.get("kind") or "text"
|
||||
content: str = file_info.get("content") or ""
|
||||
ext: str = file_info.get("ext") or ""
|
||||
mime: str = file_info.get("mime") or "application/octet-stream"
|
||||
name: str = rel_path.split("/")[-1] or rel_path
|
||||
|
||||
try:
|
||||
if kind == "image":
|
||||
res = await summarize_image(image_b64=content, mime=mime)
|
||||
elif kind == "pdf":
|
||||
res = await summarize_pdf(pdf_b64=content, name=name)
|
||||
elif kind == "docx":
|
||||
res = await summarize_docx(docx_b64=content, name=name)
|
||||
else:
|
||||
res = await summarize_text(content=content, ext=ext, name=name)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
||||
session_id, rel_path, exc,
|
||||
)
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": None,
|
||||
"tokensUsed": 0,
|
||||
"error": str(exc),
|
||||
}))
|
||||
session["processed"] += 1
|
||||
continue
|
||||
|
||||
# Account for token usage and check cap.
|
||||
usage = await add_token_usage(
|
||||
user_id=user_id,
|
||||
feature="folder_index",
|
||||
tokens=res.tokens_used,
|
||||
db=db,
|
||||
cap=cap,
|
||||
)
|
||||
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_file_result,
|
||||
"sessionId": session_id,
|
||||
"relPath": rel_path,
|
||||
"summary": res.summary,
|
||||
"tokensUsed": res.tokens_used,
|
||||
}))
|
||||
session["processed"] += 1
|
||||
|
||||
if usage.exhausted:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "quota_exceeded",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session quota_exceeded user=%s session=%s",
|
||||
user_id, session_id,
|
||||
)
|
||||
return
|
||||
|
||||
# After processing the batch, emit progress.
|
||||
processed = session["processed"]
|
||||
total = session["total"]
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_progress,
|
||||
"sessionId": session_id,
|
||||
"processed": processed,
|
||||
"total": total,
|
||||
}))
|
||||
|
||||
if processed >= total:
|
||||
await websocket.send_text(json.dumps({
|
||||
"type": WsFrameType.index_session_done,
|
||||
"sessionId": session_id,
|
||||
"status": "completed",
|
||||
}))
|
||||
_index_sessions.pop(session_id, None)
|
||||
logger.info(
|
||||
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
||||
user_id, session_id, processed,
|
||||
)
|
||||
|
||||
|
||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||
|
||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||
"""Send a ping frame every 30 s to keep the connection alive."""
|
||||
while True:
|
||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||
|
||||
|
||||
# ── Disconnect cleanup ────────────────────────────────────────────────
|
||||
|
||||
async def _mark_runs_disconnected(user_id: str) -> None:
|
||||
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await db.execute(
|
||||
update(AgentRunLog)
|
||||
.where(
|
||||
AgentRunLog.user_id == user_id,
|
||||
AgentRunLog.status == "running",
|
||||
)
|
||||
.values(
|
||||
status="error",
|
||||
errors=["device disconnected"],
|
||||
)
|
||||
)
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"device_ws: failed to mark runs as disconnected for user=%s: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
@@ -1,225 +0,0 @@
|
||||
"""Memory management routes — view/edit/delete user memory tiers.
|
||||
|
||||
All routes require authentication. Data is always user-scoped.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Annotated
|
||||
|
||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.core.memory_middleware import MemoryMiddleware
|
||||
from app.db import get_session
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
MemoryRelation,
|
||||
)
|
||||
from app.schemas import UserProfile
|
||||
|
||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_ALLOWED_PREDICATES = {
|
||||
"works_at",
|
||||
"reports_to",
|
||||
"stakeholder_of",
|
||||
"last_contacted_on",
|
||||
"owes_followup",
|
||||
"manages",
|
||||
"collaborates_with",
|
||||
"owns",
|
||||
"member_of",
|
||||
"custom",
|
||||
}
|
||||
|
||||
|
||||
# ── Response schemas ─────────────────────────────────────────────────────────
|
||||
|
||||
class RelationOut(BaseModel):
|
||||
id: str
|
||||
subject_label: str
|
||||
subject_type: str
|
||||
predicate: str
|
||||
object_label: str
|
||||
object_type: str
|
||||
confidence: float
|
||||
last_confirmed_at: int | None = None # epoch ms
|
||||
|
||||
|
||||
class RelationPatch(BaseModel):
|
||||
subject_label: str | None = None
|
||||
object_label: str | None = None
|
||||
predicate: str | None = None
|
||||
confidence: float | None = Field(None, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class CoreAddBody(BaseModel):
|
||||
key: str = Field(..., min_length=1, max_length=255)
|
||||
value: str = Field(..., min_length=1)
|
||||
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
||||
|
||||
def _relation_to_out(row: MemoryRelation) -> RelationOut:
|
||||
last_ms: int | None = None
|
||||
if row.last_confirmed_at is not None:
|
||||
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
|
||||
return RelationOut(
|
||||
id=row.id,
|
||||
subject_label=row.subject_label,
|
||||
subject_type=row.subject_type,
|
||||
predicate=row.predicate,
|
||||
object_label=row.object_label,
|
||||
object_type=row.object_type,
|
||||
confidence=row.confidence,
|
||||
last_confirmed_at=last_ms,
|
||||
)
|
||||
|
||||
|
||||
# ── Routes ───────────────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/core", response_model=dict[str, str])
|
||||
async def get_core_memory(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Return all core memory k/v pairs (plaintext) for the current user."""
|
||||
mw = MemoryMiddleware(db)
|
||||
blocks = await mw.list_core_blocks(current_user.id)
|
||||
return {b["label"]: b["value"] for b in blocks}
|
||||
|
||||
|
||||
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_core_key(
|
||||
key: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Delete a single core memory key (GDPR Art. 17)."""
|
||||
mw = MemoryMiddleware(db)
|
||||
deleted = await mw.delete_core(current_user.id, key)
|
||||
if not deleted:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
||||
|
||||
|
||||
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
|
||||
async def add_core_key(
|
||||
body: CoreAddBody,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, str]:
|
||||
"""Add or overwrite a core memory key/value pair."""
|
||||
mw = MemoryMiddleware(db)
|
||||
await mw.update_core(current_user.id, body.key, body.value)
|
||||
return {body.key: body.value}
|
||||
|
||||
|
||||
@router.get("/relational", response_model=list[RelationOut])
|
||||
async def get_relational_memory(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[RelationOut]:
|
||||
"""Return all relational memory rows for the current user."""
|
||||
mw = MemoryMiddleware(db)
|
||||
rows = await mw.query_relations(current_user.id, limit=200)
|
||||
return [_relation_to_out(r) for r in rows]
|
||||
|
||||
|
||||
@router.patch("/relational/{relation_id}", response_model=RelationOut)
|
||||
async def patch_relation(
|
||||
relation_id: str,
|
||||
body: RelationPatch,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> RelationOut:
|
||||
"""Edit a relation row's labels, predicate, or confidence."""
|
||||
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
||||
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
|
||||
)
|
||||
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.id == relation_id,
|
||||
MemoryRelation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||
|
||||
if body.subject_label is not None:
|
||||
row.subject_label = body.subject_label
|
||||
if body.object_label is not None:
|
||||
row.object_label = body.object_label
|
||||
if body.predicate is not None:
|
||||
row.predicate = body.predicate
|
||||
if body.confidence is not None:
|
||||
row.confidence = body.confidence
|
||||
row.last_confirmed_at = datetime.now(timezone.utc)
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(row)
|
||||
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
|
||||
return _relation_to_out(row)
|
||||
|
||||
|
||||
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def delete_relation(
|
||||
relation_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Hard-delete a relation row (GDPR Art. 17)."""
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.id == relation_id,
|
||||
MemoryRelation.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
||||
await db.delete(row)
|
||||
await db.commit()
|
||||
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
||||
|
||||
|
||||
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
|
||||
async def forget_all(
|
||||
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> None:
|
||||
"""Wipe all memory tiers for the current user (GDPR Art. 17).
|
||||
|
||||
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
|
||||
"""
|
||||
if x_confirm != "true":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
|
||||
)
|
||||
|
||||
uid = current_user.id
|
||||
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
|
||||
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
|
||||
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
|
||||
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
|
||||
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
|
||||
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
|
||||
await db.commit()
|
||||
logger.warning("memory: forget_all GDPR wipe user=%s", uid)
|
||||
@@ -1 +0,0 @@
|
||||
"OAuth provider abstractions and utilities."
|
||||
@@ -1,135 +0,0 @@
|
||||
"""OAuth 2.0 + PKCE provider abstractions.
|
||||
|
||||
Each provider implements a three-step flow designed for a desktop (public) client:
|
||||
|
||||
1. get_authorization_url(state, code_challenge) → str
|
||||
Build the provider's consent-screen URL. State and code_challenge are
|
||||
generated server-side; the client opens this URL in the system browser.
|
||||
|
||||
2. exchange_code(code, code_verifier, redirect_uri) → dict
|
||||
Exchange the short-lived authorization code for an access token.
|
||||
The code_verifier proves ownership of the PKCE challenge.
|
||||
|
||||
3. get_userinfo(access_token) → OAuthUserInfo
|
||||
Fetch the canonical user identity from the provider.
|
||||
|
||||
Currently supported providers:
|
||||
- GoogleOAuthProvider (scope: openid email profile)
|
||||
|
||||
Adding a new provider:
|
||||
- Implement the three methods above.
|
||||
- Register in _PROVIDERS inside routes/auth.py.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import hashlib
|
||||
import os
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
|
||||
import httpx
|
||||
|
||||
|
||||
# ── Data transfer objects ─────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class OAuthUserInfo:
|
||||
"""Normalized user identity returned by any provider."""
|
||||
|
||||
provider_user_id: str
|
||||
email: str
|
||||
email_verified: bool
|
||||
avatar_url: str | None
|
||||
name: str | None
|
||||
|
||||
|
||||
# ── PKCE helpers ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def generate_pkce_pair() -> tuple[str, str]:
|
||||
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
|
||||
|
||||
The code_verifier is a random 32-byte URL-safe base64 string.
|
||||
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
|
||||
"""
|
||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
|
||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||
return code_verifier, code_challenge
|
||||
|
||||
|
||||
# ── Google provider ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
class GoogleOAuthProvider:
|
||||
"""Google OAuth 2.0 provider (openid email profile scope).
|
||||
|
||||
Uses Google's standard authorization endpoint with PKCE S256.
|
||||
Does NOT use google-auth-oauthlib to keep the flow generic and async.
|
||||
"""
|
||||
|
||||
name = "google"
|
||||
|
||||
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self, state: str, code_challenge: str) -> str:
|
||||
"""Build the Google consent-screen URL."""
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
"response_type": "code",
|
||||
"scope": "openid email profile",
|
||||
"state": state,
|
||||
"code_challenge": code_challenge,
|
||||
"code_challenge_method": "S256",
|
||||
"access_type": "offline",
|
||||
"prompt": "select_account",
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
async def exchange_code(
|
||||
self, code: str, code_verifier: str, redirect_uri: str
|
||||
) -> dict:
|
||||
"""Exchange authorization code for an access token."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
self._TOKEN_URL,
|
||||
data={
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
"code": code,
|
||||
"code_verifier": code_verifier,
|
||||
"grant_type": "authorization_code",
|
||||
"redirect_uri": redirect_uri,
|
||||
},
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
|
||||
"""Fetch the authenticated user's identity from Google."""
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.get(
|
||||
self._USERINFO_URL,
|
||||
headers={"Authorization": f"Bearer {access_token}"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
return OAuthUserInfo(
|
||||
provider_user_id=data["sub"],
|
||||
email=data["email"],
|
||||
email_verified=data.get("email_verified", False),
|
||||
avatar_url=data.get("picture"),
|
||||
name=data.get("name"),
|
||||
)
|
||||
@@ -1,4 +0,0 @@
|
||||
from app.billing.stripe_service import stripe_service
|
||||
from app.billing.tier_manager import tier_manager
|
||||
|
||||
__all__ = ["stripe_service", "tier_manager"]
|
||||
@@ -1,139 +0,0 @@
|
||||
"""Quota checks and atomic token-usage accounting for folder integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.billing.tier_manager import TierManager
|
||||
from app.models import MonthlyTokenUsage
|
||||
from app.schemas import BillingTier
|
||||
|
||||
|
||||
class QuotaExceeded(Exception):
|
||||
"""Raised when a folder operation cannot proceed under the user's tier."""
|
||||
|
||||
def __init__(self, reason: str, message: str) -> None:
|
||||
super().__init__(message)
|
||||
self.reason = reason # "max_files" | "monthly_tokens"
|
||||
|
||||
|
||||
@dataclass
|
||||
class TokenUsageResult:
|
||||
tokens_used: int
|
||||
exhausted: bool
|
||||
|
||||
|
||||
def _current_year_month() -> str:
|
||||
return datetime.now(timezone.utc).strftime("%Y-%m")
|
||||
|
||||
|
||||
_tier_manager = TierManager()
|
||||
|
||||
|
||||
async def check_folder_quota(
|
||||
*,
|
||||
user_id: str,
|
||||
tier: BillingTier,
|
||||
estimated_files: int,
|
||||
db: AsyncSession,
|
||||
) -> None:
|
||||
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
|
||||
would be violated. -1 in either feature means unlimited."""
|
||||
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
|
||||
if max_files != -1 and estimated_files > max_files:
|
||||
raise QuotaExceeded(
|
||||
"max_files",
|
||||
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
|
||||
)
|
||||
|
||||
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
||||
if cap == -1:
|
||||
return
|
||||
ym = _current_year_month()
|
||||
row = (
|
||||
await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == user_id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == "folder_index",
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
used = row.tokens_used if row else 0
|
||||
if used >= cap:
|
||||
raise QuotaExceeded(
|
||||
"monthly_tokens",
|
||||
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
|
||||
)
|
||||
|
||||
|
||||
async def add_token_usage(
|
||||
*,
|
||||
user_id: str,
|
||||
feature: str,
|
||||
tokens: int,
|
||||
db: AsyncSession,
|
||||
cap: int | None = None,
|
||||
) -> TokenUsageResult:
|
||||
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
|
||||
|
||||
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
|
||||
back to a read-then-write on other engines (e.g. aiosqlite in tests).
|
||||
Returns post-update total and whether cap is exhausted.
|
||||
"""
|
||||
ym = _current_year_month()
|
||||
|
||||
# Detect dialect to choose between native upsert and portable fallback.
|
||||
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
|
||||
|
||||
if dialect_name == "postgresql":
|
||||
# Native atomic upsert — production path.
|
||||
stmt = (
|
||||
pg_insert(MonthlyTokenUsage)
|
||||
.values(
|
||||
user_id=user_id,
|
||||
year_month=ym,
|
||||
feature=feature,
|
||||
tokens_used=tokens,
|
||||
)
|
||||
.on_conflict_do_update(
|
||||
index_elements=["user_id", "year_month", "feature"],
|
||||
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
|
||||
)
|
||||
.returning(MonthlyTokenUsage.tokens_used)
|
||||
)
|
||||
used: int = (await db.execute(stmt)).scalar_one()
|
||||
await db.commit()
|
||||
else:
|
||||
# Portable fallback — used in tests (SQLite) and any non-PG engine.
|
||||
row = (
|
||||
await db.execute(
|
||||
select(MonthlyTokenUsage).where(
|
||||
MonthlyTokenUsage.user_id == user_id,
|
||||
MonthlyTokenUsage.year_month == ym,
|
||||
MonthlyTokenUsage.feature == feature,
|
||||
)
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
|
||||
if row is None:
|
||||
row = MonthlyTokenUsage(
|
||||
user_id=user_id,
|
||||
year_month=ym,
|
||||
feature=feature,
|
||||
tokens_used=tokens,
|
||||
)
|
||||
db.add(row)
|
||||
else:
|
||||
row.tokens_used += tokens
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(row)
|
||||
used = row.tokens_used
|
||||
|
||||
exhausted = cap is not None and cap != -1 and used >= cap
|
||||
return TokenUsageResult(tokens_used=used, exhausted=exhausted)
|
||||
@@ -1,149 +0,0 @@
|
||||
"""Tier manager: feature matrix and quota enforcement.
|
||||
|
||||
``TierManager`` is the single source of truth for what each billing tier
|
||||
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import HTTPException, status
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.schemas import BillingTier
|
||||
|
||||
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||
FEATURES: dict[str, dict[str, Any]] = {
|
||||
"free": {
|
||||
"agents": 3,
|
||||
"batch_active": 2,
|
||||
"batch_runs_per_day": 5,
|
||||
"providers": 1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": False, # keyword fallback only
|
||||
"realtime_extraction": False, # batch queue (Phase 2)
|
||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
"folder_max_files": 200,
|
||||
"folder_monthly_tokens": 100_000,
|
||||
},
|
||||
"pro": {
|
||||
"agents": -1, # unlimited
|
||||
"batch_active": 10,
|
||||
"batch_runs_per_day": 50,
|
||||
"providers": -1,
|
||||
"batch_builder": False,
|
||||
"sso": False,
|
||||
"real_embeddings": True, # pgvector cosine search
|
||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
||||
"relational_memory": True, # person/project predicates
|
||||
"proactive_mining": False, # Power+ only (Phase 5)
|
||||
"folder_max_files": 5000,
|
||||
"folder_monthly_tokens": 2_000_000,
|
||||
},
|
||||
"power": {
|
||||
"agents": -1,
|
||||
"batch_active": -1, # unlimited
|
||||
"batch_runs_per_day": -1, # unlimited
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"sso": False,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
"folder_max_files": -1, # unlimited
|
||||
"folder_monthly_tokens": -1, # unlimited
|
||||
},
|
||||
"team": {
|
||||
"agents": -1,
|
||||
"batch_active": -1,
|
||||
"batch_runs_per_day": -1, # unlimited
|
||||
"providers": -1,
|
||||
"batch_builder": True,
|
||||
"sso": True,
|
||||
"real_embeddings": True,
|
||||
"realtime_extraction": True,
|
||||
"relational_memory": True, # all predicates incl. custom
|
||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
||||
"folder_max_files": -1, # unlimited
|
||||
"folder_monthly_tokens": -1, # unlimited
|
||||
},
|
||||
}
|
||||
|
||||
# Requests-per-minute limit per tier.
|
||||
RATE_LIMITS: dict[str, int] = {
|
||||
"free": 20,
|
||||
"pro": 60,
|
||||
"power": 120,
|
||||
"team": 200,
|
||||
}
|
||||
|
||||
|
||||
class TierManager:
|
||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||
|
||||
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||
|
||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||
"""Return the current billing tier for ``user_id`` from the DB.
|
||||
|
||||
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
||||
when no subscription row exists.
|
||||
"""
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
tier: str | None = result.scalar_one_or_none()
|
||||
if tier is None or tier not in FEATURES:
|
||||
return "power" if settings.ENV == "dev" else "free"
|
||||
return tier # type: ignore[return-value]
|
||||
|
||||
# ── Feature access ───────────────────────────────────────────────────
|
||||
|
||||
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||
|
||||
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||
"""
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if value is None:
|
||||
return False
|
||||
if isinstance(value, bool):
|
||||
return value
|
||||
return value != 0
|
||||
|
||||
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||
if not self.check_feature(tier, feature):
|
||||
detail = (
|
||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||
if tier_name
|
||||
else f"Feature '{feature}' is not available on your current tier."
|
||||
)
|
||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||
|
||||
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
|
||||
"""Return integer feature value for tier. -1 means unlimited."""
|
||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||
if not isinstance(value, int):
|
||||
return 0
|
||||
return value
|
||||
|
||||
# ── Rate limiting ────────────────────────────────────────────────────
|
||||
|
||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||
"""Return the requests-per-minute limit for ``tier``."""
|
||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||
|
||||
|
||||
# Module-level singleton shared across the app.
|
||||
tier_manager = TierManager()
|
||||
@@ -1,85 +0,0 @@
|
||||
from typing import Literal
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
|
||||
JWT_SECRET: str = "change-me-in-production"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||
|
||||
STRIPE_SECRET_KEY: str = ""
|
||||
STRIPE_WEBHOOK_SECRET: str = ""
|
||||
|
||||
OPENAI_API_KEY: str = ""
|
||||
ANTHROPIC_API_KEY: str = ""
|
||||
GOOGLE_API_KEY: str = ""
|
||||
CEREBRAS_API_KEY: str = ""
|
||||
GROQ_API_KEY: str = ""
|
||||
DEEPSEEK_API_KEY: str = ""
|
||||
|
||||
LLM_MODEL: str = "gpt-4o"
|
||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||
|
||||
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
||||
LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
|
||||
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
||||
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
|
||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
||||
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
||||
|
||||
# GitHub Copilot OAuth token storage directory.
|
||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
||||
|
||||
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
|
||||
GMAIL_CLIENT_ID: str = ""
|
||||
GMAIL_CLIENT_SECRET: str = ""
|
||||
MS_CLIENT_ID: str = ""
|
||||
MS_CLIENT_SECRET: str = ""
|
||||
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||
MS_TENANT_ID: str = "common"
|
||||
|
||||
# Google Login OAuth credentials — scope: openid email profile.
|
||||
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
|
||||
GOOGLE_AUTH_CLIENT_ID: str = ""
|
||||
GOOGLE_AUTH_CLIENT_SECRET: str = ""
|
||||
# The redirect URI registered in Google Cloud Console.
|
||||
# Google redirects here after consent; this backend route then bounces to
|
||||
# the adiuvai:// deep link so the Electron app receives the code.
|
||||
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
|
||||
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
||||
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
||||
|
||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||
OAUTH_ENCRYPTION_KEY: str = ""
|
||||
|
||||
CORS_ORIGINS: list[str] = [
|
||||
"app://.",
|
||||
"http://localhost:3000",
|
||||
"http://localhost:5173",
|
||||
"http://localhost:4173", # Vite preview (web SPA)
|
||||
"https://app.adiuvai.com", # Production web portal
|
||||
]
|
||||
|
||||
LANGFUSE_SECRET_KEY: str = ""
|
||||
LANGFUSE_PUBLIC_KEY: str = ""
|
||||
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
||||
|
||||
SCHEDULER_ENABLED: bool = True
|
||||
|
||||
ENV: Literal["dev", "prod"] = "dev"
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||
|
||||
|
||||
settings = Settings()
|
||||
@@ -1,30 +0,0 @@
|
||||
"""Minimal agent base types retained for compatibility with batch runners."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
|
||||
class BaseAgent(ABC):
|
||||
"""Common base for non-chat agents still using the old base contract."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
user_id: str = "",
|
||||
shared_memory: dict[str, Any] | None = None,
|
||||
vector_store_context: list[str] | None = None,
|
||||
) -> None:
|
||||
self.user_id = user_id
|
||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
||||
self.vector_store_context: list[str] = vector_store_context or []
|
||||
|
||||
@abstractmethod
|
||||
def get_name(self) -> str: ...
|
||||
|
||||
@abstractmethod
|
||||
def get_description(self) -> str: ...
|
||||
|
||||
@property
|
||||
def skills(self) -> list[str]:
|
||||
return []
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,96 +0,0 @@
|
||||
"""In-process TTL buffer for per-session LangChain message history.
|
||||
|
||||
Stores the full message list (including AIMessage with tool_calls and ToolMessage)
|
||||
keyed by (user_id, session_id), so agents can reconstruct tool-call context across
|
||||
conversation turns without it being lossy through the wire.
|
||||
|
||||
Single-process only. For multi-worker deployments, replace the _SessionBuffer
|
||||
implementation with one backed by Redis (serialize LangChain messages to dicts via
|
||||
message_to_dict / messages_from_dict from langchain_core.messages).
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from threading import Lock
|
||||
|
||||
from langchain_core.messages import BaseMessage
|
||||
|
||||
SESSION_TTL_SECONDS = 1800 # 30-minute idle expiry
|
||||
MAX_MESSAGES_PER_SESSION = 80 # cap to avoid unbounded memory growth
|
||||
|
||||
|
||||
class _SessionBuffer:
|
||||
def __init__(self) -> None:
|
||||
self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
|
||||
self._lock = Lock()
|
||||
|
||||
def _evict_stale(self) -> None:
|
||||
now = time.monotonic()
|
||||
stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
|
||||
for k in stale:
|
||||
del self._store[k]
|
||||
|
||||
def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
|
||||
key = (user_id, session_id)
|
||||
with self._lock:
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
return None
|
||||
ts, msgs = entry
|
||||
if time.monotonic() - ts > SESSION_TTL_SECONDS:
|
||||
del self._store[key]
|
||||
return None
|
||||
self._store[key] = (time.monotonic(), msgs)
|
||||
return list(msgs)
|
||||
|
||||
def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
|
||||
key = (user_id, session_id)
|
||||
capped = messages[-MAX_MESSAGES_PER_SESSION:]
|
||||
with self._lock:
|
||||
self._evict_stale()
|
||||
self._store[key] = (time.monotonic(), capped)
|
||||
|
||||
def clear(self, user_id: str, session_id: str) -> None:
|
||||
with self._lock:
|
||||
self._store.pop((user_id, session_id), None)
|
||||
|
||||
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
|
||||
"""Append a synthetic system message to the buffer for the given session.
|
||||
|
||||
Creates the session slot if it does not yet exist. Used by the
|
||||
contextual_scope_update handler to inject navigation events without
|
||||
making an LLM call.
|
||||
"""
|
||||
from langchain_core.messages import SystemMessage # noqa: PLC0415
|
||||
|
||||
key = (user_id, session_id)
|
||||
with self._lock:
|
||||
entry = self._store.get(key)
|
||||
if entry is None:
|
||||
msgs: list[BaseMessage] = [SystemMessage(content=text)]
|
||||
else:
|
||||
_, existing = entry
|
||||
msgs = list(existing) + [SystemMessage(content=text)]
|
||||
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
|
||||
self._store[key] = (time.monotonic(), capped)
|
||||
|
||||
|
||||
class ContextualBufferProxy:
|
||||
"""Thin wrapper around _SessionBuffer that closes over user_id + session_id.
|
||||
|
||||
Returned by get_session_buffer() so callers can call
|
||||
``proxy.append_system_message(text)`` without threading user_id/session_id
|
||||
through every call site.
|
||||
"""
|
||||
|
||||
def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
|
||||
self._buf = buf
|
||||
self._user_id = user_id
|
||||
self._session_id = session_id
|
||||
|
||||
def append_system_message(self, text: str) -> None:
|
||||
self._buf.append_system_message(self._user_id, self._session_id, text)
|
||||
|
||||
|
||||
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
|
||||
session_buffer = _SessionBuffer()
|
||||
@@ -1,228 +0,0 @@
|
||||
"""Brief agent — produces plain-text home and project status briefs.
|
||||
|
||||
Read-only tool subset only. Never calls _normalize_tagged_list_lines —
|
||||
the brief prompt forbids XML tags, so skipping post-processing is intentional.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import date
|
||||
from typing import Any
|
||||
|
||||
from app.agents.note_agent import NOTE_READ_TOOLS
|
||||
from app.agents.project_agent import PROJECT_READ_TOOLS
|
||||
from app.agents.task_agent import TASK_READ_TOOLS
|
||||
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
||||
from app.core.deep_agent import (
|
||||
_language_instruction,
|
||||
_proactive_hints_injection,
|
||||
_read_only_memory_tools,
|
||||
_relational_memory_injection,
|
||||
_run_single_agent_stream,
|
||||
_trace_id_from_context,
|
||||
build_brief_multi_project_manifest,
|
||||
)
|
||||
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
||||
|
||||
_LANGUAGE_NAMES: dict[str, str] = {
|
||||
"en": "English", "it": "Italian", "es": "Spanish",
|
||||
"fr": "French", "de": "German",
|
||||
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||
"spanish": "Spanish", "español": "Spanish",
|
||||
"french": "French", "français": "French",
|
||||
"german": "German", "deutsch": "German",
|
||||
}
|
||||
|
||||
_HOME_BRIEF_FALLBACK = """\
|
||||
You are the user's personal assistant producing a short daily brief.
|
||||
|
||||
ROLE
|
||||
Act like a calm, attentive secretary writing a stand-up note for your boss.
|
||||
Warm and human, never breezy. Never cheerful filler, never emojis, never
|
||||
"here is your brief" meta-text. The user is opening the app mid-workday and
|
||||
is probably stressed — your job is to lower cognitive load, not add noise.
|
||||
|
||||
TOOLS — always call before writing
|
||||
Pull fresh data every run. Do not invent counts or titles. Use at minimum:
|
||||
- list_tasks_due_today — tasks the user owes today
|
||||
- list_timelines_today — events starting or ending today
|
||||
- list_all_projects — projects currently in progress or at risk
|
||||
- memory_list_blocks / memory_get — personal context about people, clients,
|
||||
payment habits, working preferences
|
||||
If a tool returns nothing, simply omit that topic. Never report zeros.
|
||||
|
||||
WHAT TO INCLUDE
|
||||
1. Tasks due today (title + priority; group the 1-2 most important).
|
||||
2. Timeline events starting or ending today (and anything that starts/ends
|
||||
tomorrow if the user has a very light day).
|
||||
3. Active projects that need a nudge — stalled, blocked, or awaiting input.
|
||||
4. Memory-aware colour where it sharpens the brief. Examples:
|
||||
- "Client Rossi tends to pay late — the Acme invoice is 6 days out."
|
||||
- "You usually dislike meetings before 10:00 — the call at 09:30 is unusual."
|
||||
Only add a memory line when it changes what the user does. Do not pad.
|
||||
|
||||
WHAT TO OMIT
|
||||
- Zero-counts ("no overdue items", "0 meetings today").
|
||||
- Statistics ("2 active projects, 3 completed tasks").
|
||||
- Headers, titles, greetings, sign-offs, dates, emojis, slang.
|
||||
- Meta-phrases ("here is", "let me know if", "hope this helps").
|
||||
- XML/HTML tags of any kind. Plain prose only.
|
||||
|
||||
LIGHT-DAY CLAUSE
|
||||
If tasks + events + active-project-nudges together produce fewer than two
|
||||
sentences of content, also list 1-2 projects in status on_hold or waiting
|
||||
and ask a single, specific question about them — e.g. "Is the Bianchi
|
||||
redesign still paused, or ready to pick back up?" One question max, grounded
|
||||
in a real project name.
|
||||
|
||||
VOICE
|
||||
- Calm. Concise. Human. Short sentences.
|
||||
- Use **bold** sparingly for task titles, project names, and people's names.
|
||||
- No bullet lists. Flow as 2-4 sentences of prose.
|
||||
|
||||
LENGTH
|
||||
2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence.
|
||||
|
||||
Respond in the user's language ({language}). Today is {today}.\
|
||||
"""
|
||||
|
||||
_PROJECT_BRIEF_FALLBACK = """\
|
||||
You are the project assistant producing a short status brief for ONE project.
|
||||
|
||||
ROLE
|
||||
A senior project manager summarising state-of-play for the owner. Factual,
|
||||
sharp, forward-looking. Never reassuring filler, never emojis.
|
||||
|
||||
SCOPE
|
||||
Work only with project_id = {project_id}. Do not mention or pull data from
|
||||
other projects. Use tools to fetch fresh data:
|
||||
- get_project — current status, dates, description
|
||||
- list_tasks(project_id) — open work, split by status
|
||||
- list_timelines(project_id) — milestones hit, upcoming, overdue
|
||||
- list_notes(project_id) — any recent decisions or blockers
|
||||
- memory_get — relevant context about the client, collaborators, constraints
|
||||
|
||||
STRUCTURE — follow exactly, one short paragraph per section, no headers
|
||||
1. **State.** One sentence: current phase, health (on track / at risk / blocked),
|
||||
and why. Cite the concrete signal (overdue milestone, stalled tasks, recent
|
||||
blocker note).
|
||||
2. **What's moving.** What was completed or progressed recently. Name specific
|
||||
tasks or milestones.
|
||||
3. **Next steps.** The 1-3 most important things the user should do next, in
|
||||
priority order. Be concrete — task name, who owns it, when due if known.
|
||||
If waiting on someone else, name them and what the ask is.
|
||||
4. **Risks / memory-flagged items.** One line max. Only include when there is
|
||||
a real risk or a relevant memory (e.g. late-paying client, tight deadline,
|
||||
scope change). Omit the section entirely if nothing to say.
|
||||
|
||||
WHAT TO OMIT
|
||||
- Zero-counts ("no overdue tasks").
|
||||
- Generic advice ("keep up the good work").
|
||||
- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases.
|
||||
- XML/HTML tags or bracketed id lists. Plain prose only.
|
||||
|
||||
VOICE
|
||||
- Direct. Factual. No fluff.
|
||||
- Use **bold** sparingly for task titles, milestone names, and the owner's name.
|
||||
- Short sentences. Prefer verbs over nouns ("Client review is blocking release"
|
||||
not "There is a blocker which is the client review").
|
||||
|
||||
LENGTH
|
||||
4-8 sentences total across the 3-4 sections. Hard cap 8.
|
||||
|
||||
Respond in the user's language ({language}). Today is {today}.\
|
||||
"""
|
||||
|
||||
|
||||
def _resolve_language(context: dict[str, Any]) -> str:
|
||||
core = context.get("core_memory") or {}
|
||||
raw = (core.get("language") or "en").strip().lower()
|
||||
return _LANGUAGE_NAMES.get(raw, raw.title()) or "English"
|
||||
|
||||
|
||||
def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
||||
return [
|
||||
*TASK_READ_TOOLS,
|
||||
*PROJECT_READ_TOOLS,
|
||||
*TIMELINE_READ_TOOLS,
|
||||
*NOTE_READ_TOOLS,
|
||||
*_read_only_memory_tools(user_id, trace_id),
|
||||
]
|
||||
|
||||
|
||||
async def run_home_brief(
|
||||
user_id: str,
|
||||
context: dict[str, Any],
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stream a plain-text daily home brief.
|
||||
|
||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||
"""
|
||||
from app.agents.folder_agent import FOLDER_TOOLS
|
||||
|
||||
trace_id = _trace_id_from_context(context)
|
||||
today = date.today().isoformat()
|
||||
language = _resolve_language(context)
|
||||
|
||||
raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK)
|
||||
system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
if today not in system_prompt:
|
||||
system_prompt += f"\nToday is {today}."
|
||||
|
||||
brief_manifest = await build_brief_multi_project_manifest()
|
||||
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
|
||||
|
||||
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
message="Generate the daily brief.",
|
||||
context=context,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
agent_name="brief-agent",
|
||||
tools=tools,
|
||||
):
|
||||
yield event
|
||||
|
||||
|
||||
async def run_project_brief(
|
||||
user_id: str,
|
||||
project_id: str,
|
||||
context: dict[str, Any],
|
||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||
"""Stream a plain-text project status brief for project_id.
|
||||
|
||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
||||
"""
|
||||
trace_id = _trace_id_from_context(context)
|
||||
today = date.today().isoformat()
|
||||
language = _resolve_language(context)
|
||||
|
||||
raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK)
|
||||
system_prompt = compile_prompt(
|
||||
raw_template, langfuse_prompt,
|
||||
language=language, today=today, project_id=project_id,
|
||||
)
|
||||
system_prompt += _relational_memory_injection(context)
|
||||
system_prompt += _proactive_hints_injection(context)
|
||||
system_prompt += _language_instruction(context)
|
||||
if today not in system_prompt:
|
||||
system_prompt += f"\nToday is {today}."
|
||||
|
||||
tools = _build_read_tools(user_id, trace_id)
|
||||
async for event in _run_single_agent_stream(
|
||||
user_id=user_id,
|
||||
system_prompt=system_prompt,
|
||||
message=f"Generate the project status brief for project {project_id}.",
|
||||
context=context,
|
||||
langfuse_prompt=langfuse_prompt,
|
||||
agent_name="brief-agent",
|
||||
tools=tools,
|
||||
):
|
||||
yield event
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,151 +0,0 @@
|
||||
"""Device connection manager.
|
||||
|
||||
Maintains in-memory state for all active Electron → backend WebSocket
|
||||
connections. One connection per user (latest replaces previous).
|
||||
|
||||
The manager handles the **tool-call round-trip** pattern:
|
||||
- Backend sends ``tool_call`` frame → Electron executes the action →
|
||||
returns ``tool_result`` frame.
|
||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||
receive the result dict from Electron.
|
||||
|
||||
This pattern is used by all tools (CRUD, file-system, etc.) via
|
||||
``execute_on_client()`` in ``ws_context.py``.
|
||||
|
||||
The ``device_manager`` module-level singleton is imported by both the
|
||||
device WS route and the agent runner.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fastapi import WebSocket
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DeviceConnection:
|
||||
"""State for a single connected Electron device."""
|
||||
|
||||
ws: WebSocket
|
||||
device_id: str
|
||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||
|
||||
|
||||
class DeviceConnectionManager:
|
||||
"""Singleton registry of active Electron WebSocket connections.
|
||||
|
||||
Thread/task safety note: asyncio is single-threaded by design. All
|
||||
mutations happen inside await-points on the main event loop, so no
|
||||
locking is required for the in-memory dicts.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._connections: dict[str, DeviceConnection] = {}
|
||||
|
||||
# ── Registration ──────────────────────────────────────────────────
|
||||
|
||||
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
|
||||
"""Store the active connection for *user_id*, replacing any previous one."""
|
||||
if user_id in self._connections:
|
||||
old = self._connections[user_id]
|
||||
logger.info(
|
||||
"device_manager: replacing existing connection for user=%s device=%s",
|
||||
user_id,
|
||||
old.device_id,
|
||||
)
|
||||
# Cancel any futures that were waiting on the old connection.
|
||||
for fut in old.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
|
||||
logger.info(
|
||||
"device_manager: registered user=%s device=%s", user_id, device_id
|
||||
)
|
||||
|
||||
def unregister(self, user_id: str) -> None:
|
||||
"""Remove the connection for *user_id* and cancel any pending futures."""
|
||||
conn = self._connections.pop(user_id, None)
|
||||
if conn is None:
|
||||
return
|
||||
for fut in conn.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
logger.info("device_manager: unregistered user=%s", user_id)
|
||||
|
||||
# ── Presence queries ──────────────────────────────────────────────
|
||||
|
||||
def get_ws(self, user_id: str) -> WebSocket | None:
|
||||
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
|
||||
conn = self._connections.get(user_id)
|
||||
return conn.ws if conn else None
|
||||
|
||||
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
|
||||
"""Return ``True`` if the user has an active connection.
|
||||
|
||||
If *device_id* is provided also checks that it matches the connected device.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
return False
|
||||
if device_id is not None:
|
||||
return conn.device_id == device_id
|
||||
return True
|
||||
|
||||
# ── Frame sending ─────────────────────────────────────────────────
|
||||
|
||||
async def send_frame(self, user_id: str, frame: dict) -> None:
|
||||
"""Send *frame* as a JSON text message to the device.
|
||||
|
||||
Raises ``RuntimeError`` if the user is not connected.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"send_frame: user {user_id!r} is not connected"
|
||||
)
|
||||
await conn.ws.send_text(json.dumps(frame))
|
||||
|
||||
# ── Tool-call round-trip ──────────────────────────────────────────
|
||||
|
||||
def create_pending_call(
|
||||
self, user_id: str, call_id: str
|
||||
) -> asyncio.Future[dict]:
|
||||
"""Register a Future that will be resolved when the tool_result arrives.
|
||||
|
||||
Raises ``RuntimeError`` if the user is not connected.
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
raise RuntimeError(
|
||||
f"create_pending_call: user {user_id!r} is not connected"
|
||||
)
|
||||
loop = asyncio.get_event_loop()
|
||||
fut: asyncio.Future[dict] = loop.create_future()
|
||||
conn.pending_calls[call_id] = fut
|
||||
return fut
|
||||
|
||||
def resolve_pending_call(
|
||||
self, user_id: str, call_id: str, result: dict
|
||||
) -> None:
|
||||
"""Fulfil the Future registered under *call_id* with the Electron result.
|
||||
|
||||
No-ops if the call_id is unknown (already timed out or cancelled).
|
||||
"""
|
||||
conn = self._connections.get(user_id)
|
||||
if conn is None:
|
||||
return
|
||||
fut = conn.pending_calls.pop(call_id, None)
|
||||
if fut is not None and not fut.done():
|
||||
fut.set_result(result)
|
||||
|
||||
|
||||
# Module-level singleton — import this everywhere.
|
||||
device_manager = DeviceConnectionManager()
|
||||
@@ -1,34 +0,0 @@
|
||||
"""OpenAI embedding helper for associative memory tier.
|
||||
|
||||
Single public function: ``embed_text(text) -> list[float] | None``.
|
||||
Returns None on any failure — callers must implement a keyword fallback.
|
||||
Never raises; all exceptions are logged as warnings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_MAX_INPUT_CHARS = 8000
|
||||
_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||
|
||||
|
||||
async def embed_text(text: str) -> list[float] | None:
|
||||
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
||||
try:
|
||||
client = AsyncOpenAI()
|
||||
truncated = text[:_MAX_INPUT_CHARS]
|
||||
response = await client.embeddings.create(
|
||||
input=truncated,
|
||||
model=_EMBEDDING_MODEL,
|
||||
)
|
||||
result: list[float] = response.data[0].embedding
|
||||
logger.debug("embeddings: embed_text dims=%d", len(result))
|
||||
return result
|
||||
except Exception as exc:
|
||||
logger.warning("embeddings: embed_text failed: %s", exc)
|
||||
return None
|
||||
@@ -1,183 +0,0 @@
|
||||
"""Per-file summarisation for project folder integration."""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import io
|
||||
from dataclasses import dataclass
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from pypdf import PdfReader
|
||||
from docx import Document as DocxDocument
|
||||
|
||||
from app.core.langfuse_client import (
|
||||
compile_prompt,
|
||||
extract_usage,
|
||||
get_langfuse,
|
||||
get_prompt_or_fallback,
|
||||
)
|
||||
from app.core.llm import get_llm
|
||||
|
||||
_TEXT_FALLBACK = (
|
||||
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
|
||||
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
|
||||
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
|
||||
)
|
||||
_IMAGE_FALLBACK = (
|
||||
"You are summarising an image attached to a project folder.\n"
|
||||
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
|
||||
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
|
||||
)
|
||||
_MAX_INPUT_CHARS = 6000
|
||||
|
||||
|
||||
@dataclass
|
||||
class IndexResult:
|
||||
summary: str
|
||||
tokens_used: int
|
||||
|
||||
|
||||
async def _llm_text(messages: list) -> object:
|
||||
"""Make the LLM call for text summarisation.
|
||||
|
||||
Defined as a standalone async function so tests can patch it cleanly
|
||||
without needing to mock the LLM object itself.
|
||||
"""
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||
return await llm.ainvoke(messages)
|
||||
|
||||
|
||||
async def _llm_vision(messages: list) -> object:
|
||||
"""Make the LLM call for vision (image) summarisation.
|
||||
|
||||
Accepts the message list and returns the response directly, mirroring
|
||||
the ``_llm_text`` caller pattern so tests can patch it at the module level.
|
||||
"""
|
||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
||||
return await llm.ainvoke(messages)
|
||||
|
||||
|
||||
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
|
||||
"""Return a compact summary of an image file using vision.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
image_b64:
|
||||
Base64-encoded image bytes.
|
||||
mime:
|
||||
MIME type of the image, e.g. ``"image/png"``.
|
||||
file_name:
|
||||
Optional file name, attached to the Langfuse trace as input metadata.
|
||||
"""
|
||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
|
||||
messages = [
|
||||
SystemMessage(content=template),
|
||||
HumanMessage(content=[
|
||||
{"type": "text", "text": "Summarise this image."},
|
||||
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
|
||||
]),
|
||||
]
|
||||
lf = get_langfuse()
|
||||
if lf is not None:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="folder-summarize-image",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input={"file_name": file_name, "mime": mime},
|
||||
) as gen:
|
||||
response = await _llm_vision(messages)
|
||||
usage = extract_usage(response)
|
||||
gen.update(output=response.content, usage_details=usage)
|
||||
else:
|
||||
response = await _llm_vision(messages)
|
||||
usage = extract_usage(response)
|
||||
summary = (response.content or "").strip()[:500]
|
||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||
|
||||
|
||||
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a text file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
content:
|
||||
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
|
||||
ext:
|
||||
File extension including the leading dot, e.g. ``".md"``.
|
||||
name:
|
||||
File name, e.g. ``"kickoff.md"``.
|
||||
"""
|
||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
|
||||
truncated = content[:_MAX_INPUT_CHARS]
|
||||
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
|
||||
messages = [
|
||||
SystemMessage(content=compiled),
|
||||
HumanMessage(content="Summarise this file."),
|
||||
]
|
||||
lf = get_langfuse()
|
||||
if lf is not None:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="folder-summarize-text",
|
||||
model="gpt-4o-mini",
|
||||
prompt=prompt_obj,
|
||||
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
|
||||
) as gen:
|
||||
response = await _llm_text(messages)
|
||||
usage = extract_usage(response)
|
||||
gen.update(output=response.content, usage_details=usage)
|
||||
else:
|
||||
response = await _llm_text(messages)
|
||||
usage = extract_usage(response)
|
||||
summary = (response.content or "").strip()[:500]
|
||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
||||
|
||||
|
||||
def _extract_pdf_text(pdf_b64: str) -> str:
|
||||
buf = io.BytesIO(base64.b64decode(pdf_b64))
|
||||
reader = PdfReader(buf)
|
||||
parts: list[str] = []
|
||||
for page in reader.pages:
|
||||
try:
|
||||
parts.append(page.extract_text() or "")
|
||||
except Exception:
|
||||
continue
|
||||
return "\n".join(parts).strip()
|
||||
|
||||
|
||||
def _extract_docx_text(docx_b64: str) -> str:
|
||||
buf = io.BytesIO(base64.b64decode(docx_b64))
|
||||
doc = DocxDocument(buf)
|
||||
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
|
||||
|
||||
|
||||
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a PDF file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
pdf_b64:
|
||||
Base64-encoded PDF bytes.
|
||||
name:
|
||||
File name, e.g. ``"report.pdf"``.
|
||||
"""
|
||||
text = _extract_pdf_text(pdf_b64)
|
||||
if not text:
|
||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||
return await summarize_text(content=text, ext=".pdf", name=name)
|
||||
|
||||
|
||||
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
|
||||
"""Return a compact summary of a DOCX file.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
docx_b64:
|
||||
Base64-encoded DOCX bytes.
|
||||
name:
|
||||
File name, e.g. ``"spec.docx"``.
|
||||
"""
|
||||
text = _extract_docx_text(docx_b64)
|
||||
if not text:
|
||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
||||
return await summarize_text(content=text, ext=".docx", name=name)
|
||||
@@ -1,190 +0,0 @@
|
||||
"""Langfuse observability — singleton client and prompt helpers.
|
||||
|
||||
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
|
||||
all helpers are no-ops so the app works without Langfuse configured.
|
||||
|
||||
Usage
|
||||
-----
|
||||
Tracing::
|
||||
|
||||
from app.core.langfuse_client import get_langfuse
|
||||
|
||||
lf = get_langfuse()
|
||||
if lf:
|
||||
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
|
||||
span.update(input=user_message)
|
||||
# ... do work ...
|
||||
span.update(output=result)
|
||||
lf.flush()
|
||||
|
||||
Prompt management::
|
||||
|
||||
from app.core.langfuse_client import get_prompt_or_fallback
|
||||
|
||||
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
|
||||
# Use text as the system prompt; pass prompt_obj to generations for linking.
|
||||
|
||||
Linking a prompt to a generation::
|
||||
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="llm-call",
|
||||
model="gpt-4o",
|
||||
prompt=prompt_obj, # links generation → prompt version in the UI
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=_usage(response))
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_client: Any = None
|
||||
_initialized: bool = False
|
||||
|
||||
|
||||
def get_langfuse() -> Any | None:
|
||||
"""Return the Langfuse singleton, or ``None`` when not configured."""
|
||||
global _client, _initialized
|
||||
if _initialized:
|
||||
return _client
|
||||
_initialized = True
|
||||
|
||||
from app.config.settings import settings # local import to avoid circular deps
|
||||
|
||||
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
||||
logger.debug("langfuse: not configured — observability disabled")
|
||||
return None
|
||||
|
||||
try:
|
||||
from langfuse import Langfuse
|
||||
|
||||
_client = Langfuse(
|
||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||
host=settings.LANGFUSE_BASE_URL,
|
||||
)
|
||||
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
|
||||
except Exception as exc:
|
||||
logger.warning("langfuse: failed to initialize: %s", exc)
|
||||
_client = None
|
||||
|
||||
return _client
|
||||
|
||||
|
||||
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
||||
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
||||
|
||||
Returns ``(raw_template, prompt_obj_or_None)``.
|
||||
|
||||
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
|
||||
on it directly; use :func:`compile_prompt` instead so the correct variable
|
||||
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
|
||||
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
||||
unavailable / the fetch failed. Pass this to generation observations so
|
||||
Langfuse links the generation to the exact prompt version in the UI.
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
if lf is None:
|
||||
return fallback, None
|
||||
|
||||
try:
|
||||
prompt = lf.get_prompt(name, label="production", fallback=fallback)
|
||||
# For text-type prompts .prompt holds the raw template string.
|
||||
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
|
||||
return raw, prompt
|
||||
except Exception as exc:
|
||||
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
|
||||
return fallback, None
|
||||
|
||||
|
||||
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
|
||||
"""Compile *template* with *variables*, choosing the right syntax.
|
||||
|
||||
* When *prompt_obj* is a real Langfuse prompt object, calls
|
||||
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
|
||||
substitution as defined in the Langfuse UI.
|
||||
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
|
||||
falls back to ``template.format(**variables)`` which handles the
|
||||
``{variable}`` syntax used in the hardcoded fallback strings.
|
||||
|
||||
This keeps callers oblivious to which syntax is in use.
|
||||
"""
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
compiled = prompt_obj.compile(**variables)
|
||||
# compile() returns a string for text prompts.
|
||||
if isinstance(compiled, str):
|
||||
return compiled
|
||||
# Chat prompts return a list of dicts — join text parts.
|
||||
if isinstance(compiled, list):
|
||||
return "\n".join(
|
||||
m.get("content", "") for m in compiled if isinstance(m, dict)
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
|
||||
getattr(prompt_obj, "name", "?"),
|
||||
exc,
|
||||
)
|
||||
return template.format(**variables)
|
||||
|
||||
|
||||
def extract_usage(response: Any) -> dict[str, int]:
|
||||
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
||||
meta = getattr(response, "usage_metadata", None)
|
||||
if not meta:
|
||||
return {}
|
||||
return {
|
||||
"input": int(meta.get("input_tokens", 0)),
|
||||
"output": int(meta.get("output_tokens", 0)),
|
||||
"total": int(meta.get("total_tokens", 0)),
|
||||
}
|
||||
|
||||
|
||||
def hash_user_id(user_id: str) -> str:
|
||||
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
|
||||
|
||||
This avoids sending raw database UUIDs to external observability services
|
||||
while still providing a stable, deterministic identifier for per-user
|
||||
metrics in the Langfuse dashboard.
|
||||
"""
|
||||
return hashlib.sha256(user_id.encode()).hexdigest()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def langfuse_context(
|
||||
user_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> Generator[None, None, None]:
|
||||
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
|
||||
|
||||
No-op when Langfuse is not configured or parameters are empty.
|
||||
"""
|
||||
lf = get_langfuse()
|
||||
if lf is None or (not user_id and not session_id):
|
||||
yield
|
||||
return
|
||||
|
||||
try:
|
||||
from langfuse import propagate_attributes
|
||||
except ImportError:
|
||||
logger.debug("langfuse: propagate_attributes not available — skipping context")
|
||||
yield
|
||||
return
|
||||
|
||||
attrs: dict[str, str] = {}
|
||||
if user_id:
|
||||
attrs["user_id"] = hash_user_id(user_id)
|
||||
if session_id:
|
||||
attrs["session_id"] = session_id
|
||||
|
||||
with propagate_attributes(**attrs):
|
||||
yield
|
||||
156
app/core/llm.py
156
app/core/llm.py
@@ -1,156 +0,0 @@
|
||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||
|
||||
Every agent and the orchestrator call ``get_llm()``
|
||||
instead of directly constructing a provider-specific class. The model string
|
||||
follows the `LiteLLM model naming convention
|
||||
<https://docs.litellm.ai/docs/providers>`_:
|
||||
|
||||
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
|
||||
* Anthropic: ``anthropic/claude-3.5-sonnet``
|
||||
* Google: ``gemini/gemini-pro``
|
||||
* Ollama: ``ollama/llama3``
|
||||
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||
|
||||
Switch providers by changing **LLM_MODEL** in ``.env``
|
||||
— no code changes required.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
import litellm
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||
|
||||
from app.config.settings import settings
|
||||
|
||||
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
|
||||
# Drop them silently instead of raising UnsupportedParamsError.
|
||||
litellm.drop_params = True
|
||||
|
||||
# Some provider responses include a plain dict in the `usage` field where a
|
||||
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||
category=UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def _api_key_for_model(model: str) -> str | None:
|
||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||
if model.startswith("anthropic/"):
|
||||
return settings.ANTHROPIC_API_KEY or None
|
||||
if model.startswith("gemini/") or model.startswith("google/"):
|
||||
return settings.GOOGLE_API_KEY or None
|
||||
if model.startswith("cerebras/"):
|
||||
return settings.CEREBRAS_API_KEY or None
|
||||
if model.startswith("groq/"):
|
||||
return settings.GROQ_API_KEY or None
|
||||
if model.startswith("deepseek/"):
|
||||
return settings.DEEPSEEK_API_KEY or None
|
||||
if model.startswith("github_copilot/"):
|
||||
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||
# No API key is required; returning None lets LiteLLM handle auth.
|
||||
return None
|
||||
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
||||
return settings.OPENAI_API_KEY or None
|
||||
|
||||
|
||||
def get_llm(
|
||||
*,
|
||||
model: str | None = None,
|
||||
temperature: float = 0,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
"""Return a LangChain chat model backed by LiteLLM.
|
||||
|
||||
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
|
||||
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
|
||||
``openai`` client transparently when the model string contains a provider
|
||||
prefix (``anthropic/…``, ``gemini/…``, etc.).
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model:
|
||||
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
|
||||
temperature:
|
||||
Sampling temperature. ``0`` = deterministic.
|
||||
"""
|
||||
model = model or settings.LLM_MODEL
|
||||
|
||||
# Point LiteLLM to the custom token directory when configured.
|
||||
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||
|
||||
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
|
||||
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
|
||||
if "/" in model:
|
||||
return ChatLiteLLM(model=model, temperature=temperature)
|
||||
|
||||
return ChatOpenAI(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=_api_key_for_model(model),
|
||||
)
|
||||
|
||||
|
||||
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
||||
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
|
||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
||||
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
||||
"note-summarizer": lambda: "gpt-4o-mini",
|
||||
}
|
||||
|
||||
|
||||
def model_for_agent(agent_name: str) -> str:
|
||||
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
|
||||
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
|
||||
|
||||
|
||||
def get_agent_llm(
|
||||
agent_name: str,
|
||||
*,
|
||||
temperature: float = 0,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
|
||||
|
||||
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
|
||||
per-agent override is left empty in ``.env``.
|
||||
"""
|
||||
model = model_for_agent(agent_name)
|
||||
return get_llm(model=model, temperature=temperature)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
|
||||
"""Return an embedding vector for *text*.
|
||||
|
||||
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
|
||||
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
|
||||
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
|
||||
model names to preserve existing behaviour.
|
||||
"""
|
||||
model = settings.LLM_EMBED_MODEL
|
||||
|
||||
if model.startswith("github_copilot/") or "/" in model:
|
||||
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
|
||||
# so the provider's auth mechanism is applied correctly.
|
||||
response = await litellm.aembedding(model=model, input=[text])
|
||||
return response.data[0]["embedding"]
|
||||
|
||||
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
|
||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
response = await client.embeddings.create(model=model, input=text)
|
||||
return response.data[0].embedding
|
||||
@@ -1,450 +0,0 @@
|
||||
"""Mem0-style Extract/Update pipeline — Phase 2.
|
||||
|
||||
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
||||
routines, and relations from the latest conversation turn.
|
||||
|
||||
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
||||
|
||||
Design notes
|
||||
------------
|
||||
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
|
||||
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
||||
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
||||
treated as identifiers (safe to log per spec).
|
||||
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
||||
from app.core.llm import get_agent_llm, model_for_agent
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
||||
|
||||
_EXTRACTION_FALLBACK = (
|
||||
"You are a memory extractor for a personal AI secretary. Given the last conversation "
|
||||
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
|
||||
"preferences, routines, and person/project relations worth remembering.\n\n"
|
||||
"Output JSON matching this schema exactly:\n"
|
||||
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
|
||||
'"content": "<short canonical statement>", '
|
||||
'"target_tier": "<core|associative|relational|proactive>", '
|
||||
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
|
||||
"Rules:\n"
|
||||
"- Skip small talk, greetings, one-off questions.\n"
|
||||
"- Max 5 candidates per call.\n"
|
||||
"- Only extract durable information (still true next week).\n"
|
||||
"- For type=relation: subject/predicate/object required.\n"
|
||||
"- Default confidence=0.7.\n\n"
|
||||
"## Last turn\n{last_turn}\n\n"
|
||||
"## Core memory (current)\n{core_memory}\n\n"
|
||||
"## Recent episodes\n{recent_episodes}"
|
||||
)
|
||||
|
||||
_DECIDE_FALLBACK = (
|
||||
"You are a memory update decision engine. Given a new memory candidate and a list of "
|
||||
"existing memories from the same tier, decide what action to take.\n\n"
|
||||
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
|
||||
"- ADD: new information not in existing memories.\n"
|
||||
"- UPDATE: contradicts or supersedes an existing memory.\n"
|
||||
"- DELETE: states something is no longer true.\n"
|
||||
"- NOOP: already captured accurately.\n\n"
|
||||
"## New candidate\n{candidate}\n\n"
|
||||
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
|
||||
)
|
||||
|
||||
|
||||
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
||||
|
||||
class MemoryCandidate(BaseModel):
|
||||
type: Literal["fact", "preference", "relation", "routine"]
|
||||
content: str
|
||||
target_tier: Literal["core", "associative", "relational", "proactive"]
|
||||
subject: str | None = None
|
||||
predicate: str | None = None
|
||||
object: str | None = None
|
||||
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class ExtractionResult(BaseModel):
|
||||
candidates: list[MemoryCandidate] = Field(default_factory=list)
|
||||
|
||||
|
||||
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
||||
|
||||
async def extract_candidates(
|
||||
last_turn: str,
|
||||
core_memory: dict[str, str],
|
||||
recent_episodes: list[str],
|
||||
) -> ExtractionResult:
|
||||
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
|
||||
|
||||
Returns an ExtractionResult (may be empty on failure — never raises).
|
||||
"""
|
||||
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
|
||||
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
|
||||
|
||||
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: compile failed: %s", exc)
|
||||
system_text = template.format(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
else:
|
||||
system_text = template.format(
|
||||
last_turn=last_turn,
|
||||
core_memory=core_str,
|
||||
recent_episodes=episodes_str,
|
||||
)
|
||||
|
||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||
# Bind JSON mode so the model always returns parseable output.
|
||||
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
||||
|
||||
lf = get_langfuse()
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Extract memory candidates as JSON."),
|
||||
]
|
||||
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-extraction",
|
||||
model=model_for_agent("memory-extractor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm_json.ainvoke(messages)
|
||||
|
||||
raw = json.loads(response.content)
|
||||
result = ExtractionResult.model_validate(raw)
|
||||
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
|
||||
return result
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
|
||||
return ExtractionResult(candidates=[])
|
||||
|
||||
|
||||
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
||||
|
||||
async def decide_action(
|
||||
candidate: MemoryCandidate,
|
||||
existing: list[str],
|
||||
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
|
||||
"""Decide what to do with a candidate given existing memories in the same tier.
|
||||
|
||||
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
|
||||
Never raises.
|
||||
"""
|
||||
if not existing:
|
||||
return "ADD"
|
||||
|
||||
candidate_str = f"[{candidate.type}] {candidate.content}"
|
||||
existing_str = "\n".join(f"- {m}" for m in existing)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
|
||||
|
||||
if prompt_obj is not None:
|
||||
try:
|
||||
system_text = prompt_obj.compile(
|
||||
candidate=candidate_str,
|
||||
existing_memories=existing_str,
|
||||
)
|
||||
if isinstance(system_text, list):
|
||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: decide compile failed: %s", exc)
|
||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||
else:
|
||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
||||
|
||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
|
||||
try:
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Decide action."),
|
||||
]
|
||||
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-decide-action",
|
||||
model=model_for_agent("memory-extractor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
verb = response.content.strip().upper()
|
||||
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
|
||||
return verb # type: ignore[return-value]
|
||||
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
|
||||
return "ADD"
|
||||
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: decide_action failed: %s", exc)
|
||||
return "ADD"
|
||||
|
||||
|
||||
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
||||
|
||||
async def run_extraction(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Full Mem0-style extract/update pipeline for one conversation turn.
|
||||
|
||||
Steps:
|
||||
1. Load core memory + last 5 episodes.
|
||||
2. extract_candidates() → up to 5 MemoryCandidate objects.
|
||||
3. For each candidate: find top-3 neighbours → decide_action() → apply.
|
||||
4. Trace via Langfuse.
|
||||
|
||||
Never raises — wraps everything in try/except.
|
||||
"""
|
||||
try:
|
||||
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _run_extraction_inner(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||
|
||||
middleware = MemoryMiddleware(db)
|
||||
fernet = await middleware._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
|
||||
return
|
||||
|
||||
# 1. Load context
|
||||
core: dict[str, str] = await middleware._load_core(user_id, fernet)
|
||||
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
|
||||
|
||||
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
|
||||
|
||||
lf = get_langfuse()
|
||||
|
||||
async def _run(trace_id: str | None) -> dict[str, Any]:
|
||||
# 2. Extract candidates
|
||||
result = await extract_candidates(last_turn, core, episodes)
|
||||
if not result.candidates:
|
||||
logger.info("memory_extraction: no candidates user=%s", user_id)
|
||||
return {"candidates": 0, "applied": 0}
|
||||
|
||||
logger.info(
|
||||
"memory_extraction: processing %d candidates user=%s trace=%s",
|
||||
len(result.candidates),
|
||||
user_id,
|
||||
trace_id or "-",
|
||||
)
|
||||
|
||||
# 3. Apply each candidate
|
||||
applied = 0
|
||||
actions: list[str] = []
|
||||
for candidate in result.candidates:
|
||||
try:
|
||||
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
|
||||
applied += 1
|
||||
actions.append(f"{candidate.type}:{candidate.target_tier}")
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_extraction: apply failed candidate=%r user=%s: %s",
|
||||
candidate.content[:80],
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
logger.info(
|
||||
"memory_extraction: applied %d/%d candidates user=%s",
|
||||
applied,
|
||||
len(result.candidates),
|
||||
user_id,
|
||||
)
|
||||
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
|
||||
|
||||
with langfuse_context(user_id=user_id, session_id=session_id):
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
name="memory-extraction-pipeline",
|
||||
input={"last_turn_preview": last_turn[:200]},
|
||||
) as span:
|
||||
summary = await _run(trace_id=span.id)
|
||||
span.update(output=summary)
|
||||
try:
|
||||
lf.flush()
|
||||
except Exception:
|
||||
pass
|
||||
else:
|
||||
await _run(trace_id=None)
|
||||
|
||||
|
||||
async def _apply_candidate(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Any,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Fetch neighbours, decide action, apply to the appropriate tier."""
|
||||
|
||||
neighbours: list[str] = []
|
||||
|
||||
if candidate.target_tier == "core":
|
||||
# For core tier: neighbours are existing core block values for similar keys.
|
||||
blocks = await middleware.list_core_blocks(user_id)
|
||||
neighbours = [b["value"] for b in blocks[:3]]
|
||||
|
||||
elif candidate.target_tier == "associative":
|
||||
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
|
||||
|
||||
elif candidate.target_tier == "relational":
|
||||
# Relation candidates handled specially — passed to upsert_relation directly.
|
||||
# Neighbours: search by subject label if available.
|
||||
neighbours = []
|
||||
|
||||
elif candidate.target_tier == "proactive":
|
||||
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
|
||||
|
||||
action = await decide_action(candidate, neighbours)
|
||||
logger.info(
|
||||
"memory_extraction: candidate type=%s tier=%s action=%s",
|
||||
candidate.type,
|
||||
candidate.target_tier,
|
||||
action,
|
||||
)
|
||||
|
||||
if action == "NOOP":
|
||||
return
|
||||
|
||||
if candidate.target_tier == "relational":
|
||||
# Always upsert relations — decide_action skipped (no neighbour search).
|
||||
if candidate.subject and candidate.predicate and candidate.object:
|
||||
await _upsert_relation(
|
||||
middleware, db, user_id, candidate, trace_id
|
||||
)
|
||||
return
|
||||
|
||||
if action in ("ADD", "UPDATE"):
|
||||
if candidate.target_tier == "core":
|
||||
# Derive a short key from the content (first 40 chars, snake_cased).
|
||||
key = _content_to_key(candidate.content)
|
||||
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
|
||||
|
||||
elif candidate.target_tier == "associative":
|
||||
await middleware.store_associative(user_id, candidate.content)
|
||||
|
||||
elif candidate.target_tier == "proactive":
|
||||
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
|
||||
|
||||
elif action == "DELETE":
|
||||
if candidate.target_tier == "core":
|
||||
key = _content_to_key(candidate.content)
|
||||
await middleware.delete_core(user_id, key)
|
||||
|
||||
|
||||
def _content_to_key(content: str) -> str:
|
||||
"""Derive a short snake_case key from a content string (first 40 chars)."""
|
||||
import re # noqa: PLC0415
|
||||
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
|
||||
return slug or "memory"
|
||||
|
||||
|
||||
async def _upsert_relation(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
trace_id: str | None,
|
||||
) -> None:
|
||||
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
|
||||
await middleware.upsert_relation(
|
||||
user_id=user_id,
|
||||
subject=candidate.subject or "unknown",
|
||||
subject_type="unknown",
|
||||
predicate=candidate.predicate or "related_to",
|
||||
object_=candidate.object or "unknown",
|
||||
object_type="unknown",
|
||||
confidence=candidate.confidence,
|
||||
)
|
||||
logger.info(
|
||||
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
|
||||
candidate.subject,
|
||||
candidate.predicate,
|
||||
candidate.object,
|
||||
)
|
||||
|
||||
|
||||
async def _store_proactive_stub(
|
||||
middleware: Any,
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
candidate: MemoryCandidate,
|
||||
fernet: Any,
|
||||
) -> None:
|
||||
"""Store a proactive pattern row directly (MemoryProactive model)."""
|
||||
import uuid # noqa: PLC0415
|
||||
from app.models import MemoryProactive # noqa: PLC0415
|
||||
from app.core.memory_middleware import _encrypt # noqa: PLC0415
|
||||
|
||||
encrypted = _encrypt(fernet, candidate.content)
|
||||
row = MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
pattern_encrypted=encrypted,
|
||||
confidence=candidate.confidence,
|
||||
source="inferred",
|
||||
)
|
||||
db.add(row)
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_extraction: store proactive failed: %s", exc)
|
||||
await db.rollback()
|
||||
@@ -1,581 +0,0 @@
|
||||
"""Memory maintenance jobs — Phase 3/5.
|
||||
|
||||
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
||||
|
||||
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
||||
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
|
||||
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
||||
|
||||
All are safe to call manually or from tests; they never raise.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Decay parameters for relations
|
||||
_DECAY_FACTOR = 0.95
|
||||
_DECAY_PERIOD_DAYS = 30
|
||||
_PRUNE_THRESHOLD = 0.2
|
||||
|
||||
# Proactive pattern decay: 10 % per 7 days since last sighting
|
||||
_PROACTIVE_DECAY_FACTOR = 0.9
|
||||
_PROACTIVE_DECAY_PERIOD_DAYS = 7
|
||||
_PROACTIVE_PRUNE_THRESHOLD = 0.2
|
||||
|
||||
# Mining: require at least this many episodes to attempt pattern extraction
|
||||
_MIN_EPISODES_FOR_MINING = 3
|
||||
_MINING_LOOKBACK_DAYS = 30
|
||||
|
||||
# Audit: caps to control token cost
|
||||
_AUDIT_MAX_FACTS = 50
|
||||
_AUDIT_MAX_LABELS = 100
|
||||
|
||||
|
||||
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
||||
"""Apply confidence decay to all relation rows for a user.
|
||||
|
||||
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
|
||||
Rows whose confidence falls below 0.2 are deleted.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _decay_relations_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
now = datetime.now(timezone.utc)
|
||||
deleted = 0
|
||||
decayed = 0
|
||||
|
||||
for row in rows:
|
||||
reference = row.last_confirmed_at or row.created_at
|
||||
if reference is None:
|
||||
continue
|
||||
if reference.tzinfo is None:
|
||||
reference = reference.replace(tzinfo=timezone.utc)
|
||||
|
||||
days_elapsed = (now - reference).days
|
||||
if days_elapsed < _DECAY_PERIOD_DAYS:
|
||||
continue
|
||||
|
||||
periods = days_elapsed // _DECAY_PERIOD_DAYS
|
||||
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
|
||||
|
||||
if new_confidence < _PRUNE_THRESHOLD:
|
||||
await db.delete(row)
|
||||
deleted += 1
|
||||
logger.info(
|
||||
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
|
||||
"confidence=%.3f (below threshold)",
|
||||
row.id, user_id, row.subject_label, row.predicate, new_confidence,
|
||||
)
|
||||
else:
|
||||
row.confidence = new_confidence
|
||||
decayed += 1
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
|
||||
user_id, decayed, deleted,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
async def drain_extraction_queue(db: AsyncSession) -> None:
|
||||
"""Process pending ExtractionQueue rows for Free-tier users.
|
||||
|
||||
Each row corresponds to a stored episode that should be fed through the
|
||||
Mem0-style extraction pipeline. Rows are deleted after successful processing.
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _drain_extraction_queue_inner(db)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
||||
|
||||
|
||||
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
|
||||
from app.models import ExtractionQueue # noqa: PLC0415
|
||||
|
||||
result = await db.execute(select(ExtractionQueue))
|
||||
rows = result.scalars().all()
|
||||
|
||||
if not rows:
|
||||
logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
|
||||
return
|
||||
|
||||
logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))
|
||||
|
||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||
|
||||
processed = 0
|
||||
for row in rows:
|
||||
try:
|
||||
await run_extraction(
|
||||
db=db,
|
||||
user_id=row.user_id,
|
||||
last_user_msg="",
|
||||
last_assistant_msg="",
|
||||
session_id=None,
|
||||
)
|
||||
await db.delete(row)
|
||||
await db.commit()
|
||||
processed += 1
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: drain failed row=%s user=%s: %s",
|
||||
row.id, row.user_id, exc,
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
|
||||
|
||||
|
||||
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
|
||||
"""Mine recurring behavioral patterns from last 30 days of episodes (Power+ only).
|
||||
|
||||
Steps:
|
||||
1. Gate on proactive_mining tier feature.
|
||||
2. Load + decrypt last 30 days of episodic summaries.
|
||||
3. Call gpt-4o-mini to identify recurring patterns.
|
||||
4. Encrypt and store each pattern in memory_proactive.
|
||||
5. Apply decay to existing proactive rows.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _mine_proactive_patterns_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
tier = await tier_manager.get_tier(user_id, db)
|
||||
if not tier_manager.check_feature(tier, "proactive_mining"):
|
||||
logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
|
||||
return
|
||||
|
||||
# Load user Fernet key
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
|
||||
return
|
||||
|
||||
fernet = Fernet(user.encryption_key.encode())
|
||||
cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
|
||||
|
||||
episodes_result = await db.execute(
|
||||
select(MemoryEpisodic)
|
||||
.where(
|
||||
MemoryEpisodic.user_id == user_id,
|
||||
MemoryEpisodic.created_at >= cutoff,
|
||||
)
|
||||
.order_by(MemoryEpisodic.created_at.asc())
|
||||
)
|
||||
episode_rows = episodes_result.scalars().all()
|
||||
|
||||
if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
|
||||
logger.info(
|
||||
"memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
|
||||
user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
|
||||
)
|
||||
return
|
||||
|
||||
summaries: list[str] = []
|
||||
for ep in episode_rows:
|
||||
try:
|
||||
plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode()
|
||||
summaries.append(plaintext)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if not summaries:
|
||||
return
|
||||
|
||||
patterns = await _extract_proactive_patterns(summaries)
|
||||
if not patterns:
|
||||
logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
|
||||
return
|
||||
|
||||
stored = 0
|
||||
for pattern_text in patterns:
|
||||
try:
|
||||
encrypted = fernet.encrypt(pattern_text.encode()).decode()
|
||||
row = MemoryProactive(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
pattern_encrypted=encrypted,
|
||||
confidence=0.7,
|
||||
source="inferred",
|
||||
)
|
||||
db.add(row)
|
||||
stored += 1
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: mine_proactive_patterns user=%s stored=%d",
|
||||
user_id, stored,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
return
|
||||
|
||||
await _decay_proactive_patterns(db, user_id, fernet)
|
||||
|
||||
|
||||
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
|
||||
"""Call memory-miner LLM to identify recurring behavioral/temporal patterns."""
|
||||
from app.core.llm import get_agent_llm # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-miner", temperature=0)
|
||||
combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage
|
||||
prompt = (
|
||||
"You are analyzing conversation history for a personal AI secretary. "
|
||||
"Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
|
||||
"'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
|
||||
"Return each pattern as a plain, short English sentence on its own line. "
|
||||
"No numbering, no bullet points, no extra text.\n\n"
|
||||
f"Conversation history:\n{combined}"
|
||||
)
|
||||
try:
|
||||
response = await llm.ainvoke(prompt)
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
|
||||
return lines[:5]
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
|
||||
"""Decay confidence of existing proactive patterns; prune below threshold."""
|
||||
result = await db.execute(
|
||||
select(MemoryProactive).where(MemoryProactive.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
now = datetime.now(timezone.utc)
|
||||
deleted = 0
|
||||
decayed = 0
|
||||
|
||||
for row in rows:
|
||||
reference = row.created_at
|
||||
if reference is None:
|
||||
continue
|
||||
if reference.tzinfo is None:
|
||||
reference = reference.replace(tzinfo=timezone.utc)
|
||||
|
||||
days_elapsed = (now - reference).days
|
||||
if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS:
|
||||
continue
|
||||
|
||||
periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS
|
||||
new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)
|
||||
|
||||
if new_confidence < _PROACTIVE_PRUNE_THRESHOLD:
|
||||
await db.delete(row)
|
||||
deleted += 1
|
||||
else:
|
||||
row.confidence = new_confidence
|
||||
decayed += 1
|
||||
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
|
||||
user_id, decayed, deleted,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
|
||||
await db.rollback()
|
||||
|
||||
|
||||
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
||||
|
||||
_AUDIT_CONTRADICTIONS_FALLBACK = (
|
||||
"You are auditing a personal AI assistant's memory bank. "
|
||||
"Each fact has an ID in brackets. "
|
||||
"Find pairs that directly contradict each other "
|
||||
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
|
||||
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
|
||||
'Return ONLY a valid JSON array, no markdown fences: '
|
||||
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
|
||||
"If no contradictions, return [].\n\n"
|
||||
"Facts:\n{facts}"
|
||||
)
|
||||
|
||||
_AUDIT_CANONICALIZE_FALLBACK = (
|
||||
"You are auditing entity labels in a personal AI assistant's relational memory. "
|
||||
"These are names of people, companies, projects, or topics. "
|
||||
"Group labels that clearly refer to the same real-world entity "
|
||||
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
|
||||
"Return ONLY a valid JSON array, no markdown fences: "
|
||||
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
|
||||
"Only include groups with at least one variant. Singletons: omit.\n\n"
|
||||
"Labels:\n{labels}"
|
||||
)
|
||||
|
||||
|
||||
async def audit_memory(db: AsyncSession, user_id: str) -> None:
|
||||
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
|
||||
|
||||
Steps:
|
||||
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
|
||||
2. LLM flags rows to delete (direct contradictions); hard-delete them.
|
||||
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
|
||||
4. Rewrite variant labels to their canonical form in-place.
|
||||
|
||||
Never raises — wraps in try/except.
|
||||
"""
|
||||
try:
|
||||
await _audit_memory_inner(db, user_id)
|
||||
except Exception as exc:
|
||||
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
|
||||
|
||||
|
||||
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
|
||||
result = await db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
|
||||
return
|
||||
|
||||
fernet = Fernet(user.encryption_key.encode())
|
||||
await _scan_associative_contradictions(db, user_id, fernet)
|
||||
await _canonicalize_relation_labels(db, user_id)
|
||||
|
||||
|
||||
async def _scan_associative_contradictions(
|
||||
db: AsyncSession,
|
||||
user_id: str,
|
||||
fernet: Fernet,
|
||||
) -> None:
|
||||
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
|
||||
result = await db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(_AUDIT_MAX_FACTS)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if len(rows) < 2:
|
||||
return
|
||||
|
||||
id_to_text: dict[str, str] = {}
|
||||
for row in rows:
|
||||
try:
|
||||
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
|
||||
id_to_text[row.id] = plaintext
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if len(id_to_text) < 2:
|
||||
return
|
||||
|
||||
id_list = list(id_to_text.keys())
|
||||
numbered = "\n".join(
|
||||
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
|
||||
)
|
||||
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, facts=numbered)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Audit facts for contradictions."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-contradictions",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
deletions = json.loads(text.strip())
|
||||
if not isinstance(deletions, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
deleted = 0
|
||||
for item in deletions:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
rid = item.get("delete")
|
||||
if not rid or rid not in id_to_text:
|
||||
continue
|
||||
result2 = await db.execute(
|
||||
select(MemoryAssociative).where(
|
||||
MemoryAssociative.id == rid,
|
||||
MemoryAssociative.user_id == user_id,
|
||||
)
|
||||
)
|
||||
target = result2.scalar_one_or_none()
|
||||
if target:
|
||||
await db.delete(target)
|
||||
deleted += 1
|
||||
logger.info(
|
||||
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
|
||||
rid, user_id, item.get("reason", ""),
|
||||
)
|
||||
|
||||
if deleted:
|
||||
try:
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
|
||||
logger.info(
|
||||
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
|
||||
)
|
||||
|
||||
|
||||
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
|
||||
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
|
||||
result = await db.execute(
|
||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
if not rows:
|
||||
return
|
||||
|
||||
all_labels: set[str] = set()
|
||||
for row in rows:
|
||||
all_labels.add(row.subject_label)
|
||||
all_labels.add(row.object_label)
|
||||
|
||||
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
|
||||
if len(labels_list) < 2:
|
||||
return
|
||||
|
||||
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
|
||||
template, prompt_obj = get_prompt_or_fallback(
|
||||
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
|
||||
)
|
||||
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
|
||||
|
||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
||||
|
||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
||||
lf = get_langfuse()
|
||||
messages = [
|
||||
SystemMessage(content=system_text),
|
||||
HumanMessage(content="Canonicalize entity labels."),
|
||||
]
|
||||
try:
|
||||
if lf:
|
||||
with lf.start_as_current_observation(
|
||||
as_type="generation",
|
||||
name="memory-audit-canonicalize",
|
||||
model=model_for_agent("memory-auditor"),
|
||||
prompt=prompt_obj,
|
||||
input=messages,
|
||||
) as gen:
|
||||
response = await llm.ainvoke(messages)
|
||||
gen.update(output=response.content, usage=extract_usage(response))
|
||||
else:
|
||||
response = await llm.ainvoke(messages)
|
||||
|
||||
text = response.content if hasattr(response, "content") else str(response)
|
||||
groups = json.loads(text.strip())
|
||||
if not isinstance(groups, list):
|
||||
return
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
|
||||
user_id, exc,
|
||||
)
|
||||
return
|
||||
|
||||
# Build variant → canonical map
|
||||
remap: dict[str, str] = {}
|
||||
for group in groups:
|
||||
if not isinstance(group, dict):
|
||||
continue
|
||||
canonical = group.get("canonical", "")
|
||||
variants = group.get("variants") or []
|
||||
if not canonical:
|
||||
continue
|
||||
for v in variants:
|
||||
if isinstance(v, str) and v != canonical:
|
||||
remap[v] = canonical
|
||||
|
||||
if not remap:
|
||||
return
|
||||
|
||||
updated = 0
|
||||
for row in rows:
|
||||
changed = False
|
||||
if row.subject_label in remap:
|
||||
row.subject_label = remap[row.subject_label]
|
||||
changed = True
|
||||
if row.object_label in remap:
|
||||
row.object_label = remap[row.object_label]
|
||||
changed = True
|
||||
if changed:
|
||||
updated += 1
|
||||
|
||||
if updated:
|
||||
try:
|
||||
await db.commit()
|
||||
logger.info(
|
||||
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
|
||||
user_id, updated,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await db.rollback()
|
||||
@@ -1,733 +0,0 @@
|
||||
"""Memory Middleware — enrich requests with memory context and store interactions.
|
||||
|
||||
Four-tier memory model (MemGPT-style):
|
||||
core — persistent key/value user preferences, always injected
|
||||
associative — semantic similarity search via pgvector (top-k)
|
||||
episodic — recent session summaries (last N)
|
||||
proactive — behavioral patterns above confidence threshold
|
||||
|
||||
All memory content is encrypted at rest using the per-user Fernet key
|
||||
stored in User.encryption_key. Decryption happens in-memory only.
|
||||
|
||||
Usage:
|
||||
memory = MemoryMiddleware(db_session)
|
||||
context = await memory.enrich_context(user_id, message)
|
||||
# ... run agent ...
|
||||
await memory.store_episode(user_id, session_id, message, response)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.models import (
|
||||
ExtractionQueue,
|
||||
MemoryAssociative,
|
||||
MemoryCore,
|
||||
MemoryEpisodic,
|
||||
MemoryProactive,
|
||||
MemoryRelation,
|
||||
User,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _now() -> datetime:
|
||||
return datetime.now(timezone.utc)
|
||||
|
||||
|
||||
# Tuning constants
|
||||
_ASSOCIATIVE_TOP_K = 5
|
||||
_EPISODIC_RECENT_N = 10
|
||||
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||
|
||||
|
||||
class MemoryMiddleware:
|
||||
"""Enrich orchestrator context with memory and persist interactions after."""
|
||||
|
||||
def __init__(self, db: AsyncSession) -> None:
|
||||
self._db = db
|
||||
|
||||
# ── Public API ────────────────────────────────────────────────────────────
|
||||
|
||||
async def enrich_context(
|
||||
self,
|
||||
user_id: str,
|
||||
message: str,
|
||||
trace_id: str | None = None,
|
||||
session_id: str | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
||||
|
||||
Returns a dict with keys:
|
||||
core_memory — {key: plaintext_value, ...}
|
||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return {}
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier: str = user_dbg.get("tier") or "free"
|
||||
|
||||
core = await self._load_core(user_id, fernet)
|
||||
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||
proactive = await self._load_proactive(user_id, fernet)
|
||||
relational = await self._load_relational(user_id, user_tier=user_tier)
|
||||
|
||||
logger.info(
|
||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_tier,
|
||||
len(core),
|
||||
len(associative),
|
||||
len(episodic),
|
||||
len(proactive),
|
||||
len(relational),
|
||||
)
|
||||
|
||||
return {
|
||||
"core_memory": core,
|
||||
"associative_memory": associative,
|
||||
"episodic_memory": episodic,
|
||||
"proactive_hints": proactive,
|
||||
"relational_memory": relational,
|
||||
}
|
||||
|
||||
async def store_episode(
|
||||
self,
|
||||
user_id: str,
|
||||
session_id: str,
|
||||
message: str,
|
||||
response: str,
|
||||
trace_id: str | None = None,
|
||||
) -> None:
|
||||
"""Summarise and store a completed interaction in episodic memory.
|
||||
|
||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||
latency low. After committing the episode row, dispatches the Mem0-style
|
||||
extraction pipeline:
|
||||
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
||||
- Free → enqueue an ExtractionQueue row for the daily cron.
|
||||
"""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||
encrypted = _encrypt(fernet, summary)
|
||||
|
||||
episode = MemoryEpisodic(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
summary_encrypted=encrypted,
|
||||
session_id=session_id,
|
||||
)
|
||||
self._db.add(episode)
|
||||
episode_id: str = episode.id
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
tier = user_dbg.get("tier") or "free"
|
||||
logger.info(
|
||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
tier,
|
||||
session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
return
|
||||
|
||||
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
||||
await self._dispatch_extraction(
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
last_user_msg=message,
|
||||
last_assistant_msg=response,
|
||||
session_id=session_id,
|
||||
)
|
||||
|
||||
async def _dispatch_extraction(
|
||||
self,
|
||||
user_id: str,
|
||||
episode_id: str,
|
||||
last_user_msg: str,
|
||||
last_assistant_msg: str,
|
||||
session_id: str | None,
|
||||
) -> None:
|
||||
"""Route extraction to realtime task or batch queue based on user tier."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
tier = await tier_manager.get_tier(user_id, self._db)
|
||||
|
||||
if tier_manager.check_feature(tier, "realtime_extraction"):
|
||||
# Pro/Power/Team: fire-and-forget in the background.
|
||||
# Must open a fresh session — request session closes after handler returns.
|
||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
|
||||
async def _task() -> None:
|
||||
try:
|
||||
async with async_session() as fresh_db:
|
||||
await run_extraction(
|
||||
db=fresh_db,
|
||||
user_id=user_id,
|
||||
last_user_msg=last_user_msg,
|
||||
last_assistant_msg=last_assistant_msg,
|
||||
session_id=session_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction task failed user=%s: %s", user_id, exc
|
||||
)
|
||||
|
||||
asyncio.create_task(_task())
|
||||
logger.info("memory: realtime extraction dispatched user=%s", user_id)
|
||||
else:
|
||||
# Free tier: enqueue for daily batch cron.
|
||||
queue_row = ExtractionQueue(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
episode_id=episode_id,
|
||||
)
|
||||
self._db.add(queue_row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: extraction enqueued (batch) user=%s episode=%s",
|
||||
user_id,
|
||||
episode_id,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: extraction queue insert failed user=%s: %s", user_id, exc
|
||||
)
|
||||
await self._db.rollback()
|
||||
|
||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
||||
"""Upsert a core memory key/value for a user."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, value)
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == key,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
if existing is not None:
|
||||
existing.value_encrypted = encrypted
|
||||
else:
|
||||
self._db.add(MemoryCore(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
key=key,
|
||||
value_encrypted=encrypted,
|
||||
))
|
||||
try:
|
||||
await self._db.commit()
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
logger.info(
|
||||
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
||||
trace_id or "-",
|
||||
user_id,
|
||||
user_dbg.get("tier") or "-",
|
||||
key,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
||||
"""Return core memory as editable blocks (label/value)."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore)
|
||||
.where(MemoryCore.user_id == user_id)
|
||||
.order_by(MemoryCore.key.asc())
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[dict[str, str]] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append({"label": row.key, "value": plaintext})
|
||||
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
||||
return out
|
||||
|
||||
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
||||
"""Return a single core memory block value by label."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return None
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == label,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
||||
return None
|
||||
value = _safe_decrypt(fernet, row.value_encrypted)
|
||||
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
||||
return value
|
||||
|
||||
async def delete_core(self, user_id: str, label: str) -> bool:
|
||||
"""Delete a core memory block by label. Returns True if deleted."""
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(
|
||||
MemoryCore.user_id == user_id,
|
||||
MemoryCore.key == label,
|
||||
)
|
||||
)
|
||||
row = result.scalar_one_or_none()
|
||||
if row is None:
|
||||
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
||||
return False
|
||||
|
||||
await self._db.delete(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
||||
return True
|
||||
except Exception as exc:
|
||||
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
||||
await self._db.rollback()
|
||||
return False
|
||||
|
||||
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
||||
"""Append content to a core block, creating it if missing."""
|
||||
current = await self.get_core_block(user_id, label)
|
||||
if current is None:
|
||||
await self.update_core(user_id, label, content)
|
||||
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
||||
return
|
||||
await self.update_core(user_id, label, f"{current}\n{content}")
|
||||
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
||||
|
||||
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
||||
"""Replace one exact string inside a core block. Returns False if not found."""
|
||||
current = await self.get_core_block(user_id, label)
|
||||
if current is None or old not in current:
|
||||
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
||||
return False
|
||||
await self.update_core(user_id, label, current.replace(old, new, 1))
|
||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||
return True
|
||||
|
||||
async def store_associative(
|
||||
self,
|
||||
user_id: str,
|
||||
content: str,
|
||||
entity_type: str | None = None,
|
||||
entity_id: str | None = None,
|
||||
) -> None:
|
||||
"""Store associative memory; embed if user tier has real_embeddings."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, content)
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier = user_dbg.get("tier") or "free"
|
||||
|
||||
embedding: list[float] | None = None
|
||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||
embedding = await embed_text(content)
|
||||
|
||||
row = MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=encrypted,
|
||||
embedding=embedding,
|
||||
entity_type=entity_type,
|
||||
entity_id=entity_id,
|
||||
)
|
||||
self._db.add(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: store_associative user=%s embedded=%s",
|
||||
user_id,
|
||||
embedding is not None,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def upsert_relation(
|
||||
self,
|
||||
user_id: str,
|
||||
subject: str,
|
||||
subject_type: str,
|
||||
predicate: str,
|
||||
object_: str,
|
||||
object_type: str,
|
||||
*,
|
||||
confidence: float = 0.7,
|
||||
source_episode_id: str | None = None,
|
||||
notes: str | None = None,
|
||||
) -> None:
|
||||
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
|
||||
|
||||
subject_label / object_label are plaintext entity identifiers — not encrypted.
|
||||
notes is optional; encrypted with user Fernet if provided.
|
||||
"""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
user_dbg = await self._get_user_debug(user_id)
|
||||
user_tier = user_dbg.get("tier") or "free"
|
||||
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
|
||||
return
|
||||
|
||||
notes_encrypted: bytes | None = None
|
||||
if notes:
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet:
|
||||
notes_encrypted = fernet.encrypt(notes.encode())
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryRelation).where(
|
||||
MemoryRelation.user_id == user_id,
|
||||
MemoryRelation.subject_label == subject,
|
||||
MemoryRelation.predicate == predicate,
|
||||
MemoryRelation.object_label == object_,
|
||||
)
|
||||
)
|
||||
existing = result.scalar_one_or_none()
|
||||
|
||||
if existing is not None:
|
||||
existing.subject_type = subject_type
|
||||
existing.object_type = object_type
|
||||
existing.confidence = confidence
|
||||
existing.last_confirmed_at = _now()
|
||||
if notes_encrypted is not None:
|
||||
existing.notes_encrypted = notes_encrypted
|
||||
else:
|
||||
self._db.add(MemoryRelation(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
subject_label=subject,
|
||||
subject_type=subject_type,
|
||||
predicate=predicate,
|
||||
object_label=object_,
|
||||
object_type=object_type,
|
||||
confidence=confidence,
|
||||
source_episode_id=source_episode_id,
|
||||
notes_encrypted=notes_encrypted,
|
||||
))
|
||||
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info(
|
||||
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
|
||||
user_id, subject, predicate, object_,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def query_relations(
|
||||
self,
|
||||
user_id: str,
|
||||
subject: str | None = None,
|
||||
predicate: str | None = None,
|
||||
object_: str | None = None,
|
||||
limit: int = 20,
|
||||
) -> list[MemoryRelation]:
|
||||
"""Query relation rows for a user with optional filters."""
|
||||
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
||||
if subject is not None:
|
||||
q = q.where(MemoryRelation.subject_label == subject)
|
||||
if predicate is not None:
|
||||
q = q.where(MemoryRelation.predicate == predicate)
|
||||
if object_ is not None:
|
||||
q = q.where(MemoryRelation.object_label == object_)
|
||||
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
|
||||
result = await self._db.execute(q)
|
||||
return list(result.scalars().all())
|
||||
|
||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||
"""Insert a long-term archival memory entry."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return
|
||||
|
||||
encrypted = _encrypt(fernet, content)
|
||||
row = MemoryAssociative(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
content_encrypted=encrypted,
|
||||
embedding=None,
|
||||
entity_type=source,
|
||||
entity_id=None,
|
||||
)
|
||||
self._db.add(row)
|
||||
try:
|
||||
await self._db.commit()
|
||||
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
||||
except Exception as exc:
|
||||
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
||||
await self._db.rollback()
|
||||
|
||||
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(100)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
needle = query.strip().lower()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is None:
|
||||
continue
|
||||
if not needle or needle in plaintext.lower():
|
||||
out.append(plaintext)
|
||||
if len(out) >= max(top_k, 1):
|
||||
break
|
||||
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||
return out
|
||||
|
||||
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
||||
"""Search recall memory (episodic summaries) by keyword."""
|
||||
fernet = await self._get_fernet(user_id)
|
||||
if fernet is None:
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryEpisodic)
|
||||
.where(MemoryEpisodic.user_id == user_id)
|
||||
.order_by(MemoryEpisodic.created_at.desc())
|
||||
.limit(100)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
needle = query.strip().lower()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||
if plaintext is None:
|
||||
continue
|
||||
if not needle or needle in plaintext.lower():
|
||||
out.append(plaintext)
|
||||
if len(out) >= max(top_k, 1):
|
||||
break
|
||||
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
||||
return out
|
||||
|
||||
# ── Private helpers ───────────────────────────────────────────────────────
|
||||
|
||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||
"""Load the user's Fernet key from DB. Returns None if missing."""
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not user.encryption_key:
|
||||
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||
return None
|
||||
return Fernet(user.encryption_key.encode())
|
||||
|
||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
||||
"""Load lightweight user debug fields for trace logs."""
|
||||
from app.config.settings import settings # noqa: PLC0415
|
||||
from app.models import Subscription # noqa: PLC0415
|
||||
|
||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None:
|
||||
return {"tier": None}
|
||||
|
||||
sub_result = await self._db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
sub_tier: str | None = sub_result.scalar_one_or_none()
|
||||
if sub_tier:
|
||||
tier = sub_tier
|
||||
elif settings.ENV == "dev":
|
||||
tier = "power"
|
||||
else:
|
||||
tier = user.tier or "free"
|
||||
|
||||
return {"tier": tier}
|
||||
|
||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||
result = await self._db.execute(
|
||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: dict[str, str] = {}
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||
if plaintext is not None:
|
||||
out[row.key] = plaintext
|
||||
return out
|
||||
|
||||
async def _load_associative(
|
||||
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
||||
) -> list[str]:
|
||||
"""Load top-k associative memories.
|
||||
|
||||
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
||||
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
||||
"""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||
|
||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||
vec = await embed_text(message)
|
||||
if vec is not None:
|
||||
try:
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(
|
||||
MemoryAssociative.user_id == user_id,
|
||||
MemoryAssociative.embedding.isnot(None),
|
||||
)
|
||||
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
logger.info(
|
||||
"memory: _load_associative user=%s mode=vector hits=%d",
|
||||
user_id,
|
||||
len(out),
|
||||
)
|
||||
return out
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"memory: vector search failed user=%s, falling back to keyword: %s",
|
||||
user_id,
|
||||
exc,
|
||||
)
|
||||
|
||||
# Keyword fallback: most recent rows
|
||||
result = await self._db.execute(
|
||||
select(MemoryAssociative)
|
||||
.where(MemoryAssociative.user_id == user_id)
|
||||
.order_by(MemoryAssociative.updated_at.desc())
|
||||
.limit(_ASSOCIATIVE_TOP_K)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
async def _load_episodic(
|
||||
self,
|
||||
user_id: str,
|
||||
fernet: Fernet,
|
||||
session_id: str | None = None,
|
||||
) -> list[str]:
|
||||
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
||||
if session_id:
|
||||
query = query.where(MemoryEpisodic.session_id == session_id)
|
||||
result = await self._db.execute(
|
||||
query
|
||||
.order_by(MemoryEpisodic.created_at.desc())
|
||||
.limit(_EPISODIC_RECENT_N)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
|
||||
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
|
||||
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
||||
return []
|
||||
|
||||
result = await self._db.execute(
|
||||
select(MemoryRelation)
|
||||
.where(MemoryRelation.user_id == user_id)
|
||||
.order_by(MemoryRelation.confidence.desc())
|
||||
.limit(10)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out = [
|
||||
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
|
||||
for r in rows
|
||||
]
|
||||
return out
|
||||
|
||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||
result = await self._db.execute(
|
||||
select(MemoryProactive)
|
||||
.where(
|
||||
MemoryProactive.user_id == user_id,
|
||||
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||
)
|
||||
.order_by(MemoryProactive.confidence.desc())
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
out: list[str] = []
|
||||
for row in rows:
|
||||
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||
if plaintext is not None:
|
||||
out.append(plaintext)
|
||||
return out
|
||||
|
||||
|
||||
# ── Encryption helpers ────────────────────────────────────────────────────────
|
||||
|
||||
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||
return fernet.encrypt(plaintext.encode()).decode()
|
||||
|
||||
|
||||
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
|
||||
try:
|
||||
return fernet.decrypt(ciphertext.encode()).decode()
|
||||
except (InvalidToken, Exception) as exc:
|
||||
logger.warning("memory: decrypt failed: %s", exc)
|
||||
return None
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Note summarizer — generates a compact AI summary for a note.
|
||||
|
||||
Called fire-and-forget from create_note / update_note tools so the
|
||||
``notes.ai_summary`` column stays current without blocking the agent loop.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
|
||||
from app.core.langfuse_client import get_prompt_or_fallback
|
||||
from app.core.llm import get_agent_llm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FALLBACK_PROMPT = """\
|
||||
Summarize this note in <=250 characters. Be terse and dense.
|
||||
Keep proper nouns, dates, decisions, and action items.
|
||||
Do not start with "This note".
|
||||
Respond with the summary text only — no intro, no labels.
|
||||
|
||||
Title: {title}
|
||||
Content: {content}"""
|
||||
|
||||
_MAX_CONTENT_CHARS = 4000
|
||||
|
||||
|
||||
async def generate_note_summary(title: str, content: str) -> str:
|
||||
"""Return a <=250-char summary of *title* + *content*.
|
||||
|
||||
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
|
||||
fallback. Truncates *content* to 4000 chars before sending to avoid
|
||||
token waste on large notes.
|
||||
"""
|
||||
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
|
||||
trimmed = content[:_MAX_CONTENT_CHARS]
|
||||
system_prompt = template.format(title=title, content=trimmed)
|
||||
|
||||
try:
|
||||
llm = get_agent_llm("note-summarizer")
|
||||
response = await llm.ainvoke([
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content="Generate the summary."),
|
||||
])
|
||||
text = response.content if isinstance(response.content, str) else ""
|
||||
return text.strip()[:250]
|
||||
except Exception as exc:
|
||||
logger.warning("note_summarizer: failed to generate summary: %s", exc)
|
||||
return ""
|
||||
@@ -1,63 +0,0 @@
|
||||
"""Output formatter for deep-agent stream events."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any
|
||||
|
||||
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
|
||||
|
||||
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
|
||||
_CANVAS_BLOCK_RE = re.compile(
|
||||
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
|
||||
re.DOTALL | re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
|
||||
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
|
||||
|
||||
Returns ``(visible_text, canvas_content, canvas_kind)``.
|
||||
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
|
||||
"""
|
||||
match = _CANVAS_BLOCK_RE.search(text)
|
||||
if not match:
|
||||
return text, None, None
|
||||
|
||||
canvas_kind = match.group(1).strip()
|
||||
canvas_content = match.group(2).strip()
|
||||
visible = text[: match.start()] + text[match.end() :]
|
||||
visible = visible.strip()
|
||||
return visible, canvas_content, canvas_kind
|
||||
|
||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
|
||||
|
||||
|
||||
class StreamFormatter:
|
||||
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
||||
|
||||
def __init__(self, request_id: str) -> None:
|
||||
self.request_id = request_id
|
||||
|
||||
async def format(
|
||||
self,
|
||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||
) -> AsyncGenerator[WsFrame, None]:
|
||||
started = False
|
||||
|
||||
async for event_type, data in event_stream:
|
||||
if event_type != "token":
|
||||
continue
|
||||
|
||||
if not started:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
started = True
|
||||
|
||||
text = str(data or "")
|
||||
if text:
|
||||
yield WsStreamText(request_id=self.request_id, chunk=text)
|
||||
|
||||
if not started:
|
||||
yield WsStreamStart(request_id=self.request_id)
|
||||
yield WsStreamEnd(request_id=self.request_id)
|
||||
@@ -1,104 +0,0 @@
|
||||
"""Preprocessor registry: detect content type and dispatch to handlers.
|
||||
|
||||
Public API
|
||||
----------
|
||||
detect_content_type(filename, raw_content) -> str
|
||||
Heuristic detection based on file extension and content patterns.
|
||||
|
||||
preprocess(content_type, raw_content) -> PreprocessResult
|
||||
Dispatch to the appropriate handler.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from app.core.preprocessors.base import PreprocessResult
|
||||
|
||||
# ── Heuristics ────────────────────────────────────────────────────────
|
||||
|
||||
# Patterns that strongly suggest an email HTML file
|
||||
_EMAIL_SIGNALS = re.compile(
|
||||
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
# Patterns that suggest a generic HTML page (not an email)
|
||||
_GENERIC_HTML_SIGNALS = re.compile(
|
||||
r"<(nav|main|header|footer|article|section)\b",
|
||||
re.IGNORECASE,
|
||||
)
|
||||
|
||||
|
||||
def detect_content_type(filename: str, raw_content: str) -> str:
|
||||
"""Return a content-type string for the given file.
|
||||
|
||||
Supported types: ``"email_html"``, ``"generic_html"``,
|
||||
``"plain_text"``, ``"unknown"``.
|
||||
"""
|
||||
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
||||
|
||||
if ext == "txt":
|
||||
return "plain_text"
|
||||
|
||||
if ext in ("html", "htm", "eml", "mhtml", "mht"):
|
||||
# Prefer email detection over generic HTML
|
||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||
return "email_html"
|
||||
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
|
||||
return "generic_html"
|
||||
# .html without clear signals — check for any email header
|
||||
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
|
||||
return "email_html"
|
||||
return "generic_html"
|
||||
|
||||
# Plain text files with email headers
|
||||
if ext in ("", "txt") or not ext:
|
||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||
return "email_html"
|
||||
|
||||
# Detect binary content
|
||||
try:
|
||||
raw_content.encode("utf-8")
|
||||
except (UnicodeEncodeError, AttributeError):
|
||||
return "unknown"
|
||||
|
||||
# Non-text bytes heuristic: high ratio of non-printable chars
|
||||
sample = raw_content[:512]
|
||||
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
|
||||
if len(sample) > 0 and non_printable / len(sample) > 0.1:
|
||||
return "unknown"
|
||||
|
||||
return "unknown"
|
||||
|
||||
|
||||
# ── Generic fallback handler ──────────────────────────────────────────
|
||||
|
||||
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
|
||||
"""Strip HTML tags if present, return text as-is."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup
|
||||
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
|
||||
except ImportError:
|
||||
# No BeautifulSoup — strip tags with a simple regex
|
||||
text = re.sub(r"<[^>]+>", "", raw_content)
|
||||
|
||||
text = re.sub(r"\n{3,}", "\n\n", text).strip()
|
||||
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
|
||||
|
||||
|
||||
# ── Dispatch ──────────────────────────────────────────────────────────
|
||||
|
||||
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
|
||||
"""Dispatch *raw_content* to the handler registered for *content_type*.
|
||||
|
||||
Falls back to the generic handler for unknown types.
|
||||
"""
|
||||
if content_type == "email_html":
|
||||
from app.core.preprocessors.email_html import preprocess_email_html
|
||||
return preprocess_email_html(raw_content)
|
||||
|
||||
return _preprocess_generic(raw_content, content_type)
|
||||
|
||||
|
||||
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]
|
||||
@@ -1,25 +0,0 @@
|
||||
"""Base types for the preprocessor system."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreprocessResult:
|
||||
"""Output of a preprocessor handler.
|
||||
|
||||
Attributes
|
||||
----------
|
||||
content_type:
|
||||
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
|
||||
clean_text:
|
||||
Human-readable text stripped of markup/binary noise.
|
||||
metadata:
|
||||
Dict of extracted metadata (keys vary by handler).
|
||||
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
|
||||
"""
|
||||
|
||||
content_type: str
|
||||
clean_text: str
|
||||
metadata: dict = field(default_factory=dict)
|
||||
@@ -1,111 +0,0 @@
|
||||
"""Preprocessor for email HTML files.
|
||||
|
||||
Handles:
|
||||
- HTML stripping via BeautifulSoup
|
||||
- Metadata extraction (Subject, From, To, Date)
|
||||
- Thread splitting — isolates the latest reply
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from app.core.preprocessors.base import PreprocessResult
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
# ── Thread split markers ──────────────────────────────────────────────
|
||||
|
||||
# Matches patterns like:
|
||||
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
|
||||
# "-----Original Message-----"
|
||||
# "> " (plain-text quote prefix)
|
||||
_THREAD_PATTERNS = [
|
||||
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
|
||||
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
|
||||
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
|
||||
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
|
||||
]
|
||||
|
||||
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
|
||||
|
||||
_META_PATTERNS: dict[str, list[re.Pattern]] = {
|
||||
"subject": [
|
||||
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
|
||||
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
|
||||
],
|
||||
"from": [
|
||||
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||
re.compile(r"From:\s*(.+)", re.IGNORECASE),
|
||||
],
|
||||
"to": [
|
||||
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||
re.compile(r"To:\s*(.+)", re.IGNORECASE),
|
||||
],
|
||||
"date": [
|
||||
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
|
||||
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _extract_metadata(raw_html: str, text: str) -> dict:
|
||||
"""Extract Subject/From/To/Date from raw HTML or plain text."""
|
||||
metadata: dict[str, str] = {}
|
||||
for field, patterns in _META_PATTERNS.items():
|
||||
for pat in patterns:
|
||||
m = pat.search(raw_html) or pat.search(text)
|
||||
if m:
|
||||
metadata[field] = m.group(1).strip()
|
||||
break
|
||||
return metadata
|
||||
|
||||
|
||||
def _split_thread(text: str) -> str:
|
||||
"""Return only the latest message in a threaded email."""
|
||||
earliest_pos: int | None = None
|
||||
for pat in _THREAD_PATTERNS:
|
||||
m = pat.search(text)
|
||||
if m and (earliest_pos is None or m.start() < earliest_pos):
|
||||
earliest_pos = m.start()
|
||||
|
||||
if earliest_pos is not None and earliest_pos > 0:
|
||||
return text[:earliest_pos].strip()
|
||||
return text.strip()
|
||||
|
||||
|
||||
def preprocess_email_html(raw_content: str) -> PreprocessResult:
|
||||
"""Strip HTML, extract metadata, split thread from an email HTML file."""
|
||||
try:
|
||||
from bs4 import BeautifulSoup # lazy import — optional dep
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"beautifulsoup4 is required for email_html preprocessing. "
|
||||
"Install it with: pip install beautifulsoup4"
|
||||
) from exc
|
||||
|
||||
# Parse with lxml if available, fall back to html.parser
|
||||
try:
|
||||
soup = BeautifulSoup(raw_content, "lxml")
|
||||
except Exception:
|
||||
soup = BeautifulSoup(raw_content, "html.parser")
|
||||
|
||||
# Remove noise tags
|
||||
for tag in soup(["style", "script", "head", "noscript"]):
|
||||
tag.decompose()
|
||||
|
||||
clean_text = soup.get_text(separator="\n")
|
||||
# Collapse excessive blank lines
|
||||
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
|
||||
|
||||
metadata = _extract_metadata(raw_content, clean_text)
|
||||
latest_message = _split_thread(clean_text)
|
||||
|
||||
return PreprocessResult(
|
||||
content_type="email_html",
|
||||
clean_text=latest_message,
|
||||
metadata=metadata,
|
||||
)
|
||||
@@ -1,115 +0,0 @@
|
||||
"""WebSocket client executor context.
|
||||
|
||||
Holds a per-request async callback that tools call to execute CRUD
|
||||
operations on the Electron client's local SQLite / LanceDB databases.
|
||||
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Callable, Coroutine
|
||||
from uuid import uuid4
|
||||
|
||||
_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")
|
||||
|
||||
|
||||
def _key_to_camel(key: str) -> str:
|
||||
return _SNAKE_TO_CAMEL_RE.sub(lambda m: m.group(1).upper(), key)
|
||||
|
||||
|
||||
def _keys_to_camel(obj: Any) -> Any:
|
||||
"""Recursively convert dict keys from snake_case to camelCase.
|
||||
|
||||
Mirrors the JS-side ``toCamelCase`` applied to incoming WS frames in
|
||||
``adiuvAI/src/main/api/backend-client.ts``. The Electron executor wraps
|
||||
tool_result payloads in ``toSnakeCase`` before sending; this restores the
|
||||
camelCase schema property names that the tool code expects to read.
|
||||
"""
|
||||
if isinstance(obj, dict):
|
||||
return {_key_to_camel(k): _keys_to_camel(v) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_keys_to_camel(v) for v in obj]
|
||||
return obj
|
||||
|
||||
# Holds the execute callback for the current WS session.
|
||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||
"_client_executor"
|
||||
)
|
||||
|
||||
# Optional collector that captures raw execute_on_client results.
|
||||
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||
"_tool_result_collector", default=None
|
||||
)
|
||||
|
||||
|
||||
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||
"""Register *lst* as the collector for this async context."""
|
||||
_tool_result_collector.set(lst)
|
||||
|
||||
|
||||
def clear_tool_result_collector() -> None:
|
||||
"""Clear the collector (best-effort)."""
|
||||
_tool_result_collector.set(None)
|
||||
|
||||
|
||||
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||
_client_executor.set(fn)
|
||||
|
||||
|
||||
def clear_client_executor() -> None:
|
||||
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
|
||||
try:
|
||||
_client_executor.set(None) # type: ignore[arg-type]
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def execute_on_client(
|
||||
action: str,
|
||||
table: str | None = None,
|
||||
data: dict[str, Any] | None = None,
|
||||
filters: dict[str, Any] | None = None,
|
||||
vector: list[float] | None = None,
|
||||
limit: int | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Send a CRUD/vector operation to the Electron client and return the result.
|
||||
|
||||
Builds a ``tool_call`` payload, invokes the per-session WS callback,
|
||||
and returns the ``tool_result`` dict from Electron.
|
||||
|
||||
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
|
||||
"""
|
||||
callback = _client_executor.get(None)
|
||||
if callback is None:
|
||||
raise RuntimeError(
|
||||
"execute_on_client() called outside a WebSocket session — "
|
||||
"no client executor is set."
|
||||
)
|
||||
|
||||
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
|
||||
if table is not None:
|
||||
payload["table"] = table
|
||||
if data is not None:
|
||||
payload["data"] = data
|
||||
if filters is not None:
|
||||
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||
if vector is not None:
|
||||
payload["vector"] = vector
|
||||
if limit is not None:
|
||||
payload["limit"] = limit
|
||||
|
||||
result = await callback(payload)
|
||||
result = _keys_to_camel(result)
|
||||
collector = _tool_result_collector.get(None)
|
||||
if collector is not None:
|
||||
collector.append({
|
||||
"action": action,
|
||||
"table": table,
|
||||
"data": result,
|
||||
})
|
||||
return result
|
||||
143
app/main.py
143
app/main.py
@@ -1,143 +0,0 @@
|
||||
from contextlib import asynccontextmanager
|
||||
import logging
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||
from app.config.settings import settings
|
||||
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||
)
|
||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||
|
||||
|
||||
async def _memory_audit_cron_tick() -> None:
|
||||
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
||||
import logging # noqa: PLC0415
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("memory audit cron tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
|
||||
from app.models import User # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(User.id))
|
||||
user_ids: list[str] = list(result.scalars().all())
|
||||
|
||||
for uid in user_ids:
|
||||
try:
|
||||
async with async_session() as db:
|
||||
await audit_memory(db, uid)
|
||||
except Exception as exc:
|
||||
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
|
||||
|
||||
_log.info("memory audit cron tick: done users=%d", len(user_ids))
|
||||
except Exception as exc:
|
||||
_log.warning("memory audit cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
async def _memory_cron_tick() -> None:
|
||||
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
|
||||
import logging # noqa: PLC0415
|
||||
_log = logging.getLogger(__name__)
|
||||
_log.info("memory cron tick: starting")
|
||||
try:
|
||||
from app.db import async_session # noqa: PLC0415
|
||||
from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns # noqa: PLC0415
|
||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||
from app.models import User # noqa: PLC0415
|
||||
from sqlalchemy import select # noqa: PLC0415
|
||||
|
||||
async with async_session() as db:
|
||||
await drain_extraction_queue(db)
|
||||
|
||||
# mine proactive patterns for every Power+ user
|
||||
async with async_session() as db:
|
||||
result = await db.execute(select(User.id))
|
||||
user_ids: list[str] = list(result.scalars().all())
|
||||
|
||||
for uid in user_ids:
|
||||
try:
|
||||
async with async_session() as db:
|
||||
tier = await tier_manager.get_tier(uid, db)
|
||||
if tier_manager.check_feature(tier, "proactive_mining"):
|
||||
await mine_proactive_patterns(db, uid)
|
||||
except Exception as exc:
|
||||
_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)
|
||||
|
||||
_log.info("memory cron tick: done users=%d", len(user_ids))
|
||||
except Exception as exc:
|
||||
_log.warning("memory cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup: ensure agent tool modules are loaded.
|
||||
import app.agents # noqa: F401
|
||||
|
||||
scheduler = None
|
||||
if settings.SCHEDULER_ENABLED:
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # noqa: PLC0415
|
||||
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
||||
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
||||
scheduler.start()
|
||||
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
||||
|
||||
yield
|
||||
|
||||
if scheduler is not None:
|
||||
scheduler.shutdown(wait=False)
|
||||
|
||||
# Shutdown: dispose SQLAlchemy connection pool
|
||||
from app.db import engine
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="AdiuvAI Cloud API",
|
||||
version="0.1.0",
|
||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||
redoc_url=None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
||||
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
||||
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import agents, auth, billing, chat, device_ws, memory
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
app.include_router(agents.router, prefix="/api/v1")
|
||||
app.include_router(device_ws.router, prefix="/api/v1")
|
||||
app.include_router(memory.router, prefix="/api/v1")
|
||||
|
||||
@app.get("/api/v1/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
return {"status": "ok", "version": app.version}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
@@ -1,73 +0,0 @@
|
||||
"""Contextual sidebar scope schema and prompt block renderer.
|
||||
|
||||
ContextualScope mirrors the TypeScript ContextualScope type sent by the
|
||||
Electron renderer when the user opens the side chat anchored to a specific
|
||||
view. The renderer ships camelCase keys; Pydantic's alias_generator maps
|
||||
them to snake_case Python attributes automatically.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from pydantic.alias_generators import to_camel
|
||||
|
||||
|
||||
PageType = Literal[
|
||||
"timeline",
|
||||
"tasks",
|
||||
"projects-list",
|
||||
"project",
|
||||
"note",
|
||||
]
|
||||
|
||||
EntityType = Literal["project", "note", "task", "timeline_event"]
|
||||
|
||||
|
||||
class ContextualScope(BaseModel):
|
||||
"""Scope payload sent by the Electron renderer for contextual chat.
|
||||
|
||||
The renderer ships camelCase keys (entityType, entityId, ...). Pydantic's
|
||||
alias generator maps them to snake_case Python attrs.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
|
||||
|
||||
page: PageType
|
||||
entity_type: Optional[EntityType] = None
|
||||
entity_id: Optional[str] = None
|
||||
entity_name: Optional[str] = None
|
||||
project_id: Optional[str] = None
|
||||
char_count: Optional[int] = None
|
||||
counts: Optional[dict[str, int]] = None
|
||||
filters: Optional[dict] = None
|
||||
|
||||
|
||||
def render_scope_block(scope: ContextualScope) -> str:
|
||||
"""Produce a single-paragraph human-readable summary of the current view
|
||||
for injection into the contextual agent system prompt.
|
||||
|
||||
Never emits internal ids — only names. The LLM is told to use names in
|
||||
prose; ids travel through tool calls.
|
||||
"""
|
||||
if scope.entity_type == "project":
|
||||
c = scope.counts or {}
|
||||
return (
|
||||
f"User is viewing the project {scope.entity_name!r}. "
|
||||
f"{c.get('tasks', 0)} tasks, "
|
||||
f"{c.get('notes', 0)} notes, "
|
||||
f"{c.get('milestones', 0)} milestones."
|
||||
)
|
||||
if scope.entity_type == "note":
|
||||
return (
|
||||
f"User is viewing the note {scope.entity_name!r} "
|
||||
f"({scope.char_count or 0} characters)."
|
||||
)
|
||||
if scope.page == "tasks":
|
||||
return "User is viewing the global Tasks list (all projects)."
|
||||
if scope.page == "timeline":
|
||||
return "User is viewing the global Timeline view."
|
||||
if scope.page == "projects-list":
|
||||
return "User is viewing the Projects list."
|
||||
return f"User is on page {scope.page}."
|
||||
@@ -1,27 +1,34 @@
|
||||
# ── Adiuva Microservices ─────────────────────────────────────────────
|
||||
# docker compose up --build
|
||||
# docker compose up --build auth ws-gateway chat # subset
|
||||
|
||||
services:
|
||||
app:
|
||||
build: .
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
# Infrastructure
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
|
||||
traefik:
|
||||
image: traefik:v3.1
|
||||
ports:
|
||||
- "8080:8000"
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
- "8080:8080" # dashboard (dev only)
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
|
||||
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||
CF_DNS_API_TOKEN: ${CF_DNS_API_TOKEN:-}
|
||||
volumes:
|
||||
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- ./traefik/traefik.yml:/etc/traefik/traefik.yml:ro
|
||||
- ./traefik/dynamic:/etc/traefik/dynamic:ro
|
||||
- traefik_acme:/etc/traefik/acme
|
||||
restart: unless-stopped
|
||||
|
||||
db:
|
||||
image: pgvector/pgvector:pg16
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: adiuvai
|
||||
POSTGRES_USER: ${POSTGRES_USER:-postgres}
|
||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres}
|
||||
POSTGRES_DB: ${POSTGRES_DB:-adiuva}
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
@@ -31,11 +38,161 @@ services:
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
# Optional Redis for future rate-limit or caching needs
|
||||
# redis:
|
||||
# image: redis:7-alpine
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
# ── Optional infrastructure (uncomment as needed) ────────────────
|
||||
|
||||
# minio:
|
||||
# image: minio/minio:latest
|
||||
# command: server /data --console-address ":9001"
|
||||
# ports:
|
||||
# - "9000:9000"
|
||||
# - "9001:9001"
|
||||
# environment:
|
||||
# MINIO_ROOT_USER: minioadmin
|
||||
# MINIO_ROOT_PASSWORD: minioadmin
|
||||
# volumes:
|
||||
# - minio_data:/data
|
||||
# healthcheck:
|
||||
# test: ["CMD", "mc", "ready", "local"]
|
||||
# interval: 5s
|
||||
# timeout: 5s
|
||||
# retries: 5
|
||||
# restart: unless-stopped
|
||||
|
||||
# qdrant:
|
||||
# image: qdrant/qdrant:latest
|
||||
# ports:
|
||||
# - "6333:6333"
|
||||
# - "6334:6334"
|
||||
# volumes:
|
||||
# - qdrant_data:/qdrant/storage
|
||||
# restart: unless-stopped
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
# Migrations (run once, then exit)
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
|
||||
migrate:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: Dockerfile
|
||||
command: ["python", "-m", "alembic", "upgrade", "head"]
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
restart: "no"
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
# Application Services
|
||||
# ═══════════════════════════════════════════════════════════════════
|
||||
|
||||
auth:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: services/auth/Dockerfile
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||
REDIS_URL: redis://redis:6379/0
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
restart: unless-stopped
|
||||
|
||||
ws-gateway:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: services/ws-gateway/Dockerfile
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||
REDIS_URL: redis://redis:6379/0
|
||||
depends_on:
|
||||
redis:
|
||||
condition: service_healthy
|
||||
auth:
|
||||
condition: service_started
|
||||
restart: unless-stopped
|
||||
|
||||
chat:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: services/chat/Dockerfile
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||
REDIS_URL: redis://redis:6379/0
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
restart: unless-stopped
|
||||
|
||||
batch-agent:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: services/batch-agent/Dockerfile
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||
REDIS_URL: redis://redis:6379/0
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
restart: unless-stopped
|
||||
|
||||
billing:
|
||||
build:
|
||||
context: .
|
||||
dockerfile: services/billing/Dockerfile
|
||||
env_file:
|
||||
- path: .env
|
||||
required: false
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
migrate:
|
||||
condition: service_completed_successfully
|
||||
restart: unless-stopped
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
copilot_tokens:
|
||||
redis_data:
|
||||
traefik_acme:
|
||||
# minio_data:
|
||||
# qdrant_data:
|
||||
|
||||
941
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
941
docs/MICROSERVICES_ARCHITECTURE.md
Normal file
@@ -0,0 +1,941 @@
|
||||
# Adiuva — Architettura Microservizi (MVP)
|
||||
|
||||
## Panoramica
|
||||
|
||||
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
|
||||
|
||||
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
|
||||
|
||||
```
|
||||
┌──────────────┐
|
||||
│ Cloudflare │
|
||||
│ (DNS + CDN) │
|
||||
└──────┬───────┘
|
||||
│ HTTPS / WSS
|
||||
┌──────▼───────┐
|
||||
│ Traefik │
|
||||
│ API Gateway │
|
||||
│ (routing, │
|
||||
│ TLS, rate │
|
||||
│ limiting) │
|
||||
└──────┬───────┘
|
||||
│
|
||||
┌──────────┬───────────┼───────────┐
|
||||
│ │ │ │
|
||||
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
|
||||
│ Auth │ │ Chat │ │ Agent │ │Billing │
|
||||
│ Service │ │Service│ │ Service │ │Service │
|
||||
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
|
||||
│ │ │ │
|
||||
┌─────▼──────────▼──────────▼───────────▼────┐
|
||||
│ Infrastruttura │
|
||||
│ PostgreSQL │ Redis │ Qdrant │
|
||||
└─────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 1. Suddivisione dei Servizi
|
||||
|
||||
### 1.1 Auth Service (`auth-service`)
|
||||
|
||||
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
||||
|
||||
| Endpoint originale | Metodo |
|
||||
|---|---|
|
||||
| `/api/v1/auth/register` | POST |
|
||||
| `/api/v1/auth/login` | POST |
|
||||
| `/api/v1/auth/refresh` | POST |
|
||||
| `/api/v1/auth/me` | GET / PUT |
|
||||
|
||||
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
||||
|
||||
**Modifica chiave — JWT con RS256**:
|
||||
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
||||
- L'Auth Service firma i JWT con la **chiave privata**.
|
||||
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
||||
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
||||
|
||||
```python
|
||||
# auth-service/app/auth/jwt.py
|
||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
||||
from jose import jwt
|
||||
|
||||
PRIVATE_KEY = ... # Da env/secret
|
||||
PUBLIC_KEY = ... # Derivata o da env
|
||||
|
||||
def create_access_token(user_id: str, tier: str) -> str:
|
||||
return jwt.encode(
|
||||
{"sub": user_id, "tier": tier, "exp": ...},
|
||||
PRIVATE_KEY,
|
||||
algorithm="RS256",
|
||||
)
|
||||
```
|
||||
|
||||
```python
|
||||
# shared/auth.py (usato da tutti gli altri servizi)
|
||||
from jose import jwt
|
||||
|
||||
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
||||
|
||||
def verify_token(token: str) -> dict:
|
||||
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||
```
|
||||
|
||||
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
||||
|
||||
---
|
||||
|
||||
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
|
||||
|
||||
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
|
||||
|
||||
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
|
||||
|
||||
| Endpoint | Tipo |
|
||||
|---|---|
|
||||
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
|
||||
| `/api/v1/chat` | POST (REST fallback) |
|
||||
|
||||
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
|
||||
|
||||
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
|
||||
|
||||
**Scaling**: 2–N repliche. Sticky cookies per le WS + Redis per cross-instance.
|
||||
|
||||
---
|
||||
|
||||
### 1.3 Agent Service (`agent-service`) ⭐ Batch
|
||||
|
||||
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
|
||||
|
||||
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
|
||||
|
||||
| Endpoint | Tipo |
|
||||
|---|---|
|
||||
| `/api/v1/agents/catalog` | GET |
|
||||
| `/api/v1/agents/can-create` | POST |
|
||||
| `/api/v1/agents/trigger` | POST |
|
||||
| `/api/v1/agents/journey/start` | POST (o WS relay) |
|
||||
| `/api/v1/agents/journey/message` | POST (o WS relay) |
|
||||
|
||||
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
|
||||
|
||||
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
|
||||
|
||||
```
|
||||
┌──────────────┐ ┌──────────────┐ ┌──────────┐
|
||||
│ Agent Service│ │ Redis │ │ Chat │
|
||||
│ (batch run) │ │ │ │ Service │
|
||||
│ │ │ │ │ (ha WS) │
|
||||
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
|
||||
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
|
||||
│ from │ │ │ │ al │
|
||||
│ device │ │ │ │ device│
|
||||
│ │ │ │ │ via WS│
|
||||
│ │ SUBSCRIBE │ │ PUBLISH │ │
|
||||
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
|
||||
│ risultato │ │ │ │ reply │
|
||||
└──────────────┘ └──────────────┘ └──────────┘
|
||||
```
|
||||
|
||||
**Scaling**: 1–N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
|
||||
|
||||
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
|
||||
|
||||
---
|
||||
|
||||
### 1.4 Billing Service (`billing-service`)
|
||||
|
||||
**Responsabilità**: Stripe checkout, webhook, subscription management.
|
||||
|
||||
| Endpoint originale | Metodo |
|
||||
|---|---|
|
||||
| `/api/v1/billing/checkout` | POST |
|
||||
| `/api/v1/billing/webhook` | POST |
|
||||
| `/api/v1/billing/subscription` | GET / DELETE |
|
||||
|
||||
**Database**: Tabelle `subscriptions` (schema `billing`).
|
||||
|
||||
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
|
||||
|
||||
**Scaling**: 1 replica sufficiente. Basso traffico.
|
||||
|
||||
---
|
||||
|
||||
### 1.5 Servizi esclusi dall'MVP
|
||||
|
||||
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
|
||||
|
||||
| Servizio | Responsabilità | Note |
|
||||
|---|---|---|
|
||||
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
|
||||
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
|
||||
|
||||
---
|
||||
|
||||
## 2. Tier Check — Dove e Come
|
||||
|
||||
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
|
||||
|
||||
### Strategia: Tier nel JWT
|
||||
|
||||
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
|
||||
|
||||
```json
|
||||
{
|
||||
"sub": "user_123",
|
||||
"tier": "pro",
|
||||
"exp": 1742515200,
|
||||
"iat": 1742511600
|
||||
}
|
||||
```
|
||||
|
||||
Ogni servizio:
|
||||
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
|
||||
2. Legge `payload["tier"]` — **zero chiamate extra**
|
||||
3. Applica le sue regole di enforcement localmente
|
||||
|
||||
```python
|
||||
# shared/auth.py — dependency FastAPI condivisa
|
||||
from fastapi import Depends, HTTPException, Request
|
||||
from jose import jwt
|
||||
|
||||
PUBLIC_KEY = ...
|
||||
|
||||
class CurrentUser:
|
||||
def __init__(self, user_id: str, tier: str):
|
||||
self.user_id = user_id
|
||||
self.tier = tier
|
||||
|
||||
async def get_current_user(request: Request) -> CurrentUser:
|
||||
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
|
||||
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
||||
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
|
||||
|
||||
def require_tier(*allowed_tiers: str):
|
||||
"""Dependency che blocca se il tier non è tra quelli ammessi."""
|
||||
async def check(user: CurrentUser = Depends(get_current_user)):
|
||||
if user.tier not in allowed_tiers:
|
||||
raise HTTPException(403, "Tier insufficient")
|
||||
return user
|
||||
return check
|
||||
```
|
||||
|
||||
### Cosa succede quando il tier cambia (upgrade/downgrade)?
|
||||
|
||||
```
|
||||
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
|
||||
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
|
||||
│ │ │ Service │ (Redis pub/sub) │ Service │
|
||||
└──────────┘ └──────────┘ └────┬─────┘
|
||||
│
|
||||
UPDATE users
|
||||
SET tier = 'power'
|
||||
│
|
||||
Al prossimo /refresh
|
||||
il JWT conterrà tier='power'
|
||||
```
|
||||
|
||||
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 15–30 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
|
||||
|
||||
### Dove si applica in ciascun servizio
|
||||
|
||||
| Servizio | Enforcement |
|
||||
|---|---|
|
||||
| **Auth Service** | Nessuno (è lui che scrive il tier) |
|
||||
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
|
||||
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
|
||||
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
|
||||
|
||||
### Rate-limit distribuito via Redis
|
||||
|
||||
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
|
||||
|
||||
```python
|
||||
# shared/middleware/rate_limit.py
|
||||
import redis.asyncio as aioredis
|
||||
|
||||
class DistributedRateLimiter:
|
||||
def __init__(self, redis: aioredis.Redis):
|
||||
self._redis = redis
|
||||
|
||||
async def check(self, user_id: str, tier: str, service: str) -> bool:
|
||||
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
||||
max_req = limits.get(tier, 20)
|
||||
key = f"rate:{service}:{user_id}"
|
||||
|
||||
pipe = self._redis.pipeline()
|
||||
pipe.incr(key)
|
||||
pipe.expire(key, 60)
|
||||
count, _ = await pipe.execute()
|
||||
|
||||
return count <= max_req
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
||||
|
||||
`DeviceConnectionManager` è un **singleton in-memory**:
|
||||
|
||||
```python
|
||||
class DeviceConnectionManager:
|
||||
def __init__(self):
|
||||
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
||||
```
|
||||
|
||||
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
||||
|
||||
### La soluzione: Redis Pub/Sub + Registry
|
||||
|
||||
```
|
||||
┌──────────────────────────────────────────────────────────────┐
|
||||
│ Redis │
|
||||
│ │
|
||||
│ Hash: ws:connections │
|
||||
│ user_123 → instance_A │
|
||||
│ user_456 → instance_B │
|
||||
│ │
|
||||
│ Pub/Sub channels: │
|
||||
│ tool_call:{user_id} → tool call payloads │
|
||||
│ tool_result:{call_id} → tool result payloads │
|
||||
│ stream:{user_id} → text_chunk streaming │
|
||||
└──────────────────────────────────────────────────────────────┘
|
||||
|
||||
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
||||
┌───────────────────────┐ ┌───────────────────────┐
|
||||
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
||||
│ tool_call:user_123│ │ → user_123 è su A │
|
||||
│ │ │ │
|
||||
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
||||
│ da Redis channel │ │ tool_call:user_123 │
|
||||
│ │ │ {id, action, ...} │
|
||||
│ 3. Invia al device │ │ │
|
||||
│ via WS │ │ 4. SUBSCRIBE │
|
||||
│ │ │ tool_result:{id} │
|
||||
│ 4. Device risponde │ │ │
|
||||
│ tool_result │──────────│► 5. Riceve risultato │
|
||||
│ │ │ │
|
||||
│ 5. PUBLISH │ │ │
|
||||
│ tool_result:{id} │ │ │
|
||||
└───────────────────────┘ └───────────────────────┘
|
||||
```
|
||||
|
||||
### Implementazione: `RedisDeviceManager`
|
||||
|
||||
```python
|
||||
# chat-service/app/core/device_manager.py
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import redis.asyncio as aioredis
|
||||
from dataclasses import dataclass, field
|
||||
from fastapi import WebSocket
|
||||
|
||||
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
||||
|
||||
@dataclass
|
||||
class LocalConnection:
|
||||
ws: WebSocket
|
||||
device_id: str
|
||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||
|
||||
|
||||
class RedisDeviceManager:
|
||||
"""Device manager backed by Redis for cross-instance communication."""
|
||||
|
||||
def __init__(self, redis_url: str = "redis://redis:6379"):
|
||||
self._redis = aioredis.from_url(redis_url)
|
||||
self._pubsub = self._redis.pubsub()
|
||||
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
||||
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
||||
|
||||
async def start(self):
|
||||
"""Avvia il listener Redis per tool_call in arrivo."""
|
||||
asyncio.create_task(self._listen_tool_calls())
|
||||
|
||||
# ── Registrazione ──
|
||||
|
||||
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
||||
# Registra localmente
|
||||
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
||||
# Registra in Redis quale istanza ha la connessione
|
||||
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
||||
# Sottoscrivi ai tool_call per questo utente
|
||||
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
||||
|
||||
async def unregister(self, user_id: str):
|
||||
conn = self._local.pop(user_id, None)
|
||||
if conn:
|
||||
for fut in conn.pending_calls.values():
|
||||
if not fut.done():
|
||||
fut.cancel()
|
||||
await self._redis.hdel("ws:connections", user_id)
|
||||
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
||||
|
||||
# ── Presenza ──
|
||||
|
||||
async def is_online(self, user_id: str) -> bool:
|
||||
return await self._redis.hexists("ws:connections", user_id)
|
||||
|
||||
# ── Tool-call round-trip (cross-instance) ──
|
||||
|
||||
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
||||
"""
|
||||
Invia un tool_call al device dell'utente.
|
||||
Funziona sia che la WS sia locale che su un'altra istanza.
|
||||
"""
|
||||
call_id = payload["id"]
|
||||
|
||||
# Caso 1: connessione locale → invio diretto
|
||||
if user_id in self._local:
|
||||
conn = self._local[user_id]
|
||||
loop = asyncio.get_event_loop()
|
||||
fut: asyncio.Future[dict] = loop.create_future()
|
||||
conn.pending_calls[call_id] = fut
|
||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
||||
return await asyncio.wait_for(fut, timeout=30.0)
|
||||
|
||||
# Caso 2: connessione remota → Redis pub/sub
|
||||
loop = asyncio.get_event_loop()
|
||||
fut = loop.create_future()
|
||||
self._remote_futures[call_id] = fut
|
||||
|
||||
# Sottoscrivi al canale di risposta
|
||||
result_channel = f"tool_result:{call_id}"
|
||||
await self._pubsub.subscribe(result_channel)
|
||||
|
||||
# Pubblica il tool_call
|
||||
await self._redis.publish(
|
||||
f"tool_call:{user_id}",
|
||||
json.dumps(payload),
|
||||
)
|
||||
|
||||
try:
|
||||
return await asyncio.wait_for(fut, timeout=30.0)
|
||||
finally:
|
||||
self._remote_futures.pop(call_id, None)
|
||||
await self._pubsub.unsubscribe(result_channel)
|
||||
|
||||
# ── Risoluzione tool_result (da WS locale) ──
|
||||
|
||||
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
||||
conn = self._local.get(user_id)
|
||||
if conn:
|
||||
fut = conn.pending_calls.pop(call_id, None)
|
||||
if fut and not fut.done():
|
||||
fut.set_result(result)
|
||||
|
||||
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
||||
"""Chiamato quando il device locale invia un tool_result."""
|
||||
self.resolve_local(user_id, call_id, result)
|
||||
# Pubblica anche su Redis per l'istanza remota che aspetta
|
||||
await self._redis.publish(
|
||||
f"tool_result:{call_id}",
|
||||
json.dumps(result),
|
||||
)
|
||||
|
||||
# ── Listener Redis ──
|
||||
|
||||
async def _listen_tool_calls(self):
|
||||
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
||||
async for message in self._pubsub.listen():
|
||||
if message["type"] != "message":
|
||||
continue
|
||||
channel = message["channel"]
|
||||
if isinstance(channel, bytes):
|
||||
channel = channel.decode()
|
||||
|
||||
data = json.loads(message["data"])
|
||||
|
||||
if channel.startswith("tool_call:"):
|
||||
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
||||
user_id = channel.split(":", 1)[1]
|
||||
conn = self._local.get(user_id)
|
||||
if conn:
|
||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
||||
|
||||
elif channel.startswith("tool_result:"):
|
||||
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
||||
call_id = channel.split(":", 1)[1]
|
||||
fut = self._remote_futures.pop(call_id, None)
|
||||
if fut and not fut.done():
|
||||
fut.set_result(data)
|
||||
|
||||
# ── Stream cross-instance ──
|
||||
|
||||
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
||||
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
||||
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 4. Struttura Directory Proposta (MVP)
|
||||
|
||||
```
|
||||
adiuva-api/
|
||||
├── docker-compose.yml # Orchestrazione completa
|
||||
├── docker-compose.dev.yml # Override per sviluppo locale
|
||||
├── shared/ # Codice condiviso (montato come volume)
|
||||
│ ├── auth.py # JWT verification (chiave pubblica)
|
||||
│ ├── schemas.py # Pydantic schemas condivisi
|
||||
│ ├── middleware/
|
||||
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
|
||||
│ │ └── sanitizer.py
|
||||
│ └── models/
|
||||
│ └── base.py # SQLAlchemy base condivisa
|
||||
│
|
||||
├── auth-service/
|
||||
│ ├── Dockerfile
|
||||
│ ├── requirements.txt
|
||||
│ └── app/
|
||||
│ ├── main.py
|
||||
│ ├── config.py
|
||||
│ ├── db.py
|
||||
│ ├── models.py # users, refresh_tokens
|
||||
│ ├── routes/
|
||||
│ │ └── auth.py
|
||||
│ └── services/
|
||||
│ ├── jwt_service.py # RS256 signing
|
||||
│ └── user_service.py
|
||||
│
|
||||
├── chat-service/
|
||||
│ ├── Dockerfile
|
||||
│ ├── requirements.txt
|
||||
│ └── app/
|
||||
│ ├── main.py
|
||||
│ ├── config.py
|
||||
│ ├── db.py
|
||||
│ ├── models.py # memory_*
|
||||
│ ├── routes/
|
||||
│ │ ├── device_ws.py # WS connection owner
|
||||
│ │ └── chat.py # REST fallback
|
||||
│ ├── core/
|
||||
│ │ ├── device_manager.py # RedisDeviceManager
|
||||
│ │ ├── deep_agent.py # Home + floating chat
|
||||
│ │ ├── memory_middleware.py
|
||||
│ │ ├── ws_context.py
|
||||
│ │ ├── output_formatter.py
|
||||
│ │ └── llm.py
|
||||
│ └── agents/ # Tool definitions (used by deep_agent)
|
||||
│ ├── task_agent.py
|
||||
│ ├── project_agent.py
|
||||
│ ├── note_agent.py
|
||||
│ └── timeline_agent.py
|
||||
│
|
||||
├── agent-service/
|
||||
│ ├── Dockerfile
|
||||
│ ├── requirements.txt
|
||||
│ └── app/
|
||||
│ ├── main.py
|
||||
│ ├── config.py
|
||||
│ ├── db.py
|
||||
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
|
||||
│ ├── routes/
|
||||
│ │ ├── agents.py # catalog, can-create, trigger
|
||||
│ │ └── agent_setup.py # journey start/message
|
||||
│ ├── core/
|
||||
│ │ ├── agent_runner.py # Batch classify → process
|
||||
│ │ ├── agent_registry.py
|
||||
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
|
||||
│ │ └── llm.py
|
||||
│ └── agents/
|
||||
│ ├── task_agent.py # Tool definitions (batch context)
|
||||
│ ├── project_agent.py
|
||||
│ ├── note_agent.py
|
||||
│ ├── timeline_agent.py
|
||||
│ └── filesystem_agent.py
|
||||
│
|
||||
├── billing-service/
|
||||
│ ├── Dockerfile
|
||||
│ ├── requirements.txt
|
||||
│ └── app/
|
||||
│ ├── main.py
|
||||
│ ├── config.py
|
||||
│ ├── db.py
|
||||
│ ├── models.py # subscriptions
|
||||
│ ├── routes/
|
||||
│ │ └── billing.py
|
||||
│ └── services/
|
||||
│ ├── stripe_service.py
|
||||
│ └── tier_manager.py
|
||||
│
|
||||
└── infra/
|
||||
├── traefik/
|
||||
│ └── traefik.yml
|
||||
├── keys/
|
||||
│ ├── jwt_private.pem # Solo auth-service
|
||||
│ └── jwt_public.pem # Tutti i servizi
|
||||
└── alembic/ # Migrazioni condivise o per-servizio
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 5. Docker Compose — Configurazione MVP
|
||||
|
||||
```yaml
|
||||
# docker-compose.yml
|
||||
|
||||
services:
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# API Gateway
|
||||
# ══════════════════════════════════════════════════════════
|
||||
traefik:
|
||||
image: traefik:v3.2
|
||||
command:
|
||||
- "--api.insecure=true"
|
||||
- "--providers.docker=true"
|
||||
- "--providers.docker.exposedbydefault=false"
|
||||
- "--entrypoints.web.address=:80"
|
||||
- "--entrypoints.websecure.address=:443"
|
||||
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
||||
ports:
|
||||
- "80:80"
|
||||
- "443:443"
|
||||
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
|
||||
volumes:
|
||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
||||
- ./infra/certs:/certs:ro
|
||||
restart: unless-stopped
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# Auth Service (2 repliche)
|
||||
# ══════════════════════════════════════════════════════════
|
||||
auth-service:
|
||||
build: ./auth-service
|
||||
deploy:
|
||||
replicas: 2
|
||||
env_file: .env
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||
REDIS_URL: redis://redis:6379
|
||||
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
||||
SERVICE_NAME: auth
|
||||
secrets:
|
||||
- jwt_private_key
|
||||
- jwt_public_key
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
||||
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# Chat Service — Real-time WS + Chat (scalabile)
|
||||
# ══════════════════════════════════════════════════════════
|
||||
chat-service:
|
||||
build: ./chat-service
|
||||
deploy:
|
||||
replicas: 2
|
||||
env_file: .env
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||
REDIS_URL: redis://redis:6379
|
||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||
SERVICE_NAME: chat
|
||||
secrets:
|
||||
- jwt_public_key
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
# REST chat endpoint
|
||||
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
|
||||
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
||||
# WebSocket route con sticky session
|
||||
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
||||
- "traefik.http.routers.ws.service=chat-ws"
|
||||
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# Agent Service — Batch processing (scalabile indipendentemente)
|
||||
# ══════════════════════════════════════════════════════════
|
||||
agent-service:
|
||||
build: ./agent-service
|
||||
deploy:
|
||||
replicas: 2
|
||||
env_file: .env
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||
REDIS_URL: redis://redis:6379
|
||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||
SERVICE_NAME: agent
|
||||
secrets:
|
||||
- jwt_public_key
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
|
||||
- "traefik.http.services.agents.loadbalancer.server.port=8000"
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# Billing Service (1 replica)
|
||||
# ══════════════════════════════════════════════════════════
|
||||
billing-service:
|
||||
build: ./billing-service
|
||||
deploy:
|
||||
replicas: 1
|
||||
env_file: .env
|
||||
environment:
|
||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||
REDIS_URL: redis://redis:6379
|
||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
||||
SERVICE_NAME: billing
|
||||
secrets:
|
||||
- jwt_public_key
|
||||
labels:
|
||||
- "traefik.enable=true"
|
||||
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
||||
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
||||
depends_on:
|
||||
db:
|
||||
condition: service_healthy
|
||||
redis:
|
||||
condition: service_healthy
|
||||
|
||||
# ══════════════════════════════════════════════════════════
|
||||
# Infrastruttura
|
||||
# ══════════════════════════════════════════════════════════
|
||||
db:
|
||||
image: pgvector/pgvector:pg16
|
||||
environment:
|
||||
POSTGRES_USER: postgres
|
||||
POSTGRES_PASSWORD: postgres
|
||||
POSTGRES_DB: adiuva
|
||||
volumes:
|
||||
- postgres_data:/var/lib/postgresql/data
|
||||
healthcheck:
|
||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
redis:
|
||||
image: redis:7-alpine
|
||||
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
||||
volumes:
|
||||
- redis_data:/data
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli", "ping"]
|
||||
interval: 5s
|
||||
timeout: 3s
|
||||
retries: 5
|
||||
restart: unless-stopped
|
||||
|
||||
qdrant:
|
||||
image: qdrant/qdrant:latest
|
||||
volumes:
|
||||
- qdrant_data:/qdrant/storage
|
||||
restart: unless-stopped
|
||||
|
||||
secrets:
|
||||
jwt_private_key:
|
||||
file: ./infra/keys/jwt_private.pem
|
||||
jwt_public_key:
|
||||
file: ./infra/keys/jwt_public.pem
|
||||
|
||||
volumes:
|
||||
postgres_data:
|
||||
redis_data:
|
||||
qdrant_data:
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 6. Configurazione Cloudflare + VPS
|
||||
|
||||
### 6.1 DNS
|
||||
|
||||
```
|
||||
api.tuodominio.com → A record → IP del VPS
|
||||
→ Proxy: ON (orange cloud)
|
||||
```
|
||||
|
||||
### 6.2 Cloudflare Settings
|
||||
|
||||
| Setting | Valore | Motivo |
|
||||
|---------|--------|--------|
|
||||
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
||||
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
||||
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
||||
| Under Attack Mode | Off (attivare se necessario) | |
|
||||
|
||||
### 6.3 TLS sul VPS
|
||||
|
||||
Due opzioni:
|
||||
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
||||
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
||||
|
||||
```yaml
|
||||
# traefik.yml — con Cloudflare Origin Certificate
|
||||
entryPoints:
|
||||
websecure:
|
||||
address: ":443"
|
||||
|
||||
tls:
|
||||
certificates:
|
||||
- certFile: /certs/origin.pem
|
||||
keyFile: /certs/origin-key.pem
|
||||
```
|
||||
|
||||
### 6.4 Rete VPS
|
||||
|
||||
```bash
|
||||
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
||||
# https://www.cloudflare.com/ips/
|
||||
ufw default deny incoming
|
||||
ufw allow from 173.245.48.0/20 to any port 443
|
||||
ufw allow from 103.21.244.0/22 to any port 443
|
||||
# ... (tutti gli IP range di Cloudflare)
|
||||
ufw allow ssh
|
||||
ufw enable
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 7. Comunicazione Inter-Servizio
|
||||
|
||||
### 7.1 Redis Pub/Sub — Event Bus
|
||||
|
||||
```
|
||||
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
||||
│ Billing │ ────────────────────────► │ Auth │
|
||||
│ Service │ │ Service │
|
||||
└──────────┘ └──────────┘
|
||||
|
||||
┌──────────┐ tool_call:user_123 ┌──────────┐
|
||||
│ Agent │ ────────────────────────► │ Chat │
|
||||
│ Service │ │ Service │
|
||||
│ (batch) │ ◄────────────────────────│ (ha WS) │
|
||||
└──────────┘ tool_result:{call_id} └──────────┘
|
||||
```
|
||||
|
||||
### 7.2 Health Checks e Service Discovery
|
||||
|
||||
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
||||
- **Redis pub/sub** (tool-call cross-instance, tier events)
|
||||
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
|
||||
- **PostgreSQL** (dati persistenti condivisi)
|
||||
|
||||
---
|
||||
|
||||
## 8. Piano di Migrazione Incrementale (MVP)
|
||||
|
||||
### Fase 1 — Preparazione (nel monolite attuale)
|
||||
1. Aggiungere Redis al `docker-compose.yml` attuale
|
||||
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
|
||||
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
|
||||
4. Estrarre `shared/` con auth verification, schemas, middleware
|
||||
|
||||
### Fase 2 — Auth Service (primo split)
|
||||
1. Estrarre `auth.py` routes + models in `auth-service/`
|
||||
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
||||
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
|
||||
4. Il monolite continua a servire tutto il resto
|
||||
|
||||
### Fase 3 — Billing Service
|
||||
1. Estrarre billing routes, Stripe service, tier manager
|
||||
2. Configurare Redis pub/sub per `tier_changed` events
|
||||
3. Routare via Traefik
|
||||
|
||||
### Fase 4 — Split Chat + Agent (il più delicato)
|
||||
1. Il monolite residuo contiene WS + chat + agents
|
||||
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
|
||||
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
|
||||
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
|
||||
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
|
||||
|
||||
### Fase 5 — Scaling test
|
||||
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
|
||||
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
|
||||
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
||||
|
||||
---
|
||||
|
||||
## 9. Monitoraggio e Logging
|
||||
|
||||
```yaml
|
||||
# Aggiungere al docker-compose.yml
|
||||
|
||||
prometheus:
|
||||
image: prom/prometheus:latest
|
||||
volumes:
|
||||
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
||||
restart: unless-stopped
|
||||
|
||||
grafana:
|
||||
image: grafana/grafana:latest
|
||||
ports:
|
||||
- "3000:3000"
|
||||
volumes:
|
||||
- grafana_data:/var/lib/grafana
|
||||
restart: unless-stopped
|
||||
|
||||
loki:
|
||||
image: grafana/loki:latest
|
||||
restart: unless-stopped
|
||||
```
|
||||
|
||||
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
||||
|
||||
---
|
||||
|
||||
## 10. Sizing VPS Minimo Consigliato (MVP)
|
||||
|
||||
| Componente | CPU | RAM | Note |
|
||||
|---|---|---|---|
|
||||
| Traefik | 0.25 | 128MB | |
|
||||
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
|
||||
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
|
||||
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
|
||||
| Billing Service | 0.25 | 128MB | |
|
||||
| PostgreSQL | 1.0 | 1GB | |
|
||||
| Redis | 0.25 | 256MB | |
|
||||
| Qdrant | 0.5 | 512MB | |
|
||||
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
|
||||
|
||||
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
|
||||
|
||||
---
|
||||
|
||||
## Riepilogo Architettura MVP
|
||||
|
||||
| Servizio | Repliche | Proprietario di |
|
||||
|---|---|---|
|
||||
| **Traefik** | 1 | Routing, TLS, sticky sessions |
|
||||
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
|
||||
| **Chat Service** | 2–N | WebSocket, home/floating chat, streaming |
|
||||
| **Agent Service** | 2–N | Batch processing, directory scan, agent setup |
|
||||
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
|
||||
|
||||
| Decisione | Scelta | Motivazione |
|
||||
|---|---|---|
|
||||
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
||||
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
||||
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
|
||||
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
|
||||
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
|
||||
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
|
||||
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
|
||||
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
|
||||
| TLS | Cloudflare Origin Certificate | Zero maintenance |
|
||||
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
||||
| Storage / Plugin | Post-MVP | Non critici per il lancio |
|
||||
@@ -1,43 +0,0 @@
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.34.0
|
||||
gunicorn>=22.0.0
|
||||
langchain>=0.3.0
|
||||
langchain-openai>=0.3.0
|
||||
langchain-litellm>=0.1.0
|
||||
litellm>=1.50.0
|
||||
pydantic>=2.10.0
|
||||
pydantic-settings>=2.7.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
stripe>=11.0.0
|
||||
boto3>=1.35.0
|
||||
slowapi>=0.1.9
|
||||
sqlalchemy>=2.0.0
|
||||
asyncpg>=0.30.0
|
||||
alembic>=1.14.0
|
||||
bcrypt>=4.2.0
|
||||
python-dotenv>=1.0.0
|
||||
httpx>=0.28.0
|
||||
websockets>=14.0
|
||||
psycopg2-binary>=2.9.0
|
||||
pytest>=8.0.0
|
||||
pytest-asyncio>=0.24.0
|
||||
aiosqlite>=0.20.0
|
||||
moto[s3]>=5.0.0
|
||||
pinecone>=5.0.0
|
||||
qdrant-client>=1.7.0
|
||||
croniter>=3.0.0
|
||||
google-api-python-client>=2.130.0
|
||||
google-auth>=2.29.0
|
||||
google-auth-oauthlib>=1.2.0
|
||||
google-auth-httplib2>=0.2.0
|
||||
msal>=1.28.0
|
||||
cryptography>=42.0.0
|
||||
pgvector>=0.2.5
|
||||
langfuse>=3.3.1
|
||||
beautifulsoup4>=4.12.0
|
||||
lxml>=5.0.0
|
||||
PyYAML>=6.0.0
|
||||
apscheduler>=3.10.0
|
||||
ruff>=0.8.0
|
||||
pypdf>=4.0
|
||||
python-docx>=1.1
|
||||
File diff suppressed because one or more lines are too long
19
services/auth/.env.example
Normal file
19
services/auth/.env.example
Normal file
@@ -0,0 +1,19 @@
|
||||
# ── Auth Service ──────────────────────────────────────────────────────────────
|
||||
# This file contains env vars specific to the Auth Service.
|
||||
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
|
||||
# or from docker-compose environment.
|
||||
|
||||
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
|
||||
# Generate keypair:
|
||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||
# openssl rsa -in private.pem -pubout -out public.pem
|
||||
#
|
||||
# Paste PEM content with literal \n for newlines:
|
||||
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
|
||||
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
|
||||
|
||||
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
|
||||
JWT_PRIVATE_KEY=
|
||||
|
||||
# PUBLIC KEY — used to VERIFY JWTs.
|
||||
JWT_PUBLIC_KEY=
|
||||
36
services/auth/Dockerfile
Normal file
36
services/auth/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
||||
# ── builder ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
# Install shared + service deps in one layer
|
||||
COPY services/auth/requirements.txt ./requirements.txt
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||
|
||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS runtime
|
||||
|
||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /install /usr/local
|
||||
|
||||
# Copy shared module (available to all services)
|
||||
COPY shared/ shared/
|
||||
|
||||
# Copy service source
|
||||
COPY services/auth/app/ app/
|
||||
|
||||
RUN chown -R appuser:appgroup /app
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
CMD ["gunicorn", "app.main:app", \
|
||||
"-k", "uvicorn.workers.UvicornWorker", \
|
||||
"--bind", "0.0.0.0:8000", \
|
||||
"--workers", "2", \
|
||||
"--timeout", "30"]
|
||||
16
services/auth/README.md
Normal file
16
services/auth/README.md
Normal file
@@ -0,0 +1,16 @@
|
||||
# Auth Service
|
||||
|
||||
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
|
||||
|
||||
## Tables owned
|
||||
- `users`
|
||||
- `refresh_tokens`
|
||||
- `subscriptions` (read; Billing Service writes)
|
||||
|
||||
## Endpoints
|
||||
- `POST /auth/register`
|
||||
- `POST /auth/login`
|
||||
- `POST /auth/refresh`
|
||||
- `GET /auth/me`
|
||||
- `PUT /auth/me`
|
||||
- `GET /auth/verify` (ForwardAuth for Traefik)
|
||||
34
services/auth/app/config.py
Normal file
34
services/auth/app/config.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Auth Service — local configuration.
|
||||
|
||||
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
|
||||
These are NOT in shared/config.py to prevent other services from accessing them.
|
||||
"""
|
||||
|
||||
from pydantic import field_validator
|
||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||
|
||||
|
||||
class AuthSettings(BaseSettings):
|
||||
# RS256 private key (PEM format). Used to SIGN JWTs.
|
||||
# Only the Auth Service has this. Generate with:
|
||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
||||
# Then set the env var (newlines as \n):
|
||||
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
|
||||
JWT_PRIVATE_KEY: str = ""
|
||||
|
||||
# RS256 public key (PEM format). Used to VERIFY JWTs.
|
||||
# Derived from the private key:
|
||||
# openssl rsa -in private.pem -pubout -out public.pem
|
||||
JWT_PUBLIC_KEY: str = ""
|
||||
|
||||
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
|
||||
@classmethod
|
||||
def _expand_pem_newlines(cls, v: str) -> str:
|
||||
if isinstance(v, str) and r"\n" in v:
|
||||
return v.replace(r"\n", "\n")
|
||||
return v
|
||||
|
||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||
|
||||
|
||||
auth_settings = AuthSettings()
|
||||
69
services/auth/app/deps.py
Normal file
69
services/auth/app/deps.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Auth dependencies — JWT validation for the Auth Service.
|
||||
|
||||
This is the canonical get_current_user used by protected endpoints
|
||||
within the Auth Service itself (/me, /me PUT).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from shared.config import settings
|
||||
from shared.db import get_session
|
||||
from shared.models import Subscription, User
|
||||
from shared.schemas import UserProfile
|
||||
|
||||
from app.config import auth_settings
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
token: str = Depends(oauth2_scheme),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Validate a Bearer JWT and return the authenticated user.
|
||||
|
||||
The JWT is used for identity and expiry. Tier is fetched live from the
|
||||
subscriptions table so upgrades/downgrades take effect immediately.
|
||||
"""
|
||||
credentials_exc = HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Could not validate credentials",
|
||||
headers={"WWW-Authenticate": "Bearer"},
|
||||
)
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
if not user_id or not email:
|
||||
raise credentials_exc
|
||||
except JWTError:
|
||||
raise credentials_exc
|
||||
|
||||
# Live tier lookup
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
# Fetch name/surname
|
||||
user_result = await db.execute(
|
||||
select(User.name, User.surname).where(User.id == user_id)
|
||||
)
|
||||
user_row = user_result.one_or_none()
|
||||
|
||||
return UserProfile(
|
||||
id=user_id,
|
||||
email=email,
|
||||
name=user_row.name if user_row else None,
|
||||
surname=user_row.surname if user_row else None,
|
||||
tier=tier,
|
||||
) # type: ignore[arg-type]
|
||||
62
services/auth/app/main.py
Normal file
62
services/auth/app/main.py
Normal file
@@ -0,0 +1,62 @@
|
||||
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
|
||||
|
||||
Standalone FastAPI service extracted from the adiuva-api monolith.
|
||||
Owns: users, refresh_tokens, subscriptions (read).
|
||||
"""
|
||||
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the repo root is on sys.path so "shared" is importable.
|
||||
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
|
||||
# In local dev, we need to add the repo root (two levels up from this file).
|
||||
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||
if _repo_root not in sys.path:
|
||||
sys.path.insert(0, _repo_root)
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
yield
|
||||
from shared.db import engine
|
||||
|
||||
await engine.dispose()
|
||||
|
||||
|
||||
def create_app() -> FastAPI:
|
||||
app = FastAPI(
|
||||
title="Adiuva Auth Service",
|
||||
version="0.1.0",
|
||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||
redoc_url=None,
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=settings.CORS_ORIGINS,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
from app.routes import router
|
||||
from app.verify import router as verify_router
|
||||
|
||||
app.include_router(router, prefix="/api/v1")
|
||||
app.include_router(verify_router, prefix="/api/v1")
|
||||
|
||||
@app.get("/api/v1/health", tags=["health"])
|
||||
async def health() -> dict:
|
||||
return {"status": "ok", "service": "auth", "version": app.version}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
249
services/auth/app/routes.py
Normal file
249
services/auth/app/routes.py
Normal file
@@ -0,0 +1,249 @@
|
||||
"""Auth routes: register, login, refresh, me.
|
||||
|
||||
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import bcrypt
|
||||
from cryptography.fernet import Fernet
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from shared.config import settings
|
||||
from shared.db import get_session
|
||||
from shared.models import RefreshToken, Subscription, User
|
||||
from shared.schemas import AuthTokens, UserProfile
|
||||
|
||||
from app.config import auth_settings
|
||||
from app.deps import get_current_user
|
||||
|
||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||
|
||||
|
||||
# ── Internal helpers ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _hash_password(password: str) -> str:
|
||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
||||
|
||||
|
||||
def _verify_password(password: str, hashed: str) -> bool:
|
||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
||||
|
||||
|
||||
def _hash_token(plain_token: str) -> str:
|
||||
"""SHA-256 of the plain refresh token string."""
|
||||
return hashlib.sha256(plain_token.encode()).hexdigest()
|
||||
|
||||
|
||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||
"""Return (RS256-signed JWT, expires_at_ms)."""
|
||||
now = int(time.time())
|
||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||
payload = {
|
||||
"sub": user_id,
|
||||
"email": email,
|
||||
"tier": tier,
|
||||
"exp": exp,
|
||||
"iat": now,
|
||||
}
|
||||
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
|
||||
return token, exp * 1000 # ms for client
|
||||
|
||||
|
||||
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
|
||||
"""Fetch authoritative tier from subscriptions table."""
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
return result.scalar_one_or_none() or default_tier
|
||||
|
||||
|
||||
# ── Request bodies ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class _RegisterRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
|
||||
|
||||
class _LoginRequest(BaseModel):
|
||||
email: str
|
||||
password: str
|
||||
|
||||
|
||||
class _RefreshRequest(BaseModel):
|
||||
refresh_token: str
|
||||
|
||||
|
||||
class _UpdateProfileRequest(BaseModel):
|
||||
name: str | None = None
|
||||
surname: str | None = None
|
||||
|
||||
|
||||
# ── Routes ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
||||
async def register(
|
||||
body: _RegisterRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Create a new account and return JWT tokens."""
|
||||
existing = await db.execute(select(User).where(User.email == body.email))
|
||||
if existing.scalar_one_or_none() is not None:
|
||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
||||
|
||||
user = User(
|
||||
id=str(uuid.uuid4()),
|
||||
email=body.email,
|
||||
name=body.name,
|
||||
surname=body.surname,
|
||||
password_hash=_hash_password(body.password),
|
||||
tier="free",
|
||||
encryption_key=Fernet.generate_key().decode(),
|
||||
)
|
||||
db.add(user)
|
||||
await db.flush()
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login", response_model=AuthTokens)
|
||||
async def login(
|
||||
body: _LoginRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Validate credentials and return JWT tokens."""
|
||||
result = await db.execute(select(User).where(User.email == body.email))
|
||||
user = result.scalar_one_or_none()
|
||||
if user is None or not _verify_password(body.password, user.password_hash):
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||
|
||||
# Fetch live tier for the JWT claim
|
||||
tier = await _get_live_tier(db, user.id)
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||
)
|
||||
rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=expires_at,
|
||||
)
|
||||
db.add(rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/refresh", response_model=AuthTokens)
|
||||
async def refresh(
|
||||
body: _RefreshRequest,
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AuthTokens:
|
||||
"""Rotate a refresh token and return a new token pair."""
|
||||
token_hash = _hash_token(body.refresh_token)
|
||||
result = await db.execute(
|
||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
||||
)
|
||||
rt = result.scalar_one_or_none()
|
||||
|
||||
now = datetime.now(timezone.utc)
|
||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||
|
||||
await db.delete(rt)
|
||||
|
||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||
user = user_result.scalar_one_or_none()
|
||||
if user is None:
|
||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||
|
||||
# Fetch live tier for the new JWT
|
||||
tier = await _get_live_tier(db, user.id)
|
||||
|
||||
plain_token = str(uuid.uuid4())
|
||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
new_rt = RefreshToken(
|
||||
user_id=user.id,
|
||||
token_hash=_hash_token(plain_token),
|
||||
expires_at=new_expires,
|
||||
)
|
||||
db.add(new_rt)
|
||||
await db.commit()
|
||||
|
||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
||||
return AuthTokens(
|
||||
access_token=access_token,
|
||||
refresh_token=plain_token,
|
||||
expires_at=expires_at_ms,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/me", response_model=UserProfile)
|
||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||
"""Return the profile for the authenticated user."""
|
||||
return current_user
|
||||
|
||||
|
||||
@router.put("/me", response_model=UserProfile)
|
||||
async def update_profile(
|
||||
body: _UpdateProfileRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> UserProfile:
|
||||
"""Update the authenticated user's name and surname."""
|
||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||
user = result.scalar_one()
|
||||
|
||||
if body.name is not None:
|
||||
user.name = body.name
|
||||
if body.surname is not None:
|
||||
user.surname = body.surname
|
||||
|
||||
await db.commit()
|
||||
await db.refresh(user)
|
||||
|
||||
return UserProfile(
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
name=user.name,
|
||||
surname=user.surname,
|
||||
tier=current_user.tier,
|
||||
)
|
||||
66
services/auth/app/verify.py
Normal file
66
services/auth/app/verify.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""ForwardAuth verification endpoint for Traefik.
|
||||
|
||||
Traefik calls GET /api/v1/auth/verify on every request to a protected
|
||||
service. This endpoint validates the JWT from the Authorization header
|
||||
and returns identity headers that Traefik injects into downstream requests.
|
||||
|
||||
Downstream services NEVER validate JWTs themselves — they trust the
|
||||
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi import APIRouter, Request, Response
|
||||
from fastapi import status as http_status
|
||||
from jose import JWTError, jwt
|
||||
from sqlalchemy import select
|
||||
|
||||
from shared.config import settings
|
||||
from shared.db import async_session
|
||||
from shared.models import Subscription
|
||||
|
||||
from app.config import auth_settings
|
||||
|
||||
router = APIRouter(tags=["auth"])
|
||||
|
||||
|
||||
@router.get("/auth/verify")
|
||||
async def verify(request: Request) -> Response:
|
||||
"""Validate JWT and return identity headers for Traefik ForwardAuth.
|
||||
|
||||
Returns 200 with X-User-* headers on success, 401 on failure.
|
||||
Traefik copies response headers to the downstream request.
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization", "")
|
||||
if not auth_header.startswith("Bearer "):
|
||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
token = auth_header[7:] # strip "Bearer "
|
||||
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
||||
)
|
||||
user_id: str | None = payload.get("sub")
|
||||
email: str | None = payload.get("email")
|
||||
if not user_id or not email:
|
||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||
except JWTError:
|
||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
||||
|
||||
# Live tier lookup from subscriptions table
|
||||
async with async_session() as db:
|
||||
result = await db.execute(
|
||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||
)
|
||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||
tier: str = result.scalar_one_or_none() or default_tier
|
||||
|
||||
return Response(
|
||||
status_code=http_status.HTTP_200_OK,
|
||||
headers={
|
||||
"X-User-Id": user_id,
|
||||
"X-User-Email": email,
|
||||
"X-User-Tier": tier,
|
||||
},
|
||||
)
|
||||
11
services/auth/requirements.txt
Normal file
11
services/auth/requirements.txt
Normal file
@@ -0,0 +1,11 @@
|
||||
fastapi>=0.115.0
|
||||
uvicorn[standard]>=0.34.0
|
||||
gunicorn>=22.0.0
|
||||
pydantic>=2.10.0
|
||||
pydantic-settings>=2.7.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
sqlalchemy>=2.0.0
|
||||
asyncpg>=0.30.0
|
||||
bcrypt>=4.2.0
|
||||
cryptography>=42.0.0
|
||||
python-dotenv>=1.0.0
|
||||
36
services/batch-agent/Dockerfile
Normal file
36
services/batch-agent/Dockerfile
Normal file
@@ -0,0 +1,36 @@
|
||||
# ── builder ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS builder
|
||||
|
||||
WORKDIR /build
|
||||
|
||||
COPY services/batch-agent/requirements.txt ./requirements.txt
|
||||
RUN pip install --upgrade pip && \
|
||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||
|
||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||
FROM python:3.12-slim AS runtime
|
||||
|
||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
COPY --from=builder /install /usr/local
|
||||
|
||||
# Shared module
|
||||
COPY shared/ shared/
|
||||
|
||||
# Service source
|
||||
COPY services/batch-agent/app/ app/
|
||||
|
||||
RUN chown -R appuser:appgroup /app
|
||||
|
||||
USER appuser
|
||||
|
||||
EXPOSE 8000
|
||||
|
||||
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
|
||||
CMD ["gunicorn", "app.main:app", \
|
||||
"-k", "uvicorn.workers.UvicornWorker", \
|
||||
"--bind", "0.0.0.0:8000", \
|
||||
"--workers", "2", \
|
||||
"--timeout", "300"]
|
||||
23
services/batch-agent/README.md
Normal file
23
services/batch-agent/README.md
Normal file
@@ -0,0 +1,23 @@
|
||||
# Batch Agent Service
|
||||
|
||||
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
|
||||
|
||||
## Tables owned
|
||||
- `local_agent_configs`
|
||||
- `cloud_agent_configs`
|
||||
- `agent_run_logs`
|
||||
|
||||
## Endpoints
|
||||
- `GET /agents/catalog`
|
||||
- `POST /agents/can-create`
|
||||
- `POST /agents/trigger`
|
||||
- `GET /agents/{id}/history`
|
||||
|
||||
## Redis channels
|
||||
- Subscribe: `batch:request:{user_id}`
|
||||
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
|
||||
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
||||
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
|
||||
|
||||
## TODO
|
||||
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.
|
||||
910
services/batch-agent/app/agent_runner.py
Normal file
910
services/batch-agent/app/agent_runner.py
Normal file
@@ -0,0 +1,910 @@
|
||||
"""Agent run orchestrator — adapted for Batch Agent Service.
|
||||
|
||||
Key changes from monolith app/core/agent_runner.py:
|
||||
- No DeviceConnectionManager — tool calls go through Redis ws_context.
|
||||
- set_current_user / clear_current_user replace set_client_executor.
|
||||
- run_local_agent accepts a serialized dict (from Redis / REST) instead
|
||||
of SQLAlchemy model objects.
|
||||
- _finalize_run writes to PostgreSQL via shared.db.async_session.
|
||||
- Cloud agent import path changed to app.integrations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||
from shared.agents.note_agent import NOTE_TOOLS
|
||||
from shared.agents.project_agent import PROJECT_TOOLS
|
||||
from shared.agents.task_agent import TASK_TOOLS
|
||||
from shared.agents.timeline_agent import TIMELINE_TOOLS
|
||||
from shared.llm import get_llm
|
||||
from shared.ws_context import execute_on_client, set_current_user, clear_current_user
|
||||
import app.tracing as tracing
|
||||
from shared.db import async_session
|
||||
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||
from shared.redis import redis_client, ws_out_channel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Concurrency guard ─────────────────────────────────────────────────────
|
||||
_running_agents: set[str] = set()
|
||||
|
||||
|
||||
def is_agent_running(agent_id: str) -> bool:
|
||||
return agent_id in _running_agents
|
||||
|
||||
|
||||
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||
_TOOL_CALL_TIMEOUT: int = 30
|
||||
_MAX_PROCESSING_STEPS: int = 12
|
||||
_MAX_SCAN_DEPTH: int = 5
|
||||
|
||||
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||
"tasks": TASK_TOOLS,
|
||||
"notes": NOTE_TOOLS,
|
||||
"timelines": TIMELINE_TOOLS,
|
||||
}
|
||||
|
||||
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
||||
|
||||
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
||||
"tasks": (
|
||||
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
||||
"assigned to someone, or tracked with a due date or status."
|
||||
),
|
||||
"notes": (
|
||||
"Documentation, meeting notes, summaries, reference material — "
|
||||
"written content meant to be read and referenced rather than acted on."
|
||||
),
|
||||
"timelines": (
|
||||
"Project milestones, deadlines, scheduled events — "
|
||||
"specific dates that mark a point in the progress of a project."
|
||||
),
|
||||
"projects": (
|
||||
"High-level project entities — only relevant if the file clearly introduces "
|
||||
"a new project or updates the scope of an existing one."
|
||||
),
|
||||
}
|
||||
|
||||
_STEP1_SYSTEM_PROMPT = """\
|
||||
You are a file classifier for a freelance project management tool.
|
||||
|
||||
Your job is to match a file to an existing project and identify which data domains to extract.
|
||||
|
||||
## Project matching rules (STRICT — follow in order)
|
||||
|
||||
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
||||
that overlaps with the existing projects listed below.
|
||||
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
||||
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
||||
when the file has zero meaningful connection to any listed project.
|
||||
4. When in doubt, pick the closest match from the list.
|
||||
|
||||
## Response format
|
||||
|
||||
Respond ONLY with a JSON object — no markdown, no explanation:
|
||||
|
||||
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
||||
|
||||
## Domain definitions (only consider domains in the allowed list)
|
||||
|
||||
{domain_definitions}
|
||||
|
||||
## Existing projects
|
||||
|
||||
{projects_list}
|
||||
"""
|
||||
|
||||
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
||||
|
||||
_PROCESSING_SYSTEM_PROMPT = """\
|
||||
You are a data extraction assistant for a freelance project management tool.
|
||||
|
||||
Your task: extract structured data from the file content and persist it using the available tools.
|
||||
|
||||
## Mandatory process — follow this order for EVERY item you extract
|
||||
|
||||
1. READ the existing records listed below for the relevant domain.
|
||||
2. SEARCH for a match by title, topic, or semantic similarity.
|
||||
3. If a match exists → call the update_* tool with the existing record's id.
|
||||
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
||||
|
||||
NEVER call create_* without first checking the existing records.
|
||||
NEVER duplicate a record that already exists under a different wording.
|
||||
|
||||
## Existing records (source of truth)
|
||||
|
||||
{existing_context}
|
||||
|
||||
## Context
|
||||
|
||||
Project: {project_context}
|
||||
Domains to extract: {data_types}
|
||||
|
||||
{custom_prompt_section}
|
||||
"""
|
||||
|
||||
# ── Cloud processing prompt ───────────────────────────────────────────────
|
||||
|
||||
_CLOUD_PROCESSING_PROMPT = """\
|
||||
You are a data extraction and management assistant for a freelance project
|
||||
management tool.
|
||||
|
||||
Available tools:
|
||||
Filesystem : read_file_content, list_directory, get_file_metadata
|
||||
Tasks : list_tasks, create_task, update_task, add_task_comment
|
||||
Notes : list_notes, get_note, create_note, update_note
|
||||
Timelines : list_timelines, create_timeline, update_timeline
|
||||
Projects : list_all_projects, get_project, create_project, update_project
|
||||
|
||||
Your task:
|
||||
1. Read the full content of each file below using read_file_content.
|
||||
2. For each piece of information found, ALWAYS try to match and update an
|
||||
existing record before creating a new one.
|
||||
3. ONLY act on these entity types: {data_types}.
|
||||
4. Do NOT invent data. Only extract what is clearly present in the files.
|
||||
5. If a file contains no relevant data for the target entity types, skip it.
|
||||
|
||||
{project_context}
|
||||
|
||||
Files to process:
|
||||
{file_list}
|
||||
|
||||
{custom_prompt_section}
|
||||
|
||||
After processing all files, respond with a brief summary of what you updated
|
||||
and what you created.
|
||||
"""
|
||||
|
||||
|
||||
# ── LLM tool-calling loop ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
async def _run_agent_with_tools(
|
||||
*,
|
||||
system_prompt: str,
|
||||
user_message: str,
|
||||
tools: list[Any],
|
||||
max_steps: int,
|
||||
langfuse_handler: Any | None = None,
|
||||
) -> str:
|
||||
"""Run an LLM agent with tool-calling, returning the final text response."""
|
||||
callbacks = [langfuse_handler] if langfuse_handler else None
|
||||
llm = get_llm(callbacks=callbacks)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
messages: list[Any] = [
|
||||
SystemMessage(content=system_prompt),
|
||||
HumanMessage(content=user_message),
|
||||
]
|
||||
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
|
||||
for _ in range(max_steps):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
return _as_text(response.content)
|
||||
|
||||
for call in response.tool_calls:
|
||||
call_id = str(call.get("id", ""))
|
||||
call_name = str(call.get("name", ""))
|
||||
call_args = call.get("args", {})
|
||||
logger.info(
|
||||
"agent_runner: tool_call name=%s args=%s",
|
||||
call_name,
|
||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||
)
|
||||
|
||||
tool_fn = tool_map.get(call_name)
|
||||
if tool_fn is None:
|
||||
tool_output = f"Unknown tool: {call_name}"
|
||||
else:
|
||||
tool_output = await tool_fn.ainvoke(call_args)
|
||||
|
||||
logger.info(
|
||||
"agent_runner: tool_result name=%s output=%s",
|
||||
call_name,
|
||||
str(tool_output)[:200],
|
||||
)
|
||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||
|
||||
final = await llm.ainvoke(messages)
|
||||
return _as_text(final.content)
|
||||
|
||||
|
||||
# ── Tool list builder ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _build_processing_tools(data_types: list[str]) -> list[Any]:
|
||||
tools: list[Any] = list(FILESYSTEM_TOOLS)
|
||||
for dt in data_types:
|
||||
dt_tools = _DATA_TYPE_TOOLS.get(dt)
|
||||
if dt_tools:
|
||||
tools.extend(dt_tools)
|
||||
return tools
|
||||
|
||||
|
||||
# ── Code-based directory scanner ─────────────────────────────────────────
|
||||
|
||||
|
||||
async def _scan_directories(
|
||||
paths: list[str],
|
||||
extensions: list[str],
|
||||
last_run_at: datetime | None,
|
||||
) -> list[str]:
|
||||
all_files: list[str] = []
|
||||
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
|
||||
|
||||
async def _walk(path: str, depth: int) -> None:
|
||||
if depth > _MAX_SCAN_DEPTH:
|
||||
return
|
||||
try:
|
||||
result = await execute_on_client(action="list_directory", data={"path": path})
|
||||
except Exception as exc:
|
||||
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
|
||||
return
|
||||
for entry in result.get("entries", []):
|
||||
entry_path = entry.get("path", "")
|
||||
if not entry_path:
|
||||
continue
|
||||
if entry.get("type") == "directory":
|
||||
await _walk(entry_path, depth + 1)
|
||||
elif entry.get("type") == "file":
|
||||
if ext_set:
|
||||
dot_pos = entry_path.rfind(".")
|
||||
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
|
||||
if file_ext not in ext_set:
|
||||
continue
|
||||
all_files.append(entry_path)
|
||||
|
||||
for root in paths:
|
||||
await _walk(root, depth=0)
|
||||
|
||||
if last_run_at is None:
|
||||
return all_files
|
||||
|
||||
last_run_ms = int(last_run_at.timestamp() * 1000)
|
||||
filtered: list[str] = []
|
||||
for file_path in all_files:
|
||||
try:
|
||||
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
||||
modified_at = meta.get("modifiedAt")
|
||||
if modified_at is None:
|
||||
filtered.append(file_path)
|
||||
continue
|
||||
if isinstance(modified_at, (int, float)):
|
||||
mod_ms = int(modified_at)
|
||||
else:
|
||||
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
|
||||
if mod_ms > last_run_ms:
|
||||
filtered.append(file_path)
|
||||
except Exception:
|
||||
filtered.append(file_path)
|
||||
|
||||
return filtered
|
||||
|
||||
|
||||
# ── Code-based entity fetchers ────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _fetch_projects() -> list[dict]:
|
||||
try:
|
||||
result = await execute_on_client(action="select", table="projects")
|
||||
return result.get("rows", [])
|
||||
except Exception as exc:
|
||||
logger.warning("agent_runner: failed to fetch projects: %s", exc)
|
||||
return []
|
||||
|
||||
|
||||
_DOMAIN_TABLE: dict[str, str] = {
|
||||
"tasks": "tasks",
|
||||
"notes": "notes",
|
||||
"timelines": "timelines",
|
||||
"projects": "projects",
|
||||
}
|
||||
|
||||
|
||||
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
|
||||
table = _DOMAIN_TABLE.get(domain)
|
||||
if not table:
|
||||
return []
|
||||
filters: dict[str, Any] = {}
|
||||
if project_id != "standalone" and domain != "projects":
|
||||
filters["projectId"] = project_id
|
||||
try:
|
||||
result = await execute_on_client(
|
||||
action="select",
|
||||
table=table,
|
||||
filters=filters if filters else None,
|
||||
)
|
||||
return result.get("rows", [])
|
||||
except Exception as exc:
|
||||
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
|
||||
return []
|
||||
|
||||
|
||||
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
||||
if not rows:
|
||||
return f"No existing {domain}."
|
||||
lines: list[str] = []
|
||||
for r in rows:
|
||||
if domain == "tasks":
|
||||
desc = r.get("description") or ""
|
||||
desc_part = f" — {desc[:120]}" if desc else ""
|
||||
assignee = r.get("assignee") or r.get("assignees") or ""
|
||||
due = r.get("dueDate") or r.get("due_date") or ""
|
||||
meta = ", ".join(filter(None, [
|
||||
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
|
||||
f"assignee: {assignee}" if assignee else "",
|
||||
f"due: {due}" if due else "",
|
||||
]))
|
||||
lines.append(
|
||||
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
|
||||
f" ({meta}, id: {r['id']})"
|
||||
)
|
||||
elif domain == "notes":
|
||||
snippet = (r.get("content") or "")[:200].replace("\n", " ")
|
||||
snippet_part = f"\n Preview: {snippet}" if snippet else ""
|
||||
lines.append(
|
||||
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
|
||||
)
|
||||
elif domain == "timelines":
|
||||
lines.append(
|
||||
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
|
||||
)
|
||||
elif domain == "projects":
|
||||
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
|
||||
summary_part = f" — {summary}" if summary else ""
|
||||
lines.append(
|
||||
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
|
||||
f" (id: {r['id']})"
|
||||
)
|
||||
return f"Existing {domain}:\n" + "\n".join(lines)
|
||||
|
||||
|
||||
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
||||
|
||||
|
||||
async def _classify_file(
|
||||
file_path: str,
|
||||
file_content: str,
|
||||
projects: list[dict],
|
||||
config_data_types: list[str],
|
||||
langfuse_handler: Any | None = None,
|
||||
custom_system_prompt: str | None = None,
|
||||
) -> tuple[str, list[str], str | None]:
|
||||
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
||||
|
||||
if not file_content.strip():
|
||||
return fallback
|
||||
|
||||
valid_project_ids = {p["id"] for p in projects}
|
||||
|
||||
def _fmt_project(p: dict) -> str:
|
||||
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
||||
summary_part = f" — {summary[:100]}" if summary else ""
|
||||
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
||||
|
||||
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
||||
|
||||
domain_definitions = "\n".join(
|
||||
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
||||
for d in config_data_types
|
||||
if d in _DOMAIN_DESCRIPTIONS
|
||||
)
|
||||
|
||||
if custom_system_prompt:
|
||||
# Fixture-provided prompt takes absolute priority
|
||||
system = custom_system_prompt.format_map(
|
||||
{"domain_definitions": domain_definitions, "projects_list": projects_list}
|
||||
)
|
||||
else:
|
||||
system = tracing.compile_prompt(
|
||||
"batch_file_classifier",
|
||||
fallback=_STEP1_SYSTEM_PROMPT,
|
||||
variables={
|
||||
"domain_definitions": domain_definitions,
|
||||
"projects_list": projects_list,
|
||||
},
|
||||
)
|
||||
|
||||
llm = get_llm(callbacks=[langfuse_handler] if langfuse_handler else None)
|
||||
try:
|
||||
response = await llm.ainvoke([
|
||||
SystemMessage(content=system),
|
||||
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
||||
])
|
||||
raw = _as_text(response.content).strip()
|
||||
if raw.startswith("```"):
|
||||
raw = raw.split("```")[1]
|
||||
if raw.startswith("json"):
|
||||
raw = raw[4:]
|
||||
parsed = json.loads(raw.strip())
|
||||
raw_project_id: str = str(parsed.get("project_id") or "new")
|
||||
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
||||
new_project_name: str | None = (
|
||||
str(parsed["new_project_name"]).strip() or None
|
||||
if project_id == "new" and parsed.get("new_project_name")
|
||||
else None
|
||||
)
|
||||
domains: list[str] = [
|
||||
d for d in parsed.get("domains", [])
|
||||
if d in config_data_types
|
||||
]
|
||||
if not domains:
|
||||
domains = list(config_data_types)
|
||||
return project_id, domains, new_project_name
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
||||
)
|
||||
return fallback
|
||||
|
||||
|
||||
# ── Local agent runner (two-step per file) ────────────────────────────────
|
||||
|
||||
|
||||
async def run_local_agent(user_id: str, trigger_data: dict[str, Any], *, langfuse_handler: Any | None = None) -> None:
|
||||
"""Execute a local directory agent run.
|
||||
|
||||
In the microservice world, trigger_data is a serialized dict from
|
||||
the REST route (forwarded via Redis), containing the agent config
|
||||
fields and run_context.
|
||||
|
||||
set_current_user() must be called BEFORE this function.
|
||||
"""
|
||||
run_context: dict = trigger_data.get("run_context", {})
|
||||
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
|
||||
run_id = run_context.get("run_id")
|
||||
|
||||
_running_agents.add(agent_id)
|
||||
|
||||
# Extract config from trigger payload
|
||||
directory_paths: list[str] = trigger_data.get("directory_paths", [])
|
||||
if not directory_paths:
|
||||
directory = trigger_data.get("directory", "")
|
||||
if directory:
|
||||
directory_paths = [directory]
|
||||
|
||||
data_types: list[str] = trigger_data.get("data_types", [])
|
||||
file_extensions: list[str] = trigger_data.get("file_extensions", [])
|
||||
prompt_template: str = trigger_data.get("prompt_template", "")
|
||||
last_run_at_raw = trigger_data.get("last_run_at")
|
||||
last_run_at: datetime | None = None
|
||||
if last_run_at_raw:
|
||||
if isinstance(last_run_at_raw, str):
|
||||
last_run_at = datetime.fromisoformat(last_run_at_raw)
|
||||
elif isinstance(last_run_at_raw, (int, float)):
|
||||
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
|
||||
|
||||
errors: list[str] = []
|
||||
items_processed = 0
|
||||
items_created = 0
|
||||
|
||||
custom_section = (
|
||||
f"User instructions:\n{prompt_template}"
|
||||
if prompt_template
|
||||
else ""
|
||||
)
|
||||
|
||||
# Create or load run log
|
||||
run_log_id = run_id
|
||||
if not run_log_id:
|
||||
async with async_session() as db:
|
||||
run_log = AgentRunLog(
|
||||
agent_id=agent_id,
|
||||
agent_type="local",
|
||||
user_id=user_id,
|
||||
status="running",
|
||||
)
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
run_log_id = run_log.id
|
||||
|
||||
try:
|
||||
# ── Scan directories ─────────────────────────────────────────
|
||||
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
|
||||
file_paths = await _scan_directories(
|
||||
paths=directory_paths,
|
||||
extensions=file_extensions,
|
||||
last_run_at=last_run_at,
|
||||
)
|
||||
logger.info(
|
||||
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
|
||||
)
|
||||
|
||||
if not file_paths:
|
||||
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
|
||||
return
|
||||
|
||||
# ── Fetch all projects once ──────────────────────────────────
|
||||
projects = await _fetch_projects()
|
||||
|
||||
for file_path in file_paths:
|
||||
try:
|
||||
file_result = await execute_on_client(
|
||||
action="read_file_content", data={"path": file_path}
|
||||
)
|
||||
file_content: str = file_result.get("content", "")
|
||||
if not file_content:
|
||||
continue
|
||||
|
||||
items_processed += 1
|
||||
|
||||
# Step 1 — classify file
|
||||
project_id, domains, new_project_name = await _classify_file(
|
||||
file_path=file_path,
|
||||
file_content=file_content,
|
||||
projects=projects,
|
||||
config_data_types=data_types,
|
||||
langfuse_handler=langfuse_handler,
|
||||
)
|
||||
|
||||
# Step 2 — resolve project_id, fetch entities, process
|
||||
if project_id == "new":
|
||||
proj_name = new_project_name or "Untitled Project"
|
||||
try:
|
||||
proj_result = await execute_on_client(
|
||||
action="insert",
|
||||
table="projects",
|
||||
data={"name": proj_name, "clientId": None},
|
||||
)
|
||||
created = proj_result.get("row", {})
|
||||
effective_project_id = created.get("id", "standalone")
|
||||
if "id" in created:
|
||||
projects.append(created)
|
||||
except Exception as exc:
|
||||
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
|
||||
effective_project_id = "standalone"
|
||||
proj_name = "unknown"
|
||||
project_context = (
|
||||
f"Project: {proj_name} (id: {effective_project_id}). "
|
||||
"Always set projectId to this id on every record you create."
|
||||
)
|
||||
else:
|
||||
effective_project_id = project_id
|
||||
proj = next((p for p in projects if p["id"] == project_id), None)
|
||||
proj_name = proj.get("name", project_id) if proj else project_id
|
||||
project_context = (
|
||||
f"Project: {proj_name} (id: {project_id}). "
|
||||
"Always set projectId to this id on every record you create."
|
||||
)
|
||||
|
||||
domains = [d for d in domains if d != "projects"]
|
||||
|
||||
existing_blocks: list[str] = []
|
||||
for domain in domains:
|
||||
rows = await _fetch_domain_entities(domain, effective_project_id)
|
||||
existing_blocks.append(_format_entities_for_context(domain, rows))
|
||||
|
||||
existing_context = "\n\n".join(existing_blocks)
|
||||
|
||||
system_prompt = tracing.compile_prompt(
|
||||
"batch_processing",
|
||||
fallback=_PROCESSING_SYSTEM_PROMPT,
|
||||
variables={
|
||||
"existing_context": existing_context,
|
||||
"project_context": project_context,
|
||||
"data_types": ", ".join(domains),
|
||||
"custom_prompt_section": custom_section,
|
||||
},
|
||||
)
|
||||
|
||||
processing_tools = _build_processing_tools(domains)
|
||||
|
||||
result_text = await _run_agent_with_tools(
|
||||
system_prompt=system_prompt,
|
||||
user_message=(
|
||||
f"Process this file and extract relevant information.\n\n"
|
||||
f"File: {file_path}\n\nContent:\n{file_content}"
|
||||
),
|
||||
tools=processing_tools,
|
||||
max_steps=_MAX_PROCESSING_STEPS,
|
||||
langfuse_handler=langfuse_handler,
|
||||
)
|
||||
logger.info(
|
||||
"agent_runner: run=%s file=%r result=%s",
|
||||
run_log_id, file_path, result_text[:200],
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
errors.append(f"Error processing '{file_path}': {exc}")
|
||||
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
|
||||
|
||||
except Exception as exc:
|
||||
errors.append(f"Agent run failed: {exc}")
|
||||
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
|
||||
finally:
|
||||
_running_agents.discard(agent_id)
|
||||
|
||||
# ── Finalise ────────────────────────────────────────────────────
|
||||
if errors and items_processed == 0:
|
||||
final_status = "error"
|
||||
elif errors:
|
||||
final_status = "partial"
|
||||
else:
|
||||
final_status = "success"
|
||||
|
||||
await _finalize_run(
|
||||
run_log_id,
|
||||
status=final_status,
|
||||
items_processed=items_processed,
|
||||
items_created=items_created,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
# Notify Electron that the run is complete via Redis
|
||||
if run_context:
|
||||
try:
|
||||
channel = ws_out_channel(user_id)
|
||||
await redis_client.publish(channel, json.dumps({
|
||||
"type": "run_complete",
|
||||
"run_context": run_context,
|
||||
"status": final_status,
|
||||
}))
|
||||
except Exception as exc:
|
||||
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
|
||||
|
||||
|
||||
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||
|
||||
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||
|
||||
|
||||
async def run_cloud_agent(user_id: str, config_id: str, *, langfuse_handler: Any | None = None) -> None:
|
||||
"""Execute a cloud connector agent run.
|
||||
|
||||
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
|
||||
messages from the provider, and runs LLM extraction.
|
||||
|
||||
set_current_user() must be called BEFORE this function.
|
||||
"""
|
||||
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||
|
||||
async with async_session() as db:
|
||||
result = await db.execute(
|
||||
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||
)
|
||||
config = result.scalar_one_or_none()
|
||||
if config is None:
|
||||
logger.error("agent_runner: cloud config %s not found", config_id)
|
||||
return
|
||||
|
||||
# Create run log
|
||||
run_log = AgentRunLog(
|
||||
agent_id=config.id,
|
||||
agent_type="cloud",
|
||||
user_id=user_id,
|
||||
status="running",
|
||||
)
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
run_log_id = run_log.id
|
||||
|
||||
# ── Decrypt OAuth token ────────────────────────────────────────
|
||||
if not config.oauth_token_encrypted:
|
||||
await _finalize_run(
|
||||
run_log_id,
|
||||
status="error",
|
||||
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||
except ValueError as exc:
|
||||
await _finalize_run(
|
||||
run_log_id,
|
||||
status="error",
|
||||
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||
)
|
||||
return
|
||||
|
||||
# ── Instantiate provider ──────────────────────────────────────
|
||||
try:
|
||||
provider = get_provider(config.provider, credentials_info)
|
||||
except ValueError as exc:
|
||||
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
|
||||
return
|
||||
|
||||
# ── Fetch messages ────────────────────────────────────────────
|
||||
since: datetime | None = config.last_run_at
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||
if since.tzinfo is None:
|
||||
since = since.replace(tzinfo=timezone.utc)
|
||||
|
||||
errors: list[str] = []
|
||||
items_processed = 0
|
||||
|
||||
try:
|
||||
if config.provider == "gmail":
|
||||
raw_messages = await provider.fetch_messages(
|
||||
filter_config=config.filter_config,
|
||||
since=since,
|
||||
)
|
||||
elif config.provider == "outlook":
|
||||
raw_messages = await provider.fetch_emails(
|
||||
filter_config=config.filter_config,
|
||||
since=since,
|
||||
)
|
||||
elif config.provider == "teams":
|
||||
raw_messages = await provider.fetch_messages(
|
||||
filter_config=config.filter_config,
|
||||
since=since,
|
||||
)
|
||||
else:
|
||||
raw_messages = []
|
||||
except RuntimeError as exc:
|
||||
await _finalize_run(
|
||||
run_log_id,
|
||||
status="error",
|
||||
errors=[f"Provider fetch failed: {exc}"],
|
||||
update_config_last_run=True,
|
||||
config_id=config.id,
|
||||
config_type="cloud",
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"agent_runner: cloud agent %s fetched %d item(s) from %s",
|
||||
config.id, len(raw_messages), config.provider,
|
||||
)
|
||||
|
||||
# ── Extract + insert via LLM ─────────────────────────────────
|
||||
try:
|
||||
processing_tools = _build_processing_tools(config.data_types)
|
||||
custom_section = (
|
||||
f"User instructions:\n{config.prompt_template}"
|
||||
if config.prompt_template
|
||||
else ""
|
||||
)
|
||||
|
||||
for msg in raw_messages:
|
||||
content_text = msg.as_text
|
||||
if not content_text:
|
||||
continue
|
||||
items_processed += 1
|
||||
|
||||
processing_prompt = tracing.compile_prompt(
|
||||
"batch_cloud_processing",
|
||||
fallback=_CLOUD_PROCESSING_PROMPT,
|
||||
variables={
|
||||
"data_types": ", ".join(config.data_types),
|
||||
"project_context": "Determine the appropriate project from the message context.",
|
||||
"file_list": f"Message from {config.provider} (id: {msg.id})",
|
||||
"custom_prompt_section": custom_section,
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
await _run_agent_with_tools(
|
||||
system_prompt=processing_prompt,
|
||||
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
||||
tools=processing_tools,
|
||||
max_steps=_MAX_PROCESSING_STEPS,
|
||||
langfuse_handler=langfuse_handler,
|
||||
)
|
||||
except Exception as exc:
|
||||
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
||||
except Exception as exc:
|
||||
errors.append(f"Agent run failed: {exc}")
|
||||
|
||||
# ── Persist refreshed token ───────────────────────────────────
|
||||
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||
if refreshed:
|
||||
try:
|
||||
new_encrypted = encrypt_token(refreshed)
|
||||
async with async_session() as db:
|
||||
cfg_result = await db.execute(
|
||||
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||
)
|
||||
cfg_row = cfg_result.scalar_one_or_none()
|
||||
if cfg_row:
|
||||
cfg_row.oauth_token_encrypted = new_encrypted
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
|
||||
|
||||
# ── Finalise ──────────────────────────────────────────────────
|
||||
if errors and items_processed == 0:
|
||||
final_status = "error"
|
||||
elif errors:
|
||||
final_status = "partial"
|
||||
else:
|
||||
final_status = "success"
|
||||
|
||||
await _finalize_run(
|
||||
run_log_id,
|
||||
status=final_status,
|
||||
items_processed=items_processed,
|
||||
items_created=0,
|
||||
errors=errors,
|
||||
update_config_last_run=True,
|
||||
config_id=config.id,
|
||||
config_type="cloud",
|
||||
)
|
||||
|
||||
|
||||
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _finalize_run(
|
||||
run_log_id: int | str,
|
||||
*,
|
||||
status: str,
|
||||
items_processed: int = 0,
|
||||
items_created: int = 0,
|
||||
errors: list[str] | None = None,
|
||||
update_config_last_run: bool = False,
|
||||
config_id: str | None = None,
|
||||
config_type: str | None = None,
|
||||
) -> None:
|
||||
"""Persist the run outcome and optionally update last_run_at on the config."""
|
||||
now = datetime.now(timezone.utc)
|
||||
try:
|
||||
async with async_session() as db:
|
||||
result = await db.execute(
|
||||
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
|
||||
)
|
||||
managed = result.scalar_one_or_none()
|
||||
if managed is None:
|
||||
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
|
||||
return
|
||||
|
||||
managed.status = status
|
||||
managed.items_processed = items_processed
|
||||
managed.items_created = items_created
|
||||
managed.errors = errors or []
|
||||
managed.completed_at = now
|
||||
|
||||
if update_config_last_run and config_id:
|
||||
if config_type == "local":
|
||||
cfg_result = await db.execute(
|
||||
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||
)
|
||||
cfg = cfg_result.scalar_one_or_none()
|
||||
if cfg:
|
||||
cfg.last_run_at = now
|
||||
elif config_type == "cloud":
|
||||
cfg_result = await db.execute(
|
||||
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||
)
|
||||
cfg = cfg_result.scalar_one_or_none()
|
||||
if cfg:
|
||||
cfg.last_run_at = now
|
||||
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)
|
||||
1
services/batch-agent/app/agents/__init__.py
Normal file
1
services/batch-agent/app/agents/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Batch Agent Service domain agents and filesystem tools."""
|
||||
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
83
services/batch-agent/app/agents/filesystem_agent.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""Filesystem agent — tools for reading local directories and files on Electron.
|
||||
|
||||
Adapted for Batch Agent Service: import from app.ws_context.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.tools import tool
|
||||
|
||||
from shared.ws_context import execute_on_client
|
||||
|
||||
|
||||
@tool
|
||||
async def list_directory(path: str) -> str:
|
||||
"""List files and folders in a local directory on the user's device.
|
||||
|
||||
Returns a formatted listing of entries with name, type (file/directory),
|
||||
and full path.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="list_directory",
|
||||
data={"path": path},
|
||||
)
|
||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||
if not entries:
|
||||
return f"Directory '{path}' is empty or does not exist."
|
||||
lines: list[str] = []
|
||||
for entry in entries:
|
||||
entry_type = entry.get("type", "unknown")
|
||||
entry_name = entry.get("name", "")
|
||||
entry_path = entry.get("path", "")
|
||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||
|
||||
|
||||
@tool
|
||||
async def read_file_content(path: str) -> str:
|
||||
"""Read the text content of a local file on the user's device.
|
||||
|
||||
Returns the file content as a string. Large files may be truncated
|
||||
by the Electron client.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="read_file_content",
|
||||
data={"path": path},
|
||||
)
|
||||
content: str = result.get("content", "")
|
||||
if not content:
|
||||
return f"File '{path}' is empty or could not be read."
|
||||
return content
|
||||
|
||||
|
||||
@tool
|
||||
async def get_file_metadata(path: str) -> str:
|
||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||
|
||||
Returns a formatted summary of the file's metadata.
|
||||
"""
|
||||
result = await execute_on_client(
|
||||
action="get_file_metadata",
|
||||
data={"path": path},
|
||||
)
|
||||
size = result.get("size", "unknown")
|
||||
created = result.get("createdAt", "unknown")
|
||||
modified = result.get("modifiedAt", "unknown")
|
||||
extension = result.get("extension", "unknown")
|
||||
name = result.get("name", path)
|
||||
return (
|
||||
f"File: {name}\n"
|
||||
f" Extension: {extension}\n"
|
||||
f" Size: {size} bytes\n"
|
||||
f" Created: {created}\n"
|
||||
f" Modified: {modified}"
|
||||
)
|
||||
|
||||
|
||||
FILESYSTEM_TOOLS: list[Any] = [
|
||||
list_directory,
|
||||
read_file_content,
|
||||
get_file_metadata,
|
||||
]
|
||||
@@ -1,20 +1,11 @@
|
||||
"""Cloud provider integration utilities.
|
||||
|
||||
Provides:
|
||||
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
|
||||
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
|
||||
* ``get_provider()`` — factory that returns the correct client given a
|
||||
provider name and decrypted OAuth credentials dict.
|
||||
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
|
||||
encryption for OAuth tokens stored in ``cloud_agent_configs``.
|
||||
Adapted for Batch Agent Service: import from shared.config instead of app.config.
|
||||
|
||||
Encryption rationale
|
||||
--------------------
|
||||
Unlike user content (which is E2E-encrypted client-side and **never**
|
||||
decrypted server-side), OAuth tokens *must* be decrypted server-side
|
||||
because the backend makes provider API calls on behalf of the user.
|
||||
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
|
||||
is never returned to clients.
|
||||
Provides:
|
||||
* Shared message dataclasses (EmailMessage, ChatMessage)
|
||||
* get_provider() — factory for Gmail/MS Graph clients
|
||||
* encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -27,7 +18,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from cryptography.fernet import Fernet, InvalidToken
|
||||
|
||||
from app.config.settings import settings
|
||||
from shared.config import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.integrations.gmail import GmailClient
|
||||
@@ -35,13 +26,9 @@ if TYPE_CHECKING:
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Shared message types ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class EmailMessage:
|
||||
"""A single email message fetched from Gmail or Outlook."""
|
||||
|
||||
id: str
|
||||
subject: str
|
||||
sender: str
|
||||
@@ -51,7 +38,6 @@ class EmailMessage:
|
||||
|
||||
@property
|
||||
def as_text(self) -> str:
|
||||
"""Return a human-readable text representation for LLM extraction."""
|
||||
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||
return (
|
||||
@@ -64,8 +50,6 @@ class EmailMessage:
|
||||
|
||||
@dataclass
|
||||
class ChatMessage:
|
||||
"""A single Teams chat or channel message fetched from MS Graph."""
|
||||
|
||||
id: str
|
||||
content: str
|
||||
sender: str
|
||||
@@ -74,7 +58,6 @@ class ChatMessage:
|
||||
|
||||
@property
|
||||
def as_text(self) -> str:
|
||||
"""Return a human-readable text representation for LLM extraction."""
|
||||
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||
return (
|
||||
@@ -84,15 +67,7 @@ class ChatMessage:
|
||||
)
|
||||
|
||||
|
||||
# ── Fernet helpers ────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _get_fernet() -> Fernet:
|
||||
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
|
||||
|
||||
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set — callers
|
||||
must ensure this is configured before persisting OAuth tokens.
|
||||
"""
|
||||
key = settings.OAUTH_ENCRYPTION_KEY
|
||||
if not key:
|
||||
raise RuntimeError(
|
||||
@@ -103,15 +78,6 @@ def _get_fernet() -> Fernet:
|
||||
|
||||
|
||||
def encrypt_token(token_info: dict) -> str:
|
||||
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
|
||||
|
||||
Stores the full ``{access_token, refresh_token, token_uri, client_id,
|
||||
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
|
||||
|
||||
Raises:
|
||||
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||
ValueError: ``token_info`` is not a non-empty dict.
|
||||
"""
|
||||
if not isinstance(token_info, dict) or not token_info:
|
||||
raise ValueError("token_info must be a non-empty dict")
|
||||
plaintext = json.dumps(token_info).encode("utf-8")
|
||||
@@ -119,13 +85,6 @@ def encrypt_token(token_info: dict) -> str:
|
||||
|
||||
|
||||
def decrypt_token(encrypted: str) -> dict:
|
||||
"""Decrypt a Fernet-encrypted token string and return the credential dict.
|
||||
|
||||
Raises:
|
||||
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||
ValueError: The encrypted string is invalid or was encrypted with a
|
||||
different key.
|
||||
"""
|
||||
try:
|
||||
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||
return json.loads(plaintext)
|
||||
@@ -133,25 +92,10 @@ def decrypt_token(encrypted: str) -> dict:
|
||||
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||
|
||||
|
||||
# ── Provider factory ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_provider(
|
||||
provider: str,
|
||||
credentials_info: dict,
|
||||
) -> "GmailClient | MSGraphClient":
|
||||
"""Return the correct provider client for *provider*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
provider:
|
||||
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
|
||||
credentials_info:
|
||||
Decrypted OAuth credential dict (Google or Microsoft shape).
|
||||
|
||||
Raises:
|
||||
ValueError: Unknown provider name.
|
||||
"""
|
||||
if provider == "gmail":
|
||||
from app.integrations.gmail import GmailClient
|
||||
return GmailClient(credentials_info)
|
||||
@@ -1,26 +1,7 @@
|
||||
"""Gmail API client for cloud agent integration.
|
||||
|
||||
Wraps the Google Gmail REST API to fetch email messages matching a
|
||||
``filter_config`` dict. Uses the official ``google-api-python-client``
|
||||
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
||||
blocking the event loop.
|
||||
|
||||
Token refresh is handled transparently: when the stored access token has
|
||||
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||
token to obtain a fresh one. The caller is responsible for persisting
|
||||
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
|
||||
(see ``agent_runner.run_cloud_agent``).
|
||||
|
||||
Credential dict shape (Google OAuth2):
|
||||
{
|
||||
"token": "<access_token>",
|
||||
"refresh_token": "<refresh_token>",
|
||||
"token_uri": "https://oauth2.googleapis.com/token",
|
||||
"client_id": "<client_id>",
|
||||
"client_secret": "<client_secret>",
|
||||
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
||||
}
|
||||
Adapted for Batch Agent Service: import from app.integrations instead of
|
||||
app.integrations (same relative path within the service).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -38,13 +19,8 @@ from app.integrations import EmailMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Gmail search date format — e.g. "after:2025/01/01"
|
||||
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||
|
||||
# Maximum characters of body text forwarded to the LLM.
|
||||
_BODY_TRUNCATE = 8_000
|
||||
|
||||
# Maximum messages retrieved per run (prevents runaway quota usage).
|
||||
_MAX_MESSAGES = 200
|
||||
|
||||
|
||||
@@ -52,20 +28,9 @@ def _build_gmail_query(
|
||||
filter_config: dict[str, Any] | None,
|
||||
since: datetime | None,
|
||||
) -> str:
|
||||
"""Build a Gmail search query string from *filter_config* and *since*.
|
||||
|
||||
Supported ``filter_config`` keys:
|
||||
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
|
||||
senders (list[str]): Sender addresses or domains to include
|
||||
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
|
||||
|
||||
A hard ``since`` date (from last run) always overrides ``date_range.from``
|
||||
when it is earlier.
|
||||
"""
|
||||
parts: list[str] = []
|
||||
cfg = filter_config or {}
|
||||
|
||||
# Labels — joined with OR when multiple given.
|
||||
labels: list[str] = cfg.get("labels", [])
|
||||
if labels:
|
||||
if len(labels) == 1:
|
||||
@@ -74,17 +39,14 @@ def _build_gmail_query(
|
||||
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||
parts.append(f"({label_expr})")
|
||||
|
||||
# Senders — each prefixed with "from:".
|
||||
senders: list[str] = cfg.get("senders", [])
|
||||
for sender in senders:
|
||||
parts.append(f"from:{sender}")
|
||||
|
||||
# Date range.
|
||||
date_range: dict = cfg.get("date_range", {})
|
||||
from_str: str | None = date_range.get("from")
|
||||
to_str: str | None = date_range.get("to")
|
||||
|
||||
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
|
||||
effective_since: datetime | None = since
|
||||
if from_str:
|
||||
try:
|
||||
@@ -110,18 +72,12 @@ def _build_gmail_query(
|
||||
|
||||
|
||||
def _strip_html(raw_html: str) -> str:
|
||||
"""Remove HTML tags and decode entities to get plain text."""
|
||||
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||
decoded = html.unescape(no_tags)
|
||||
return re.sub(r"\s+", " ", decoded).strip()
|
||||
|
||||
|
||||
def _parse_body(payload: dict[str, Any]) -> str:
|
||||
"""Recursively extract the plain-text body from a Gmail message payload.
|
||||
|
||||
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
|
||||
Returns an empty string if no body can be extracted.
|
||||
"""
|
||||
mime_type: str = payload.get("mimeType", "")
|
||||
body: dict = payload.get("body", {})
|
||||
parts: list[dict] = payload.get("parts", [])
|
||||
@@ -139,7 +95,6 @@ def _parse_body(payload: dict[str, Any]) -> str:
|
||||
return _strip_html(raw)
|
||||
return ""
|
||||
|
||||
# Multipart — prefer text/plain part, fall back to text/html.
|
||||
plain_fallback = ""
|
||||
for part in parts:
|
||||
part_mime = part.get("mimeType", "")
|
||||
@@ -155,7 +110,6 @@ def _parse_body(payload: dict[str, Any]) -> str:
|
||||
|
||||
|
||||
def _parse_date(raw: str) -> datetime:
|
||||
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
|
||||
try:
|
||||
parsed = email.utils.parsedate_to_datetime(raw)
|
||||
if parsed.tzinfo is None:
|
||||
@@ -166,16 +120,6 @@ def _parse_date(raw: str) -> datetime:
|
||||
|
||||
|
||||
class GmailClient:
|
||||
"""Fetch email messages from a Gmail account via the Gmail REST API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
credentials_info:
|
||||
Decrypted OAuth2 credential dict. Must contain at minimum
|
||||
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
|
||||
``client_id`` + ``client_secret``.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||
from google.oauth2.credentials import Credentials
|
||||
|
||||
@@ -200,38 +144,20 @@ class GmailClient:
|
||||
expiry=expiry,
|
||||
)
|
||||
|
||||
# ── Public API ─────────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_messages(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[EmailMessage]:
|
||||
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
|
||||
|
||||
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
|
||||
to avoid blocking the async event loop.
|
||||
|
||||
Token refresh is performed automatically when the access token has
|
||||
expired. After the call, ``self.refreshed_credentials`` may be
|
||||
consulted to detect whether new credentials should be persisted.
|
||||
"""
|
||||
query = _build_gmail_query(filter_config, since)
|
||||
logger.debug("gmail: executing search query %r", query)
|
||||
return await asyncio.to_thread(self._fetch_sync, query)
|
||||
|
||||
@property
|
||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||
"""Return updated credential dict if the access token was refreshed.
|
||||
|
||||
If the credentials were refreshed during ``fetch_messages()``, returns
|
||||
a new dict that should be re-encrypted and written back to the DB.
|
||||
Returns ``None`` if no refresh occurred.
|
||||
"""
|
||||
creds = self._credentials
|
||||
if not creds.valid and creds.expired:
|
||||
return None
|
||||
# Check whether the token changed from what was stored.
|
||||
if creds.token != self._credentials_info.get("token"):
|
||||
result = {
|
||||
"token": creds.token,
|
||||
@@ -246,15 +172,11 @@ class GmailClient:
|
||||
return result
|
||||
return None
|
||||
|
||||
# ── Internal sync worker ───────────────────────────────────────────────
|
||||
|
||||
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
|
||||
import googleapiclient.discovery
|
||||
import googleapiclient.errors
|
||||
from google.auth.transport.requests import Request
|
||||
|
||||
# Refresh token if needed before building the service.
|
||||
if self._credentials.expired and self._credentials.refresh_token:
|
||||
try:
|
||||
self._credentials.refresh(Request())
|
||||
@@ -264,9 +186,8 @@ class GmailClient:
|
||||
service = googleapiclient.discovery.build(
|
||||
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||
)
|
||||
user_api = service.users() # type: ignore[attr-defined]
|
||||
user_api = service.users()
|
||||
|
||||
# ── List matching message IDs ──────────────────────────────────────
|
||||
ids: list[str] = []
|
||||
page_token: str | None = None
|
||||
while len(ids) < _MAX_MESSAGES:
|
||||
@@ -293,12 +214,10 @@ class GmailClient:
|
||||
break
|
||||
|
||||
if not ids:
|
||||
logger.debug("gmail: no messages matched query %r", query)
|
||||
return []
|
||||
|
||||
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||
|
||||
# ── Fetch individual message details ──────────────────────────────
|
||||
messages: list[EmailMessage] = []
|
||||
for msg_id in ids:
|
||||
try:
|
||||
@@ -326,10 +245,8 @@ class GmailClient:
|
||||
date=date,
|
||||
labels=labels,
|
||||
))
|
||||
except googleapiclient.errors.HttpError as exc:
|
||||
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
|
||||
except Exception as exc:
|
||||
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
|
||||
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
|
||||
|
||||
logger.info("gmail: returned %d message(s)", len(messages))
|
||||
return messages
|
||||
@@ -1,24 +1,6 @@
|
||||
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
|
||||
"""Microsoft Graph API client for Outlook and Teams.
|
||||
|
||||
Handles two data sources:
|
||||
|
||||
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
|
||||
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
|
||||
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
|
||||
``/me/chats/getAllMessages`` filtered by date.
|
||||
|
||||
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
|
||||
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
|
||||
dependency) is used for all API calls.
|
||||
|
||||
Credential dict shape (Microsoft OAuth2 / MSAL):
|
||||
{
|
||||
"access_token": "<access_token>",
|
||||
"refresh_token": "<refresh_token>",
|
||||
"token_type": "Bearer",
|
||||
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
|
||||
"expires_in": 3600
|
||||
}
|
||||
Adapted for Batch Agent Service: import settings from shared.config.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -30,23 +12,19 @@ from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from app.config.settings import settings
|
||||
from shared.config import settings
|
||||
from app.integrations import ChatMessage, EmailMessage
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||
|
||||
# Max items fetched per run.
|
||||
_MAX_EMAILS = 200
|
||||
_MAX_MESSAGES = 200
|
||||
|
||||
# Max characters of body forwarded to the LLM.
|
||||
_BODY_TRUNCATE = 8_000
|
||||
|
||||
|
||||
def _strip_html(raw: str) -> str:
|
||||
"""Strip HTML tags and collapse whitespace."""
|
||||
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||
import html as _html
|
||||
decoded = _html.unescape(no_tags)
|
||||
@@ -54,7 +32,6 @@ def _strip_html(raw: str) -> str:
|
||||
|
||||
|
||||
def _odata_datetime(dt: datetime) -> str:
|
||||
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
|
||||
utc = dt.astimezone(timezone.utc)
|
||||
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
||||
@@ -63,29 +40,14 @@ def _build_email_filter(
|
||||
filter_config: dict[str, Any] | None,
|
||||
since: datetime | None,
|
||||
) -> str:
|
||||
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
|
||||
|
||||
Supported ``filter_config`` keys:
|
||||
senders (list[str]): Sender email addresses.
|
||||
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
|
||||
folders (list[str]): Folder display names (not directly filterable
|
||||
via OData, so ignored here — callers iterate
|
||||
folder IDs separately if needed; listed for
|
||||
completeness).
|
||||
|
||||
A hard ``since`` date always overrides ``date_range.from`` when it is
|
||||
earlier.
|
||||
"""
|
||||
clauses: list[str] = []
|
||||
cfg = filter_config or {}
|
||||
|
||||
# Senders.
|
||||
senders: list[str] = cfg.get("senders", [])
|
||||
if senders:
|
||||
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||
|
||||
# Date range.
|
||||
date_range: dict = cfg.get("date_range", {})
|
||||
from_str: str | None = date_range.get("from")
|
||||
|
||||
@@ -117,33 +79,16 @@ def _build_email_filter(
|
||||
|
||||
|
||||
class MSGraphClient:
|
||||
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
credentials_info:
|
||||
Decrypted MSAL credential dict.
|
||||
"""
|
||||
|
||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||
self._credentials_info = credentials_info
|
||||
self._access_token: str = credentials_info.get("access_token", "")
|
||||
self._original_access_token: str = self._access_token
|
||||
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||
|
||||
# ── Token management ───────────────────────────────────────────────────
|
||||
|
||||
def _auth_headers(self) -> dict[str, str]:
|
||||
return {"Authorization": f"Bearer {self._access_token}"}
|
||||
|
||||
async def _refresh_access_token(self) -> None:
|
||||
"""Use MSAL to exchange the refresh token for a fresh access token.
|
||||
|
||||
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
|
||||
|
||||
Raises:
|
||||
RuntimeError: MSAL reports an auth error.
|
||||
"""
|
||||
import msal
|
||||
|
||||
app = msal.ConfidentialClientApplication(
|
||||
@@ -164,7 +109,6 @@ class MSGraphClient:
|
||||
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||
|
||||
self._access_token = result["access_token"]
|
||||
# MSAL may issue a new refresh token.
|
||||
if "refresh_token" in result:
|
||||
self._refresh_token = result["refresh_token"]
|
||||
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||
@@ -172,16 +116,10 @@ class MSGraphClient:
|
||||
|
||||
@property
|
||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||
"""Return updated credential dict if the access token was refreshed.
|
||||
|
||||
Returns ``None`` if no change was made.
|
||||
"""
|
||||
if self._access_token != self._original_access_token:
|
||||
return {**self._credentials_info, "access_token": self._access_token}
|
||||
return None
|
||||
|
||||
# ── HTTP helpers ───────────────────────────────────────────────────────
|
||||
|
||||
async def _get(
|
||||
self,
|
||||
client: httpx.AsyncClient,
|
||||
@@ -190,10 +128,8 @@ class MSGraphClient:
|
||||
*,
|
||||
retry_on_401: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""GET *url* with auth; refresh token on 401 and retry once."""
|
||||
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||
logger.debug("ms_graph: 401 on %s — refreshing token", url)
|
||||
await self._refresh_access_token()
|
||||
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||
if resp.status_code == 429:
|
||||
@@ -201,22 +137,11 @@ class MSGraphClient:
|
||||
resp.raise_for_status()
|
||||
return resp.json()
|
||||
|
||||
# ── Public API ─────────────────────────────────────────────────────────
|
||||
|
||||
async def fetch_emails(
|
||||
self,
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[EmailMessage]:
|
||||
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
filter_config:
|
||||
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
|
||||
since:
|
||||
Hard lower-bound on email date (from last agent run).
|
||||
"""
|
||||
odata_filter = _build_email_filter(filter_config, since)
|
||||
params: dict[str, Any] = {
|
||||
"$top": 50,
|
||||
@@ -237,7 +162,7 @@ class MSGraphClient:
|
||||
if len(emails) >= _MAX_EMAILS:
|
||||
break
|
||||
url = data.get("@odata.nextLink", "")
|
||||
params = {} # nextLink already contains encoded params.
|
||||
params = {}
|
||||
|
||||
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||
return emails
|
||||
@@ -247,13 +172,6 @@ class MSGraphClient:
|
||||
filter_config: dict[str, Any] | None = None,
|
||||
since: datetime | None = None,
|
||||
) -> list[ChatMessage]:
|
||||
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
|
||||
|
||||
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
|
||||
The ``filter_config.channels`` key is checked as a text-filter on
|
||||
the channel name post-fetch (the API doesn't support channel OData
|
||||
filter directly on ``getAllMessages``).
|
||||
"""
|
||||
cfg = filter_config or {}
|
||||
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||
params: dict[str, Any] = {"$top": 50}
|
||||
@@ -268,11 +186,9 @@ class MSGraphClient:
|
||||
try:
|
||||
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||
except httpx.HTTPStatusError as exc:
|
||||
# getAllMessages requires specific licensing; degrade gracefully.
|
||||
if exc.response.status_code in (403, 404):
|
||||
logger.warning(
|
||||
"ms_graph: /me/chats/getAllMessages not available (%d) — "
|
||||
"check Teams license or permissions",
|
||||
"ms_graph: /me/chats/getAllMessages not available (%d)",
|
||||
exc.response.status_code,
|
||||
)
|
||||
break
|
||||
@@ -292,8 +208,6 @@ class MSGraphClient:
|
||||
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||
return messages
|
||||
|
||||
# ── Parsers ────────────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||
395
services/batch-agent/app/journey.py
Normal file
395
services/batch-agent/app/journey.py
Normal file
@@ -0,0 +1,395 @@
|
||||
"""Chatbot Journey — guided conversation to build an agent prompt_template.
|
||||
|
||||
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
|
||||
and app.llm instead of monolith paths. Session state is in-memory (could
|
||||
be moved to Redis for horizontal scaling in the future).
|
||||
|
||||
Journey flow:
|
||||
1. Redis consumer dispatches ``journey_start`` with basic agent config.
|
||||
2. Server creates an in-memory session, runs the setup LLM with
|
||||
file-system tools to explore the directory, returns first question.
|
||||
3. ``journey_message`` frames drive the conversation.
|
||||
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
|
||||
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||
|
||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
||||
from shared.llm import get_llm
|
||||
import app.tracing as tracing
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||
|
||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||
|
||||
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||
|
||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||
_MAX_TURNS: int = 15
|
||||
_MAX_TOOL_STEPS: int = 6
|
||||
|
||||
# ── In-memory session store ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class JourneySession:
|
||||
session_id: str
|
||||
user_id: str
|
||||
agent_type: str # "local" | "cloud"
|
||||
directory: str
|
||||
data_types: list[str]
|
||||
history: list[dict[str, Any]] = field(default_factory=list)
|
||||
system_prompt: str = ""
|
||||
created_at: float = field(default_factory=time.monotonic)
|
||||
|
||||
def is_expired(self) -> bool:
|
||||
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||
|
||||
|
||||
# session_id → session
|
||||
_sessions: dict[str, JourneySession] = {}
|
||||
|
||||
|
||||
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
||||
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
||||
s = _sessions.get(session_id)
|
||||
if s is None or s.is_expired():
|
||||
_sessions.pop(session_id, None)
|
||||
return None
|
||||
if s.user_id != user_id:
|
||||
return None
|
||||
return s
|
||||
|
||||
|
||||
# ── System prompt builder ─────────────────────────────────────────────────
|
||||
|
||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||
Your job is to understand exactly what data the user wants to extract from their
|
||||
local directory and produce a concise prompt_template that a separate AI will use
|
||||
as its instruction set.
|
||||
|
||||
You have access to file-system tools to explore the user's directory:
|
||||
- list_directory: to see folder structure
|
||||
- read_file_content: to peek at file contents
|
||||
- get_file_metadata: to check file info
|
||||
|
||||
The user's configured directory is: {directory}
|
||||
Target data types: {data_types}
|
||||
|
||||
IMPORTANT — project assignment is handled automatically. You MUST NOT ask the user
|
||||
about projects, projectId, or how to link records to projects. Never include
|
||||
projectId logic or project creation instructions in the generated prompt_template.
|
||||
|
||||
Start by exploring the directory to understand its structure. Then ask concise,
|
||||
focused questions one at a time. Cover only the topics relevant to the target
|
||||
data types listed above:
|
||||
|
||||
1. Content type and format — confirmed by your exploration.
|
||||
2. For TASKS (if in scope): field mapping for title, status, priority, content,
|
||||
dueDate (where is the date found? what's the fallback when absent?),
|
||||
and assignee (is there a person name to assign?).
|
||||
3. For NOTES when TASKS are also in scope: note vs task distinction —
|
||||
what makes something a note rather than a task?
|
||||
4. For TIMELINES (if in scope): the date source — what marks a milestone or event?
|
||||
5. Exclusions and special handling applicable to the target data types.
|
||||
|
||||
Keep asking focused questions until you are at least 90% confident. Then stop and
|
||||
output the final prompt_template immediately, wrapped between these exact markers
|
||||
on their own lines:
|
||||
|
||||
{template_start}
|
||||
<the complete extraction prompt here>
|
||||
{template_end}
|
||||
|
||||
The prompt_template must be concise (bullet points, ~15–25 lines maximum).
|
||||
Specify only:
|
||||
- Scope: what files/content qualify and what entity types to create.
|
||||
- Field mapping rules per entity type (camelCase fields: title, status, priority,
|
||||
dueDate, content, assignee, etc.).
|
||||
- dueDate rule (if tasks in scope): source and fallback behaviour.
|
||||
- Note vs task rule (if both in scope): the criterion that separates them.
|
||||
- Timeline date rule (if timelines in scope): what constitutes a timeline event.
|
||||
- Exclusion/filtering rules.
|
||||
- 2–3 concrete mapping examples based on what you discovered.
|
||||
|
||||
{existing_section}Begin by exploring the directory, then ask your first question.\
|
||||
"""
|
||||
|
||||
|
||||
def _build_system_prompt(
|
||||
directory: str,
|
||||
data_types: list[str],
|
||||
existing_template: str | None = None,
|
||||
) -> str:
|
||||
existing_section = (
|
||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||
f"---\n{existing_template}\n---\n"
|
||||
if existing_template
|
||||
else ""
|
||||
)
|
||||
# Use Langfuse compile_prompt ({{variable}} syntax) with Python .format() fallback
|
||||
return tracing.compile_prompt(
|
||||
"journey_system",
|
||||
fallback=_SYSTEM_PROMPT_TEMPLATE,
|
||||
variables={
|
||||
"directory": directory,
|
||||
"data_types": ", ".join(data_types),
|
||||
"existing_section": existing_section,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# ── Template extraction ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _extract_template(text: str) -> str | None:
|
||||
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||
return None
|
||||
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||
end_idx = text.index(_TEMPLATE_END)
|
||||
return text[start_idx:end_idx].strip() or None
|
||||
|
||||
|
||||
# ── LLM call with tool support ───────────────────────────────────────────
|
||||
|
||||
|
||||
def _as_text(content: Any) -> str:
|
||||
if content is None:
|
||||
return ""
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts: list[str] = []
|
||||
for item in content:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
text = item.get("text")
|
||||
if isinstance(text, str):
|
||||
parts.append(text)
|
||||
return "".join(parts)
|
||||
return str(content)
|
||||
|
||||
|
||||
async def _call_llm_with_tools(
|
||||
system_prompt: str,
|
||||
history: list[dict[str, Any]],
|
||||
tools: list[Any],
|
||||
langfuse_handler: Any | None = None,
|
||||
) -> str:
|
||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||
|
||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||
continue until a final text response is produced.
|
||||
"""
|
||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||
for turn in history:
|
||||
if turn["role"] == "user":
|
||||
messages.append(HumanMessage(content=turn["content"]))
|
||||
else:
|
||||
messages.append(AIMessage(content=turn["content"]))
|
||||
|
||||
callbacks = [langfuse_handler] if langfuse_handler else None
|
||||
llm = get_llm(model=None, temperature=0.4, callbacks=callbacks)
|
||||
llm_with_tools = llm.bind_tools(tools)
|
||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||
|
||||
for _ in range(_MAX_TOOL_STEPS):
|
||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||
messages.append(response)
|
||||
|
||||
if not response.tool_calls:
|
||||
return _as_text(response.content)
|
||||
|
||||
for call in response.tool_calls:
|
||||
call_name = str(call.get("name", ""))
|
||||
call_args = call.get("args", {})
|
||||
logger.info(
|
||||
"journey: tool_call name=%s args=%s",
|
||||
call_name,
|
||||
json.dumps(call_args, ensure_ascii=True)[:500],
|
||||
)
|
||||
|
||||
tool_fn = tool_map.get(call_name)
|
||||
if tool_fn is None:
|
||||
tool_output = f"Unknown tool: {call_name}"
|
||||
else:
|
||||
tool_output = await tool_fn.ainvoke(call_args)
|
||||
|
||||
logger.info(
|
||||
"journey: tool_result name=%s output=%s",
|
||||
call_name,
|
||||
str(tool_output)[:800],
|
||||
)
|
||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||
|
||||
# Fallback: exceeded max tool steps.
|
||||
final = await llm.ainvoke(messages)
|
||||
return _as_text(final.content)
|
||||
|
||||
|
||||
# ── Journey handlers (called from redis_consumer) ────────────────────────
|
||||
|
||||
|
||||
async def handle_journey_start(
|
||||
user_id: str,
|
||||
frame: dict[str, Any],
|
||||
*,
|
||||
langfuse_handler: Any | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Handle a ``journey_start`` request.
|
||||
|
||||
Creates a session, runs the setup LLM with directory exploration,
|
||||
and returns the ``journey_reply`` payload.
|
||||
"""
|
||||
agent_type = frame.get("agent_type", "local")
|
||||
directory = frame.get("directory", "")
|
||||
data_types = frame.get("data_types", [])
|
||||
existing_template = frame.get("existing_template")
|
||||
|
||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
||||
|
||||
session = JourneySession(
|
||||
session_id=session_id,
|
||||
user_id=user_id,
|
||||
agent_type=agent_type,
|
||||
directory=directory,
|
||||
data_types=data_types,
|
||||
system_prompt=system_prompt,
|
||||
)
|
||||
|
||||
seed_history: list[dict[str, Any]] = [
|
||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||
]
|
||||
ai_reply = await _call_llm_with_tools(
|
||||
system_prompt=system_prompt,
|
||||
history=seed_history,
|
||||
tools=list(FILESYSTEM_TOOLS),
|
||||
langfuse_handler=langfuse_handler,
|
||||
)
|
||||
|
||||
session.history.extend(seed_history)
|
||||
session.history.append({"role": "assistant", "content": ai_reply})
|
||||
_sessions[session_id] = session
|
||||
|
||||
logger.info(
|
||||
"journey: session %s started for user %s (directory=%s)",
|
||||
session_id,
|
||||
user_id,
|
||||
directory,
|
||||
)
|
||||
|
||||
prompt_template = _extract_template(ai_reply)
|
||||
done = prompt_template is not None
|
||||
|
||||
display_message = ai_reply
|
||||
if done:
|
||||
display_message = (
|
||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||
or "Here is your agent configuration. You can save it or continue refining."
|
||||
)
|
||||
_sessions.pop(session_id, None)
|
||||
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": display_message,
|
||||
"done": done,
|
||||
"prompt_template": prompt_template,
|
||||
}
|
||||
|
||||
|
||||
async def handle_journey_message(
|
||||
user_id: str,
|
||||
frame: dict[str, Any],
|
||||
*,
|
||||
langfuse_handler: Any | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Handle a ``journey_message`` request.
|
||||
|
||||
Appends the user message, calls the LLM, and returns the
|
||||
``journey_reply`` payload.
|
||||
"""
|
||||
session_id = frame.get("session_id", "")
|
||||
message = frame.get("message", "")
|
||||
|
||||
session = get_journey_session(session_id, user_id)
|
||||
if session is None:
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": "Journey session not found or expired. Please start a new setup.",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
}
|
||||
|
||||
session.history.append({"role": "user", "content": message})
|
||||
|
||||
ai_reply = await _call_llm_with_tools(
|
||||
system_prompt=session.system_prompt,
|
||||
history=session.history,
|
||||
tools=list(FILESYSTEM_TOOLS),
|
||||
langfuse_handler=langfuse_handler,
|
||||
)
|
||||
|
||||
session.history.append({"role": "assistant", "content": ai_reply})
|
||||
|
||||
prompt_template = _extract_template(ai_reply)
|
||||
done = prompt_template is not None
|
||||
|
||||
if not done:
|
||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||
if turns >= _MAX_TURNS:
|
||||
nudge_content = (
|
||||
"[System: You have enough information. Please generate the final "
|
||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||
)
|
||||
session.history.append({"role": "user", "content": nudge_content})
|
||||
|
||||
nudge_reply = await _call_llm_with_tools(
|
||||
system_prompt=session.system_prompt,
|
||||
history=session.history,
|
||||
tools=list(FILESYSTEM_TOOLS),
|
||||
langfuse_handler=langfuse_handler,
|
||||
)
|
||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||
|
||||
prompt_template = _extract_template(nudge_reply)
|
||||
if prompt_template is not None:
|
||||
done = True
|
||||
ai_reply = nudge_reply
|
||||
|
||||
display_message = ai_reply
|
||||
if done:
|
||||
display_message = (
|
||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||
if _TEMPLATE_START in ai_reply
|
||||
else "Here is your agent configuration. You can save it or continue refining."
|
||||
)
|
||||
_sessions.pop(session_id, None)
|
||||
logger.info("journey: session %s completed for user %s", session_id, user_id)
|
||||
|
||||
return {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": display_message,
|
||||
"done": done,
|
||||
"prompt_template": prompt_template,
|
||||
}
|
||||
76
services/batch-agent/app/llm.py
Normal file
76
services/batch-agent/app/llm.py
Normal file
@@ -0,0 +1,76 @@
|
||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||
|
||||
Identical to services/chat/app/llm.py. Uses shared.config.settings.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import warnings
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
import litellm
|
||||
|
||||
from langchain_openai import ChatOpenAI
|
||||
from langchain_litellm import ChatLiteLLM
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
||||
category=UserWarning,
|
||||
)
|
||||
|
||||
|
||||
def _api_key_for_model(model: str) -> str | None:
|
||||
if model.startswith("anthropic/"):
|
||||
return settings.ANTHROPIC_API_KEY or None
|
||||
if model.startswith("gemini/") or model.startswith("google/"):
|
||||
return settings.GOOGLE_API_KEY or None
|
||||
if model.startswith("cerebras/"):
|
||||
return settings.CEREBRAS_API_KEY or None
|
||||
if model.startswith("github/"):
|
||||
return settings.GITHUB_TOKEN or None
|
||||
if model.startswith("github_copilot/"):
|
||||
return None
|
||||
return settings.OPENAI_API_KEY or None
|
||||
|
||||
|
||||
def get_llm(
|
||||
*,
|
||||
model: str | None = None,
|
||||
temperature: float = 0,
|
||||
callbacks: list | None = None,
|
||||
) -> ChatOpenAI | ChatLiteLLM:
|
||||
model = model or settings.LLM_MODEL
|
||||
|
||||
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||
|
||||
if settings.GITHUB_TOKEN:
|
||||
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
|
||||
|
||||
if "/" in model:
|
||||
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
|
||||
|
||||
return ChatOpenAI(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=_api_key_for_model(model),
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
|
||||
async def embed(text: str) -> list[float]:
|
||||
model = settings.LLM_EMBED_MODEL
|
||||
|
||||
if model.startswith("github_copilot/") or "/" in model:
|
||||
response = await litellm.aembedding(model=model, input=[text])
|
||||
return response.data[0]["embedding"]
|
||||
|
||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||
response = await client.embeddings.create(model=model, input=text)
|
||||
return response.data[0].embedding
|
||||
79
services/batch-agent/app/main.py
Normal file
79
services/batch-agent/app/main.py
Normal file
@@ -0,0 +1,79 @@
|
||||
"""Batch Agent Service — FastAPI application.
|
||||
|
||||
Owns: agent_runner (local directory + cloud connectors), journey builder,
|
||||
filesystem_agent, integrations (Gmail, MS Graph).
|
||||
|
||||
Communicates with WS Gateway via Redis:
|
||||
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
|
||||
- Publishes to ws:out:{user_id} (journey replies + tool calls)
|
||||
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
|
||||
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the repo root is on sys.path so ``shared`` is importable when
|
||||
# running locally (in Docker the COPY already places it at /app/shared/).
|
||||
_repo_root = str(Path(__file__).resolve().parents[3])
|
||||
if _repo_root not in sys.path:
|
||||
sys.path.insert(0, _repo_root)
|
||||
|
||||
from contextlib import asynccontextmanager
|
||||
from typing import AsyncGenerator
|
||||
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
||||
from app.redis_consumer import start_consumer
|
||||
from app.routes import router
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
# Initialise Langfuse tracing (no-op if keys are missing)
|
||||
from app.tracing import init_langfuse
|
||||
init_langfuse()
|
||||
|
||||
logger.info("batch-agent: starting Redis consumer")
|
||||
task = asyncio.create_task(start_consumer())
|
||||
yield
|
||||
task.cancel()
|
||||
try:
|
||||
await task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
|
||||
from app.tracing import shutdown as shutdown_langfuse
|
||||
shutdown_langfuse()
|
||||
|
||||
from shared.db import engine
|
||||
await engine.dispose()
|
||||
|
||||
from shared.redis import redis_client
|
||||
await redis_client.aclose()
|
||||
|
||||
logger.info("batch-agent: Redis consumer stopped")
|
||||
|
||||
|
||||
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
allow_methods=["GET", "POST"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
app.include_router(router)
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
async def health() -> dict[str, str]:
|
||||
return {"status": "ok", "service": "batch-agent"}
|
||||
183
services/batch-agent/app/redis_consumer.py
Normal file
183
services/batch-agent/app/redis_consumer.py
Normal file
@@ -0,0 +1,183 @@
|
||||
"""Redis consumer for the Batch Agent Service.
|
||||
|
||||
Subscribes to batch:request:* (pattern) and dispatches:
|
||||
- journey_start → handle_journey_start
|
||||
- journey_message → handle_journey_message
|
||||
- agent_trigger → run_local_agent / run_cloud_agent
|
||||
|
||||
Results are published back to ws:out:{user_id} via Redis.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from shared.redis import redis_client, batch_request_channel, ws_out_channel
|
||||
|
||||
import app.tracing as tracing
|
||||
from shared.ws_context import set_current_user, clear_current_user
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
|
||||
"""Publish a frame to the user's WS outbound channel."""
|
||||
channel = ws_out_channel(user_id)
|
||||
await redis_client.publish(channel, json.dumps(payload))
|
||||
|
||||
|
||||
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
|
||||
"""Handle a journey_start request from WS Gateway."""
|
||||
from app.journey import handle_journey_start
|
||||
|
||||
session_id = data.get("session_id", "")
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
with tracing.trace_span(
|
||||
name="journey_start",
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
input=data.get("directory", ""),
|
||||
metadata={"data_types": data.get("data_types", [])},
|
||||
tags=["journey"],
|
||||
) as span:
|
||||
langfuse_handler = tracing.get_langfuse_callback()
|
||||
reply = await handle_journey_start(user_id, data, langfuse_handler=langfuse_handler)
|
||||
tracing.link_prompt_to_trace(span, "journey_system")
|
||||
span.update(output=reply.get("message", "")[:500])
|
||||
await _publish_to_user(user_id, reply)
|
||||
tracing.flush()
|
||||
except Exception as exc:
|
||||
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
|
||||
await _publish_to_user(user_id, {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": f"Journey setup failed: {exc}",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
})
|
||||
finally:
|
||||
clear_current_user()
|
||||
|
||||
|
||||
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
|
||||
"""Handle a journey_message from WS Gateway."""
|
||||
from app.journey import handle_journey_message
|
||||
|
||||
session_id = data.get("session_id", "")
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
with tracing.trace_span(
|
||||
name="journey_message",
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
input=data.get("message", "")[:200],
|
||||
tags=["journey"],
|
||||
) as span:
|
||||
langfuse_handler = tracing.get_langfuse_callback()
|
||||
reply = await handle_journey_message(user_id, data, langfuse_handler=langfuse_handler)
|
||||
tracing.link_prompt_to_trace(span, "journey_system")
|
||||
span.update(output=reply.get("message", "")[:500])
|
||||
await _publish_to_user(user_id, reply)
|
||||
tracing.flush()
|
||||
except Exception as exc:
|
||||
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
|
||||
await _publish_to_user(user_id, {
|
||||
"type": "journey_reply",
|
||||
"session_id": session_id,
|
||||
"message": f"Journey processing failed: {exc}",
|
||||
"done": True,
|
||||
"prompt_template": None,
|
||||
})
|
||||
finally:
|
||||
clear_current_user()
|
||||
|
||||
|
||||
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
|
||||
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
|
||||
from app.agent_runner import run_local_agent
|
||||
|
||||
run_context = data.get("run_context", {})
|
||||
agent_id = run_context.get("agent_id", "")
|
||||
set_current_user(user_id)
|
||||
try:
|
||||
with tracing.trace_span(
|
||||
name="agent_trigger",
|
||||
user_id=user_id,
|
||||
trace_id=run_context.get("run_id"),
|
||||
input={"agent_id": agent_id, "directory": data.get("directory", "")},
|
||||
metadata={"data_types": data.get("data_types", [])},
|
||||
tags=["batch", "agent_run"],
|
||||
) as span:
|
||||
langfuse_handler = tracing.get_langfuse_callback()
|
||||
await run_local_agent(user_id, data, langfuse_handler=langfuse_handler)
|
||||
tracing.link_prompt_to_trace(span, "batch_processing")
|
||||
span.update(output={"status": "completed"})
|
||||
tracing.flush()
|
||||
except Exception as exc:
|
||||
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
|
||||
await _publish_to_user(user_id, {
|
||||
"type": "run_complete",
|
||||
"status": "error",
|
||||
"run_context": run_context,
|
||||
})
|
||||
finally:
|
||||
clear_current_user()
|
||||
|
||||
|
||||
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
|
||||
"""Route a batch request to the correct handler."""
|
||||
msg_type = message_data.get("type", "")
|
||||
|
||||
if msg_type == "journey_start":
|
||||
await _handle_journey_start(user_id, message_data)
|
||||
elif msg_type == "journey_message":
|
||||
await _handle_journey_message(user_id, message_data)
|
||||
elif msg_type == "agent_trigger":
|
||||
await _handle_agent_trigger(user_id, message_data)
|
||||
elif msg_type == "device_online":
|
||||
logger.info("batch-agent: device_online user=%s device=%s", user_id, message_data.get("device_id", "?"))
|
||||
else:
|
||||
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
|
||||
|
||||
|
||||
async def start_consumer() -> None:
|
||||
"""Subscribe to batch:request:* and dispatch incoming frames."""
|
||||
pubsub = redis_client.pubsub()
|
||||
await pubsub.psubscribe("batch:request:*")
|
||||
logger.info("batch-agent: subscribed to batch:request:*")
|
||||
|
||||
try:
|
||||
async for message in pubsub.listen():
|
||||
if message["type"] != "pmessage":
|
||||
continue
|
||||
|
||||
channel: str = message["channel"]
|
||||
if isinstance(channel, bytes):
|
||||
channel = channel.decode()
|
||||
|
||||
# Extract user_id from channel: batch:request:{user_id}
|
||||
parts = channel.split(":", 2)
|
||||
if len(parts) < 3:
|
||||
continue
|
||||
user_id = parts[2]
|
||||
|
||||
raw = message["data"]
|
||||
if isinstance(raw, bytes):
|
||||
raw = raw.decode()
|
||||
|
||||
try:
|
||||
data = json.loads(raw)
|
||||
except json.JSONDecodeError:
|
||||
logger.warning("batch-agent: invalid JSON on channel %s", channel)
|
||||
continue
|
||||
|
||||
# Dispatch in a separate task to avoid blocking the consumer
|
||||
asyncio.create_task(_dispatch(user_id, data))
|
||||
except asyncio.CancelledError:
|
||||
logger.info("batch-agent: consumer shutting down")
|
||||
finally:
|
||||
await pubsub.punsubscribe("batch:request:*")
|
||||
208
services/batch-agent/app/routes.py
Normal file
208
services/batch-agent/app/routes.py
Normal file
@@ -0,0 +1,208 @@
|
||||
"""Agent REST routes — catalog, billing checks, trigger.
|
||||
|
||||
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
|
||||
Agent trigger dispatches via Redis to the consumer instead of spawning
|
||||
an in-process background task.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, status
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from shared.db import async_session
|
||||
from shared.models import AgentRunLog
|
||||
from shared.redis import redis_client, batch_request_channel
|
||||
|
||||
from app.agent_runner import is_agent_running
|
||||
|
||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||
|
||||
# ── Tier feature limits ───────────────────────────────────────────────
|
||||
# Mirrors app/billing/tier_manager.py FEATURES dict.
|
||||
FEATURES: dict[str, dict] = {
|
||||
"free": {"batch_active": 1, "batch_runs_per_day": 3},
|
||||
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
|
||||
"power": {"batch_active": 20, "batch_runs_per_day": 100},
|
||||
"team": {"batch_active": -1, "batch_runs_per_day": -1},
|
||||
}
|
||||
|
||||
|
||||
def _dt_ms(dt: datetime) -> int:
|
||||
return int(dt.timestamp() * 1000)
|
||||
|
||||
|
||||
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
def _to_data_types(values: list[str]) -> list[str]:
|
||||
normalize = {
|
||||
"task": "tasks", "tasks": "tasks",
|
||||
"note": "notes", "notes": "notes",
|
||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
||||
"project": "projects", "projects": "projects",
|
||||
}
|
||||
seen: set[str] = set()
|
||||
result: list[str] = []
|
||||
for v in values:
|
||||
mapped = normalize.get(v)
|
||||
if mapped and mapped not in seen:
|
||||
seen.add(mapped)
|
||||
result.append(mapped)
|
||||
return result
|
||||
|
||||
|
||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||
if limit != -1 and current_count >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||
)
|
||||
return limit
|
||||
|
||||
|
||||
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
||||
if limit == -1:
|
||||
return
|
||||
today_start = datetime.now(timezone.utc).replace(
|
||||
hour=0, minute=0, second=0, microsecond=0
|
||||
)
|
||||
async with async_session() as db:
|
||||
result = await db.execute(
|
||||
select(func.count(AgentRunLog.id)).where(
|
||||
AgentRunLog.user_id == user_id,
|
||||
AgentRunLog.started_at >= today_start,
|
||||
)
|
||||
)
|
||||
runs_today: int = result.scalar_one()
|
||||
|
||||
if runs_today >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||
detail=f"Daily batch run limit ({limit}) reached for your tier.",
|
||||
)
|
||||
|
||||
|
||||
# ── Catalog ───────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/catalog")
|
||||
async def get_agent_catalog(
|
||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||
) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"type": "local_directory",
|
||||
"name": "Local Directory Monitor",
|
||||
"description": "Watches local directories, extracts data from files using AI",
|
||||
},
|
||||
{
|
||||
"type": "gmail",
|
||||
"name": "Gmail Connector",
|
||||
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
|
||||
},
|
||||
{
|
||||
"type": "teams",
|
||||
"name": "Microsoft Teams Connector",
|
||||
"description": "Monitors Teams messages, extracts action items",
|
||||
},
|
||||
{
|
||||
"type": "outlook",
|
||||
"name": "Outlook Connector",
|
||||
"description": "Scans Outlook inbox, extracts tasks/notes",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
# ── Can-create check ─────────────────────────────────────────────────
|
||||
|
||||
@router.post("/can-create")
|
||||
async def can_create_agent(
|
||||
body: dict,
|
||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||
) -> dict:
|
||||
active_agents = body.get("active_agents", 0)
|
||||
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
|
||||
allowed = limit == -1 or active_agents < limit
|
||||
return {
|
||||
"allowed": allowed,
|
||||
"tier": x_user_tier,
|
||||
"active_agents": active_agents,
|
||||
"limit": limit,
|
||||
}
|
||||
|
||||
|
||||
# ── Trigger ──────────────────────────────────────────────────────────
|
||||
|
||||
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
|
||||
async def trigger_agent_run(
|
||||
body: dict,
|
||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
||||
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
||||
) -> dict:
|
||||
"""Trigger a local agent run — creates run log and dispatches via Redis."""
|
||||
active_agents = body.get("active_agents", 0)
|
||||
_enforce_agent_limit(x_user_tier, active_agents)
|
||||
await _enforce_run_frequency(x_user_tier, x_user_id)
|
||||
|
||||
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
|
||||
|
||||
if is_agent_running(stable_agent_id):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_409_CONFLICT,
|
||||
detail="Agent is already running.",
|
||||
)
|
||||
|
||||
# Create run log in DB
|
||||
async with async_session() as db:
|
||||
run_log = AgentRunLog(
|
||||
agent_id=stable_agent_id,
|
||||
agent_type="local",
|
||||
user_id=x_user_id,
|
||||
status="running",
|
||||
)
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
run_log_id = run_log.id
|
||||
|
||||
run_context = {
|
||||
"type": "agent_batch",
|
||||
"run_id": run_log_id,
|
||||
"agent_id": stable_agent_id,
|
||||
}
|
||||
|
||||
# Dispatch to the Redis consumer for processing
|
||||
trigger_data = {
|
||||
"type": "agent_trigger",
|
||||
"directory": body.get("directory", ""),
|
||||
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
|
||||
"data_types": _to_data_types(body.get("what_to_extract", [])),
|
||||
"file_extensions": body.get("file_extensions", []),
|
||||
"prompt_template": body.get("custom_agent_prompt", ""),
|
||||
"device_id": body.get("device_id", ""),
|
||||
"run_context": run_context,
|
||||
}
|
||||
|
||||
channel = batch_request_channel(x_user_id)
|
||||
await redis_client.publish(channel, json.dumps(trigger_data))
|
||||
|
||||
return {
|
||||
"id": run_log_id,
|
||||
"agent_id": stable_agent_id,
|
||||
"agent_type": "local",
|
||||
"status": "running",
|
||||
"items_processed": 0,
|
||||
"items_created": 0,
|
||||
"errors": [],
|
||||
"started_at": _dt_ms(run_log.started_at),
|
||||
"completed_at": None,
|
||||
}
|
||||
336
services/batch-agent/app/tracing.py
Normal file
336
services/batch-agent/app/tracing.py
Normal file
@@ -0,0 +1,336 @@
|
||||
"""Langfuse tracing & prompt management for the Batch Agent Service (v4 SDK).
|
||||
|
||||
Provides:
|
||||
- ``init_langfuse()`` — initialise the singleton client at startup
|
||||
- ``trace_span()`` — context manager that creates a trace + span
|
||||
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
|
||||
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
|
||||
- ``flush()`` / ``shutdown()`` — lifecycle management
|
||||
|
||||
All functions gracefully degrade to no-ops when Langfuse is not configured,
|
||||
so the service works identically with or without observability keys.
|
||||
|
||||
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
|
||||
from shared.config import settings
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── State ────────────────────────────────────────────────────────────────
|
||||
|
||||
_initialised: bool = False
|
||||
_disabled: bool = False
|
||||
|
||||
|
||||
def _is_configured() -> bool:
|
||||
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
|
||||
|
||||
|
||||
def init_langfuse() -> None:
|
||||
"""Initialise the Langfuse singleton. Call once at startup."""
|
||||
global _initialised, _disabled
|
||||
|
||||
if _initialised or _disabled:
|
||||
return
|
||||
|
||||
if not _is_configured():
|
||||
_disabled = True
|
||||
logger.info("tracing: Langfuse keys not set — tracing disabled")
|
||||
return
|
||||
|
||||
try:
|
||||
from langfuse import Langfuse
|
||||
|
||||
Langfuse(
|
||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||
host=settings.LANGFUSE_HOST,
|
||||
)
|
||||
_initialised = True
|
||||
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
|
||||
except Exception as exc:
|
||||
_disabled = True
|
||||
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
|
||||
|
||||
|
||||
def _get_client() -> Any | None:
|
||||
"""Return the singleton Langfuse client, or *None* if disabled."""
|
||||
if _disabled:
|
||||
return None
|
||||
if not _initialised:
|
||||
init_langfuse()
|
||||
if _disabled:
|
||||
return None
|
||||
try:
|
||||
from langfuse import get_client
|
||||
return get_client()
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
|
||||
|
||||
|
||||
class _NullSpan:
|
||||
"""Drop-in replacement when Langfuse is disabled."""
|
||||
|
||||
def update(self, **_: Any) -> None: ...
|
||||
def set_trace_io(self, **_: Any) -> None: ...
|
||||
def score_trace(self, **_: Any) -> None: ...
|
||||
|
||||
|
||||
# ── Trace context manager ───────────────────────────────────────────────
|
||||
|
||||
|
||||
@contextmanager
|
||||
def trace_span(
|
||||
*,
|
||||
name: str,
|
||||
user_id: str,
|
||||
session_id: str | None = None,
|
||||
trace_id: str | None = None,
|
||||
input: Any = None,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
tags: list[str] | None = None,
|
||||
):
|
||||
"""Context manager that creates a Langfuse trace/span.
|
||||
|
||||
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
|
||||
A ``CallbackHandler`` created inside this block auto-inherits the trace
|
||||
context, so there is no need to pass trace IDs manually.
|
||||
"""
|
||||
lf = _get_client()
|
||||
if lf is None:
|
||||
yield _NullSpan()
|
||||
return
|
||||
|
||||
try:
|
||||
from langfuse import Langfuse, propagate_attributes
|
||||
|
||||
trace_ctx: dict[str, str] = {}
|
||||
if trace_id is not None:
|
||||
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
|
||||
|
||||
with lf.start_as_current_observation(
|
||||
as_type="span",
|
||||
name=name,
|
||||
input=input,
|
||||
metadata=metadata or {},
|
||||
**({"trace_context": trace_ctx} if trace_ctx else {}),
|
||||
) as span:
|
||||
with propagate_attributes(
|
||||
user_id=user_id,
|
||||
session_id=session_id,
|
||||
tags=tags or [],
|
||||
):
|
||||
yield span
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
|
||||
yield _NullSpan()
|
||||
|
||||
|
||||
# ── LangChain callback handler ──────────────────────────────────────────
|
||||
|
||||
|
||||
def get_langfuse_callback() -> Any | None:
|
||||
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
|
||||
|
||||
Must be called inside a ``trace_span()`` block for proper linking.
|
||||
Returns *None* when Langfuse is disabled.
|
||||
"""
|
||||
if _disabled and not _initialised:
|
||||
return None
|
||||
|
||||
try:
|
||||
from langfuse.langchain import CallbackHandler
|
||||
return CallbackHandler()
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
|
||||
return None
|
||||
|
||||
|
||||
# ── Prompt management ────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_prompt(
|
||||
name: str,
|
||||
*,
|
||||
version: int | None = None,
|
||||
label: str | None = None,
|
||||
fallback: str | None = None,
|
||||
cache_ttl_seconds: int = 300,
|
||||
) -> str | None:
|
||||
"""Fetch a managed prompt from Langfuse by name (without variable compilation).
|
||||
|
||||
Returns the raw prompt string, or *fallback* if the prompt is not
|
||||
found or Langfuse is disabled.
|
||||
"""
|
||||
lf = _get_client()
|
||||
if lf is None:
|
||||
return fallback
|
||||
|
||||
try:
|
||||
kwargs: dict[str, Any] = {
|
||||
"name": name,
|
||||
"cache_ttl_seconds": cache_ttl_seconds,
|
||||
}
|
||||
if version is not None:
|
||||
kwargs["version"] = version
|
||||
if label is not None:
|
||||
kwargs["label"] = label
|
||||
prompt = lf.get_prompt(**kwargs)
|
||||
return prompt.prompt
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
|
||||
return fallback
|
||||
|
||||
|
||||
def compile_prompt(
|
||||
name: str,
|
||||
*,
|
||||
fallback: str,
|
||||
variables: dict[str, str],
|
||||
version: int | None = None,
|
||||
label: str | None = None,
|
||||
cache_ttl_seconds: int = 300,
|
||||
) -> str:
|
||||
"""Fetch a managed prompt from Langfuse and compile it with ``{{variables}}``.
|
||||
|
||||
If the prompt exists in Langfuse, uses the SDK's ``.compile(**variables)``
|
||||
which replaces ``{{key}}`` placeholders. If Langfuse is disabled or the
|
||||
prompt is not found, falls back to ``fallback.format(**variables)`` (Python
|
||||
``{key}`` placeholders).
|
||||
|
||||
This means:
|
||||
- Langfuse prompts use ``{{variable}}`` syntax.
|
||||
- Hardcoded fallback strings use Python ``{variable}`` syntax.
|
||||
"""
|
||||
lf = _get_client()
|
||||
if lf is None:
|
||||
return fallback.format(**variables)
|
||||
|
||||
try:
|
||||
kwargs: dict[str, Any] = {
|
||||
"name": name,
|
||||
"cache_ttl_seconds": cache_ttl_seconds,
|
||||
}
|
||||
if version is not None:
|
||||
kwargs["version"] = version
|
||||
if label is not None:
|
||||
kwargs["label"] = label
|
||||
prompt = lf.get_prompt(**kwargs)
|
||||
return prompt.compile(**variables)
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: compile_prompt(%s) failed, using fallback: %s", name, exc)
|
||||
return fallback.format(**variables)
|
||||
|
||||
|
||||
def get_prompt_object(
|
||||
name: str,
|
||||
*,
|
||||
version: int | None = None,
|
||||
label: str | None = None,
|
||||
cache_ttl_seconds: int = 300,
|
||||
) -> Any | None:
|
||||
"""Fetch the raw Langfuse prompt *object* (not the compiled string).
|
||||
|
||||
Returns ``None`` when Langfuse is disabled or the prompt is not found.
|
||||
Use this when you need to pass the prompt to ``start_observation(prompt=...)``
|
||||
for linking the prompt to a trace in the Langfuse UI.
|
||||
"""
|
||||
lf = _get_client()
|
||||
if lf is None:
|
||||
return None
|
||||
|
||||
try:
|
||||
kwargs: dict[str, Any] = {
|
||||
"name": name,
|
||||
"cache_ttl_seconds": cache_ttl_seconds,
|
||||
}
|
||||
if version is not None:
|
||||
kwargs["version"] = version
|
||||
if label is not None:
|
||||
kwargs["label"] = label
|
||||
return lf.get_prompt(**kwargs)
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: get_prompt_object(%s) failed: %s", name, exc)
|
||||
return None
|
||||
|
||||
|
||||
def link_prompt_to_trace(
|
||||
span: Any,
|
||||
prompt_name: str,
|
||||
*,
|
||||
version: int | None = None,
|
||||
label: str | None = None,
|
||||
) -> None:
|
||||
"""Link a Langfuse managed prompt to a span/observation.
|
||||
|
||||
Uses the SDK v4 ``prompt=`` parameter so that the prompt version
|
||||
appears linked in the Langfuse UI with metrics tracking.
|
||||
"""
|
||||
lf = _get_client()
|
||||
if lf is None or isinstance(span, _NullSpan):
|
||||
return
|
||||
|
||||
try:
|
||||
prompt = get_prompt_object(prompt_name, version=version, label=label)
|
||||
if prompt is not None:
|
||||
span.update(prompt=prompt)
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
|
||||
|
||||
|
||||
# ── Scoring helper ───────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def score_trace(
|
||||
trace_id: str,
|
||||
name: str,
|
||||
value: float,
|
||||
*,
|
||||
comment: str | None = None,
|
||||
) -> None:
|
||||
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
|
||||
lf = _get_client()
|
||||
if lf is None:
|
||||
return
|
||||
|
||||
try:
|
||||
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: score_trace failed: %s", exc)
|
||||
|
||||
|
||||
# ── Shutdown ─────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def flush() -> None:
|
||||
"""Flush pending Langfuse events."""
|
||||
lf = _get_client()
|
||||
if lf is not None:
|
||||
try:
|
||||
lf.flush()
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: flush failed: %s", exc)
|
||||
|
||||
|
||||
def shutdown() -> None:
|
||||
"""Flush and close the Langfuse client."""
|
||||
global _initialised, _disabled
|
||||
lf = _get_client()
|
||||
if lf is not None:
|
||||
try:
|
||||
lf.flush()
|
||||
lf.shutdown()
|
||||
except Exception as exc:
|
||||
logger.warning("tracing: shutdown failed: %s", exc)
|
||||
_initialised = False
|
||||
_disabled = False
|
||||
1
services/batch-agent/eval/__init__.py
Normal file
1
services/batch-agent/eval/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Batch Agent E2E evaluation harness."""
|
||||
5
services/batch-agent/eval/__main__.py
Normal file
5
services/batch-agent/eval/__main__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
"""Allow running the eval package as ``python -m eval``."""
|
||||
|
||||
from eval.cli import main
|
||||
|
||||
main()
|
||||
285
services/batch-agent/eval/cli.py
Normal file
285
services/batch-agent/eval/cli.py
Normal file
@@ -0,0 +1,285 @@
|
||||
"""CLI entry point for the batch agent evaluation harness.
|
||||
|
||||
Usage::
|
||||
|
||||
# From services/batch-agent/:
|
||||
python -m eval run # all agent fixtures, default model
|
||||
python -m eval run --fixture=classify-invoices # single fixture
|
||||
python -m eval run --models=gpt-4o,gpt-5.3-codex # multiple models
|
||||
python -m eval run --mode=step1 # only step1 fixtures
|
||||
python -m eval run --no-judge # skip LLM judge scoring
|
||||
|
||||
python -m eval interactive # interactive journey session
|
||||
python -m eval interactive --fixture=journey-invoice-setup
|
||||
python -m eval interactive --model=gpt-4o
|
||||
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
|
||||
|
||||
python -m eval list # list all fixtures
|
||||
python -m eval sync # sync fixtures to Langfuse datasets
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
# Ensure the service root and repo root are in sys.path.
|
||||
# Service root must come BEFORE repo root so its ``app/`` package
|
||||
# shadows the monolith ``app/`` in the repo root.
|
||||
_SERVICE_ROOT = Path(__file__).resolve().parent.parent
|
||||
_REPO_ROOT = _SERVICE_ROOT.parent.parent
|
||||
_sr = str(_SERVICE_ROOT)
|
||||
_rr = str(_REPO_ROOT)
|
||||
if _rr not in sys.path:
|
||||
sys.path.insert(0, _rr)
|
||||
# Always force service root to position 0 (python -m may have already
|
||||
# added CWD further down the list, which loses to repo root).
|
||||
if _sr in sys.path:
|
||||
sys.path.remove(_sr)
|
||||
sys.path.insert(0, _sr)
|
||||
|
||||
from eval.config import discover_fixtures, discover_journey_fixtures
|
||||
from eval.runner import run_fixture_eval, print_results
|
||||
from eval.interactive import run_interactive
|
||||
from eval import langfuse_eval
|
||||
|
||||
|
||||
def _setup_logging(verbose: bool) -> None:
|
||||
level = logging.DEBUG if verbose else logging.INFO
|
||||
logging.basicConfig(
|
||||
level=level,
|
||||
format="%(asctime)s %(name)-20s %(levelname)-5s %(message)s",
|
||||
datefmt="%H:%M:%S",
|
||||
)
|
||||
# Quiet noisy libraries
|
||||
for name in ("httpx", "httpcore", "openai", "litellm", "urllib3"):
|
||||
logging.getLogger(name).setLevel(logging.WARNING)
|
||||
|
||||
|
||||
def _parse_args() -> argparse.Namespace:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Batch Agent E2E evaluation harness",
|
||||
prog="python -m eval",
|
||||
)
|
||||
sub = parser.add_subparsers(dest="command", required=True)
|
||||
|
||||
# ── run ───────────────────────────────────────────────────────
|
||||
run_cmd = sub.add_parser("run", help="Run evaluations")
|
||||
run_cmd.add_argument(
|
||||
"--fixture", "-f",
|
||||
help="Run only the named fixture (default: all)",
|
||||
)
|
||||
run_cmd.add_argument(
|
||||
"--models", "-m",
|
||||
default="github_copilot/gpt-5.3-codex",
|
||||
help="Comma-separated list of models to test (default: github_copilot/gpt-5.3-codex)",
|
||||
)
|
||||
run_cmd.add_argument(
|
||||
"--mode",
|
||||
default=None,
|
||||
choices=["step1", "step2", "full"],
|
||||
help="Only run fixtures with this mode (default: all)",
|
||||
)
|
||||
run_cmd.add_argument(
|
||||
"--no-judge",
|
||||
action="store_true",
|
||||
help="Skip LLM-as-judge scoring",
|
||||
)
|
||||
run_cmd.add_argument(
|
||||
"--judge-model",
|
||||
default="gpt-4o",
|
||||
help="Model for LLM judge (default: gpt-4o)",
|
||||
)
|
||||
run_cmd.add_argument(
|
||||
"--fixtures-dir",
|
||||
default=None,
|
||||
help="Path to fixtures directory (default: eval/fixtures/)",
|
||||
)
|
||||
run_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||
|
||||
# ── list ──────────────────────────────────────────────────────
|
||||
list_cmd = sub.add_parser("list", help="List available fixtures")
|
||||
list_cmd.add_argument("--fixtures-dir", default=None)
|
||||
list_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||
|
||||
# ── sync ──────────────────────────────────────────────────────
|
||||
sync_cmd = sub.add_parser("sync", help="Sync fixtures to Langfuse datasets")
|
||||
sync_cmd.add_argument("--fixture", "-f", default=None, help="Sync only the named fixture")
|
||||
sync_cmd.add_argument("--fixtures-dir", default=None)
|
||||
sync_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||
|
||||
# ── interactive ───────────────────────────────────────────────
|
||||
inter_cmd = sub.add_parser("interactive", help="Interactive journey session (human-in-the-loop)")
|
||||
inter_cmd.add_argument(
|
||||
"--fixture", "-f",
|
||||
help="Journey fixture to use (default: pick interactively)",
|
||||
)
|
||||
inter_cmd.add_argument(
|
||||
"--model", "-m",
|
||||
default="github_copilot/gpt-5.3-codex",
|
||||
help="Model for the journey AI (default: github_copilot/gpt-5.3-codex)",
|
||||
)
|
||||
inter_cmd.add_argument(
|
||||
"--judge-model",
|
||||
default="gpt-4o",
|
||||
help="Model for LLM judge (default: gpt-4o)",
|
||||
)
|
||||
inter_cmd.add_argument(
|
||||
"--fixtures-dir",
|
||||
default=None,
|
||||
help="Path to fixtures directory (default: eval/fixtures/)",
|
||||
)
|
||||
inter_cmd.add_argument(
|
||||
"--data-dir",
|
||||
default=None,
|
||||
help="Override sample data directory (e.g. path to private test files not in git)",
|
||||
)
|
||||
inter_cmd.add_argument("-v", "--verbose", action="store_true")
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def _fixtures_dir(arg: str | None) -> Path | None:
|
||||
if arg:
|
||||
return Path(arg)
|
||||
return None
|
||||
|
||||
|
||||
async def _cmd_run(args: argparse.Namespace) -> None:
|
||||
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||
if not fixtures:
|
||||
print("No fixtures found. Create YAML files in eval/fixtures/.")
|
||||
return
|
||||
|
||||
if args.fixture:
|
||||
fixtures = [f for f in fixtures if f.name == args.fixture]
|
||||
if not fixtures:
|
||||
print(f"Fixture '{args.fixture}' not found.")
|
||||
return
|
||||
|
||||
models = [m.strip() for m in args.models.split(",")]
|
||||
|
||||
all_results = []
|
||||
for fixture in fixtures:
|
||||
if args.mode and fixture.mode != args.mode:
|
||||
continue
|
||||
results = await run_fixture_eval(
|
||||
fixture,
|
||||
models=models,
|
||||
use_llm_judge=not args.no_judge,
|
||||
judge_model=args.judge_model,
|
||||
)
|
||||
all_results.extend(results)
|
||||
|
||||
print_results(all_results)
|
||||
|
||||
|
||||
def _cmd_list(args: argparse.Namespace) -> None:
|
||||
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||
|
||||
if not fixtures and not journey_fixtures:
|
||||
print("No fixtures found.")
|
||||
return
|
||||
|
||||
if fixtures:
|
||||
print(f"\n{'[Agent Fixtures]'}")
|
||||
print(f"{'Name':<30} {'Mode':<6} {'Types':<25} {'Expected'}")
|
||||
print("-" * 90)
|
||||
for f in fixtures:
|
||||
types = ", ".join(f.data_types)
|
||||
n_expected = len(f.expected) + len(f.expected_classification)
|
||||
print(f"{f.name:<30} {f.mode:<6} {types:<25} {n_expected}")
|
||||
|
||||
if journey_fixtures:
|
||||
print(f"\n{'[Journey Fixtures]'}")
|
||||
print(f"{'Name':<30} {'Types':<25} {'Messages':<10} {'Criteria'}")
|
||||
print("-" * 90)
|
||||
for f in journey_fixtures:
|
||||
types = ", ".join(f.data_types)
|
||||
print(f"{f.name:<30} {types:<25} {len(f.user_messages):<10} {len(f.expected_template_criteria)}")
|
||||
|
||||
print()
|
||||
|
||||
|
||||
def _cmd_sync(args: argparse.Namespace) -> None:
|
||||
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||
|
||||
if args.fixture:
|
||||
fixtures = [f for f in fixtures if f.name == args.fixture]
|
||||
journey_fixtures = [f for f in journey_fixtures if f.name == args.fixture]
|
||||
|
||||
if not fixtures and not journey_fixtures:
|
||||
print("No fixtures to sync.")
|
||||
return
|
||||
|
||||
for fixture in fixtures:
|
||||
name = langfuse_eval.sync_fixture_to_dataset(fixture)
|
||||
if name:
|
||||
print(f"Synced: {fixture.name} → {name}")
|
||||
else:
|
||||
print(f"Skipped: {fixture.name} (Langfuse not configured)")
|
||||
|
||||
for fixture in journey_fixtures:
|
||||
name = langfuse_eval.sync_journey_fixture_to_dataset(fixture)
|
||||
if name:
|
||||
print(f"Synced: {fixture.name} → {name}")
|
||||
else:
|
||||
print(f"Skipped: {fixture.name} (Langfuse not configured)")
|
||||
|
||||
|
||||
async def _cmd_interactive(args: argparse.Namespace) -> None:
|
||||
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
||||
if not journey_fixtures:
|
||||
print("No journey fixtures found. Create YAML files with type: journey in eval/fixtures/.")
|
||||
return
|
||||
|
||||
if args.fixture:
|
||||
fixtures = [f for f in journey_fixtures if f.name == args.fixture]
|
||||
if not fixtures:
|
||||
print(f"Journey fixture '{args.fixture}' not found.")
|
||||
return
|
||||
fixture = fixtures[0]
|
||||
elif len(journey_fixtures) == 1:
|
||||
fixture = journey_fixtures[0]
|
||||
else:
|
||||
# Let user pick
|
||||
print("\nAvailable journey fixtures:")
|
||||
for i, f in enumerate(journey_fixtures, 1):
|
||||
print(f" {i}. {f.name} — {f.description[:60]}")
|
||||
print()
|
||||
try:
|
||||
choice = int(input("Pick a fixture number: ").strip()) - 1
|
||||
fixture = journey_fixtures[choice]
|
||||
except (ValueError, IndexError, EOFError, KeyboardInterrupt):
|
||||
print("Invalid choice.")
|
||||
return
|
||||
|
||||
await run_interactive(
|
||||
fixture,
|
||||
model=args.model,
|
||||
judge_model=args.judge_model,
|
||||
data_dir=Path(args.data_dir).resolve() if args.data_dir else None,
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
args = _parse_args()
|
||||
_setup_logging(args.verbose)
|
||||
|
||||
if args.command == "run":
|
||||
asyncio.run(_cmd_run(args))
|
||||
elif args.command == "interactive":
|
||||
asyncio.run(_cmd_interactive(args))
|
||||
elif args.command == "list":
|
||||
_cmd_list(args)
|
||||
elif args.command == "sync":
|
||||
_cmd_sync(args)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
220
services/batch-agent/eval/config.py
Normal file
220
services/batch-agent/eval/config.py
Normal file
@@ -0,0 +1,220 @@
|
||||
"""Eval configuration — YAML fixture loader and dataclasses.
|
||||
|
||||
Fixtures come in two families:
|
||||
|
||||
1. **Agent fixtures** — test the batch agent pipeline.
|
||||
Three modes controlled by ``mode``:
|
||||
|
||||
``step1`` — classification prompt only.
|
||||
``step2`` — processing prompt only.
|
||||
``full`` — both steps in sequence.
|
||||
|
||||
2. **Journey fixtures** — test the prompt-template builder conversation
|
||||
(unchanged).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal
|
||||
|
||||
import yaml
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
EvalMode = Literal["step1", "step2", "full"]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpectedRecord:
|
||||
"""A single expected extraction result.
|
||||
|
||||
Only the fields specified are checked — unspecified fields are ignored.
|
||||
"""
|
||||
|
||||
table: str # tasks | notes | timelines | projects
|
||||
fields: dict[str, Any] # field_name → expected_value
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExpectedClassification:
|
||||
"""Expected output of step-1 classification for one file."""
|
||||
|
||||
file: str # relative path to the sample file
|
||||
project_id: str # expected matched project id, or "new"
|
||||
domains: list[str] # expected domain list
|
||||
new_project_name: str | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalFixture:
|
||||
"""A complete test scenario loaded from YAML.
|
||||
|
||||
``mode`` determines which pipeline steps are exercised:
|
||||
|
||||
- **step1**: only ``_classify_file``
|
||||
- **step2**: only the processing LLM + tool loop
|
||||
- **full**: both steps in sequence (``run_local_agent``)
|
||||
"""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
mode: EvalMode
|
||||
directory: str # relative path to sample files
|
||||
data_types: list[str]
|
||||
file_extensions: list[str]
|
||||
models: list[str] # if empty, use CLI default
|
||||
fixture_path: Path = field(default_factory=lambda: Path("."))
|
||||
|
||||
# ── Step-1 inputs (classification) ───────────────────────────
|
||||
domain_definitions: str = ""
|
||||
projects_list: list[dict[str, Any]] = field(default_factory=list)
|
||||
custom_step1_prompt: str = ""
|
||||
|
||||
# ── Step-2 inputs (processing) ───────────────────────────────
|
||||
existing_context: str = ""
|
||||
project_context: str = ""
|
||||
custom_prompt_section: str = ""
|
||||
|
||||
# ── Seed records for mock executor ───────────────────────────
|
||||
seed_records: dict[str, list[dict]] = field(default_factory=dict)
|
||||
|
||||
# ── Expected outputs ─────────────────────────────────────────
|
||||
expected_classification: list[ExpectedClassification] = field(default_factory=list)
|
||||
expected: list[ExpectedRecord] = field(default_factory=list)
|
||||
|
||||
@property
|
||||
def fixture_dir(self) -> Path:
|
||||
"""Absolute path to the sample files directory."""
|
||||
return self.fixture_path.parent / self.directory
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: Path) -> "EvalFixture":
|
||||
"""Load a fixture from a YAML file."""
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||
|
||||
mode: EvalMode = raw.get("mode", "full")
|
||||
|
||||
# Parse expected records (step2/full)
|
||||
expected: list[ExpectedRecord] = []
|
||||
for table, records in (raw.get("expected") or {}).items():
|
||||
for rec in records:
|
||||
expected.append(ExpectedRecord(table=table, fields=rec))
|
||||
|
||||
# Parse expected classification (step1/full)
|
||||
expected_classification: list[ExpectedClassification] = []
|
||||
for item in raw.get("expected_classification") or []:
|
||||
expected_classification.append(ExpectedClassification(
|
||||
file=item["file"],
|
||||
project_id=item["project_id"],
|
||||
domains=item.get("domains", []),
|
||||
new_project_name=item.get("new_project_name"),
|
||||
))
|
||||
|
||||
return cls(
|
||||
name=raw["name"],
|
||||
description=raw.get("description", ""),
|
||||
mode=mode,
|
||||
directory=raw.get("directory", "sample_files"),
|
||||
data_types=raw.get("data_types", ["tasks"]),
|
||||
file_extensions=raw.get("file_extensions", []),
|
||||
models=raw.get("models", []),
|
||||
fixture_path=path,
|
||||
# Step-1 inputs
|
||||
domain_definitions=raw.get("domain_definitions", ""),
|
||||
projects_list=raw.get("projects_list", []),
|
||||
# Step-2 inputs
|
||||
existing_context=raw.get("existing_context", ""),
|
||||
project_context=raw.get("project_context", ""),
|
||||
custom_prompt_section=raw.get("custom_prompt_section", ""),
|
||||
# Shared
|
||||
seed_records=raw.get("seed_records", {}),
|
||||
expected_classification=expected_classification,
|
||||
expected=expected,
|
||||
)
|
||||
|
||||
|
||||
def discover_fixtures(fixtures_dir: Path | None = None) -> list[EvalFixture]:
|
||||
"""Find and load all YAML fixtures in the fixtures directory."""
|
||||
if fixtures_dir is None:
|
||||
fixtures_dir = Path(__file__).parent / "fixtures"
|
||||
|
||||
fixtures: list[EvalFixture] = []
|
||||
if not fixtures_dir.is_dir():
|
||||
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
|
||||
return fixtures
|
||||
|
||||
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
|
||||
try:
|
||||
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
||||
if raw.get("type") == "journey":
|
||||
continue # Skip journey fixtures
|
||||
fixtures.append(EvalFixture.from_yaml(yaml_path))
|
||||
logger.info("eval: loaded fixture %s from %s", fixtures[-1].name, yaml_path.name)
|
||||
except Exception as exc:
|
||||
logger.error("eval: failed to load fixture %s: %s", yaml_path.name, exc)
|
||||
|
||||
return fixtures
|
||||
|
||||
|
||||
# ── Journey fixtures ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@dataclass
|
||||
class JourneyFixture:
|
||||
"""A journey test scenario — tests the prompt_template builder conversation."""
|
||||
|
||||
name: str
|
||||
description: str
|
||||
directory: str # relative path to sample files
|
||||
data_types: list[str]
|
||||
expected_template_criteria: list[str] # what the template should contain/satisfy
|
||||
user_messages: list[str] = field(default_factory=list) # for automated journey runs (unused in interactive mode)
|
||||
models: list[str] = field(default_factory=list)
|
||||
fixture_path: Path = field(default_factory=lambda: Path("."))
|
||||
|
||||
@property
|
||||
def fixture_dir(self) -> Path:
|
||||
"""Absolute path to the sample files directory."""
|
||||
return self.fixture_path.parent / self.directory
|
||||
|
||||
@classmethod
|
||||
def from_yaml(cls, path: Path) -> "JourneyFixture":
|
||||
"""Load a journey fixture from a YAML file."""
|
||||
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
|
||||
|
||||
return cls(
|
||||
name=raw["name"],
|
||||
description=raw.get("description", ""),
|
||||
directory=raw.get("directory", "sample_files"),
|
||||
data_types=raw.get("data_types", ["tasks"]),
|
||||
user_messages=raw.get("user_messages", []),
|
||||
expected_template_criteria=raw.get("expected_template_criteria", []),
|
||||
models=raw.get("models", []),
|
||||
fixture_path=path,
|
||||
)
|
||||
|
||||
|
||||
def discover_journey_fixtures(fixtures_dir: Path | None = None) -> list[JourneyFixture]:
|
||||
"""Find and load all journey YAML fixtures in the fixtures directory."""
|
||||
if fixtures_dir is None:
|
||||
fixtures_dir = Path(__file__).parent / "fixtures"
|
||||
|
||||
fixtures: list[JourneyFixture] = []
|
||||
if not fixtures_dir.is_dir():
|
||||
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
|
||||
return fixtures
|
||||
|
||||
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
|
||||
try:
|
||||
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
||||
if raw.get("type") != "journey":
|
||||
continue
|
||||
fixtures.append(JourneyFixture.from_yaml(yaml_path))
|
||||
logger.info("eval: loaded journey fixture %s from %s", fixtures[-1].name, yaml_path.name)
|
||||
except Exception as exc:
|
||||
logger.error("eval: failed to load journey fixture %s: %s", yaml_path.name, exc)
|
||||
|
||||
return fixtures
|
||||
40
services/batch-agent/eval/fixtures/classify_invoices.yaml
Normal file
40
services/batch-agent/eval/fixtures/classify_invoices.yaml
Normal file
@@ -0,0 +1,40 @@
|
||||
# Fixture: classify-invoices (step1)
|
||||
# Tests _STEP1_SYSTEM_PROMPT — file classification and project matching.
|
||||
# Verifies that the LLM correctly matches files to existing projects
|
||||
# and identifies the right data domains.
|
||||
|
||||
name: classify-invoices
|
||||
mode: step1
|
||||
description: >
|
||||
Test file classification on Italian freelance invoices and meeting notes.
|
||||
Verifies project matching and domain identification.
|
||||
|
||||
directory: sample_files/invoices
|
||||
data_types: [tasks, notes, timelines]
|
||||
file_extensions: [txt, md]
|
||||
|
||||
# ── Step-1 prompt variables ──────────────────────────────────────
|
||||
domain_definitions: |
|
||||
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
|
||||
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
|
||||
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
|
||||
|
||||
projects_list:
|
||||
- id: "proj-web-redesign"
|
||||
name: "Redesign Sito Web Corporate"
|
||||
status: "active"
|
||||
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||
- id: "proj-ecommerce"
|
||||
name: "E-Commerce FashionStore"
|
||||
status: "active"
|
||||
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||
|
||||
# ── Expected classification results ─────────────────────────────
|
||||
expected_classification:
|
||||
- file: "sample_files/invoices/fattura_042.txt"
|
||||
project_id: "proj-web-redesign"
|
||||
domains: [tasks, notes, timelines]
|
||||
|
||||
- file: "sample_files/invoices/meeting_ecommerce.md"
|
||||
project_id: "proj-ecommerce"
|
||||
domains: [tasks, notes, timelines]
|
||||
108
services/batch-agent/eval/fixtures/full_invoices.yaml
Normal file
108
services/batch-agent/eval/fixtures/full_invoices.yaml
Normal file
@@ -0,0 +1,108 @@
|
||||
# Fixture: full-invoices (full)
|
||||
# Tests both _STEP1_SYSTEM_PROMPT and _PROCESSING_SYSTEM_PROMPT in sequence
|
||||
# via run_local_agent(). Verifies end-to-end classification + extraction.
|
||||
|
||||
name: full-invoices
|
||||
mode: full
|
||||
description: >
|
||||
End-to-end test: classify Italian invoices/meeting notes into the
|
||||
correct project, then extract tasks, notes, and timeline events.
|
||||
|
||||
directory: sample_files/invoices
|
||||
data_types: [tasks, notes, timelines]
|
||||
file_extensions: [txt, md]
|
||||
|
||||
# ── Step-1 prompt variables ──────────────────────────────────────
|
||||
domain_definitions: |
|
||||
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
|
||||
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
|
||||
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
|
||||
|
||||
projects_list:
|
||||
- id: "proj-web-redesign"
|
||||
name: "Redesign Sito Web Corporate"
|
||||
status: "active"
|
||||
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||
- id: "proj-ecommerce"
|
||||
name: "E-Commerce FashionStore"
|
||||
status: "active"
|
||||
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||
|
||||
# ── Step-2 prompt variables ──────────────────────────────────────
|
||||
existing_context: |
|
||||
Existing tasks:
|
||||
(none)
|
||||
|
||||
Existing notes:
|
||||
(none)
|
||||
|
||||
Existing timelines:
|
||||
(none)
|
||||
|
||||
project_context: ""
|
||||
|
||||
custom_prompt_section: |
|
||||
User instructions:
|
||||
Estrai i dati dai file come segue:
|
||||
- TASK: ogni azione da fare, deliverable, o item con scadenza.
|
||||
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
|
||||
Mappa "media priorità" → priority: medium.
|
||||
Mappa "bassa priorità" → priority: low.
|
||||
Se un item è marcato come "completato" o [x], impostalo status: done.
|
||||
Altrimenti status: todo.
|
||||
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
|
||||
- TIMELINE: date di scadenza, milestone, meeting futuri.
|
||||
Imposta sempre isAiSuggested=1.
|
||||
|
||||
# ── Seed records (pre-existing DB state) ─────────────────────────
|
||||
seed_records:
|
||||
projects:
|
||||
- id: "proj-web-redesign"
|
||||
name: "Redesign Sito Web Corporate"
|
||||
status: "active"
|
||||
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
||||
- id: "proj-ecommerce"
|
||||
name: "E-Commerce FashionStore"
|
||||
status: "active"
|
||||
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
||||
tasks: []
|
||||
notes: []
|
||||
timelines: []
|
||||
|
||||
# ── Expected classification (step 1) ─────────────────────────────
|
||||
expected_classification:
|
||||
- file: "sample_files/invoices/fattura_042.txt"
|
||||
project_id: "proj-web-redesign"
|
||||
domains: [tasks, notes, timelines]
|
||||
|
||||
- file: "sample_files/invoices/meeting_ecommerce.md"
|
||||
project_id: "proj-ecommerce"
|
||||
domains: [tasks, notes, timelines]
|
||||
|
||||
# ── Expected extractions (step 2) ────────────────────────────────
|
||||
expected:
|
||||
tasks:
|
||||
- title: "Sviluppo frontend React"
|
||||
priority: "high"
|
||||
status: "todo"
|
||||
- title: "Integrazione API backend"
|
||||
priority: "medium"
|
||||
status: "todo"
|
||||
- title: "Testing cross-browser e fix bug responsive"
|
||||
status: "todo"
|
||||
- title: "Preparare wireframe homepage"
|
||||
priority: "high"
|
||||
status: "todo"
|
||||
- title: "Setup progetto Next.js e configurare CI/CD"
|
||||
priority: "medium"
|
||||
status: "todo"
|
||||
- title: "Ricerca plugin Stripe per gestione abbonamenti"
|
||||
priority: "low"
|
||||
status: "todo"
|
||||
|
||||
notes:
|
||||
- title: "Meeting Kickoff Progetto E-Commerce"
|
||||
|
||||
timelines:
|
||||
- title: "MVP E-Commerce pronto"
|
||||
- title: "Meeting di revisione"
|
||||
@@ -0,0 +1,28 @@
|
||||
# Journey Fixture: journey-invoice-setup
|
||||
# Used by `python -m eval interactive` for human-in-the-loop testing
|
||||
# of the journey chatbot's prompt-building conversation.
|
||||
|
||||
type: journey
|
||||
name: journey-invoice-setup
|
||||
description: >
|
||||
Interactive test for the journey chatbot — explore a directory of
|
||||
Italian invoices and meeting notes, answer the chatbot's questions,
|
||||
and verify it produces a well-structured prompt_template for data
|
||||
extraction.
|
||||
|
||||
directory: sample_files/invoices
|
||||
data_types: [tasks, notes, timelines, projects]
|
||||
|
||||
# Criteria the generated prompt_template must satisfy
|
||||
# Each is scored 0-1 by an LLM judge
|
||||
expected_template_criteria:
|
||||
- "Mentions creating tasks from action items and work descriptions"
|
||||
- "Mentions creating notes from meeting summaries"
|
||||
- "Mentions extracting timeline events from deadlines and meeting dates"
|
||||
- "Mentions creating projects from relevant information"
|
||||
- "Sets isAiSuggested=1 on all created records"
|
||||
- "Does NOT include projectId assignment logic"
|
||||
- "Uses camelCase field names (title, status, priority, dueDate, content)"
|
||||
|
||||
# Models to test (empty = use CLI --models default)
|
||||
models: []
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user