Compare commits
29 Commits
5753f8def9
...
2d8abb6311
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2d8abb6311 | ||
|
|
e668e3fd20 | ||
|
|
7ccdad431f | ||
|
|
4073863dc6 | ||
|
|
a85f8fde29 | ||
|
|
90500a3462 | ||
|
|
c1a8ac7669 | ||
|
|
c510cbaae5 | ||
|
|
ce139bbac3 | ||
|
|
3cf067faea | ||
|
|
7253f6fe72 | ||
|
|
41db3a7089 | ||
|
|
cc94194fd1 | ||
|
|
96c91e386d | ||
|
|
c0aef71141 | ||
|
|
467abc8d42 | ||
|
|
e672b58b6f | ||
|
|
d8add7e8cb | ||
|
|
c6c4578f9a | ||
|
|
3aa0b36a6c | ||
|
|
fa231a3642 | ||
|
|
d91c98f86d | ||
|
|
c0619f5c4d | ||
|
|
da282229ff | ||
|
|
7fa6ad5760 | ||
|
|
dcd14220ca | ||
|
|
3cc32569d9 | ||
|
|
bf445ac2ce | ||
|
|
a2d6d689e4 |
46
.env.example
46
.env.example
@@ -2,7 +2,7 @@
|
|||||||
ENV=dev
|
ENV=dev
|
||||||
|
|
||||||
# ── Database ──────────────────────────────────────────────────────────────────
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
|
||||||
|
|
||||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||||
JWT_SECRET=replace-with-a-long-random-secret
|
JWT_SECRET=replace-with-a-long-random-secret
|
||||||
@@ -13,11 +13,45 @@ JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
|||||||
# ── LLM ───────────────────────────────────────────────────────────────────────
|
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||||
# LiteLLM model identifiers — change to swap providers without code changes.
|
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||||
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||||
|
#
|
||||||
|
# API keys — only the key(s) matching your chosen provider(s) are required.
|
||||||
|
# The correct key is picked automatically from the model prefix (e.g.
|
||||||
|
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
LLM_MODEL=gpt-4o
|
CEREBRAS_API_KEY=
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
# Default model used by any agent that does not have a specific override below.
|
||||||
|
LLM_MODEL=gpt-5-mini
|
||||||
|
LLM_EMBED_MODEL=text-embedding-3-small
|
||||||
|
|
||||||
|
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
|
||||||
|
# In Docker, point this to a named-volume path so tokens survive restarts.
|
||||||
|
# GITHUB_COPILOT_TOKEN_DIR=
|
||||||
|
|
||||||
|
# ── Per-agent model overrides ─────────────────────────────────────────────────
|
||||||
|
# Leave a value empty to fall back to LLM_MODEL.
|
||||||
|
# Each agent resolves its API key from the model prefix automatically.
|
||||||
|
#
|
||||||
|
# Intent classifier — routes user messages to the right domain agent.
|
||||||
|
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
|
||||||
|
LLM_MODEL_CLASSIFIER=
|
||||||
|
|
||||||
|
# Home-agent — handles chat from the home screen (all tools available).
|
||||||
|
LLM_MODEL_HOME_AGENT=
|
||||||
|
|
||||||
|
# Floating-agent — handles contextual chat triggered from a task/project/note.
|
||||||
|
LLM_MODEL_FLOATING_AGENT=
|
||||||
|
|
||||||
|
# Unified-processor — processes local directory files (local agent runner).
|
||||||
|
LLM_MODEL_UNIFIED_PROCESSOR=
|
||||||
|
|
||||||
|
# Cloud-processor — fetches and processes data from cloud connectors.
|
||||||
|
LLM_MODEL_CLOUD_PROCESSOR=
|
||||||
|
|
||||||
|
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||||
|
LLM_MODEL_SETUP_AGENT=
|
||||||
|
|
||||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
@@ -27,9 +61,9 @@ STRIPE_WEBHOOK_SECRET=
|
|||||||
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
||||||
LANGFUSE_SECRET_KEY=
|
LANGFUSE_SECRET_KEY=
|
||||||
LANGFUSE_PUBLIC_KEY=
|
LANGFUSE_PUBLIC_KEY=
|
||||||
# LANGFUSE_HOST=https://cloud.langfuse.com # EU (default)
|
# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
|
||||||
# LANGFUSE_HOST=https://us.cloud.langfuse.com # US
|
# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
|
||||||
# LANGFUSE_HOST=http://localhost:3000 # Self-hosted
|
# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
|
|||||||
@@ -48,23 +48,23 @@ jobs:
|
|||||||
key: ${{ secrets.SSH_KEY }}
|
key: ${{ secrets.SSH_KEY }}
|
||||||
script: |
|
script: |
|
||||||
set -e
|
set -e
|
||||||
DEPLOY_DIR="/opt/adiuva-api"
|
DEPLOY_DIR="/opt/adiuvai-api"
|
||||||
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||||
TAG="${{ gitea.ref_name }}"
|
TAG="${{ gitea.ref_name }}"
|
||||||
|
|
||||||
# ── Pull latest code ──
|
# ── Pull latest code ──
|
||||||
cd /tmp && rm -rf adiuva-api-deploy
|
cd /tmp && rm -rf adiuvai-api-deploy
|
||||||
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
|
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
|
||||||
|
|
||||||
# ── Sync source (preserve .env) ──
|
# ── Sync source (preserve .env) ──
|
||||||
cp -rf /tmp/adiuva-api-deploy/app/ \
|
cp -rf /tmp/adiuvai-api-deploy/app/ \
|
||||||
/tmp/adiuva-api-deploy/alembic/ \
|
/tmp/adiuvai-api-deploy/alembic/ \
|
||||||
/tmp/adiuva-api-deploy/alembic.ini \
|
/tmp/adiuvai-api-deploy/alembic.ini \
|
||||||
/tmp/adiuva-api-deploy/Dockerfile \
|
/tmp/adiuvai-api-deploy/Dockerfile \
|
||||||
/tmp/adiuva-api-deploy/docker-compose.yml \
|
/tmp/adiuvai-api-deploy/docker-compose.yml \
|
||||||
/tmp/adiuva-api-deploy/requirements.txt \
|
/tmp/adiuvai-api-deploy/requirements.txt \
|
||||||
"$DEPLOY_DIR/"
|
"$DEPLOY_DIR/"
|
||||||
rm -rf /tmp/adiuva-api-deploy
|
rm -rf /tmp/adiuvai-api-deploy
|
||||||
|
|
||||||
# ── Verify .env ──
|
# ── Verify .env ──
|
||||||
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||||
|
|||||||
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Build image
|
- name: Build image
|
||||||
run: docker build -t adiuva-api:ci .
|
run: docker build -t adiuvai-api:ci .
|
||||||
|
|
||||||
- name: Verify gunicorn installed
|
- name: Verify gunicorn installed
|
||||||
run: docker run --rm adiuva-api:ci gunicorn --version
|
run: docker run --rm adiuvai-api:ci gunicorn --version
|
||||||
|
|||||||
591
README.md
591
README.md
@@ -1,591 +0,0 @@
|
|||||||
# Adiuva Cloud API
|
|
||||||
|
|
||||||
**AI-powered project management backend with LLM orchestration and subscription billing.**
|
|
||||||
|
|
||||||
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Table of Contents
|
|
||||||
|
|
||||||
- [Overview](#overview)
|
|
||||||
- [Architecture](#architecture)
|
|
||||||
- [Key Features](#key-features)
|
|
||||||
- [Tech Stack](#tech-stack)
|
|
||||||
- [Getting Started](#getting-started)
|
|
||||||
- [Docker Deployment](#docker-deployment)
|
|
||||||
- [Environment Variables](#environment-variables)
|
|
||||||
- [API Reference](#api-reference)
|
|
||||||
- [Data Model](#data-model)
|
|
||||||
- [AI Agent System](#ai-agent-system)
|
|
||||||
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
|
||||||
- [Middleware](#middleware)
|
|
||||||
- [Billing & Tiers](#billing--tiers)
|
|
||||||
- [Testing](#testing)
|
|
||||||
- [Project Structure](#project-structure)
|
|
||||||
- [License](#license)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, text embedding generation, and Stripe-based subscription billing across four tiers.
|
|
||||||
|
|
||||||
### Design Principles
|
|
||||||
|
|
||||||
1. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
|
||||||
2. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
|
||||||
3. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
|
|
||||||
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
|
|
||||||
│ Desktop App │────▶│ │
|
|
||||||
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
|
|
||||||
└──────────────┘ │ │
|
|
||||||
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
|
||||||
│ │ Auth Routes │ │ Chat Routes │ │
|
|
||||||
│ │ Billing Routes │ │ ↓ │ │
|
|
||||||
│ │ Agent Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
|
||||||
│ │ Device WS │ │ ↓ classify intent │ │
|
|
||||||
│ └──────────────────┘ │ Agent Registry │ │
|
|
||||||
│ │ ↓ │ │
|
|
||||||
│ │ TaskAgent | ProjectAgent │ │
|
|
||||||
│ │ NoteAgent | CheckptAgent │ │
|
|
||||||
│ │ (GPT-4o + LangChain) │ │
|
|
||||||
│ └────────────────────────────┘ │
|
|
||||||
└────────────────────────────────────────────────────────┘
|
|
||||||
│
|
|
||||||
┌────────▼───┐
|
|
||||||
│ PostgreSQL │
|
|
||||||
│ (Auth, │
|
|
||||||
│ Billing, │
|
|
||||||
│ Agents) │
|
|
||||||
└────────────┘
|
|
||||||
│
|
|
||||||
┌────────▼───┐
|
|
||||||
│ Stripe │
|
|
||||||
│ (Billing) │
|
|
||||||
└────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Key Features
|
|
||||||
|
|
||||||
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
|
||||||
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
|
||||||
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
|
||||||
4. **Text embeddings** — Generates text-embedding-3-small vectors for local client-side note search.
|
|
||||||
5. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
|
||||||
6. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
|
||||||
7. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
|
||||||
8. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
|
||||||
9. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
|
||||||
10. **Alembic migrations** — Versioned schema management.
|
|
||||||
11. **Comprehensive test suite** — In-memory SQLite, per-tier test fixtures, and full API coverage without external dependencies.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tech Stack
|
|
||||||
|
|
||||||
| Package | Version | Purpose |
|
|
||||||
|---|---|---|
|
|
||||||
| `fastapi` | ≥ 0.115.0 | Web framework |
|
|
||||||
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
|
|
||||||
| `gunicorn` | ≥ 22.0.0 | Production process manager |
|
|
||||||
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
|
|
||||||
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
|
|
||||||
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
|
|
||||||
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
|
|
||||||
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
|
||||||
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
|
||||||
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
|
||||||
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
|
||||||
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
|
||||||
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
|
||||||
| `alembic` | ≥ 1.14.0 | Database migration management |
|
|
||||||
| `bcrypt` | ≥ 4.2.0 | Password hashing |
|
|
||||||
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
|
|
||||||
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
|
||||||
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
|
||||||
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
|
||||||
| `pytest` | ≥ 8.0.0 | Test framework |
|
|
||||||
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
|
||||||
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
|
||||||
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Getting Started
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
- Python 3.12+
|
|
||||||
- PostgreSQL 16+
|
|
||||||
- An OpenAI API key (for LLM features)
|
|
||||||
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Clone the repository
|
|
||||||
git clone <repo-url> && cd adiuva-api
|
|
||||||
|
|
||||||
# Create a virtual environment
|
|
||||||
python -m venv .venv && source .venv/bin/activate
|
|
||||||
|
|
||||||
# Install dependencies
|
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
# Configure environment
|
|
||||||
cp .env.example .env
|
|
||||||
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
|
|
||||||
```
|
|
||||||
|
|
||||||
### Database Setup
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Start PostgreSQL (or use the Docker Compose database)
|
|
||||||
docker compose up db -d
|
|
||||||
|
|
||||||
# Run migrations
|
|
||||||
alembic upgrade head
|
|
||||||
```
|
|
||||||
|
|
||||||
### Run the Development Server
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
|
||||||
```
|
|
||||||
|
|
||||||
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Docker Deployment
|
|
||||||
|
|
||||||
### Quick Start
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose up --build
|
|
||||||
```
|
|
||||||
|
|
||||||
This starts two services:
|
|
||||||
|
|
||||||
- **app** — FastAPI server on port `8000`
|
|
||||||
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
|
||||||
|
|
||||||
### Dockerfile Details
|
|
||||||
|
|
||||||
The Dockerfile uses a multi-stage build:
|
|
||||||
|
|
||||||
1. **Builder stage** — Installs Python dependencies into a virtual environment.
|
|
||||||
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
|
|
||||||
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Production command (run by the container)
|
|
||||||
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Homelab / Self-Hosted Deployment
|
|
||||||
|
|
||||||
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**.
|
|
||||||
|
|
||||||
### 1. Start all services
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
This starts PostgreSQL alongside the app.
|
|
||||||
|
|
||||||
### 2. Configure your `.env`
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Database (uses the compose PostgreSQL)
|
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
|
|
||||||
# Billing — leave empty to stub (no Stripe needed)
|
|
||||||
STRIPE_SECRET_KEY=
|
|
||||||
STRIPE_WEBHOOK_SECRET=
|
|
||||||
|
|
||||||
# LLM — the only external service
|
|
||||||
OPENAI_API_KEY=sk-...
|
|
||||||
LLM_MODEL=gpt-4o
|
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
|
||||||
# Auth
|
|
||||||
JWT_SECRET=your-secret-here
|
|
||||||
ENV=dev
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Run migrations
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose exec app alembic upgrade head
|
|
||||||
```
|
|
||||||
|
|
||||||
### What runs where
|
|
||||||
|
|
||||||
| Service | Runs on | Port | Notes |
|
|
||||||
|---|---|---|---|
|
|
||||||
| FastAPI app | Docker | 8000 | API server |
|
|
||||||
| PostgreSQL | Docker | 5432 | Auth, billing, agents |
|
|
||||||
| Stripe | — | — | Stubbed when keys are empty |
|
|
||||||
| OpenAI / LLM | Cloud | — | Only external dependency |
|
|
||||||
|
|
||||||
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Environment Variables
|
|
||||||
|
|
||||||
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
|
|
||||||
|
|
||||||
| Variable | Type | Default | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
|
|
||||||
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
|
|
||||||
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
|
|
||||||
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
|
||||||
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
|
||||||
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
|
||||||
| `STRIPE_WEBHOOK_SECRET` | `str` | `\"\"` | Stripe webhook signature secret |\n| `OPENAI_API_KEY` | `str` | `\"\"` | OpenAI key for LLM agent calls |
|
|
||||||
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
|
||||||
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
|
||||||
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
|
||||||
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Reference
|
|
||||||
|
|
||||||
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
|
|
||||||
|
|
||||||
### Health
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
|
|
||||||
|
|
||||||
### Auth
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
|
|
||||||
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
|
|
||||||
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
|
|
||||||
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
|
|
||||||
|
|
||||||
### Chat
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
|
||||||
| `POST` | `/api/v1/chat/embed` | JWT | Generate a 1536-dim text embedding vector (`text-embedding-3-small`). Used by Electron for local note search. |
|
|
||||||
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
|
||||||
|
|
||||||
### Plans
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
|
||||||
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
|
||||||
|
|
||||||
### Billing
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
|
|
||||||
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
|
|
||||||
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
|
|
||||||
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Data Model
|
|
||||||
|
|
||||||
3 tables managed by Alembic migrations. Source: `app/models.py`
|
|
||||||
|
|
||||||
### Tables
|
|
||||||
|
|
||||||
| Table | Primary Key | Key Columns | Purpose |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
|
||||||
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
|
||||||
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
|
||||||
|
|
||||||
### Enum Types
|
|
||||||
|
|
||||||
| Enum | Values |
|
|
||||||
|---|---|
|
|
||||||
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
|
||||||
|
|
||||||
### Migrations
|
|
||||||
|
|
||||||
| Version | Description |
|
|
||||||
|---|---|
|
|
||||||
| `001_initial_schema` | Creates core auth and billing tables with indexes and foreign key constraints |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## AI Agent System
|
|
||||||
|
|
||||||
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
- **`BaseAgent`** — Abstract base with `user_id` and `shared_memory`.
|
|
||||||
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
|
||||||
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
|
||||||
|
|
||||||
### Registered Agents
|
|
||||||
|
|
||||||
| Agent | Registry Name | Tools | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
|
||||||
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
|
||||||
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
|
|
||||||
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
|
||||||
|
|
||||||
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
|
||||||
|
|
||||||
### Switching LLM Providers
|
|
||||||
|
|
||||||
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# OpenAI (default)
|
|
||||||
LLM_MODEL=gpt-4o
|
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
|
||||||
# Anthropic
|
|
||||||
LLM_MODEL=anthropic/claude-3.5-sonnet
|
|
||||||
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
|
|
||||||
|
|
||||||
# Google Gemini
|
|
||||||
LLM_MODEL=gemini/gemini-pro
|
|
||||||
LLM_ROUTER_MODEL=gemini/gemini-flash
|
|
||||||
|
|
||||||
# Local Ollama
|
|
||||||
LLM_MODEL=ollama/llama3
|
|
||||||
LLM_ROUTER_MODEL=ollama/llama3
|
|
||||||
|
|
||||||
# AWS Bedrock
|
|
||||||
LLM_MODEL=bedrock/anthropic.claude-v2
|
|
||||||
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
|
|
||||||
```
|
|
||||||
|
|
||||||
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Orchestration & Execution Plans
|
|
||||||
|
|
||||||
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
|
||||||
|
|
||||||
### Orchestrator
|
|
||||||
|
|
||||||
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
|
|
||||||
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
|
|
||||||
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
|
|
||||||
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
|
|
||||||
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
|
|
||||||
|
|
||||||
### Execution Plans
|
|
||||||
|
|
||||||
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
|
|
||||||
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
|
|
||||||
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
|
|
||||||
|
|
||||||
### Built-in Templates (6)
|
|
||||||
|
|
||||||
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
|
||||||
|
|
||||||
### Built-in Playbooks (2)
|
|
||||||
|
|
||||||
| Playbook | Description |
|
|
||||||
|---|---|
|
|
||||||
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
|
|
||||||
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Middleware
|
|
||||||
|
|
||||||
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
|
|
||||||
|
|
||||||
### JWT Authentication
|
|
||||||
|
|
||||||
Source: `app/api/middleware/auth.py`
|
|
||||||
|
|
||||||
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
|
|
||||||
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
|
|
||||||
- Falls back to `free` when no subscription row exists.
|
|
||||||
- Raises `401 Unauthorized` on invalid or expired tokens.
|
|
||||||
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
### Tier-Based Rate Limiter
|
|
||||||
|
|
||||||
Source: `app/api/middleware/rate_limit.py`
|
|
||||||
|
|
||||||
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
|
|
||||||
- Per-user 60-second window sized by subscription tier:
|
|
||||||
|
|
||||||
| Tier | Requests / Minute |
|
|
||||||
|---|---|
|
|
||||||
| Free | 20 |
|
|
||||||
| Pro | 60 |
|
|
||||||
| Power | 120 |
|
|
||||||
| Team | 200 |
|
|
||||||
|
|
||||||
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
|
|
||||||
- **Exempt paths:** register, login, webhook, health
|
|
||||||
|
|
||||||
### Response Sanitizer
|
|
||||||
|
|
||||||
Source: `app/api/middleware/sanitizer.py`
|
|
||||||
|
|
||||||
- Runs only on `/api/v1/chat` endpoints.
|
|
||||||
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
|
||||||
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
|
||||||
- Logs sanitization events as `WARNING`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Billing & Tiers
|
|
||||||
|
|
||||||
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
|
||||||
|
|
||||||
### Feature Matrix
|
|
||||||
|
|
||||||
| Feature | Free | Pro | Power | Team |
|
|
||||||
|---|---|---|---|---|
|
|
||||||
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
|
||||||
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
|
||||||
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
|
||||||
| Batch Builder | — | — | ✓ | ✓ |
|
|
||||||
| SSO | — | — | — | ✓ |
|
|
||||||
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
|
||||||
|
|
||||||
### Stripe Integration
|
|
||||||
|
|
||||||
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
|
|
||||||
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
|
|
||||||
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
|
|
||||||
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
|
|
||||||
|
|
||||||
### Tier Manager
|
|
||||||
|
|
||||||
- `get_tier(user_id)` — Returns the user's current billing tier.
|
|
||||||
- `check_feature(tier, feature)` — Boolean feature gate check.
|
|
||||||
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
### Running Tests
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run all tests
|
|
||||||
pytest
|
|
||||||
|
|
||||||
# Run a specific test file
|
|
||||||
pytest tests/test_auth.py
|
|
||||||
|
|
||||||
# Run with verbose output
|
|
||||||
pytest -v
|
|
||||||
```
|
|
||||||
|
|
||||||
### Test Infrastructure
|
|
||||||
|
|
||||||
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
|
||||||
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
|
||||||
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
|
||||||
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
|
||||||
- **No external dependencies** — all tests run fully offline.
|
|
||||||
|
|
||||||
### Test Coverage
|
|
||||||
|
|
||||||
| File | Coverage |
|
|
||||||
|---|---|
|
|
||||||
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
|
||||||
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── alembic.ini # Alembic configuration
|
|
||||||
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
|
||||||
├── Dockerfile # Multi-stage production build
|
|
||||||
├── requirements.txt # Python dependencies
|
|
||||||
│
|
|
||||||
├── alembic/ # Database migrations
|
|
||||||
│ ├── env.py # Alembic environment config
|
|
||||||
│ ├── script.py.mako # Migration template
|
|
||||||
│ └── versions/
|
|
||||||
│ └── 001_initial_schema.py # Tables, indexes, FKs
|
|
||||||
│
|
|
||||||
├── app/ # Application source
|
|
||||||
│ ├── main.py # FastAPI app factory, middleware, routes
|
|
||||||
│ ├── db.py # Async SQLAlchemy engine & session
|
|
||||||
│ ├── models.py # SQLAlchemy ORM models
|
|
||||||
│ ├── schemas.py # Pydantic request/response schemas
|
|
||||||
│ │
|
|
||||||
│ ├── config/
|
|
||||||
│ │ └── settings.py # Pydantic Settings (env vars)
|
|
||||||
│ │
|
|
||||||
│ ├── agents/ # LLM-powered domain agents
|
|
||||||
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
|
||||||
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
|
||||||
│ │ ├── timeline_agent.py # Milestones (4 tools)
|
|
||||||
│ │ └── note_agent.py # Markdown notes (5 tools)
|
|
||||||
│ │
|
|
||||||
│ ├── core/ # Orchestration engine
|
|
||||||
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
|
||||||
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
|
||||||
│ │ └── deep_agent.py # Deep agent orchestration
|
|
||||||
│ │
|
|
||||||
│ ├── api/ # HTTP layer
|
|
||||||
│ │ ├── deps.py # Shared FastAPI dependencies
|
|
||||||
│ │ ├── middleware/
|
|
||||||
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
|
||||||
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
|
||||||
│ │ └── routes/
|
|
||||||
│ │ ├── auth.py # Register, login, refresh, me
|
|
||||||
│ │ ├── chat.py # Chat + embed endpoint
|
|
||||||
│ │ ├── billing.py # Stripe checkout, webhooks, subscription
|
|
||||||
│ │ ├── agents.py # Agent catalog, config, runs
|
|
||||||
│ │ └── device_ws.py # Persistent device WebSocket
|
|
||||||
│ │
|
|
||||||
│ └── billing/
|
|
||||||
│ ├── stripe_service.py # Stripe API wrapper
|
|
||||||
│ └── tier_manager.py # Feature matrix, rate limits
|
|
||||||
│
|
|
||||||
└── tests/ # Test suite
|
|
||||||
├── conftest.py # Fixtures: DB, auth, seeds
|
|
||||||
├── test_auth.py
|
|
||||||
├── test_orchestrator.py
|
|
||||||
├── test_agents.py
|
|
||||||
├── test_agent_registry.py
|
|
||||||
├── test_execution_plan.py
|
|
||||||
└── test_middleware.py
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
*To be determined.*
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import re
|
|||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sqlalchemy import engine_from_config, pool
|
from sqlalchemy import pool
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
# Alembic Config object (gives access to alembic.ini values).
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from alembic import op
|
|||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
revision: str = "003"
|
revision: str = "003"
|
||||||
down_revision: Union[str, None] = "002"
|
down_revision: Union[str, None] = "001"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|||||||
54
alembic/versions/005_associative_pgvector.py
Normal file
54
alembic/versions/005_associative_pgvector.py
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
"""Phase 1 — confirm pgvector activation on memory_associative.
|
||||||
|
|
||||||
|
Migration 004 created the embedding column as vector(1536) and added the
|
||||||
|
IVFFlat index. This migration is the Phase-1 checkpoint:
|
||||||
|
1. Ensures the pgvector extension is enabled (idempotent).
|
||||||
|
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
||||||
|
memory_associative_embedding_idx (creates it only if absent).
|
||||||
|
|
||||||
|
Revision ID: 005
|
||||||
|
Revises: 9a1f2d0b6c7e
|
||||||
|
Create Date: 2026-04-15
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "005"
|
||||||
|
down_revision: Union[str, None] = "e04100e88ace"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
||||||
|
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
||||||
|
|
||||||
|
# Ensure the canonical Phase-1 IVFFlat index exists.
|
||||||
|
# 004 may have created ix_memory_associative_embedding; this adds the
|
||||||
|
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
||||||
|
op.execute(
|
||||||
|
"""
|
||||||
|
DO $$
|
||||||
|
BEGIN
|
||||||
|
IF NOT EXISTS (
|
||||||
|
SELECT 1
|
||||||
|
FROM pg_indexes
|
||||||
|
WHERE tablename = 'memory_associative'
|
||||||
|
AND indexname = 'memory_associative_embedding_idx'
|
||||||
|
) THEN
|
||||||
|
CREATE INDEX memory_associative_embedding_idx
|
||||||
|
ON memory_associative
|
||||||
|
USING ivfflat (embedding vector_cosine_ops)
|
||||||
|
WITH (lists = 100);
|
||||||
|
END IF;
|
||||||
|
END $$;
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""Restore agent config tables and add agent_config column.
|
||||||
|
|
||||||
|
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
|
||||||
|
ORM models are still active. This migration recreates them with agent_config
|
||||||
|
added to local_agent_configs.
|
||||||
|
|
||||||
|
Revision ID: a3b9c0d1e2f3
|
||||||
|
Revises: 9a1f2d0b6c7e
|
||||||
|
Create Date: 2026-04-07 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a3b9c0d1e2f3"
|
||||||
|
down_revision: Union[str, None] = "9a1f2d0b6c7e"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Recreate enum types (idempotent — they may already exist from migration 003)
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
# ── local_agent_configs (with agent_config column) ────────────────────
|
||||||
|
if "local_agent_configs" not in existing:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("agent_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
# ── cloud_agent_configs ───────────────────────────────────────────────
|
||||||
|
if "cloud_agent_configs" not in existing:
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
56
alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
56
alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
|
||||||
|
|
||||||
|
Revision ID: b4c0d1e2f3a4
|
||||||
|
Revises: a3b9c0d1e2f3
|
||||||
|
Create Date: 2026-04-10 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b4c0d1e2f3a4"
|
||||||
|
down_revision: Union[str, None] = "a3b9c0d1e2f3"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── users: make password_hash nullable (social users have no password) ──
|
||||||
|
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
|
||||||
|
|
||||||
|
# ── users: add avatar_url ─────────────────────────────────────────────
|
||||||
|
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
|
||||||
|
|
||||||
|
# ── oauth_accounts ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"oauth_accounts",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("provider", sa.String(50), nullable=False),
|
||||||
|
sa.Column("provider_user_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("provider_email", sa.String(255), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
|
||||||
|
op.drop_table("oauth_accounts")
|
||||||
|
op.drop_column("users", "avatar_url")
|
||||||
|
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)
|
||||||
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Add onboarding_completed_at column to users table.
|
||||||
|
|
||||||
|
Revision ID: c5d1e2f3a4b5
|
||||||
|
Revises: b4c0d1e2f3a4
|
||||||
|
Create Date: 2026-04-11 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "c5d1e2f3a4b5"
|
||||||
|
down_revision: Union[str, None] = "b4c0d1e2f3a4"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("users", "onboarding_completed_at")
|
||||||
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""avatar_url_varchar_to_text
|
||||||
|
|
||||||
|
Revision ID: e04100e88ace
|
||||||
|
Revises: c5d1e2f3a4b5
|
||||||
|
Create Date: 2026-04-13 09:13:06.733674
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'e04100e88ace'
|
||||||
|
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column('users', 'avatar_url',
|
||||||
|
existing_type=sa.VARCHAR(length=2048),
|
||||||
|
type_=sa.Text(),
|
||||||
|
existing_nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column('users', 'avatar_url',
|
||||||
|
existing_type=sa.Text(),
|
||||||
|
type_=sa.VARCHAR(length=2048),
|
||||||
|
existing_nullable=True)
|
||||||
@@ -7,12 +7,31 @@ handles actual disk I/O and responds with ``tool_result`` frames.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
# Max characters returned by read_file_content in journey (exploration) tools.
|
||||||
|
# The journey only needs to understand file structure, not full content.
|
||||||
|
_JOURNEY_READ_MAX_CHARS: int = 4000
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path(path: str, base: str) -> str:
|
||||||
|
"""Resolve *path* against *base* when *path* is relative.
|
||||||
|
|
||||||
|
The LLM often passes ``"."`` meaning "the configured directory".
|
||||||
|
Without this, Electron resolves ``"."`` relative to its own CWD instead
|
||||||
|
of the user's chosen directory.
|
||||||
|
"""
|
||||||
|
if os.path.isabs(path):
|
||||||
|
return path
|
||||||
|
return str(Path(base) / path)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_directory(path: str) -> str:
|
async def list_directory(path: str) -> str:
|
||||||
@@ -83,3 +102,93 @@ FILESYSTEM_TOOLS: list[Any] = [
|
|||||||
read_file_content,
|
read_file_content,
|
||||||
get_file_metadata,
|
get_file_metadata,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def make_directory_tools(base_directory: str) -> list[Any]:
|
||||||
|
"""Return filesystem tools that resolve relative paths against *base_directory*.
|
||||||
|
|
||||||
|
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
|
||||||
|
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
|
||||||
|
from the LLM are resolved to the correct absolute path before being sent to
|
||||||
|
the Electron client, preventing it from falling back to its own CWD.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _compact_for_journey(raw: str) -> str:
|
||||||
|
"""Strip HTML noise and truncate for journey exploration.
|
||||||
|
|
||||||
|
The journey LLM only needs to understand file structure (headers,
|
||||||
|
first paragraphs). Full CSS/style blocks are pure noise that eat
|
||||||
|
up context window budget.
|
||||||
|
"""
|
||||||
|
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
|
||||||
|
if len(text) > _JOURNEY_READ_MAX_CHARS:
|
||||||
|
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
|
||||||
|
return text
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str: # noqa: F811
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{resolved}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str: # noqa: F811
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{resolved}' is empty or could not be read."
|
||||||
|
return _compact_for_journey(content)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str: # noqa: F811
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", resolved)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [list_directory, read_file_content, get_file_metadata]
|
||||||
|
|||||||
@@ -18,21 +18,6 @@ _UUID_RE = re.compile(
|
|||||||
def _is_uuid(value: str) -> bool:
|
def _is_uuid(value: str) -> bool:
|
||||||
return bool(_UUID_RE.match(value))
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
NOTE_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
|
|||||||
@@ -8,22 +8,6 @@ from langchain_core.tools import tool
|
|||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
PROJECT_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_projects(
|
async def list_projects(
|
||||||
|
|||||||
@@ -18,23 +18,6 @@ _UUID_RE = re.compile(
|
|||||||
def _is_uuid(value: str) -> bool:
|
def _is_uuid(value: str) -> bool:
|
||||||
return bool(_UUID_RE.match(value))
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
TASK_SYSTEM_PROMPT = (
|
|
||||||
"You are a task management assistant for a project workspace.\n"
|
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -17,20 +17,6 @@ _UUID_RE = re.compile(
|
|||||||
def _is_uuid(value: str) -> bool:
|
def _is_uuid(value: str) -> bool:
|
||||||
return bool(_UUID_RE.match(value))
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
TIMELINE_SYSTEM_PROMPT = (
|
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all timelines across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
|
|||||||
@@ -65,16 +65,39 @@ async def get_current_user(
|
|||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
select(User.name, User.surname).where(User.id == user_id)
|
select(
|
||||||
|
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||||
|
User.password_hash,
|
||||||
|
).where(User.id == user_id)
|
||||||
)
|
)
|
||||||
user_row = user_result.one_or_none()
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
# Convert onboarding_completed_at to epoch ms (int) or None.
|
||||||
|
onboarding_ms: int | None = None
|
||||||
|
if user_row and user_row.onboarding_completed_at is not None:
|
||||||
|
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||||
|
|
||||||
|
# Load decrypted core memory.
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||||
|
|
||||||
|
memory_dict: dict[str, str] = {}
|
||||||
|
try:
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
blocks = await mw.list_core_blocks(user_id)
|
||||||
|
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||||
|
except Exception:
|
||||||
|
pass # Non-critical — return empty memory on failure
|
||||||
|
|
||||||
return UserProfile(
|
return UserProfile(
|
||||||
id=user_id,
|
id=user_id,
|
||||||
email=email,
|
email=email,
|
||||||
name=user_row.name if user_row else None,
|
name=user_row.name if user_row else None,
|
||||||
surname=user_row.surname if user_row else None,
|
surname=user_row.surname if user_row else None,
|
||||||
|
avatar_url=user_row.avatar_url if user_row else None,
|
||||||
|
has_password=bool(user_row.password_hash) if user_row else False,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
|
onboarding_completed_at=onboarding_ms,
|
||||||
|
memory=memory_dict,
|
||||||
) # type: ignore[arg-type]
|
) # type: ignore[arg-type]
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
|
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
|
||||||
|
|
||||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||||
frames to the functions exported here.
|
frames to the functions exported here.
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. FE sends ``journey_start`` frame with basic agent config (directory,
|
1. FE sends ``journey_start`` frame with basic agent info (directory,
|
||||||
data_types, schedule).
|
data_types, schedule).
|
||||||
2. Server creates an in-memory session, sets up a WS executor so the
|
2. Server creates an in-memory session, sets up a WS executor so the
|
||||||
setup LLM can use file-system tools, does a first directory scrape,
|
setup LLM can use file-system tools, does a first directory scrape,
|
||||||
@@ -13,10 +13,11 @@ Journey flow:
|
|||||||
3. FE sends ``journey_message`` frames for each user reply.
|
3. FE sends ``journey_message`` frames for each user reply.
|
||||||
4. Server appends the user message, calls the LLM (which may read files
|
4. Server appends the user message, calls the LLM (which may read files
|
||||||
via tools), and sends back a ``journey_reply``.
|
via tools), and sends back a ``journey_reply``.
|
||||||
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
|
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
|
||||||
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
||||||
6. Server parses the block, sends ``journey_reply`` with ``done=True``
|
6. Server parses and validates the JSON with Pydantic, sends
|
||||||
and the template. FE stores it locally.
|
``journey_reply`` with ``done=True`` and the serialised config.
|
||||||
|
FE stores it locally.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -30,10 +31,10 @@ from typing import Any
|
|||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
|
||||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
from app.agents.filesystem_agent import make_directory_tools
|
||||||
from app.config.settings import settings
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
from app.core.llm import get_llm
|
from app.schemas import AgentConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -41,9 +42,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
|
||||||
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
_CONFIG_START = "AGENT_CONFIG_START"
|
||||||
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
_CONFIG_END = "AGENT_CONFIG_END"
|
||||||
|
|
||||||
# Minimum turns before we consider nudging the LLM to wrap up.
|
# Minimum turns before we consider nudging the LLM to wrap up.
|
||||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
@@ -86,61 +87,76 @@ def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt builder ─────────────────────────────────────────────────
|
# ── System prompt ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_JOURNEY_SYSTEM_PROMPT = """\
|
_JOURNEY_SYSTEM_PROMPT = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand exactly what data the user wants to extract from their
|
Your job is to understand what files the user has in their directory and produce a
|
||||||
local directory and produce a detailed prompt_template that a separate AI will use
|
structured AgentConfig JSON that the extraction agent will use as its instruction set.
|
||||||
as its instruction set.
|
|
||||||
|
|
||||||
The extraction agent already has this base behaviour built in:
|
|
||||||
- Reads each file using file-system tools.
|
|
||||||
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
|
||||||
- Sets isAiSuggested=1 on every new record.
|
|
||||||
- Only extracts data explicitly present in the files — it never invents information.
|
|
||||||
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
|
||||||
what to look for and how to map it — not on the general extraction mechanics.
|
|
||||||
|
|
||||||
You have access to file-system tools to explore the user's directory:
|
You have access to file-system tools to explore the user's directory:
|
||||||
- list_directory: to see folder structure
|
- list_directory: see folder structure and file names
|
||||||
- read_file_content: to peek at file contents
|
- read_file_content: peek at a file's content
|
||||||
- get_file_metadata: to check file info
|
- get_file_metadata: check file size, extension, dates
|
||||||
|
|
||||||
The user's configured directory is: {directory}
|
The user's configured directory is: {directory}
|
||||||
Target data types: {data_types}
|
Target data types: {data_types}
|
||||||
|
|
||||||
IMPORTANT — project assignment is handled automatically by the main agent runner
|
## Your process
|
||||||
before the custom prompt is ever used. You MUST NOT ask the user about projects,
|
|
||||||
projectId, or how to link records to projects. Never include projectId logic or
|
|
||||||
project creation instructions in the generated prompt_template.
|
|
||||||
|
|
||||||
Start by exploring the directory to understand its structure. Then ask concise,
|
### Step 1 — Explore the directory
|
||||||
focused questions one at a time. Cover these topics (not necessarily in this order):
|
Use list_directory and read_file_content to understand what types of files are present
|
||||||
1. The type and format of the source content (confirmed by your exploration).
|
(HTML emails, plain-text documents, CSVs, etc.).
|
||||||
2. How fields should be mapped (e.g. filename → task title).
|
|
||||||
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
|
||||||
4. Any special handling, date extraction, or exclusions.
|
|
||||||
|
|
||||||
Once you reach 90% confidence, output the final prompt_template between these exact
|
### Step 2 — Identify content types
|
||||||
markers on their own lines:
|
For each distinct file type found, decide:
|
||||||
|
- A short id (e.g. "email_html", "plain_text", "csv")
|
||||||
|
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
|
||||||
|
- A human-readable label and optional detection_hint
|
||||||
|
|
||||||
{template_start}
|
### Step 3 — Ask focused questions (one at a time)
|
||||||
<the complete extraction prompt here>
|
Cover these topics based on what you discovered:
|
||||||
{template_end}
|
1. How to map content to entity types (task / note / timeline entry)
|
||||||
|
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
|
||||||
|
3. Priority or status rules (e.g. "urgent" in subject → high priority)
|
||||||
|
4. Date extraction (e.g. "by Friday" → dueDate)
|
||||||
|
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
||||||
|
|
||||||
The prompt_template must be a self-contained instruction for an AI that reads files
|
### Step 4 — Produce the AgentConfig JSON
|
||||||
and must perform CRUD operations using tools to create records. It should specify:
|
Once you are ≥ 90% confident, output the final config between these exact markers
|
||||||
- What entity types to create (tasks, notes, timelines) — never projects.
|
(each on its own line):
|
||||||
- How to map file content to record fields (camelCase: title, status, priority,
|
|
||||||
dueDate, content, etc.) — never include projectId.
|
{config_start}
|
||||||
- That isAiSuggested must be set to 1 on every new record.
|
{{
|
||||||
- Concrete examples of mappings based on what you discovered in the directory.
|
"content_types": [
|
||||||
|
{{
|
||||||
|
"id": "email_html",
|
||||||
|
"label": "Email HTML",
|
||||||
|
"detection_hint": "HTML file with From/To/Subject headers",
|
||||||
|
"preprocessing": "email_html",
|
||||||
|
"extraction_prompt": "Detailed extraction instructions for this content type..."
|
||||||
|
}}
|
||||||
|
],
|
||||||
|
"global_rules": [
|
||||||
|
"If the file cannot be matched to any project, do not create any entity."
|
||||||
|
],
|
||||||
|
"data_types": {data_types_json}
|
||||||
|
}}
|
||||||
|
{config_end}
|
||||||
|
|
||||||
|
## Rules for the extraction_prompt field
|
||||||
|
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
|
||||||
|
- Include field mapping rules based on what you found in the directory
|
||||||
|
- Include priority/status/date rules if applicable
|
||||||
|
- Do NOT include projectId logic — the runner handles project assignment automatically
|
||||||
|
- Do NOT mention isAiSuggested — the runner always sets it to 1
|
||||||
|
|
||||||
|
## Constraints
|
||||||
|
- Never ask about projects, projectId, or how to link records to projects
|
||||||
|
- Never include projectId or project creation logic in the generated config
|
||||||
|
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
|
||||||
|
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Keep asking clarifying questions until you are at least 90% confident you have
|
|
||||||
enough information to generate an accurate prompt_template. Once you reach that
|
|
||||||
confidence level, stop asking and produce the final template immediately.
|
|
||||||
Begin by exploring the directory, then ask your first question.\
|
Begin by exploring the directory, then ask your first question.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -148,38 +164,53 @@ Begin by exploring the directory, then ask your first question.\
|
|||||||
def _build_system_prompt(
|
def _build_system_prompt(
|
||||||
directory: str,
|
directory: str,
|
||||||
data_types: list[str],
|
data_types: list[str],
|
||||||
existing_template: str | None = None,
|
existing_config: str | None = None,
|
||||||
) -> tuple[str, Any]:
|
) -> tuple[str, Any]:
|
||||||
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
||||||
existing_section = (
|
existing_section = (
|
||||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
|
||||||
f"---\n{existing_template}\n---\n"
|
f"```json\n{existing_config}\n```\n"
|
||||||
if existing_template
|
if existing_config
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
template, prompt_obj = get_prompt_or_fallback(
|
template, prompt_obj = get_prompt_or_fallback(
|
||||||
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
compiled = template.format(
|
compiled = compile_prompt(
|
||||||
|
template,
|
||||||
|
prompt_obj,
|
||||||
directory=directory,
|
directory=directory,
|
||||||
data_types=", ".join(data_types),
|
data_types=", ".join(data_types),
|
||||||
template_start=_TEMPLATE_START,
|
data_types_json=json.dumps(data_types),
|
||||||
template_end=_TEMPLATE_END,
|
config_start=_CONFIG_START,
|
||||||
|
config_end=_CONFIG_END,
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
)
|
)
|
||||||
return compiled, prompt_obj
|
return compiled, prompt_obj
|
||||||
|
|
||||||
|
|
||||||
# ── Template extraction ───────────────────────────────────────────────────
|
# ── AgentConfig extraction ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _extract_template(text: str) -> str | None:
|
def _extract_agent_config(text: str) -> str | None:
|
||||||
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
"""Return validated AgentConfig JSON string from between markers, or None.
|
||||||
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
|
||||||
|
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
||||||
|
returning. Returns None if markers are absent or JSON is invalid.
|
||||||
|
"""
|
||||||
|
if _CONFIG_START not in text or _CONFIG_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
|
||||||
|
end_idx = text.index(_CONFIG_END)
|
||||||
|
raw = text[start_idx:end_idx].strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = AgentConfig.model_validate_json(raw)
|
||||||
|
return parsed.model_dump_json()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
|
||||||
return None
|
return None
|
||||||
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
|
||||||
end_idx = text.index(_TEMPLATE_END)
|
|
||||||
return text[start_idx:end_idx].strip() or None
|
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call with tool support ───────────────────────────────────────────
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
@@ -225,16 +256,17 @@ async def _call_llm_with_tools(
|
|||||||
else:
|
else:
|
||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_agent_llm("setup", temperature=0.4)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
_span_ctx = (
|
_span_ctx = (
|
||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="span",
|
as_type="span",
|
||||||
name="journey-setup",
|
name="journey-setup",
|
||||||
user_id=user_id or None,
|
|
||||||
session_id=session_id or None,
|
|
||||||
input=history[-1]["content"] if history else "",
|
input=history[-1]["content"] if history else "",
|
||||||
)
|
)
|
||||||
if lf else None
|
if lf else None
|
||||||
@@ -242,12 +274,12 @@ async def _call_llm_with_tools(
|
|||||||
_span = _span_ctx.__enter__() if _span_ctx else None
|
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _ in range(_MAX_TOOL_STEPS):
|
for step in range(_MAX_TOOL_STEPS):
|
||||||
_gen_ctx = (
|
_gen_ctx = (
|
||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="generation",
|
as_type="generation",
|
||||||
name="journey-setup-llm",
|
name="journey-setup-llm",
|
||||||
model=settings.LLM_MODEL,
|
model=model_for_agent("setup"),
|
||||||
prompt=langfuse_prompt,
|
prompt=langfuse_prompt,
|
||||||
input=messages,
|
input=messages,
|
||||||
)
|
)
|
||||||
@@ -256,15 +288,27 @@ async def _call_llm_with_tools(
|
|||||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
if _gen_ctx:
|
if _gen_ctx:
|
||||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
_gen_ctx.__exit__(None, None, None)
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
|
resp_text = _as_text(response.content)
|
||||||
|
|
||||||
|
# Guard against empty responses (e.g. model returned finish_reason
|
||||||
|
# 'error' which LiteLLM maps to 'stop' with empty content).
|
||||||
|
if not response.tool_calls and not resp_text.strip():
|
||||||
|
logger.warning(
|
||||||
|
"agent_setup: journey LLM returned empty response at step %d — retrying",
|
||||||
|
step,
|
||||||
|
)
|
||||||
|
# Drop the empty AIMessage so we don't pollute history, and retry.
|
||||||
|
continue
|
||||||
|
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
|
|
||||||
if not response.tool_calls:
|
if not response.tool_calls:
|
||||||
if _span:
|
if _span:
|
||||||
_span.update(output=_as_text(response.content))
|
_span.update(output=resp_text)
|
||||||
return _as_text(response.content)
|
return resp_text
|
||||||
|
|
||||||
for call in response.tool_calls:
|
for call in response.tool_calls:
|
||||||
call_name = str(call.get("name", ""))
|
call_name = str(call.get("name", ""))
|
||||||
@@ -293,10 +337,14 @@ async def _call_llm_with_tools(
|
|||||||
final_text = _as_text(final.content)
|
final_text = _as_text(final.content)
|
||||||
if _span:
|
if _span:
|
||||||
_span.update(output=final_text)
|
_span.update(output=final_text)
|
||||||
return final_text
|
return final_text or (
|
||||||
|
"Sorry, I had trouble processing the files. "
|
||||||
|
"Could you try again? If the issue persists, the files might be too large for me to analyse."
|
||||||
|
)
|
||||||
finally:
|
finally:
|
||||||
if _span_ctx:
|
if _span_ctx:
|
||||||
_span_ctx.__exit__(None, None, None)
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
if lf:
|
if lf:
|
||||||
lf.flush()
|
lf.flush()
|
||||||
|
|
||||||
@@ -316,12 +364,12 @@ async def handle_journey_start(
|
|||||||
agent_type = frame.get("agent_type", "local")
|
agent_type = frame.get("agent_type", "local")
|
||||||
directory = frame.get("directory", "")
|
directory = frame.get("directory", "")
|
||||||
data_types = frame.get("data_types", [])
|
data_types = frame.get("data_types", [])
|
||||||
existing_template = frame.get("existing_template")
|
existing_config = frame.get("existing_config")
|
||||||
|
|
||||||
# Use the session_id provided by the FE so the reply matches the
|
# Use the session_id provided by the FE so the reply matches the
|
||||||
# listener key; fall back to a generated one if absent.
|
# listener key; fall back to a generated one if absent.
|
||||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_template)
|
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
|
||||||
|
|
||||||
session = JourneySession(
|
session = JourneySession(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -333,17 +381,15 @@ async def handle_journey_start(
|
|||||||
langfuse_prompt=langfuse_prompt,
|
langfuse_prompt=langfuse_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
|
# Seed with an initial user message — some providers require at least one
|
||||||
# ws_context executor (already set by the WS handler before calling us).
|
# user/input message to be present.
|
||||||
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
|
|
||||||
# require at least one user/input message to be present.
|
|
||||||
seed_history: list[dict[str, Any]] = [
|
seed_history: list[dict[str, Any]] = [
|
||||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
]
|
]
|
||||||
ai_reply = await _call_llm_with_tools(
|
ai_reply = await _call_llm_with_tools(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history=seed_history,
|
history=seed_history,
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
tools=make_directory_tools(directory),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
langfuse_prompt=langfuse_prompt,
|
langfuse_prompt=langfuse_prompt,
|
||||||
@@ -360,14 +406,14 @@ async def handle_journey_start(
|
|||||||
directory,
|
directory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the LLM produced the template on the first turn (unlikely but possible).
|
# Check if the LLM produced the config on the first turn (unlikely but possible).
|
||||||
prompt_template = _extract_template(ai_reply)
|
agent_config = _extract_agent_config(ai_reply)
|
||||||
done = prompt_template is not None
|
done = agent_config is not None
|
||||||
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
@@ -377,7 +423,7 @@ async def handle_journey_start(
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"message": display_message,
|
"message": display_message,
|
||||||
"done": done,
|
"done": done,
|
||||||
"prompt_template": prompt_template,
|
"agent_config": agent_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -400,17 +446,18 @@ async def handle_journey_message(
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"message": "Journey session not found or expired. Please start a new setup.",
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
"done": True,
|
"done": True,
|
||||||
"prompt_template": None,
|
"agent_config": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Append user turn.
|
# Append user turn.
|
||||||
session.history.append({"role": "user", "content": message})
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
# Call the LLM with tools.
|
# Call the LLM with tools.
|
||||||
|
session_tools = make_directory_tools(session.directory)
|
||||||
ai_reply = await _call_llm_with_tools(
|
ai_reply = await _call_llm_with_tools(
|
||||||
system_prompt=session.system_prompt,
|
system_prompt=session.system_prompt,
|
||||||
history=session.history,
|
history=session.history,
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
tools=session_tools,
|
||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
langfuse_prompt=session.langfuse_prompt,
|
langfuse_prompt=session.langfuse_prompt,
|
||||||
@@ -418,41 +465,40 @@ async def handle_journey_message(
|
|||||||
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
# Check if the LLM produced the final template.
|
# Check if the LLM produced the final config.
|
||||||
prompt_template = _extract_template(ai_reply)
|
agent_config = _extract_agent_config(ai_reply)
|
||||||
done = prompt_template is not None
|
done = agent_config is not None
|
||||||
|
|
||||||
# If the LLM didn't produce a template, nudge it once it has asked enough
|
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
|
||||||
# questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap.
|
|
||||||
if not done:
|
if not done:
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
if turns >= _MAX_TURNS:
|
if turns >= _MAX_TURNS:
|
||||||
nudge_content = (
|
nudge_content = (
|
||||||
"[System: You have enough information. Please generate the final "
|
"[System: You have enough information. Please generate the final "
|
||||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
||||||
)
|
)
|
||||||
session.history.append({"role": "user", "content": nudge_content})
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
nudge_reply = await _call_llm_with_tools(
|
nudge_reply = await _call_llm_with_tools(
|
||||||
system_prompt=session.system_prompt,
|
system_prompt=session.system_prompt,
|
||||||
history=session.history,
|
history=session.history,
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
tools=session_tools,
|
||||||
user_id=session.user_id,
|
user_id=session.user_id,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
langfuse_prompt=session.langfuse_prompt,
|
langfuse_prompt=session.langfuse_prompt,
|
||||||
)
|
)
|
||||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
prompt_template = _extract_template(nudge_reply)
|
agent_config = _extract_agent_config(nudge_reply)
|
||||||
if prompt_template is not None:
|
if agent_config is not None:
|
||||||
done = True
|
done = True
|
||||||
ai_reply = nudge_reply
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||||
if _TEMPLATE_START in ai_reply
|
if _CONFIG_START in ai_reply
|
||||||
else "Here is your agent configuration. You can save it or continue refining."
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
@@ -463,5 +509,5 @@ async def handle_journey_message(
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"message": display_message,
|
"message": display_message,
|
||||||
"done": done,
|
"done": done,
|
||||||
"prompt_template": prompt_template,
|
"agent_config": agent_config,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,11 @@ in backend agent-config tables.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
@@ -177,6 +180,11 @@ async def trigger_agent_run(
|
|||||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||||
|
|
||||||
|
last_run_dt = (
|
||||||
|
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
||||||
|
if body.last_run_at
|
||||||
|
else None
|
||||||
|
)
|
||||||
config = LocalAgentConfig(
|
config = LocalAgentConfig(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -184,10 +192,12 @@ async def trigger_agent_run(
|
|||||||
name="Local Directory Monitor",
|
name="Local Directory Monitor",
|
||||||
directory_paths=[body.directory],
|
directory_paths=[body.directory],
|
||||||
data_types=_to_data_types(body.what_to_extract),
|
data_types=_to_data_types(body.what_to_extract),
|
||||||
prompt_template=body.custom_agent_prompt,
|
prompt_template=body.custom_agent_prompt or "",
|
||||||
|
agent_config=body.agent_config,
|
||||||
file_extensions=[],
|
file_extensions=[],
|
||||||
schedule_cron=body.batch_interval,
|
schedule_cron=body.batch_interval,
|
||||||
enabled=True,
|
enabled=True,
|
||||||
|
last_run_at=last_run_dt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||||
|
|||||||
@@ -1,34 +1,68 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
|
||||||
|
|
||||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
SHA-256 hashes so plaintext never reaches the DB.
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
|
|
||||||
|
OAuth (Google):
|
||||||
|
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
|
||||||
|
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
|
import urllib.parse
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import RefreshToken, User
|
from app.models import OAuthAccount, RefreshToken, User
|
||||||
from app.schemas import AuthTokens, UserProfile
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth provider registry ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_google_provider() -> GoogleOAuthProvider:
|
||||||
|
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
"Google login is not configured on this server",
|
||||||
|
)
|
||||||
|
return GoogleOAuthProvider(
|
||||||
|
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
|
||||||
|
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||||
|
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PROVIDERS = {"google": _get_google_provider}
|
||||||
|
|
||||||
|
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
|
||||||
|
# Production note: replace with Redis for multi-process deployments.
|
||||||
|
_pending_states: dict[str, tuple[str, float]] = {}
|
||||||
|
_STATE_TTL_SECONDS = 600 # 10 minutes
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -231,5 +265,531 @@ async def update_profile(
|
|||||||
email=user.email,
|
email=user.email,
|
||||||
name=user.name,
|
name=user.name,
|
||||||
surname=user.surname,
|
surname=user.surname,
|
||||||
|
avatar_url=user.avatar_url,
|
||||||
tier=current_user.tier,
|
tier=current_user.tier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth helpers ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
|
||||||
|
"""Create a refresh token row and return (plain_token, AuthTokens)."""
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return plain_token, AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth request/response schemas ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _OAuthAuthorizeResponse(BaseModel):
|
||||||
|
url: str
|
||||||
|
state: str
|
||||||
|
|
||||||
|
|
||||||
|
class _OAuthCallbackRequest(BaseModel):
|
||||||
|
code: str
|
||||||
|
state: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth routes ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/oauth/{provider}/web-callback",
|
||||||
|
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
async def oauth_web_callback(
|
||||||
|
provider: Literal["google"],
|
||||||
|
code: str,
|
||||||
|
state: str,
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Google redirects here after user consent.
|
||||||
|
|
||||||
|
This endpoint immediately redirects to the Electron deep-link URI so the
|
||||||
|
desktop app receives the authorization code. It is intentionally simple —
|
||||||
|
no state validation here (the Electron app + backend callback do that).
|
||||||
|
|
||||||
|
Registered in Google Cloud Console as:
|
||||||
|
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
|
||||||
|
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
|
||||||
|
"""
|
||||||
|
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
|
||||||
|
deep_link = f"adiuvai://oauth/callback?{params}"
|
||||||
|
return RedirectResponse(url=deep_link, status_code=302)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/oauth/{provider}/authorize",
|
||||||
|
response_model=_OAuthAuthorizeResponse,
|
||||||
|
summary="Start OAuth flow — returns the provider consent-screen URL",
|
||||||
|
)
|
||||||
|
async def oauth_authorize(
|
||||||
|
provider: Literal["google"],
|
||||||
|
) -> _OAuthAuthorizeResponse:
|
||||||
|
"""Generate a PKCE state + code_challenge and return the authorization URL.
|
||||||
|
|
||||||
|
The client opens this URL in the system browser. After the user grants
|
||||||
|
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
|
||||||
|
with ``code`` and ``state`` query params. The client then calls
|
||||||
|
``POST /auth/oauth/{provider}/callback`` with those values.
|
||||||
|
"""
|
||||||
|
provider_factory = _PROVIDERS.get(provider)
|
||||||
|
if provider_factory is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
oauth_provider = provider_factory()
|
||||||
|
state = str(uuid.uuid4())
|
||||||
|
code_verifier, code_challenge = generate_pkce_pair()
|
||||||
|
|
||||||
|
# Purge expired states to prevent unbounded growth.
|
||||||
|
now = time.time()
|
||||||
|
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
|
||||||
|
for s in expired:
|
||||||
|
del _pending_states[s]
|
||||||
|
|
||||||
|
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
|
||||||
|
|
||||||
|
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
|
||||||
|
return _OAuthAuthorizeResponse(url=url, state=state)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/oauth/{provider}/callback",
|
||||||
|
response_model=AuthTokens,
|
||||||
|
summary="Complete OAuth flow — exchange code and issue JWT tokens",
|
||||||
|
)
|
||||||
|
async def oauth_callback(
|
||||||
|
provider: Literal["google"],
|
||||||
|
body: _OAuthCallbackRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Validate state, exchange the authorization code, and sign in (or register) the user.
|
||||||
|
|
||||||
|
Resolution order:
|
||||||
|
1. ``oauth_accounts`` row match → existing user, log in.
|
||||||
|
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
|
||||||
|
3. No match → create new user (password_hash=None, avatar from provider).
|
||||||
|
"""
|
||||||
|
provider_factory = _PROVIDERS.get(provider)
|
||||||
|
if provider_factory is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
# Validate state (CSRF protection).
|
||||||
|
now = time.time()
|
||||||
|
entry = _pending_states.pop(body.state, None)
|
||||||
|
if entry is None or entry[1] < now:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
||||||
|
|
||||||
|
code_verifier, _ = entry
|
||||||
|
|
||||||
|
oauth_provider = provider_factory()
|
||||||
|
|
||||||
|
# Exchange code for tokens.
|
||||||
|
try:
|
||||||
|
token_data = await oauth_provider.exchange_code(
|
||||||
|
code=body.code,
|
||||||
|
code_verifier=code_verifier,
|
||||||
|
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token_google = token_data.get("access_token")
|
||||||
|
if not access_token_google:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
|
||||||
|
|
||||||
|
# Fetch user identity.
|
||||||
|
try:
|
||||||
|
userinfo = await oauth_provider.get_userinfo(access_token_google)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
|
||||||
|
|
||||||
|
# ── Resolution order ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
# 1. Existing OAuth link?
|
||||||
|
oauth_result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
OAuthAccount.provider_user_id == userinfo.provider_user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
oauth_account = oauth_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if oauth_account is not None:
|
||||||
|
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
|
||||||
|
user = user_result.scalar_one()
|
||||||
|
# Backfill avatar if the user doesn't have one yet.
|
||||||
|
if user.avatar_url is None and userinfo.avatar_url:
|
||||||
|
user.avatar_url = userinfo.avatar_url
|
||||||
|
await db.commit()
|
||||||
|
plain_token, tokens = await _issue_refresh_token(user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# 2. Email match with a verified Google email → link accounts.
|
||||||
|
if userinfo.email_verified:
|
||||||
|
email_result = await db.execute(select(User).where(User.email == userinfo.email))
|
||||||
|
existing_user = email_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing_user is not None:
|
||||||
|
new_link = OAuthAccount(
|
||||||
|
user_id=existing_user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=userinfo.provider_user_id,
|
||||||
|
provider_email=userinfo.email,
|
||||||
|
)
|
||||||
|
db.add(new_link)
|
||||||
|
if existing_user.avatar_url is None and userinfo.avatar_url:
|
||||||
|
existing_user.avatar_url = userinfo.avatar_url
|
||||||
|
plain_token, tokens = await _issue_refresh_token(existing_user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# Guard: if the email is already taken but we couldn't auto-link (e.g.
|
||||||
|
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
|
||||||
|
if not userinfo.email_verified:
|
||||||
|
conflict = await db.execute(select(User).where(User.email == userinfo.email))
|
||||||
|
if conflict.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_409_CONFLICT,
|
||||||
|
"An account with this email already exists. "
|
||||||
|
"Please sign in with your password.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. New user — social-only account (no password).
|
||||||
|
new_user = User(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
email=userinfo.email,
|
||||||
|
name=userinfo.name,
|
||||||
|
password_hash=None,
|
||||||
|
avatar_url=userinfo.avatar_url,
|
||||||
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
|
)
|
||||||
|
db.add(new_user)
|
||||||
|
await db.flush() # populate new_user.id
|
||||||
|
|
||||||
|
new_oauth = OAuthAccount(
|
||||||
|
user_id=new_user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=userinfo.provider_user_id,
|
||||||
|
provider_email=userinfo.email,
|
||||||
|
)
|
||||||
|
db.add(new_oauth)
|
||||||
|
|
||||||
|
plain_token, tokens = await _issue_refresh_token(new_user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
# ── Onboarding helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
|
||||||
|
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
|
||||||
|
|
||||||
|
# We can't call the FastAPI dependency directly, but we can replicate
|
||||||
|
# the core logic inline. Instead, we just re-query the same way.
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
user_result = await db.execute(
|
||||||
|
select(
|
||||||
|
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||||
|
User.password_hash,
|
||||||
|
).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
onboarding_ms: int | None = None
|
||||||
|
if user_row and user_row.onboarding_completed_at is not None:
|
||||||
|
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||||
|
|
||||||
|
memory_dict: dict[str, str] = {}
|
||||||
|
try:
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
blocks = await mw.list_core_blocks(user_id)
|
||||||
|
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user_id,
|
||||||
|
email=email,
|
||||||
|
name=user_row.name if user_row else None,
|
||||||
|
surname=user_row.surname if user_row else None,
|
||||||
|
avatar_url=user_row.avatar_url if user_row else None,
|
||||||
|
has_password=bool(user_row.password_hash) if user_row else False,
|
||||||
|
tier=tier,
|
||||||
|
onboarding_completed_at=onboarding_ms,
|
||||||
|
memory=memory_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Onboarding routes ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateMemoryRequest(BaseModel):
|
||||||
|
memory: dict[str, str] = Field(default_factory=dict)
|
||||||
|
mark_onboarded: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/memory", response_model=UserProfile)
|
||||||
|
async def update_memory(
|
||||||
|
body: _UpdateMemoryRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update core memory key/value pairs and optionally mark onboarding complete."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
for key, value in body.memory.items():
|
||||||
|
await mw.update_core(current_user.id, key, value)
|
||||||
|
if body.mark_onboarded:
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.onboarding_completed_at = datetime.now(timezone.utc)
|
||||||
|
await db.commit()
|
||||||
|
return await _build_profile(current_user.id, current_user.email, db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/me/onboarding/reset")
|
||||||
|
async def reset_onboarding(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
):
|
||||||
|
"""Reset onboarding so the wizard runs again on next login."""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.onboarding_completed_at = None
|
||||||
|
await db.commit()
|
||||||
|
return {"status": "reset"}
|
||||||
|
|
||||||
|
|
||||||
|
class _NormalizeRequest(BaseModel):
|
||||||
|
inputs: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class _NormalizeResponse(BaseModel):
|
||||||
|
normalized: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
|
||||||
|
async def normalize_onboarding(
|
||||||
|
body: _NormalizeRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _NormalizeResponse:
|
||||||
|
"""One-shot LLM normalization for free-text onboarding answers."""
|
||||||
|
if not body.inputs:
|
||||||
|
return _NormalizeResponse(normalized={})
|
||||||
|
try:
|
||||||
|
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||||
|
prompt = (
|
||||||
|
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
|
||||||
|
"Return a JSON object with the same keys and normalized values.\n"
|
||||||
|
"Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
|
||||||
|
f"Input: {json.dumps(body.inputs)}"
|
||||||
|
)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
normalized = json.loads(response.content)
|
||||||
|
return _NormalizeResponse(normalized=normalized)
|
||||||
|
except Exception:
|
||||||
|
# LLM failure must never block onboarding — return inputs unchanged
|
||||||
|
return _NormalizeResponse(normalized=body.inputs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Password management ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _ChangePasswordRequest(BaseModel):
|
||||||
|
current_password: str = Field(min_length=1)
|
||||||
|
new_password: str = Field(min_length=8)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
||||||
|
async def change_password(
|
||||||
|
body: _ChangePasswordRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Change the authenticated user's password.
|
||||||
|
|
||||||
|
Requires the current password for verification.
|
||||||
|
Returns 400 for social-only users (no password set).
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if user.password_hash is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
"This account uses social login and has no password to change",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _verify_password(body.current_password, user.password_hash):
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
||||||
|
|
||||||
|
user.password_hash = _hash_password(body.new_password)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth account management ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/oauth-accounts", response_model=list[dict])
|
||||||
|
async def list_oauth_accounts(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""List all OAuth providers linked to the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
accounts = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"provider": a.provider,
|
||||||
|
"provider_email": a.provider_email,
|
||||||
|
"created_at": int(a.created_at.timestamp() * 1000),
|
||||||
|
}
|
||||||
|
for a in accounts
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
||||||
|
async def unlink_oauth_account(
|
||||||
|
provider: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Unlink an OAuth provider from the authenticated user.
|
||||||
|
|
||||||
|
Refuses if the user has no password and this is their only login method.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
oauth_result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
OAuthAccount.user_id == current_user.id,
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
account = oauth_result.scalar_one_or_none()
|
||||||
|
if account is None:
|
||||||
|
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
||||||
|
|
||||||
|
# Safety: don't let users lock themselves out.
|
||||||
|
all_oauth = await db.execute(
|
||||||
|
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
oauth_count = len(all_oauth.scalars().all())
|
||||||
|
|
||||||
|
if user.password_hash is None and oauth_count <= 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
"Cannot unlink the only login method. Set a password first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.delete(account)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Avatar update ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateAvatarRequest(BaseModel):
|
||||||
|
avatar_url: str = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/avatar", response_model=UserProfile)
|
||||||
|
async def update_avatar(
|
||||||
|
body: _UpdateAvatarRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's avatar URL.
|
||||||
|
|
||||||
|
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
||||||
|
to its own storage and passes the resulting URL here.
|
||||||
|
"""
|
||||||
|
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
||||||
|
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.avatar_url = body.avatar_url
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return await _build_profile(current_user.id, current_user.email, db)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Account deletion ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me", status_code=status.HTTP_200_OK)
|
||||||
|
async def delete_account(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Permanently delete the authenticated user's account.
|
||||||
|
|
||||||
|
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
||||||
|
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
||||||
|
is cancelled if active.
|
||||||
|
"""
|
||||||
|
# Cancel Stripe subscription if present.
|
||||||
|
try:
|
||||||
|
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||||
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
|
except HTTPException:
|
||||||
|
pass # No subscription — that's fine
|
||||||
|
|
||||||
|
# Delete all memory rows (core, associative, episodic, proactive).
|
||||||
|
try:
|
||||||
|
from app.models import ( # noqa: PLC0415
|
||||||
|
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
||||||
|
)
|
||||||
|
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
||||||
|
await db.execute(
|
||||||
|
model.__table__.delete().where(model.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Non-critical — cascade on User will handle most
|
||||||
|
|
||||||
|
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
await db.delete(user)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|||||||
@@ -83,3 +83,16 @@ async def cancel_subscription(
|
|||||||
"""Cancel the active subscription."""
|
"""Cancel the active subscription."""
|
||||||
await stripe_service.cancel_subscription(current_user.id, db)
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/invoices", response_model=list[dict])
|
||||||
|
async def list_invoices(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return billing history (invoices) from Stripe.
|
||||||
|
|
||||||
|
Returns an empty list when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||||
|
return invoices
|
||||||
|
|||||||
1
app/auth/__init__.py
Normal file
1
app/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"OAuth provider abstractions and utilities."
|
||||||
135
app/auth/oauth_providers.py
Normal file
135
app/auth/oauth_providers.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""OAuth 2.0 + PKCE provider abstractions.
|
||||||
|
|
||||||
|
Each provider implements a three-step flow designed for a desktop (public) client:
|
||||||
|
|
||||||
|
1. get_authorization_url(state, code_challenge) → str
|
||||||
|
Build the provider's consent-screen URL. State and code_challenge are
|
||||||
|
generated server-side; the client opens this URL in the system browser.
|
||||||
|
|
||||||
|
2. exchange_code(code, code_verifier, redirect_uri) → dict
|
||||||
|
Exchange the short-lived authorization code for an access token.
|
||||||
|
The code_verifier proves ownership of the PKCE challenge.
|
||||||
|
|
||||||
|
3. get_userinfo(access_token) → OAuthUserInfo
|
||||||
|
Fetch the canonical user identity from the provider.
|
||||||
|
|
||||||
|
Currently supported providers:
|
||||||
|
- GoogleOAuthProvider (scope: openid email profile)
|
||||||
|
|
||||||
|
Adding a new provider:
|
||||||
|
- Implement the three methods above.
|
||||||
|
- Register in _PROVIDERS inside routes/auth.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import urllib.parse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
# ── Data transfer objects ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthUserInfo:
|
||||||
|
"""Normalized user identity returned by any provider."""
|
||||||
|
|
||||||
|
provider_user_id: str
|
||||||
|
email: str
|
||||||
|
email_verified: bool
|
||||||
|
avatar_url: str | None
|
||||||
|
name: str | None
|
||||||
|
|
||||||
|
|
||||||
|
# ── PKCE helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pkce_pair() -> tuple[str, str]:
|
||||||
|
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
|
||||||
|
|
||||||
|
The code_verifier is a random 32-byte URL-safe base64 string.
|
||||||
|
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
|
||||||
|
"""
|
||||||
|
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
|
||||||
|
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||||
|
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||||
|
return code_verifier, code_challenge
|
||||||
|
|
||||||
|
|
||||||
|
# ── Google provider ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleOAuthProvider:
|
||||||
|
"""Google OAuth 2.0 provider (openid email profile scope).
|
||||||
|
|
||||||
|
Uses Google's standard authorization endpoint with PKCE S256.
|
||||||
|
Does NOT use google-auth-oauthlib to keep the flow generic and async.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "google"
|
||||||
|
|
||||||
|
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||||
|
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
self.redirect_uri = redirect_uri
|
||||||
|
|
||||||
|
def get_authorization_url(self, state: str, code_challenge: str) -> str:
|
||||||
|
"""Build the Google consent-screen URL."""
|
||||||
|
params = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": code_challenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
"access_type": "offline",
|
||||||
|
"prompt": "select_account",
|
||||||
|
}
|
||||||
|
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
|
async def exchange_code(
|
||||||
|
self, code: str, code_verifier: str, redirect_uri: str
|
||||||
|
) -> dict:
|
||||||
|
"""Exchange authorization code for an access token."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self._TOKEN_URL,
|
||||||
|
data={
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"code": code,
|
||||||
|
"code_verifier": code_verifier,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
|
||||||
|
"""Fetch the authenticated user's identity from Google."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
self._USERINFO_URL,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
return OAuthUserInfo(
|
||||||
|
provider_user_id=data["sub"],
|
||||||
|
email=data["email"],
|
||||||
|
email_verified=data.get("email_verified", False),
|
||||||
|
avatar_url=data.get("picture"),
|
||||||
|
name=data.get("name"),
|
||||||
|
)
|
||||||
@@ -43,8 +43,8 @@ class StripeService:
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tier: str,
|
tier: str,
|
||||||
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
cancel_url: str = "https://app.adiuvai.app/billing/cancel",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a Stripe checkout session and return the URL.
|
"""Create a Stripe checkout session and return the URL.
|
||||||
|
|
||||||
@@ -200,6 +200,45 @@ class StripeService:
|
|||||||
sub.status = "canceled"
|
sub.status = "canceled"
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
async def list_invoices(
|
||||||
|
self, user_id: str, db: AsyncSession, limit: int = 24
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return recent invoices for the user from Stripe.
|
||||||
|
|
||||||
|
Returns an empty list when Stripe is not configured or the user has
|
||||||
|
no ``stripe_customer_id``.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return []
|
||||||
|
|
||||||
|
from app.models import User # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(User.stripe_customer_id).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
customer_id = result.scalar_one_or_none()
|
||||||
|
if not customer_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
invoices = s.Invoice.list(customer=customer_id, limit=limit)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": inv.id,
|
||||||
|
"amount_due": inv.amount_due,
|
||||||
|
"amount_paid": inv.amount_paid,
|
||||||
|
"currency": inv.currency,
|
||||||
|
"status": inv.status,
|
||||||
|
"created": inv.created * 1000, # epoch ms
|
||||||
|
"invoice_url": inv.hosted_invoice_url,
|
||||||
|
"invoice_pdf": inv.invoice_pdf,
|
||||||
|
}
|
||||||
|
for inv in invoices.auto_paging_iter()
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
# ── Private DB helpers ───────────────────────────────────────────────
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
async def _upsert_subscription(
|
async def _upsert_subscription(
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": 1,
|
"providers": 1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": False, # keyword fallback only
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
@@ -33,6 +34,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": True, # pgvector cosine search
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -41,6 +43,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
|
"real_embeddings": True,
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
@@ -49,6 +52,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"sso": True,
|
"sso": True,
|
||||||
|
"real_embeddings": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
|
||||||
JWT_SECRET: str = "change-me-in-production"
|
JWT_SECRET: str = "change-me-in-production"
|
||||||
JWT_ALGORITHM: str = "HS256"
|
JWT_ALGORITHM: str = "HS256"
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
@@ -18,9 +18,16 @@ class Settings(BaseSettings):
|
|||||||
CEREBRAS_API_KEY: str = ""
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
|
||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
|
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
||||||
|
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
|
||||||
|
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
||||||
|
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
||||||
|
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||||
|
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||||
|
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||||
|
|
||||||
# GitHub Copilot OAuth token storage directory.
|
# GitHub Copilot OAuth token storage directory.
|
||||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||||
@@ -34,20 +41,37 @@ class Settings(BaseSettings):
|
|||||||
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||||
MS_TENANT_ID: str = "common"
|
MS_TENANT_ID: str = "common"
|
||||||
|
|
||||||
|
# Google Login OAuth credentials — scope: openid email profile.
|
||||||
|
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
|
||||||
|
GOOGLE_AUTH_CLIENT_ID: str = ""
|
||||||
|
GOOGLE_AUTH_CLIENT_SECRET: str = ""
|
||||||
|
# The redirect URI registered in Google Cloud Console.
|
||||||
|
# Google redirects here after consent; this backend route then bounces to
|
||||||
|
# the adiuvai:// deep link so the Electron app receives the code.
|
||||||
|
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
|
||||||
|
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
||||||
|
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
||||||
|
|
||||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||||
OAUTH_ENCRYPTION_KEY: str = ""
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = [
|
||||||
|
"app://.",
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:5173",
|
||||||
|
"http://localhost:4173", # Vite preview (web SPA)
|
||||||
|
"https://app.adiuvai.com", # Production web portal
|
||||||
|
]
|
||||||
|
|
||||||
LANGFUSE_SECRET_KEY: str = ""
|
LANGFUSE_SECRET_KEY: str = ""
|
||||||
LANGFUSE_PUBLIC_KEY: str = ""
|
LANGFUSE_PUBLIC_KEY: str = ""
|
||||||
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
|
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
Drives two agent types:
|
Drives two agent types:
|
||||||
|
|
||||||
* **Local directory agent** — two-step execution per file:
|
* **Local directory agent** — V2 unified flow per file:
|
||||||
Step 1 (Classification) uses code to fetch all projects and asks the LLM
|
Phase A (Detect + Preprocess, zero LLM): Python detects the content type
|
||||||
to identify which project the file belongs to and which domains are relevant.
|
and strips markup/noise, producing clean text + metadata.
|
||||||
Step 2 (Processing) fetches existing entities for that project/domains via
|
Phase B (Single LLM call with tools): the LLM identifies the project,
|
||||||
code and runs an LLM with tools — existing data in context enforces
|
checks for duplicates via list_* tools, and creates/updates records.
|
||||||
update-first naturally.
|
``items_created`` is counted from ``create_*`` tool calls.
|
||||||
|
|
||||||
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
||||||
Teams, Outlook) and pushes extracted items to Electron.
|
Teams, Outlook) and pushes extracted items to Electron.
|
||||||
@@ -29,7 +29,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import os
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -42,10 +42,10 @@ from app.agents.note_agent import NOTE_TOOLS
|
|||||||
from app.agents.project_agent import PROJECT_TOOLS
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
from app.agents.task_agent import TASK_TOOLS
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
|
from app.core.preprocessors import detect_content_type, preprocess
|
||||||
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
@@ -72,92 +72,47 @@ _MAX_PROCESSING_STEPS: int = 12
|
|||||||
_MAX_SCAN_DEPTH: int = 5
|
_MAX_SCAN_DEPTH: int = 5
|
||||||
|
|
||||||
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||||
# NOTE: "projects" is intentionally excluded — project creation/assignment is
|
|
||||||
# handled in code by the runner, never delegated to the Step 2 LLM.
|
|
||||||
|
|
||||||
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||||
"tasks": TASK_TOOLS,
|
"tasks": TASK_TOOLS,
|
||||||
"notes": NOTE_TOOLS,
|
"notes": NOTE_TOOLS,
|
||||||
"timelines": TIMELINE_TOOLS,
|
"timelines": TIMELINE_TOOLS,
|
||||||
|
"timelineEvents": TIMELINE_TOOLS,
|
||||||
|
"projects": PROJECT_TOOLS,
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
# ── V2: Unified processing prompt (hot-swappable via Langfuse "unified_processing") ──
|
||||||
|
|
||||||
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
_UNIFIED_PROCESSING_PROMPT = """\
|
||||||
"tasks": (
|
|
||||||
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
|
||||||
"assigned to someone, or tracked with a due date or status."
|
|
||||||
),
|
|
||||||
"notes": (
|
|
||||||
"Documentation, meeting notes, summaries, reference material — "
|
|
||||||
"written content meant to be read and referenced rather than acted on."
|
|
||||||
),
|
|
||||||
"timelines": (
|
|
||||||
"Project milestones, deadlines, scheduled events — "
|
|
||||||
"specific dates that mark a point in the progress of a project."
|
|
||||||
),
|
|
||||||
"projects": (
|
|
||||||
"High-level project entities — only relevant if the file clearly introduces "
|
|
||||||
"a new project or updates the scope of an existing one."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
_BATCH_FILE_CLASSIFIER_PROMPT = """\
|
|
||||||
You are a file classifier for a freelance project management tool.
|
|
||||||
|
|
||||||
Your job is to match a file to an existing project and identify which data domains to extract.
|
|
||||||
|
|
||||||
## Project matching rules (STRICT — follow in order)
|
|
||||||
|
|
||||||
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
|
||||||
that overlaps with the existing projects listed below.
|
|
||||||
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
|
||||||
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
|
||||||
when the file has zero meaningful connection to any listed project.
|
|
||||||
4. When in doubt, pick the closest match from the list.
|
|
||||||
|
|
||||||
## Response format
|
|
||||||
|
|
||||||
Respond ONLY with a JSON object — no markdown, no explanation:
|
|
||||||
|
|
||||||
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
|
||||||
|
|
||||||
## Domain definitions (only consider domains in the allowed list)
|
|
||||||
|
|
||||||
{domain_definitions}
|
|
||||||
|
|
||||||
## Existing projects
|
|
||||||
|
|
||||||
{projects_list}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
_BATCH_PROCESSING_PROMPT = """\
|
|
||||||
You are a data extraction assistant for a freelance project management tool.
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
|
||||||
Your task: extract structured data from the file content and persist it using the available tools.
|
## Your process (follow this exact order)
|
||||||
|
|
||||||
## Mandatory process — follow this order for EVERY item you extract
|
### 1. Identify the project
|
||||||
|
File: {filename}
|
||||||
|
{metadata_section}
|
||||||
|
|
||||||
1. READ the existing records listed below for the relevant domain.
|
Existing projects:
|
||||||
2. SEARCH for a match by title, topic, or semantic similarity.
|
{projects_list}
|
||||||
3. If a match exists → call the update_* tool with the existing record's id.
|
|
||||||
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
|
||||||
|
|
||||||
NEVER call create_* without first checking the existing records.
|
Match this file to an existing project using the filename and content clues.
|
||||||
NEVER duplicate a record that already exists under a different wording.
|
If no project matches, {no_match_behavior}.
|
||||||
|
|
||||||
## Existing records (source of truth)
|
### 2. Check existing records
|
||||||
|
Once you identify the project, use list_tasks / list_notes / list_timelines
|
||||||
|
(filtered by projectId) to see what already exists.
|
||||||
|
NEVER create a record that already exists under the same or similar title.
|
||||||
|
|
||||||
{existing_context}
|
### 3. Extract and create / update
|
||||||
|
{extraction_rules}
|
||||||
|
|
||||||
## Context
|
### Rules
|
||||||
|
- Set isAiSuggested=1 on every new record.
|
||||||
Project: {project_context}
|
- Set projectId on every record (use the id from the project list above).
|
||||||
Domains to extract: {data_types}
|
- Update existing records when a match is found by title or topic.
|
||||||
|
- Do NOT invent data — only extract what is clearly stated in the content.
|
||||||
{custom_prompt_section}
|
- Target entity types: {data_types}.
|
||||||
|
{global_rules}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ── Cloud processing prompt (kept separate for cloud agent) ───────────────
|
# ── Cloud processing prompt (kept separate for cloud agent) ───────────────
|
||||||
@@ -271,12 +226,18 @@ async def _run_agent_with_tools(
|
|||||||
tools: list[Any],
|
tools: list[Any],
|
||||||
max_steps: int,
|
max_steps: int,
|
||||||
user_id: str = "",
|
user_id: str = "",
|
||||||
|
session_id: str = "",
|
||||||
langfuse_prompt: Any = None,
|
langfuse_prompt: Any = None,
|
||||||
agent_name: str = "batch-agent",
|
agent_name: str = "batch-agent",
|
||||||
|
_tool_calls_out: list[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run an LLM agent with tool-calling, returning the final text response."""
|
"""Run an LLM agent with tool-calling, returning the final text response.
|
||||||
|
|
||||||
|
If *_tool_calls_out* is provided, the name of every tool called during the
|
||||||
|
run is appended to it (used by the caller to count ``create_*`` calls).
|
||||||
|
"""
|
||||||
lf = get_langfuse()
|
lf = get_langfuse()
|
||||||
llm = get_llm()
|
llm = get_agent_llm(agent_name)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
messages: list[Any] = [
|
messages: list[Any] = [
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
@@ -285,11 +246,14 @@ async def _run_agent_with_tools(
|
|||||||
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
_span_ctx = (
|
_span_ctx = (
|
||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="span",
|
as_type="span",
|
||||||
name=agent_name,
|
name=agent_name,
|
||||||
user_id=user_id or None,
|
metadata={"user_id": user_id} if user_id else None,
|
||||||
input=user_message,
|
input=user_message,
|
||||||
)
|
)
|
||||||
if lf else None
|
if lf else None
|
||||||
@@ -302,7 +266,7 @@ async def _run_agent_with_tools(
|
|||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="generation",
|
as_type="generation",
|
||||||
name=f"{agent_name}-llm",
|
name=f"{agent_name}-llm",
|
||||||
model=settings.LLM_MODEL,
|
model=model_for_agent(agent_name),
|
||||||
prompt=langfuse_prompt,
|
prompt=langfuse_prompt,
|
||||||
input=messages,
|
input=messages,
|
||||||
)
|
)
|
||||||
@@ -311,7 +275,7 @@ async def _run_agent_with_tools(
|
|||||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
if _gen_ctx:
|
if _gen_ctx:
|
||||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
_gen_ctx.__exit__(None, None, None)
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
@@ -332,6 +296,9 @@ async def _run_agent_with_tools(
|
|||||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _tool_calls_out is not None:
|
||||||
|
_tool_calls_out.append(call_name)
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
tool_fn = tool_map.get(call_name)
|
||||||
if tool_fn is None:
|
if tool_fn is None:
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
@@ -353,6 +320,7 @@ async def _run_agent_with_tools(
|
|||||||
finally:
|
finally:
|
||||||
if _span_ctx:
|
if _span_ctx:
|
||||||
_span_ctx.__exit__(None, None, None)
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
if lf:
|
if lf:
|
||||||
lf.flush()
|
lf.flush()
|
||||||
|
|
||||||
@@ -421,7 +389,8 @@ async def _scan_directories(
|
|||||||
for file_path in all_files:
|
for file_path in all_files:
|
||||||
try:
|
try:
|
||||||
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
||||||
modified_at = meta.get("modifiedAt")
|
# FE sends snake_case keys on the wire (toSnakeCase transform)
|
||||||
|
modified_at = meta.get("modified_at") or meta.get("modifiedAt")
|
||||||
if modified_at is None:
|
if modified_at is None:
|
||||||
filtered.append(file_path)
|
filtered.append(file_path)
|
||||||
continue
|
continue
|
||||||
@@ -523,99 +492,66 @@ def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
|||||||
return f"Existing {domain}:\n" + "\n".join(lines)
|
return f"Existing {domain}:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
# ── V2 helper functions ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _classify_file(
|
def _format_projects(projects: list[dict]) -> str:
|
||||||
file_path: str,
|
"""Format the project list for the unified system prompt."""
|
||||||
file_content: str,
|
if not projects:
|
||||||
projects: list[dict],
|
return " (no projects yet)"
|
||||||
config_data_types: list[str],
|
lines: list[str] = []
|
||||||
) -> tuple[str, list[str], str | None]:
|
for p in projects:
|
||||||
"""Call the LLM to classify a file by project and relevant domains.
|
|
||||||
|
|
||||||
Returns ``(project_id_or_"new", domains, new_project_name_or_None)``.
|
|
||||||
- ``project_id`` is an existing project UUID, or ``"new"`` when no match found.
|
|
||||||
- ``new_project_name`` is only set when ``project_id == "new"``.
|
|
||||||
Falls back to ``("new", config_data_types, None)`` on any error.
|
|
||||||
"""
|
|
||||||
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
|
||||||
|
|
||||||
if not file_content.strip():
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
valid_project_ids = {p["id"] for p in projects}
|
|
||||||
|
|
||||||
def _fmt_project(p: dict) -> str:
|
|
||||||
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
||||||
summary_part = f" — {summary[:100]}" if summary else ""
|
summary_part = f" — {summary[:100]}" if summary else ""
|
||||||
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
lines.append(
|
||||||
|
f" - id={p['id']} | name={p.get('name', '')} | "
|
||||||
|
f"status={p.get('status', '')}{summary_part}"
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
|
||||||
|
|
||||||
domain_definitions = "\n".join(
|
def _format_metadata(metadata: dict) -> str:
|
||||||
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
"""Format preprocessor metadata as a compact context block."""
|
||||||
for d in config_data_types
|
if not metadata:
|
||||||
if d in _DOMAIN_DESCRIPTIONS
|
return ""
|
||||||
|
parts: list[str] = []
|
||||||
|
for key in ("subject", "from", "to", "date"):
|
||||||
|
if metadata.get(key):
|
||||||
|
parts.append(f"{key.capitalize()}: {metadata[key]}")
|
||||||
|
# any remaining keys
|
||||||
|
for key, val in metadata.items():
|
||||||
|
if key not in ("subject", "from", "to", "date") and val:
|
||||||
|
parts.append(f"{key}: {val}")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_extraction_rules(agent_config: dict, content_type: str) -> str:
|
||||||
|
"""Return the extraction_prompt for *content_type* from *agent_config*.
|
||||||
|
|
||||||
|
Falls back to a generic instruction when the type is not configured.
|
||||||
|
"""
|
||||||
|
for ct in agent_config.get("content_types", []):
|
||||||
|
if ct.get("id") == content_type:
|
||||||
|
prompt = ct.get("extraction_prompt", "").strip()
|
||||||
|
if prompt:
|
||||||
|
return prompt
|
||||||
|
return (
|
||||||
|
"Extract relevant information as tasks (action items), notes "
|
||||||
|
"(informational content), or timelines (dated events)."
|
||||||
)
|
)
|
||||||
|
|
||||||
step1_template, step1_prompt_obj = get_prompt_or_fallback(
|
|
||||||
"batch_file_classifier", _BATCH_FILE_CLASSIFIER_PROMPT
|
|
||||||
)
|
|
||||||
system = step1_template.format(
|
|
||||||
domain_definitions=domain_definitions,
|
|
||||||
projects_list=projects_list,
|
|
||||||
)
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
def _get_no_match_behavior(agent_config: dict) -> str:
|
||||||
llm = get_llm()
|
"""Derive the 'no project match' instruction from global_rules."""
|
||||||
classifier_messages = [
|
rules = agent_config.get("global_rules", [])
|
||||||
SystemMessage(content=system),
|
for rule in rules:
|
||||||
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
lower = rule.lower()
|
||||||
]
|
if "no project" in lower or "no match" in lower or "skip" in lower:
|
||||||
try:
|
return rule
|
||||||
if lf:
|
return "create a new project with a concise name derived from the file content"
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="step1-classifier",
|
|
||||||
model=settings.LLM_ROUTER_MODEL,
|
|
||||||
prompt=step1_prompt_obj,
|
|
||||||
input=classifier_messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm.ainvoke(classifier_messages)
|
|
||||||
gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
|
||||||
else:
|
|
||||||
response = await llm.ainvoke(classifier_messages)
|
|
||||||
raw = _as_text(response.content).strip()
|
|
||||||
# Strip markdown fences if the model wraps the JSON.
|
|
||||||
if raw.startswith("```"):
|
|
||||||
raw = raw.split("```")[1]
|
|
||||||
if raw.startswith("json"):
|
|
||||||
raw = raw[4:]
|
|
||||||
parsed = json.loads(raw.strip())
|
|
||||||
raw_project_id: str = str(parsed.get("project_id") or "new")
|
|
||||||
# Reject hallucinated UUIDs — only accept ids that exist in the fetched list.
|
|
||||||
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
|
||||||
new_project_name: str | None = (
|
|
||||||
str(parsed["new_project_name"]).strip() or None
|
|
||||||
if project_id == "new" and parsed.get("new_project_name")
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
domains: list[str] = [
|
|
||||||
d for d in parsed.get("domains", [])
|
|
||||||
if d in config_data_types
|
|
||||||
]
|
|
||||||
if not domains:
|
|
||||||
domains = list(config_data_types)
|
|
||||||
return project_id, domains, new_project_name
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
|
||||||
)
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent runner (two-step per file) ────────────────────────────────
|
# ── Local agent runner (V2 — unified per-file flow) ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def run_local_agent(
|
async def run_local_agent(
|
||||||
@@ -625,16 +561,17 @@ async def run_local_agent(
|
|||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
run_context: dict | None = None,
|
run_context: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a local directory agent run using a two-step approach per file.
|
"""Execute a local directory agent run — V2 unified flow.
|
||||||
|
|
||||||
Step 1 — Classification (code + 1 LLM call per file, no tools):
|
Phase A — Detect + Preprocess (zero LLM, per file):
|
||||||
Code scans directories and fetches all projects via WS.
|
Python detects the content type from filename + content patterns and
|
||||||
For each file, LLM identifies the project and relevant domains.
|
runs the appropriate handler (e.g. email_html) to produce clean text
|
||||||
|
and structured metadata.
|
||||||
|
|
||||||
Step 2 — Processing (code + 1 LLM call per file, with tools):
|
Phase B — Single LLM call with tools (per file):
|
||||||
Code fetches existing entities for the identified project/domains.
|
One LLM call handles project identification, duplicate checking, and
|
||||||
LLM receives file content + existing entities in context and uses
|
record creation/update. ``create_*`` tool calls are counted to
|
||||||
tools to update existing records or create new ones.
|
produce the accurate ``items_created`` metric.
|
||||||
"""
|
"""
|
||||||
run_id = run_log.id
|
run_id = run_log.id
|
||||||
agent_id = (run_context or {}).get("agent_id") or config.id
|
agent_id = (run_context or {}).get("agent_id") or config.id
|
||||||
@@ -669,16 +606,11 @@ async def run_local_agent(
|
|||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
items_processed = 0
|
items_processed = 0
|
||||||
items_created = 0
|
items_created = 0
|
||||||
|
agent_config: dict = config.agent_config or {}
|
||||||
custom_section = (
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
f"User instructions:\n{config.prompt_template}"
|
|
||||||
if config.prompt_template
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ── Code: scan directories ───────────────────────────────────
|
# ── Code: scan directories ───────────────────────────────────
|
||||||
logger.info("agent_runner: run=%s scanning directories user=%s", run_id, user_id)
|
|
||||||
file_paths = await _scan_directories(
|
file_paths = await _scan_directories(
|
||||||
paths=config.directory_paths,
|
paths=config.directory_paths,
|
||||||
extensions=config.file_extensions or [],
|
extensions=config.file_extensions or [],
|
||||||
@@ -694,114 +626,89 @@ async def run_local_agent(
|
|||||||
|
|
||||||
# ── Code: fetch all projects once ────────────────────────────
|
# ── Code: fetch all projects once ────────────────────────────
|
||||||
projects = await _fetch_projects()
|
projects = await _fetch_projects()
|
||||||
|
projects_block = _format_projects(projects)
|
||||||
|
|
||||||
|
# Prompt template + Langfuse version linking (hot-swappable from UI).
|
||||||
|
unified_template, prompt_obj = get_prompt_or_fallback(
|
||||||
|
"unified_processing", _UNIFIED_PROCESSING_PROMPT
|
||||||
|
)
|
||||||
|
|
||||||
for file_path in file_paths:
|
for file_path in file_paths:
|
||||||
try:
|
try:
|
||||||
# Read file content via code.
|
# ── Phase A: read + detect + preprocess ─────────────
|
||||||
file_result = await execute_on_client(
|
file_result = await execute_on_client(
|
||||||
action="read_file_content", data={"path": file_path}
|
action="read_file_content", data={"path": file_path}
|
||||||
)
|
)
|
||||||
file_content: str = file_result.get("content", "")
|
raw_content: str = file_result.get("content", "")
|
||||||
if not file_content:
|
if not raw_content.strip():
|
||||||
logger.debug("agent_runner: run=%s skipping empty file %r", run_id, file_path)
|
logger.debug(
|
||||||
|
"agent_runner: run=%s skipping empty file %r", run_id, file_path
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
items_processed += 1
|
items_processed += 1
|
||||||
|
filename = os.path.basename(file_path)
|
||||||
|
content_type = detect_content_type(filename, raw_content)
|
||||||
|
preprocessed = preprocess(content_type, raw_content)
|
||||||
|
|
||||||
# Step 1 — classify file.
|
|
||||||
project_id, domains, new_project_name = await _classify_file(
|
|
||||||
file_path=file_path,
|
|
||||||
file_content=file_content,
|
|
||||||
projects=projects,
|
|
||||||
config_data_types=config.data_types,
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s file=%r → project=%s new_name=%r domains=%s",
|
"agent_runner: run=%s file=%r content_type=%s clean_len=%d",
|
||||||
run_id,
|
run_id, file_path, content_type, len(preprocessed.clean_text),
|
||||||
file_path,
|
|
||||||
project_id,
|
|
||||||
new_project_name,
|
|
||||||
domains,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 2 — resolve project_id via CODE, then fetch entities.
|
# ── Phase B: single LLM call ─────────────────────────
|
||||||
# Project creation is NEVER delegated to the Step 2 LLM.
|
extraction_rules = _get_extraction_rules(agent_config, content_type)
|
||||||
if project_id == "new":
|
no_match_behavior = _get_no_match_behavior(agent_config)
|
||||||
proj_name = new_project_name or "Untitled Project"
|
global_rules_lines = "\n".join(
|
||||||
try:
|
f"- {r}" for r in agent_config.get("global_rules", [])
|
||||||
proj_result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="projects",
|
|
||||||
data={"name": proj_name, "clientId": None},
|
|
||||||
)
|
)
|
||||||
created = proj_result.get("row", {})
|
metadata_section = _format_metadata(preprocessed.metadata)
|
||||||
effective_project_id = created.get("id", "standalone")
|
|
||||||
# Add to local list so subsequent files can match it.
|
system_prompt = compile_prompt(
|
||||||
if "id" in created:
|
unified_template,
|
||||||
projects.append(created)
|
prompt_obj,
|
||||||
logger.info(
|
filename=filename,
|
||||||
"agent_runner: run=%s created project %r id=%s",
|
metadata_section=metadata_section,
|
||||||
run_id, proj_name, effective_project_id,
|
projects_list=projects_block,
|
||||||
)
|
no_match_behavior=no_match_behavior,
|
||||||
except Exception as exc:
|
extraction_rules=extraction_rules,
|
||||||
logger.warning(
|
global_rules=global_rules_lines,
|
||||||
"agent_runner: run=%s failed to create project %r: %s",
|
data_types=", ".join(config.data_types),
|
||||||
run_id, proj_name, exc,
|
|
||||||
)
|
|
||||||
effective_project_id = "standalone"
|
|
||||||
proj_name = "unknown"
|
|
||||||
project_context = (
|
|
||||||
f"Project: {proj_name} (id: {effective_project_id}). "
|
|
||||||
"Always set projectId to this id on every record you create."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
effective_project_id = project_id
|
|
||||||
proj = next((p for p in projects if p["id"] == project_id), None)
|
|
||||||
proj_name = proj.get("name", project_id) if proj else project_id
|
|
||||||
project_context = (
|
|
||||||
f"Project: {proj_name} (id: {project_id}). "
|
|
||||||
"Always set projectId to this id on every record you create."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# "projects" domain is never passed to Step 2 — handled above in code.
|
|
||||||
domains = [d for d in domains if d != "projects"]
|
|
||||||
|
|
||||||
existing_blocks: list[str] = []
|
|
||||||
for domain in domains:
|
|
||||||
rows = await _fetch_domain_entities(domain, effective_project_id)
|
|
||||||
existing_blocks.append(_format_entities_for_context(domain, rows))
|
|
||||||
|
|
||||||
existing_context = "\n\n".join(existing_blocks)
|
|
||||||
|
|
||||||
step2_template, step2_prompt_obj = get_prompt_or_fallback(
|
|
||||||
"batch_processing", _BATCH_PROCESSING_PROMPT
|
|
||||||
)
|
|
||||||
system_prompt = step2_template.format(
|
|
||||||
existing_context=existing_context,
|
|
||||||
project_context=project_context,
|
|
||||||
data_types=", ".join(domains),
|
|
||||||
custom_prompt_section=custom_section,
|
|
||||||
)
|
|
||||||
|
|
||||||
processing_tools = _build_processing_tools(domains)
|
|
||||||
|
|
||||||
result_text = await _run_agent_with_tools(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
user_message = (
|
user_message = (
|
||||||
f"Process this file and extract relevant information.\n\n"
|
f"Process this file and extract relevant information.\n\n"
|
||||||
f"File: {file_path}\n\nContent:\n{file_content}"
|
f"File: {file_path}\n\n"
|
||||||
),
|
f"Content:\n{preprocessed.clean_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
file_tool_calls: list[str] = []
|
||||||
|
result_text = await _run_agent_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_message=user_message,
|
||||||
tools=processing_tools,
|
tools=processing_tools,
|
||||||
max_steps=_MAX_PROCESSING_STEPS,
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
langfuse_prompt=step2_prompt_obj,
|
session_id=run_id,
|
||||||
agent_name="step2-processor",
|
langfuse_prompt=prompt_obj,
|
||||||
|
agent_name="unified-processor",
|
||||||
|
_tool_calls_out=file_tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
file_created = sum(
|
||||||
|
1 for name in file_tool_calls if name.startswith("create_")
|
||||||
|
)
|
||||||
|
items_created += file_created
|
||||||
|
|
||||||
|
# Refresh project list when a project was created so
|
||||||
|
# subsequent files see it in the prompt context.
|
||||||
|
if "create_project" in file_tool_calls:
|
||||||
|
projects = await _fetch_projects()
|
||||||
|
projects_block = _format_projects(projects)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s file=%r result=%s",
|
"agent_runner: run=%s file=%r created=%d result=%s",
|
||||||
run_id,
|
run_id, file_path, file_created, result_text[:200],
|
||||||
file_path,
|
|
||||||
result_text[:200],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -833,10 +740,11 @@ async def run_local_agent(
|
|||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s done status=%s processed=%d errors=%d",
|
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
run_id,
|
run_id,
|
||||||
final_status,
|
final_status,
|
||||||
items_processed,
|
items_processed,
|
||||||
|
items_created,
|
||||||
len(errors),
|
len(errors),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -997,7 +905,9 @@ async def run_cloud_agent(
|
|||||||
cloud_template, cloud_prompt_obj = get_prompt_or_fallback(
|
cloud_template, cloud_prompt_obj = get_prompt_or_fallback(
|
||||||
"batch_cloud_processing", _BATCH_CLOUD_PROCESSING_PROMPT
|
"batch_cloud_processing", _BATCH_CLOUD_PROCESSING_PROMPT
|
||||||
)
|
)
|
||||||
processing_prompt = cloud_template.format(
|
processing_prompt = compile_prompt(
|
||||||
|
cloud_template,
|
||||||
|
cloud_prompt_obj,
|
||||||
data_types=", ".join(config.data_types),
|
data_types=", ".join(config.data_types),
|
||||||
project_context="Determine the appropriate project from the message context.",
|
project_context="Determine the appropriate project from the message context.",
|
||||||
file_list=f"Message from {config.provider} (id: {msg.id})",
|
file_list=f"Message from {config.provider} (id: {msg.id})",
|
||||||
@@ -1011,6 +921,7 @@ async def run_cloud_agent(
|
|||||||
tools=processing_tools,
|
tools=processing_tools,
|
||||||
max_steps=_MAX_PROCESSING_STEPS,
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
session_id=run_id,
|
||||||
langfuse_prompt=cloud_prompt_obj,
|
langfuse_prompt=cloud_prompt_obj,
|
||||||
agent_name="cloud-processor",
|
agent_name="cloud-processor",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,9 +16,8 @@ from app.agents.note_agent import NOTE_TOOLS
|
|||||||
from app.agents.project_agent import PROJECT_TOOLS
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
from app.agents.task_agent import TASK_TOOLS
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
from app.core.llm import get_llm
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
from app.config.settings import settings
|
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
@@ -28,6 +27,34 @@ logger = logging.getLogger(__name__)
|
|||||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||||
|
|
||||||
|
# Mapping of core-memory language values to natural-language names for prompts.
|
||||||
|
_LANGUAGE_NAMES: dict[str, str] = {
|
||||||
|
"en": "English", "it": "Italian", "es": "Spanish",
|
||||||
|
"fr": "French", "de": "German",
|
||||||
|
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||||
|
"spanish": "Spanish", "español": "Spanish",
|
||||||
|
"french": "French", "français": "French",
|
||||||
|
"german": "German", "deutsch": "German",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _language_instruction(context: dict[str, Any]) -> str:
|
||||||
|
"""Return a system-prompt suffix that tells the LLM to respond in the user's language.
|
||||||
|
|
||||||
|
Returns an empty string when the language is English or unknown — saves tokens.
|
||||||
|
"""
|
||||||
|
core = context.get("core_memory") or {}
|
||||||
|
raw = (core.get("language") or "").strip().lower()
|
||||||
|
if not raw:
|
||||||
|
return ""
|
||||||
|
lang = _LANGUAGE_NAMES.get(raw, raw.title()) # best-effort capitalisation
|
||||||
|
if lang.lower() == "english":
|
||||||
|
return ""
|
||||||
|
return (
|
||||||
|
f"\n\nIMPORTANT: Always respond in {lang}. "
|
||||||
|
f"All your output text must be written in {lang}."
|
||||||
|
)
|
||||||
|
|
||||||
_HOME_SYSTEM_PROMPT = (
|
_HOME_SYSTEM_PROMPT = (
|
||||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
"Always use tools for factual data retrieval before answering. "
|
"Always use tools for factual data retrieval before answering. "
|
||||||
@@ -149,6 +176,15 @@ def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _session_id_from_context(context: dict[str, Any]) -> str | None:
|
||||||
|
debug = context.get("_debug")
|
||||||
|
if isinstance(debug, dict):
|
||||||
|
session_id = debug.get("session_id")
|
||||||
|
if isinstance(session_id, str) and session_id:
|
||||||
|
return session_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||||
sanitized = dict(context)
|
sanitized = dict(context)
|
||||||
sanitized.pop("_debug", None)
|
sanitized.pop("_debug", None)
|
||||||
@@ -537,7 +573,7 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm = get_llm()
|
llm = get_agent_llm("classifier")
|
||||||
classifier_messages = [
|
classifier_messages = [
|
||||||
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
|
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
@@ -551,16 +587,23 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
|||||||
_, classifier_prompt_obj = get_prompt_or_fallback(
|
_, classifier_prompt_obj = get_prompt_or_fallback(
|
||||||
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
|
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract user/session from context for Langfuse attribution
|
||||||
|
_debug = context.get("_debug") if isinstance(context, dict) else None
|
||||||
|
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
|
||||||
|
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
|
||||||
|
|
||||||
|
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
|
||||||
if lf:
|
if lf:
|
||||||
with lf.start_as_current_observation(
|
with lf.start_as_current_observation(
|
||||||
as_type="generation",
|
as_type="generation",
|
||||||
name="floating-classifier",
|
name="floating-classifier",
|
||||||
model=settings.LLM_MODEL,
|
model=model_for_agent("classifier"),
|
||||||
prompt=classifier_prompt_obj,
|
prompt=classifier_prompt_obj,
|
||||||
input=classifier_messages,
|
input=classifier_messages,
|
||||||
) as gen:
|
) as gen:
|
||||||
response = await llm.ainvoke(classifier_messages)
|
response = await llm.ainvoke(classifier_messages)
|
||||||
gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
else:
|
else:
|
||||||
response = await llm.ainvoke(classifier_messages)
|
response = await llm.ainvoke(classifier_messages)
|
||||||
parsed = _parse_json_object(_as_text(response.content))
|
parsed = _parse_json_object(_as_text(response.content))
|
||||||
@@ -591,8 +634,9 @@ async def _run_single_agent(
|
|||||||
agent_name: str = "agent",
|
agent_name: str = "agent",
|
||||||
) -> str:
|
) -> str:
|
||||||
trace_id = _trace_id_from_context(context)
|
trace_id = _trace_id_from_context(context)
|
||||||
|
session_id = _session_id_from_context(context)
|
||||||
lf = get_langfuse()
|
lf = get_langfuse()
|
||||||
llm = get_llm()
|
llm = get_agent_llm(agent_name)
|
||||||
tools = _all_tools_for_user(user_id, trace_id)
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
model_context = _context_for_model(context)
|
model_context = _context_for_model(context)
|
||||||
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
@@ -611,12 +655,14 @@ async def _run_single_agent(
|
|||||||
collected: list[dict[str, Any]] = []
|
collected: list[dict[str, Any]] = []
|
||||||
set_tool_result_collector(collected)
|
set_tool_result_collector(collected)
|
||||||
|
|
||||||
|
_lf_ctx = langfuse_context(user_id=user_id, session_id=session_id)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
_span_ctx = (
|
_span_ctx = (
|
||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="span",
|
as_type="span",
|
||||||
name=agent_name,
|
name=agent_name,
|
||||||
user_id=user_id,
|
metadata={"user_id": user_id, "session_id": trace_id},
|
||||||
session_id=trace_id,
|
|
||||||
input=message,
|
input=message,
|
||||||
)
|
)
|
||||||
if lf else None
|
if lf else None
|
||||||
@@ -629,7 +675,7 @@ async def _run_single_agent(
|
|||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="generation",
|
as_type="generation",
|
||||||
name=f"{agent_name}-llm",
|
name=f"{agent_name}-llm",
|
||||||
model=settings.LLM_MODEL,
|
model=model_for_agent(agent_name),
|
||||||
prompt=langfuse_prompt,
|
prompt=langfuse_prompt,
|
||||||
input=messages,
|
input=messages,
|
||||||
)
|
)
|
||||||
@@ -638,7 +684,7 @@ async def _run_single_agent(
|
|||||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
if _gen_ctx:
|
if _gen_ctx:
|
||||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
_gen_ctx.__exit__(None, None, None)
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
@@ -700,6 +746,7 @@ async def _run_single_agent(
|
|||||||
clear_tool_result_collector()
|
clear_tool_result_collector()
|
||||||
if _span_ctx:
|
if _span_ctx:
|
||||||
_span_ctx.__exit__(None, None, None)
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
if lf:
|
if lf:
|
||||||
lf.flush()
|
lf.flush()
|
||||||
|
|
||||||
@@ -715,8 +762,9 @@ async def _run_single_agent_stream(
|
|||||||
agent_name: str = "agent",
|
agent_name: str = "agent",
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
trace_id = _trace_id_from_context(context)
|
trace_id = _trace_id_from_context(context)
|
||||||
|
session_id = _session_id_from_context(context)
|
||||||
lf = get_langfuse()
|
lf = get_langfuse()
|
||||||
llm = get_llm()
|
llm = get_agent_llm(agent_name)
|
||||||
tools = _all_tools_for_user(user_id, trace_id)
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
model_context = _context_for_model(context)
|
model_context = _context_for_model(context)
|
||||||
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
@@ -736,12 +784,14 @@ async def _run_single_agent_stream(
|
|||||||
collected: list[dict[str, Any]] = []
|
collected: list[dict[str, Any]] = []
|
||||||
set_tool_result_collector(collected)
|
set_tool_result_collector(collected)
|
||||||
|
|
||||||
|
_lf_ctx = langfuse_context(user_id=user_id, session_id=session_id)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
_span_ctx = (
|
_span_ctx = (
|
||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="span",
|
as_type="span",
|
||||||
name=f"{agent_name}-stream",
|
name=f"{agent_name}-stream",
|
||||||
user_id=user_id,
|
metadata={"user_id": user_id, "session_id": trace_id},
|
||||||
session_id=trace_id,
|
|
||||||
input=message,
|
input=message,
|
||||||
)
|
)
|
||||||
if lf else None
|
if lf else None
|
||||||
@@ -755,7 +805,7 @@ async def _run_single_agent_stream(
|
|||||||
lf.start_as_current_observation(
|
lf.start_as_current_observation(
|
||||||
as_type="generation",
|
as_type="generation",
|
||||||
name=f"{agent_name}-llm",
|
name=f"{agent_name}-llm",
|
||||||
model=settings.LLM_MODEL,
|
model=model_for_agent(agent_name),
|
||||||
prompt=langfuse_prompt,
|
prompt=langfuse_prompt,
|
||||||
input=messages,
|
input=messages,
|
||||||
)
|
)
|
||||||
@@ -764,7 +814,7 @@ async def _run_single_agent_stream(
|
|||||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
if _gen_ctx:
|
if _gen_ctx:
|
||||||
_gen.update(output=_as_text(response.content), usage=extract_usage(response))
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
_gen_ctx.__exit__(None, None, None)
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
@@ -844,6 +894,7 @@ async def _run_single_agent_stream(
|
|||||||
clear_tool_result_collector()
|
clear_tool_result_collector()
|
||||||
if _span_ctx:
|
if _span_ctx:
|
||||||
_span_ctx.__exit__(None, None, None)
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
if lf:
|
if lf:
|
||||||
lf.flush()
|
lf.flush()
|
||||||
|
|
||||||
@@ -853,6 +904,7 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"home_system", _HOME_SYSTEM_PROMPT
|
"home_system", _HOME_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
@@ -870,6 +922,7 @@ async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> t
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
@@ -893,6 +946,7 @@ async def run_home_stream(
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"home_system", _HOME_SYSTEM_PROMPT
|
"home_system", _HOME_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
text_chunks: list[str] = []
|
text_chunks: list[str] = []
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
@@ -925,6 +979,7 @@ async def run_floating_stream(
|
|||||||
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
"floating_system", _FLOATING_SYSTEM_PROMPT
|
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||||
)
|
)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
sanitizer = _FloatingStreamSanitizer()
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
emitted_sanitized = False
|
emitted_sanitized = False
|
||||||
raw_chunks: list[str] = []
|
raw_chunks: list[str] = []
|
||||||
|
|||||||
34
app/core/embeddings.py
Normal file
34
app/core/embeddings.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""OpenAI embedding helper for associative memory tier.
|
||||||
|
|
||||||
|
Single public function: ``embed_text(text) -> list[float] | None``.
|
||||||
|
Returns None on any failure — callers must implement a keyword fallback.
|
||||||
|
Never raises; all exceptions are logged as warnings.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_MAX_INPUT_CHARS = 8000
|
||||||
|
_EMBEDDING_MODEL = "text-embedding-3-small"
|
||||||
|
|
||||||
|
|
||||||
|
async def embed_text(text: str) -> list[float] | None:
|
||||||
|
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
||||||
|
try:
|
||||||
|
client = AsyncOpenAI()
|
||||||
|
truncated = text[:_MAX_INPUT_CHARS]
|
||||||
|
response = await client.embeddings.create(
|
||||||
|
input=truncated,
|
||||||
|
model=_EMBEDDING_MODEL,
|
||||||
|
)
|
||||||
|
result: list[float] = response.data[0].embedding
|
||||||
|
logger.debug("embeddings: embed_text dims=%d", len(result))
|
||||||
|
return result
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("embeddings: embed_text failed: %s", exc)
|
||||||
|
return None
|
||||||
@@ -39,8 +39,10 @@ Linking a prompt to a generation::
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
import logging
|
import logging
|
||||||
from typing import Any
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Generator
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -67,9 +69,9 @@ def get_langfuse() -> Any | None:
|
|||||||
_client = Langfuse(
|
_client = Langfuse(
|
||||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||||
host=settings.LANGFUSE_HOST,
|
host=settings.LANGFUSE_BASE_URL,
|
||||||
)
|
)
|
||||||
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_HOST)
|
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("langfuse: failed to initialize: %s", exc)
|
logger.warning("langfuse: failed to initialize: %s", exc)
|
||||||
_client = None
|
_client = None
|
||||||
@@ -80,10 +82,11 @@ def get_langfuse() -> Any | None:
|
|||||||
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
||||||
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
||||||
|
|
||||||
Returns ``(prompt_text, prompt_obj_or_None)``.
|
Returns ``(raw_template, prompt_obj_or_None)``.
|
||||||
|
|
||||||
* ``prompt_text`` — the raw template string (variables not yet substituted).
|
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
|
||||||
Callers perform variable substitution with Python's ``.format()``.
|
on it directly; use :func:`compile_prompt` instead so the correct variable
|
||||||
|
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
|
||||||
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
||||||
unavailable / the fetch failed. Pass this to generation observations so
|
unavailable / the fetch failed. Pass this to generation observations so
|
||||||
Langfuse links the generation to the exact prompt version in the UI.
|
Langfuse links the generation to the exact prompt version in the UI.
|
||||||
@@ -102,6 +105,38 @@ def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
|||||||
return fallback, None
|
return fallback, None
|
||||||
|
|
||||||
|
|
||||||
|
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
|
||||||
|
"""Compile *template* with *variables*, choosing the right syntax.
|
||||||
|
|
||||||
|
* When *prompt_obj* is a real Langfuse prompt object, calls
|
||||||
|
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
|
||||||
|
substitution as defined in the Langfuse UI.
|
||||||
|
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
|
||||||
|
falls back to ``template.format(**variables)`` which handles the
|
||||||
|
``{variable}`` syntax used in the hardcoded fallback strings.
|
||||||
|
|
||||||
|
This keeps callers oblivious to which syntax is in use.
|
||||||
|
"""
|
||||||
|
if prompt_obj is not None:
|
||||||
|
try:
|
||||||
|
compiled = prompt_obj.compile(**variables)
|
||||||
|
# compile() returns a string for text prompts.
|
||||||
|
if isinstance(compiled, str):
|
||||||
|
return compiled
|
||||||
|
# Chat prompts return a list of dicts — join text parts.
|
||||||
|
if isinstance(compiled, list):
|
||||||
|
return "\n".join(
|
||||||
|
m.get("content", "") for m in compiled if isinstance(m, dict)
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
|
||||||
|
getattr(prompt_obj, "name", "?"),
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return template.format(**variables)
|
||||||
|
|
||||||
|
|
||||||
def extract_usage(response: Any) -> dict[str, int]:
|
def extract_usage(response: Any) -> dict[str, int]:
|
||||||
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
||||||
meta = getattr(response, "usage_metadata", None)
|
meta = getattr(response, "usage_metadata", None)
|
||||||
@@ -112,3 +147,44 @@ def extract_usage(response: Any) -> dict[str, int]:
|
|||||||
"output": int(meta.get("output_tokens", 0)),
|
"output": int(meta.get("output_tokens", 0)),
|
||||||
"total": int(meta.get("total_tokens", 0)),
|
"total": int(meta.get("total_tokens", 0)),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def hash_user_id(user_id: str) -> str:
|
||||||
|
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
|
||||||
|
|
||||||
|
This avoids sending raw database UUIDs to external observability services
|
||||||
|
while still providing a stable, deterministic identifier for per-user
|
||||||
|
metrics in the Langfuse dashboard.
|
||||||
|
"""
|
||||||
|
return hashlib.sha256(user_id.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def langfuse_context(
|
||||||
|
user_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
|
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
|
||||||
|
|
||||||
|
No-op when Langfuse is not configured or parameters are empty.
|
||||||
|
"""
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is None or (not user_id and not session_id):
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import propagate_attributes
|
||||||
|
except ImportError:
|
||||||
|
logger.debug("langfuse: propagate_attributes not available — skipping context")
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
attrs: dict[str, str] = {}
|
||||||
|
if user_id:
|
||||||
|
attrs["user_id"] = hash_user_id(user_id)
|
||||||
|
if session_id:
|
||||||
|
attrs["session_id"] = session_id
|
||||||
|
|
||||||
|
with propagate_attributes(**attrs):
|
||||||
|
yield
|
||||||
|
|||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
|
Every agent and the orchestrator call ``get_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -11,7 +11,7 @@ follows the `LiteLLM model naming convention
|
|||||||
* Ollama: ``ollama/llama3``
|
* Ollama: ``ollama/llama3``
|
||||||
* Bedrock: ``bedrock/anthropic.claude-v2``
|
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||||
|
|
||||||
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
Switch providers by changing **LLM_MODEL** in ``.env``
|
||||||
— no code changes required.
|
— no code changes required.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -95,12 +96,33 @@ def get_llm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_router_llm(
|
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||||
|
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||||
|
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||||
|
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
||||||
|
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||||
|
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||||
|
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def model_for_agent(agent_name: str) -> str:
|
||||||
|
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
|
||||||
|
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_llm(
|
||||||
|
agent_name: str,
|
||||||
*,
|
*,
|
||||||
temperature: float = 0,
|
temperature: float = 0,
|
||||||
) -> ChatOpenAI | ChatLiteLLM:
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
"""Return the lighter model used for intent classification / routing."""
|
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
|
||||||
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
|
||||||
|
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
|
||||||
|
per-agent override is left empty in ``.env``.
|
||||||
|
"""
|
||||||
|
model = model_for_agent(agent_name)
|
||||||
|
return get_llm(model=model, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
async def embed(text: str) -> list[float]:
|
||||||
|
|||||||
@@ -69,17 +69,19 @@ class MemoryMiddleware:
|
|||||||
if fernet is None:
|
if fernet is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier: str = user_dbg.get("tier") or "free"
|
||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
||||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d",
|
||||||
trace_id or "-",
|
trace_id or "-",
|
||||||
user_id,
|
user_id,
|
||||||
user_dbg.get("tier") or "-",
|
user_tier,
|
||||||
len(core),
|
len(core),
|
||||||
len(associative),
|
len(associative),
|
||||||
len(episodic),
|
len(episodic),
|
||||||
@@ -255,6 +257,50 @@ class MemoryMiddleware:
|
|||||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
async def store_associative(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
content: str,
|
||||||
|
entity_type: str | None = None,
|
||||||
|
entity_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Store associative memory; embed if user tier has real_embeddings."""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||||
|
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, content)
|
||||||
|
|
||||||
|
user_dbg = await self._get_user_debug(user_id)
|
||||||
|
user_tier = user_dbg.get("tier") or "free"
|
||||||
|
|
||||||
|
embedding: list[float] | None = None
|
||||||
|
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||||
|
embedding = await embed_text(content)
|
||||||
|
|
||||||
|
row = MemoryAssociative(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
content_encrypted=encrypted,
|
||||||
|
embedding=embedding,
|
||||||
|
entity_type=entity_type,
|
||||||
|
entity_id=entity_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
logger.info(
|
||||||
|
"memory: store_associative user=%s embedded=%s",
|
||||||
|
user_id,
|
||||||
|
embedding is not None,
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
||||||
"""Insert a long-term archival memory entry."""
|
"""Insert a long-term archival memory entry."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
@@ -364,14 +410,49 @@ class MemoryMiddleware:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_associative(
|
async def _load_associative(
|
||||||
self, user_id: str, message: str, fernet: Fernet
|
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Load top-k associative memories.
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
Production: uses pgvector cosine similarity on the message embedding.
|
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
||||||
Current implementation: keyword-based fallback (no external embedding call)
|
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
||||||
so tests pass without a live OpenAI key.
|
|
||||||
"""
|
"""
|
||||||
|
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
||||||
|
from app.core.embeddings import embed_text # noqa: PLC0415
|
||||||
|
|
||||||
|
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
||||||
|
vec = await embed_text(message)
|
||||||
|
if vec is not None:
|
||||||
|
try:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(
|
||||||
|
MemoryAssociative.user_id == user_id,
|
||||||
|
MemoryAssociative.embedding.isnot(None),
|
||||||
|
)
|
||||||
|
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
||||||
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
logger.info(
|
||||||
|
"memory: _load_associative user=%s mode=vector hits=%d",
|
||||||
|
user_id,
|
||||||
|
len(out),
|
||||||
|
)
|
||||||
|
return out
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"memory: vector search failed user=%s, falling back to keyword: %s",
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Keyword fallback: most recent rows
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryAssociative)
|
select(MemoryAssociative)
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
@@ -379,7 +460,7 @@ class MemoryMiddleware:
|
|||||||
.limit(_ASSOCIATIVE_TOP_K)
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
)
|
)
|
||||||
rows = result.scalars().all()
|
rows = result.scalars().all()
|
||||||
out: list[str] = []
|
out = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
if plaintext is not None:
|
if plaintext is not None:
|
||||||
|
|||||||
104
app/core/preprocessors/__init__.py
Normal file
104
app/core/preprocessors/__init__.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Preprocessor registry: detect content type and dispatch to handlers.
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
detect_content_type(filename, raw_content) -> str
|
||||||
|
Heuristic detection based on file extension and content patterns.
|
||||||
|
|
||||||
|
preprocess(content_type, raw_content) -> PreprocessResult
|
||||||
|
Dispatch to the appropriate handler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from app.core.preprocessors.base import PreprocessResult
|
||||||
|
|
||||||
|
# ── Heuristics ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Patterns that strongly suggest an email HTML file
|
||||||
|
_EMAIL_SIGNALS = re.compile(
|
||||||
|
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patterns that suggest a generic HTML page (not an email)
|
||||||
|
_GENERIC_HTML_SIGNALS = re.compile(
|
||||||
|
r"<(nav|main|header|footer|article|section)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_content_type(filename: str, raw_content: str) -> str:
|
||||||
|
"""Return a content-type string for the given file.
|
||||||
|
|
||||||
|
Supported types: ``"email_html"``, ``"generic_html"``,
|
||||||
|
``"plain_text"``, ``"unknown"``.
|
||||||
|
"""
|
||||||
|
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
||||||
|
|
||||||
|
if ext == "txt":
|
||||||
|
return "plain_text"
|
||||||
|
|
||||||
|
if ext in ("html", "htm", "eml", "mhtml", "mht"):
|
||||||
|
# Prefer email detection over generic HTML
|
||||||
|
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||||
|
return "email_html"
|
||||||
|
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
|
||||||
|
return "generic_html"
|
||||||
|
# .html without clear signals — check for any email header
|
||||||
|
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
|
||||||
|
return "email_html"
|
||||||
|
return "generic_html"
|
||||||
|
|
||||||
|
# Plain text files with email headers
|
||||||
|
if ext in ("", "txt") or not ext:
|
||||||
|
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||||
|
return "email_html"
|
||||||
|
|
||||||
|
# Detect binary content
|
||||||
|
try:
|
||||||
|
raw_content.encode("utf-8")
|
||||||
|
except (UnicodeEncodeError, AttributeError):
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
# Non-text bytes heuristic: high ratio of non-printable chars
|
||||||
|
sample = raw_content[:512]
|
||||||
|
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
|
||||||
|
if len(sample) > 0 and non_printable / len(sample) > 0.1:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Generic fallback handler ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
|
||||||
|
"""Strip HTML tags if present, return text as-is."""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
|
||||||
|
except ImportError:
|
||||||
|
# No BeautifulSoup — strip tags with a simple regex
|
||||||
|
text = re.sub(r"<[^>]+>", "", raw_content)
|
||||||
|
|
||||||
|
text = re.sub(r"\n{3,}", "\n\n", text).strip()
|
||||||
|
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
|
||||||
|
|
||||||
|
|
||||||
|
# ── Dispatch ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
|
||||||
|
"""Dispatch *raw_content* to the handler registered for *content_type*.
|
||||||
|
|
||||||
|
Falls back to the generic handler for unknown types.
|
||||||
|
"""
|
||||||
|
if content_type == "email_html":
|
||||||
|
from app.core.preprocessors.email_html import preprocess_email_html
|
||||||
|
return preprocess_email_html(raw_content)
|
||||||
|
|
||||||
|
return _preprocess_generic(raw_content, content_type)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]
|
||||||
25
app/core/preprocessors/base.py
Normal file
25
app/core/preprocessors/base.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Base types for the preprocessor system."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PreprocessResult:
|
||||||
|
"""Output of a preprocessor handler.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
content_type:
|
||||||
|
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
|
||||||
|
clean_text:
|
||||||
|
Human-readable text stripped of markup/binary noise.
|
||||||
|
metadata:
|
||||||
|
Dict of extracted metadata (keys vary by handler).
|
||||||
|
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
content_type: str
|
||||||
|
clean_text: str
|
||||||
|
metadata: dict = field(default_factory=dict)
|
||||||
111
app/core/preprocessors/email_html.py
Normal file
111
app/core/preprocessors/email_html.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Preprocessor for email HTML files.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- HTML stripping via BeautifulSoup
|
||||||
|
- Metadata extraction (Subject, From, To, Date)
|
||||||
|
- Thread splitting — isolates the latest reply
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.core.preprocessors.base import PreprocessResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ── Thread split markers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Matches patterns like:
|
||||||
|
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
|
||||||
|
# "-----Original Message-----"
|
||||||
|
# "> " (plain-text quote prefix)
|
||||||
|
_THREAD_PATTERNS = [
|
||||||
|
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
|
||||||
|
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
|
||||||
|
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
|
||||||
|
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
|
||||||
|
|
||||||
|
_META_PATTERNS: dict[str, list[re.Pattern]] = {
|
||||||
|
"subject": [
|
||||||
|
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
|
||||||
|
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"from": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"From:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"to": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"To:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"date": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
|
||||||
|
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_metadata(raw_html: str, text: str) -> dict:
|
||||||
|
"""Extract Subject/From/To/Date from raw HTML or plain text."""
|
||||||
|
metadata: dict[str, str] = {}
|
||||||
|
for field, patterns in _META_PATTERNS.items():
|
||||||
|
for pat in patterns:
|
||||||
|
m = pat.search(raw_html) or pat.search(text)
|
||||||
|
if m:
|
||||||
|
metadata[field] = m.group(1).strip()
|
||||||
|
break
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _split_thread(text: str) -> str:
|
||||||
|
"""Return only the latest message in a threaded email."""
|
||||||
|
earliest_pos: int | None = None
|
||||||
|
for pat in _THREAD_PATTERNS:
|
||||||
|
m = pat.search(text)
|
||||||
|
if m and (earliest_pos is None or m.start() < earliest_pos):
|
||||||
|
earliest_pos = m.start()
|
||||||
|
|
||||||
|
if earliest_pos is not None and earliest_pos > 0:
|
||||||
|
return text[:earliest_pos].strip()
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_email_html(raw_content: str) -> PreprocessResult:
|
||||||
|
"""Strip HTML, extract metadata, split thread from an email HTML file."""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup # lazy import — optional dep
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"beautifulsoup4 is required for email_html preprocessing. "
|
||||||
|
"Install it with: pip install beautifulsoup4"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# Parse with lxml if available, fall back to html.parser
|
||||||
|
try:
|
||||||
|
soup = BeautifulSoup(raw_content, "lxml")
|
||||||
|
except Exception:
|
||||||
|
soup = BeautifulSoup(raw_content, "html.parser")
|
||||||
|
|
||||||
|
# Remove noise tags
|
||||||
|
for tag in soup(["style", "script", "head", "noscript"]):
|
||||||
|
tag.decompose()
|
||||||
|
|
||||||
|
clean_text = soup.get_text(separator="\n")
|
||||||
|
# Collapse excessive blank lines
|
||||||
|
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
|
||||||
|
|
||||||
|
metadata = _extract_metadata(raw_content, clean_text)
|
||||||
|
latest_message = _split_thread(clean_text)
|
||||||
|
|
||||||
|
return PreprocessResult(
|
||||||
|
content_type="email_html",
|
||||||
|
clean_text=latest_message,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Adiuva Cloud API",
|
title="AdiuvAI Cloud API",
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
redoc_url=None,
|
redoc_url=None,
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ from __future__ import annotations
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
from pgvector.sqlalchemy import Vector
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
Boolean,
|
Boolean,
|
||||||
DateTime,
|
DateTime,
|
||||||
@@ -69,7 +70,8 @@ class User(Base):
|
|||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
@@ -78,6 +80,9 @@ class User(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
|
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True, default=None
|
||||||
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
@@ -88,6 +93,9 @@ class User(Base):
|
|||||||
subscription: Mapped[Subscription | None] = relationship(
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Base):
|
class RefreshToken(Base):
|
||||||
@@ -108,6 +116,25 @@ class RefreshToken(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccount(Base):
|
||||||
|
__tablename__ = "oauth_accounts"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="oauth_accounts")
|
||||||
|
|
||||||
|
|
||||||
class Subscription(Base):
|
class Subscription(Base):
|
||||||
__tablename__ = "subscriptions"
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
@@ -143,6 +170,7 @@ class LocalAgentConfig(Base):
|
|||||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
@@ -272,8 +300,8 @@ class MemoryAssociative(Base):
|
|||||||
nullable=False, index=True,
|
nullable=False, index=True,
|
||||||
)
|
)
|
||||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
# vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
|
||||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
|
||||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
@@ -321,3 +349,25 @@ class MemoryProactive(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
"""Plugin marketplace catalog entry."""
|
||||||
|
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
||||||
|
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|||||||
@@ -30,6 +30,16 @@ class UserProfile(BaseModel):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
surname: str | None = None
|
surname: str | None = None
|
||||||
tier: BillingTier
|
tier: BillingTier
|
||||||
|
avatar_url: str | None = None
|
||||||
|
has_password: bool = True
|
||||||
|
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
||||||
|
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountInfo(BaseModel):
|
||||||
|
provider: str
|
||||||
|
provider_email: str | None = None
|
||||||
|
created_at: int # epoch ms
|
||||||
|
|
||||||
|
|
||||||
# ── Chat ─────────────────────────────────────────────────────────────
|
# ── Chat ─────────────────────────────────────────────────────────────
|
||||||
@@ -191,6 +201,27 @@ class WsFloatingDomain(BaseModel):
|
|||||||
domain: WsDomain
|
domain: WsDomain
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent Config V2 ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class ContentTypeConfig(BaseModel):
|
||||||
|
"""Per-type extraction config produced by the journey chatbot."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
label: str = ""
|
||||||
|
detection_hint: str = ""
|
||||||
|
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
|
||||||
|
extraction_prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(BaseModel):
|
||||||
|
"""Structured agent configuration (replaces freeform prompt_template)."""
|
||||||
|
|
||||||
|
content_types: list[ContentTypeConfig] = []
|
||||||
|
global_rules: list[str] = []
|
||||||
|
data_types: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
class AgentCatalogItem(BaseModel):
|
class AgentCatalogItem(BaseModel):
|
||||||
@@ -215,10 +246,11 @@ class AgentTriggerRequest(BaseModel):
|
|||||||
device_id: str = Field(default="")
|
device_id: str = Field(default="")
|
||||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||||
what_to_extract: list[str] = Field(min_length=1)
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
actions_by_type: dict[str, list[str]] | None = None
|
|
||||||
batch_interval: str = Field(min_length=1)
|
batch_interval: str = Field(min_length=1)
|
||||||
custom_agent_prompt: str = Field(min_length=1)
|
custom_agent_prompt: str | None = None
|
||||||
|
agent_config: dict | None = None
|
||||||
active_agents: int = Field(ge=0, default=0)
|
active_agents: int = Field(ge=0, default=0)
|
||||||
|
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ services:
|
|||||||
- path: .env
|
- path: .env
|
||||||
required: false
|
required: false
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
|
||||||
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||||
volumes:
|
volumes:
|
||||||
- copilot_tokens:/root/.config/litellm/github_copilot
|
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||||
@@ -21,7 +21,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: postgres
|
||||||
POSTGRES_DB: adiuva
|
POSTGRES_DB: adiuvai
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
|
|||||||
@@ -1,941 +0,0 @@
|
|||||||
# Adiuva — Architettura Microservizi (MVP)
|
|
||||||
|
|
||||||
## Panoramica
|
|
||||||
|
|
||||||
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
|
|
||||||
|
|
||||||
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐
|
|
||||||
│ Cloudflare │
|
|
||||||
│ (DNS + CDN) │
|
|
||||||
└──────┬───────┘
|
|
||||||
│ HTTPS / WSS
|
|
||||||
┌──────▼───────┐
|
|
||||||
│ Traefik │
|
|
||||||
│ API Gateway │
|
|
||||||
│ (routing, │
|
|
||||||
│ TLS, rate │
|
|
||||||
│ limiting) │
|
|
||||||
└──────┬───────┘
|
|
||||||
│
|
|
||||||
┌──────────┬───────────┼───────────┐
|
|
||||||
│ │ │ │
|
|
||||||
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
|
|
||||||
│ Auth │ │ Chat │ │ Agent │ │Billing │
|
|
||||||
│ Service │ │Service│ │ Service │ │Service │
|
|
||||||
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
|
|
||||||
│ │ │ │
|
|
||||||
┌─────▼──────────▼──────────▼───────────▼────┐
|
|
||||||
│ Infrastruttura │
|
|
||||||
│ PostgreSQL │ Redis │ Qdrant │
|
|
||||||
└─────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. Suddivisione dei Servizi
|
|
||||||
|
|
||||||
### 1.1 Auth Service (`auth-service`)
|
|
||||||
|
|
||||||
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
|
||||||
|
|
||||||
| Endpoint originale | Metodo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/auth/register` | POST |
|
|
||||||
| `/api/v1/auth/login` | POST |
|
|
||||||
| `/api/v1/auth/refresh` | POST |
|
|
||||||
| `/api/v1/auth/me` | GET / PUT |
|
|
||||||
|
|
||||||
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
|
||||||
|
|
||||||
**Modifica chiave — JWT con RS256**:
|
|
||||||
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
|
||||||
- L'Auth Service firma i JWT con la **chiave privata**.
|
|
||||||
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
|
||||||
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
|
||||||
|
|
||||||
```python
|
|
||||||
# auth-service/app/auth/jwt.py
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PRIVATE_KEY = ... # Da env/secret
|
|
||||||
PUBLIC_KEY = ... # Derivata o da env
|
|
||||||
|
|
||||||
def create_access_token(user_id: str, tier: str) -> str:
|
|
||||||
return jwt.encode(
|
|
||||||
{"sub": user_id, "tier": tier, "exp": ...},
|
|
||||||
PRIVATE_KEY,
|
|
||||||
algorithm="RS256",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/auth.py (usato da tutti gli altri servizi)
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
|
||||||
|
|
||||||
def verify_token(token: str) -> dict:
|
|
||||||
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
|
||||||
```
|
|
||||||
|
|
||||||
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
|
|
||||||
|
|
||||||
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
|
|
||||||
|
|
||||||
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
|
|
||||||
|
|
||||||
| Endpoint | Tipo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
|
|
||||||
| `/api/v1/chat` | POST (REST fallback) |
|
|
||||||
|
|
||||||
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
|
|
||||||
|
|
||||||
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
|
|
||||||
|
|
||||||
**Scaling**: 2–N repliche. Sticky cookies per le WS + Redis per cross-instance.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.3 Agent Service (`agent-service`) ⭐ Batch
|
|
||||||
|
|
||||||
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
|
|
||||||
|
|
||||||
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
|
|
||||||
|
|
||||||
| Endpoint | Tipo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/agents/catalog` | GET |
|
|
||||||
| `/api/v1/agents/can-create` | POST |
|
|
||||||
| `/api/v1/agents/trigger` | POST |
|
|
||||||
| `/api/v1/agents/journey/start` | POST (o WS relay) |
|
|
||||||
| `/api/v1/agents/journey/message` | POST (o WS relay) |
|
|
||||||
|
|
||||||
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
|
|
||||||
|
|
||||||
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐ ┌──────────────┐ ┌──────────┐
|
|
||||||
│ Agent Service│ │ Redis │ │ Chat │
|
|
||||||
│ (batch run) │ │ │ │ Service │
|
|
||||||
│ │ │ │ │ (ha WS) │
|
|
||||||
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
|
|
||||||
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
|
|
||||||
│ from │ │ │ │ al │
|
|
||||||
│ device │ │ │ │ device│
|
|
||||||
│ │ │ │ │ via WS│
|
|
||||||
│ │ SUBSCRIBE │ │ PUBLISH │ │
|
|
||||||
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
|
|
||||||
│ risultato │ │ │ │ reply │
|
|
||||||
└──────────────┘ └──────────────┘ └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
**Scaling**: 1–N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
|
|
||||||
|
|
||||||
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.4 Billing Service (`billing-service`)
|
|
||||||
|
|
||||||
**Responsabilità**: Stripe checkout, webhook, subscription management.
|
|
||||||
|
|
||||||
| Endpoint originale | Metodo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/billing/checkout` | POST |
|
|
||||||
| `/api/v1/billing/webhook` | POST |
|
|
||||||
| `/api/v1/billing/subscription` | GET / DELETE |
|
|
||||||
|
|
||||||
**Database**: Tabelle `subscriptions` (schema `billing`).
|
|
||||||
|
|
||||||
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
|
|
||||||
|
|
||||||
**Scaling**: 1 replica sufficiente. Basso traffico.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.5 Servizi esclusi dall'MVP
|
|
||||||
|
|
||||||
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
|
|
||||||
|
|
||||||
| Servizio | Responsabilità | Note |
|
|
||||||
|---|---|---|
|
|
||||||
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
|
|
||||||
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. Tier Check — Dove e Come
|
|
||||||
|
|
||||||
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
|
|
||||||
|
|
||||||
### Strategia: Tier nel JWT
|
|
||||||
|
|
||||||
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"sub": "user_123",
|
|
||||||
"tier": "pro",
|
|
||||||
"exp": 1742515200,
|
|
||||||
"iat": 1742511600
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Ogni servizio:
|
|
||||||
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
|
|
||||||
2. Legge `payload["tier"]` — **zero chiamate extra**
|
|
||||||
3. Applica le sue regole di enforcement localmente
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/auth.py — dependency FastAPI condivisa
|
|
||||||
from fastapi import Depends, HTTPException, Request
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PUBLIC_KEY = ...
|
|
||||||
|
|
||||||
class CurrentUser:
|
|
||||||
def __init__(self, user_id: str, tier: str):
|
|
||||||
self.user_id = user_id
|
|
||||||
self.tier = tier
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> CurrentUser:
|
|
||||||
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
|
|
||||||
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
|
||||||
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
|
|
||||||
|
|
||||||
def require_tier(*allowed_tiers: str):
|
|
||||||
"""Dependency che blocca se il tier non è tra quelli ammessi."""
|
|
||||||
async def check(user: CurrentUser = Depends(get_current_user)):
|
|
||||||
if user.tier not in allowed_tiers:
|
|
||||||
raise HTTPException(403, "Tier insufficient")
|
|
||||||
return user
|
|
||||||
return check
|
|
||||||
```
|
|
||||||
|
|
||||||
### Cosa succede quando il tier cambia (upgrade/downgrade)?
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
|
|
||||||
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
|
|
||||||
│ │ │ Service │ (Redis pub/sub) │ Service │
|
|
||||||
└──────────┘ └──────────┘ └────┬─────┘
|
|
||||||
│
|
|
||||||
UPDATE users
|
|
||||||
SET tier = 'power'
|
|
||||||
│
|
|
||||||
Al prossimo /refresh
|
|
||||||
il JWT conterrà tier='power'
|
|
||||||
```
|
|
||||||
|
|
||||||
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 15–30 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
|
|
||||||
|
|
||||||
### Dove si applica in ciascun servizio
|
|
||||||
|
|
||||||
| Servizio | Enforcement |
|
|
||||||
|---|---|
|
|
||||||
| **Auth Service** | Nessuno (è lui che scrive il tier) |
|
|
||||||
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
|
|
||||||
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
|
|
||||||
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
|
|
||||||
|
|
||||||
### Rate-limit distribuito via Redis
|
|
||||||
|
|
||||||
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/middleware/rate_limit.py
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
|
|
||||||
class DistributedRateLimiter:
|
|
||||||
def __init__(self, redis: aioredis.Redis):
|
|
||||||
self._redis = redis
|
|
||||||
|
|
||||||
async def check(self, user_id: str, tier: str, service: str) -> bool:
|
|
||||||
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
|
||||||
max_req = limits.get(tier, 20)
|
|
||||||
key = f"rate:{service}:{user_id}"
|
|
||||||
|
|
||||||
pipe = self._redis.pipeline()
|
|
||||||
pipe.incr(key)
|
|
||||||
pipe.expire(key, 60)
|
|
||||||
count, _ = await pipe.execute()
|
|
||||||
|
|
||||||
return count <= max_req
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
|
||||||
|
|
||||||
`DeviceConnectionManager` è un **singleton in-memory**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class DeviceConnectionManager:
|
|
||||||
def __init__(self):
|
|
||||||
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
|
||||||
```
|
|
||||||
|
|
||||||
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
|
||||||
|
|
||||||
### La soluzione: Redis Pub/Sub + Registry
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────────────────────────────────────────────────────┐
|
|
||||||
│ Redis │
|
|
||||||
│ │
|
|
||||||
│ Hash: ws:connections │
|
|
||||||
│ user_123 → instance_A │
|
|
||||||
│ user_456 → instance_B │
|
|
||||||
│ │
|
|
||||||
│ Pub/Sub channels: │
|
|
||||||
│ tool_call:{user_id} → tool call payloads │
|
|
||||||
│ tool_result:{call_id} → tool result payloads │
|
|
||||||
│ stream:{user_id} → text_chunk streaming │
|
|
||||||
└──────────────────────────────────────────────────────────────┘
|
|
||||||
|
|
||||||
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
|
||||||
┌───────────────────────┐ ┌───────────────────────┐
|
|
||||||
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
|
||||||
│ tool_call:user_123│ │ → user_123 è su A │
|
|
||||||
│ │ │ │
|
|
||||||
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
|
||||||
│ da Redis channel │ │ tool_call:user_123 │
|
|
||||||
│ │ │ {id, action, ...} │
|
|
||||||
│ 3. Invia al device │ │ │
|
|
||||||
│ via WS │ │ 4. SUBSCRIBE │
|
|
||||||
│ │ │ tool_result:{id} │
|
|
||||||
│ 4. Device risponde │ │ │
|
|
||||||
│ tool_result │──────────│► 5. Riceve risultato │
|
|
||||||
│ │ │ │
|
|
||||||
│ 5. PUBLISH │ │ │
|
|
||||||
│ tool_result:{id} │ │ │
|
|
||||||
└───────────────────────┘ └───────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### Implementazione: `RedisDeviceManager`
|
|
||||||
|
|
||||||
```python
|
|
||||||
# chat-service/app/core/device_manager.py
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from fastapi import WebSocket
|
|
||||||
|
|
||||||
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LocalConnection:
|
|
||||||
ws: WebSocket
|
|
||||||
device_id: str
|
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisDeviceManager:
|
|
||||||
"""Device manager backed by Redis for cross-instance communication."""
|
|
||||||
|
|
||||||
def __init__(self, redis_url: str = "redis://redis:6379"):
|
|
||||||
self._redis = aioredis.from_url(redis_url)
|
|
||||||
self._pubsub = self._redis.pubsub()
|
|
||||||
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
|
||||||
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""Avvia il listener Redis per tool_call in arrivo."""
|
|
||||||
asyncio.create_task(self._listen_tool_calls())
|
|
||||||
|
|
||||||
# ── Registrazione ──
|
|
||||||
|
|
||||||
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
|
||||||
# Registra localmente
|
|
||||||
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
|
||||||
# Registra in Redis quale istanza ha la connessione
|
|
||||||
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
|
||||||
# Sottoscrivi ai tool_call per questo utente
|
|
||||||
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
|
||||||
|
|
||||||
async def unregister(self, user_id: str):
|
|
||||||
conn = self._local.pop(user_id, None)
|
|
||||||
if conn:
|
|
||||||
for fut in conn.pending_calls.values():
|
|
||||||
if not fut.done():
|
|
||||||
fut.cancel()
|
|
||||||
await self._redis.hdel("ws:connections", user_id)
|
|
||||||
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
|
||||||
|
|
||||||
# ── Presenza ──
|
|
||||||
|
|
||||||
async def is_online(self, user_id: str) -> bool:
|
|
||||||
return await self._redis.hexists("ws:connections", user_id)
|
|
||||||
|
|
||||||
# ── Tool-call round-trip (cross-instance) ──
|
|
||||||
|
|
||||||
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Invia un tool_call al device dell'utente.
|
|
||||||
Funziona sia che la WS sia locale che su un'altra istanza.
|
|
||||||
"""
|
|
||||||
call_id = payload["id"]
|
|
||||||
|
|
||||||
# Caso 1: connessione locale → invio diretto
|
|
||||||
if user_id in self._local:
|
|
||||||
conn = self._local[user_id]
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
fut: asyncio.Future[dict] = loop.create_future()
|
|
||||||
conn.pending_calls[call_id] = fut
|
|
||||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
|
|
||||||
# Caso 2: connessione remota → Redis pub/sub
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
fut = loop.create_future()
|
|
||||||
self._remote_futures[call_id] = fut
|
|
||||||
|
|
||||||
# Sottoscrivi al canale di risposta
|
|
||||||
result_channel = f"tool_result:{call_id}"
|
|
||||||
await self._pubsub.subscribe(result_channel)
|
|
||||||
|
|
||||||
# Pubblica il tool_call
|
|
||||||
await self._redis.publish(
|
|
||||||
f"tool_call:{user_id}",
|
|
||||||
json.dumps(payload),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
finally:
|
|
||||||
self._remote_futures.pop(call_id, None)
|
|
||||||
await self._pubsub.unsubscribe(result_channel)
|
|
||||||
|
|
||||||
# ── Risoluzione tool_result (da WS locale) ──
|
|
||||||
|
|
||||||
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
|
||||||
conn = self._local.get(user_id)
|
|
||||||
if conn:
|
|
||||||
fut = conn.pending_calls.pop(call_id, None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(result)
|
|
||||||
|
|
||||||
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
|
||||||
"""Chiamato quando il device locale invia un tool_result."""
|
|
||||||
self.resolve_local(user_id, call_id, result)
|
|
||||||
# Pubblica anche su Redis per l'istanza remota che aspetta
|
|
||||||
await self._redis.publish(
|
|
||||||
f"tool_result:{call_id}",
|
|
||||||
json.dumps(result),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Listener Redis ──
|
|
||||||
|
|
||||||
async def _listen_tool_calls(self):
|
|
||||||
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
|
||||||
async for message in self._pubsub.listen():
|
|
||||||
if message["type"] != "message":
|
|
||||||
continue
|
|
||||||
channel = message["channel"]
|
|
||||||
if isinstance(channel, bytes):
|
|
||||||
channel = channel.decode()
|
|
||||||
|
|
||||||
data = json.loads(message["data"])
|
|
||||||
|
|
||||||
if channel.startswith("tool_call:"):
|
|
||||||
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
|
||||||
user_id = channel.split(":", 1)[1]
|
|
||||||
conn = self._local.get(user_id)
|
|
||||||
if conn:
|
|
||||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
|
||||||
|
|
||||||
elif channel.startswith("tool_result:"):
|
|
||||||
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
|
||||||
call_id = channel.split(":", 1)[1]
|
|
||||||
fut = self._remote_futures.pop(call_id, None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(data)
|
|
||||||
|
|
||||||
# ── Stream cross-instance ──
|
|
||||||
|
|
||||||
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
|
||||||
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
|
||||||
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. Struttura Directory Proposta (MVP)
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── docker-compose.yml # Orchestrazione completa
|
|
||||||
├── docker-compose.dev.yml # Override per sviluppo locale
|
|
||||||
├── shared/ # Codice condiviso (montato come volume)
|
|
||||||
│ ├── auth.py # JWT verification (chiave pubblica)
|
|
||||||
│ ├── schemas.py # Pydantic schemas condivisi
|
|
||||||
│ ├── middleware/
|
|
||||||
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
|
|
||||||
│ │ └── sanitizer.py
|
|
||||||
│ └── models/
|
|
||||||
│ └── base.py # SQLAlchemy base condivisa
|
|
||||||
│
|
|
||||||
├── auth-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # users, refresh_tokens
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ └── auth.py
|
|
||||||
│ └── services/
|
|
||||||
│ ├── jwt_service.py # RS256 signing
|
|
||||||
│ └── user_service.py
|
|
||||||
│
|
|
||||||
├── chat-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # memory_*
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ ├── device_ws.py # WS connection owner
|
|
||||||
│ │ └── chat.py # REST fallback
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── device_manager.py # RedisDeviceManager
|
|
||||||
│ │ ├── deep_agent.py # Home + floating chat
|
|
||||||
│ │ ├── memory_middleware.py
|
|
||||||
│ │ ├── ws_context.py
|
|
||||||
│ │ ├── output_formatter.py
|
|
||||||
│ │ └── llm.py
|
|
||||||
│ └── agents/ # Tool definitions (used by deep_agent)
|
|
||||||
│ ├── task_agent.py
|
|
||||||
│ ├── project_agent.py
|
|
||||||
│ ├── note_agent.py
|
|
||||||
│ └── timeline_agent.py
|
|
||||||
│
|
|
||||||
├── agent-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ ├── agents.py # catalog, can-create, trigger
|
|
||||||
│ │ └── agent_setup.py # journey start/message
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── agent_runner.py # Batch classify → process
|
|
||||||
│ │ ├── agent_registry.py
|
|
||||||
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
|
|
||||||
│ │ └── llm.py
|
|
||||||
│ └── agents/
|
|
||||||
│ ├── task_agent.py # Tool definitions (batch context)
|
|
||||||
│ ├── project_agent.py
|
|
||||||
│ ├── note_agent.py
|
|
||||||
│ ├── timeline_agent.py
|
|
||||||
│ └── filesystem_agent.py
|
|
||||||
│
|
|
||||||
├── billing-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # subscriptions
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ └── billing.py
|
|
||||||
│ └── services/
|
|
||||||
│ ├── stripe_service.py
|
|
||||||
│ └── tier_manager.py
|
|
||||||
│
|
|
||||||
└── infra/
|
|
||||||
├── traefik/
|
|
||||||
│ └── traefik.yml
|
|
||||||
├── keys/
|
|
||||||
│ ├── jwt_private.pem # Solo auth-service
|
|
||||||
│ └── jwt_public.pem # Tutti i servizi
|
|
||||||
└── alembic/ # Migrazioni condivise o per-servizio
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. Docker Compose — Configurazione MVP
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# docker-compose.yml
|
|
||||||
|
|
||||||
services:
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# API Gateway
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
traefik:
|
|
||||||
image: traefik:v3.2
|
|
||||||
command:
|
|
||||||
- "--api.insecure=true"
|
|
||||||
- "--providers.docker=true"
|
|
||||||
- "--providers.docker.exposedbydefault=false"
|
|
||||||
- "--entrypoints.web.address=:80"
|
|
||||||
- "--entrypoints.websecure.address=:443"
|
|
||||||
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
|
||||||
ports:
|
|
||||||
- "80:80"
|
|
||||||
- "443:443"
|
|
||||||
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
|
|
||||||
volumes:
|
|
||||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
|
||||||
- ./infra/certs:/certs:ro
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Auth Service (2 repliche)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
auth-service:
|
|
||||||
build: ./auth-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
|
||||||
SERVICE_NAME: auth
|
|
||||||
secrets:
|
|
||||||
- jwt_private_key
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
|
||||||
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Chat Service — Real-time WS + Chat (scalabile)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
chat-service:
|
|
||||||
build: ./chat-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: chat
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
# REST chat endpoint
|
|
||||||
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
|
|
||||||
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
|
||||||
# WebSocket route con sticky session
|
|
||||||
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
|
||||||
- "traefik.http.routers.ws.service=chat-ws"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Agent Service — Batch processing (scalabile indipendentemente)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
agent-service:
|
|
||||||
build: ./agent-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: agent
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
|
|
||||||
- "traefik.http.services.agents.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Billing Service (1 replica)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
billing-service:
|
|
||||||
build: ./billing-service
|
|
||||||
deploy:
|
|
||||||
replicas: 1
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: billing
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
|
||||||
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Infrastruttura
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
db:
|
|
||||||
image: pgvector/pgvector:pg16
|
|
||||||
environment:
|
|
||||||
POSTGRES_USER: postgres
|
|
||||||
POSTGRES_PASSWORD: postgres
|
|
||||||
POSTGRES_DB: adiuva
|
|
||||||
volumes:
|
|
||||||
- postgres_data:/var/lib/postgresql/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
redis:
|
|
||||||
image: redis:7-alpine
|
|
||||||
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
|
||||||
volumes:
|
|
||||||
- redis_data:/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 3s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
qdrant:
|
|
||||||
image: qdrant/qdrant:latest
|
|
||||||
volumes:
|
|
||||||
- qdrant_data:/qdrant/storage
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
secrets:
|
|
||||||
jwt_private_key:
|
|
||||||
file: ./infra/keys/jwt_private.pem
|
|
||||||
jwt_public_key:
|
|
||||||
file: ./infra/keys/jwt_public.pem
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
postgres_data:
|
|
||||||
redis_data:
|
|
||||||
qdrant_data:
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. Configurazione Cloudflare + VPS
|
|
||||||
|
|
||||||
### 6.1 DNS
|
|
||||||
|
|
||||||
```
|
|
||||||
api.tuodominio.com → A record → IP del VPS
|
|
||||||
→ Proxy: ON (orange cloud)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.2 Cloudflare Settings
|
|
||||||
|
|
||||||
| Setting | Valore | Motivo |
|
|
||||||
|---------|--------|--------|
|
|
||||||
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
|
||||||
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
|
||||||
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
|
||||||
| Under Attack Mode | Off (attivare se necessario) | |
|
|
||||||
|
|
||||||
### 6.3 TLS sul VPS
|
|
||||||
|
|
||||||
Due opzioni:
|
|
||||||
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
|
||||||
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# traefik.yml — con Cloudflare Origin Certificate
|
|
||||||
entryPoints:
|
|
||||||
websecure:
|
|
||||||
address: ":443"
|
|
||||||
|
|
||||||
tls:
|
|
||||||
certificates:
|
|
||||||
- certFile: /certs/origin.pem
|
|
||||||
keyFile: /certs/origin-key.pem
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.4 Rete VPS
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
|
||||||
# https://www.cloudflare.com/ips/
|
|
||||||
ufw default deny incoming
|
|
||||||
ufw allow from 173.245.48.0/20 to any port 443
|
|
||||||
ufw allow from 103.21.244.0/22 to any port 443
|
|
||||||
# ... (tutti gli IP range di Cloudflare)
|
|
||||||
ufw allow ssh
|
|
||||||
ufw enable
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 7. Comunicazione Inter-Servizio
|
|
||||||
|
|
||||||
### 7.1 Redis Pub/Sub — Event Bus
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
|
||||||
│ Billing │ ────────────────────────► │ Auth │
|
|
||||||
│ Service │ │ Service │
|
|
||||||
└──────────┘ └──────────┘
|
|
||||||
|
|
||||||
┌──────────┐ tool_call:user_123 ┌──────────┐
|
|
||||||
│ Agent │ ────────────────────────► │ Chat │
|
|
||||||
│ Service │ │ Service │
|
|
||||||
│ (batch) │ ◄────────────────────────│ (ha WS) │
|
|
||||||
└──────────┘ tool_result:{call_id} └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 7.2 Health Checks e Service Discovery
|
|
||||||
|
|
||||||
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
|
||||||
- **Redis pub/sub** (tool-call cross-instance, tier events)
|
|
||||||
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
|
|
||||||
- **PostgreSQL** (dati persistenti condivisi)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 8. Piano di Migrazione Incrementale (MVP)
|
|
||||||
|
|
||||||
### Fase 1 — Preparazione (nel monolite attuale)
|
|
||||||
1. Aggiungere Redis al `docker-compose.yml` attuale
|
|
||||||
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
|
|
||||||
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
|
|
||||||
4. Estrarre `shared/` con auth verification, schemas, middleware
|
|
||||||
|
|
||||||
### Fase 2 — Auth Service (primo split)
|
|
||||||
1. Estrarre `auth.py` routes + models in `auth-service/`
|
|
||||||
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
|
||||||
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
|
|
||||||
4. Il monolite continua a servire tutto il resto
|
|
||||||
|
|
||||||
### Fase 3 — Billing Service
|
|
||||||
1. Estrarre billing routes, Stripe service, tier manager
|
|
||||||
2. Configurare Redis pub/sub per `tier_changed` events
|
|
||||||
3. Routare via Traefik
|
|
||||||
|
|
||||||
### Fase 4 — Split Chat + Agent (il più delicato)
|
|
||||||
1. Il monolite residuo contiene WS + chat + agents
|
|
||||||
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
|
|
||||||
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
|
|
||||||
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
|
|
||||||
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
|
|
||||||
|
|
||||||
### Fase 5 — Scaling test
|
|
||||||
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
|
|
||||||
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
|
|
||||||
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 9. Monitoraggio e Logging
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Aggiungere al docker-compose.yml
|
|
||||||
|
|
||||||
prometheus:
|
|
||||||
image: prom/prometheus:latest
|
|
||||||
volumes:
|
|
||||||
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
grafana:
|
|
||||||
image: grafana/grafana:latest
|
|
||||||
ports:
|
|
||||||
- "3000:3000"
|
|
||||||
volumes:
|
|
||||||
- grafana_data:/var/lib/grafana
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
loki:
|
|
||||||
image: grafana/loki:latest
|
|
||||||
restart: unless-stopped
|
|
||||||
```
|
|
||||||
|
|
||||||
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 10. Sizing VPS Minimo Consigliato (MVP)
|
|
||||||
|
|
||||||
| Componente | CPU | RAM | Note |
|
|
||||||
|---|---|---|---|
|
|
||||||
| Traefik | 0.25 | 128MB | |
|
|
||||||
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
|
|
||||||
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
|
|
||||||
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
|
|
||||||
| Billing Service | 0.25 | 128MB | |
|
|
||||||
| PostgreSQL | 1.0 | 1GB | |
|
|
||||||
| Redis | 0.25 | 256MB | |
|
|
||||||
| Qdrant | 0.5 | 512MB | |
|
|
||||||
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
|
|
||||||
|
|
||||||
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Riepilogo Architettura MVP
|
|
||||||
|
|
||||||
| Servizio | Repliche | Proprietario di |
|
|
||||||
|---|---|---|
|
|
||||||
| **Traefik** | 1 | Routing, TLS, sticky sessions |
|
|
||||||
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
|
|
||||||
| **Chat Service** | 2–N | WebSocket, home/floating chat, streaming |
|
|
||||||
| **Agent Service** | 2–N | Batch processing, directory scan, agent setup |
|
|
||||||
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
|
|
||||||
|
|
||||||
| Decisione | Scelta | Motivazione |
|
|
||||||
|---|---|---|
|
|
||||||
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
|
||||||
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
|
||||||
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
|
|
||||||
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
|
|
||||||
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
|
|
||||||
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
|
|
||||||
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
|
|
||||||
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
|
|
||||||
| TLS | Cloudflare Origin Certificate | Zero maintenance |
|
|
||||||
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
|
||||||
| Storage / Plugin | Post-MVP | Non critici per il lancio |
|
|
||||||
@@ -32,5 +32,9 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
|
pgvector>=0.2.5
|
||||||
langfuse>=2.0.0
|
langfuse>=2.0.0
|
||||||
|
beautifulsoup4>=4.12.0
|
||||||
|
lxml>=5.0.0
|
||||||
|
PyYAML>=6.0.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
@@ -6,26 +6,21 @@ a per-test session, and a FastAPI ``TestClient`` wired to use it.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
from unittest.mock import patch
|
|
||||||
|
|
||||||
import boto3
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from moto import mock_aws
|
|
||||||
from sqlalchemy import StaticPool, event
|
from sqlalchemy import StaticPool, event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.db import Base, get_session
|
from app.db import Base, get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.models import Plugin, Subscription, User
|
from app.models import Subscription, User
|
||||||
|
|
||||||
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
||||||
|
|
||||||
@@ -109,79 +104,6 @@ def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # n
|
|||||||
app.dependency_overrides.pop(get_session, None)
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
# ── Seed data helpers ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_SEED_PLUGINS = [
|
|
||||||
Plugin(
|
|
||||||
id="plugin-github-sync",
|
|
||||||
name="GitHub Sync",
|
|
||||||
description="Sync tasks with GitHub Issues and pull requests.",
|
|
||||||
version="1.0.0",
|
|
||||||
author_name="Adiuva",
|
|
||||||
category="productivity",
|
|
||||||
price_cents=0,
|
|
||||||
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
|
||||||
status="approved",
|
|
||||||
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
|
|
||||||
install_count=0,
|
|
||||||
avg_rating=0.0,
|
|
||||||
),
|
|
||||||
Plugin(
|
|
||||||
id="plugin-slack-notify",
|
|
||||||
name="Slack Notifier",
|
|
||||||
description="Post task and timeline updates to Slack channels.",
|
|
||||||
version="1.2.0",
|
|
||||||
author_name="Adiuva",
|
|
||||||
category="communication",
|
|
||||||
price_cents=499,
|
|
||||||
permissions=json.dumps(["read:tasks", "read:timelines"]),
|
|
||||||
status="approved",
|
|
||||||
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
|
||||||
install_count=0,
|
|
||||||
avg_rating=0.0,
|
|
||||||
),
|
|
||||||
Plugin(
|
|
||||||
id="plugin-time-tracker",
|
|
||||||
name="Time Tracker",
|
|
||||||
description="Track time spent on tasks with automatic reporting.",
|
|
||||||
version="0.9.1",
|
|
||||||
author_name="Third Party",
|
|
||||||
category="productivity",
|
|
||||||
price_cents=999,
|
|
||||||
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
|
||||||
status="approved",
|
|
||||||
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
|
|
||||||
install_count=0,
|
|
||||||
avg_rating=0.0,
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
|
|
||||||
"""Insert the 3 default approved plugins and return them."""
|
|
||||||
plugins = []
|
|
||||||
for template in _SEED_PLUGINS:
|
|
||||||
p = Plugin(
|
|
||||||
id=template.id,
|
|
||||||
name=template.name,
|
|
||||||
description=template.description,
|
|
||||||
version=template.version,
|
|
||||||
author_name=template.author_name,
|
|
||||||
category=template.category,
|
|
||||||
price_cents=template.price_cents,
|
|
||||||
permissions=template.permissions,
|
|
||||||
status=template.status,
|
|
||||||
s3_package_key=template.s3_package_key,
|
|
||||||
install_count=template.install_count,
|
|
||||||
avg_rating=template.avg_rating,
|
|
||||||
)
|
|
||||||
db_session.add(p)
|
|
||||||
plugins.append(p)
|
|
||||||
await db_session.commit()
|
|
||||||
return plugins
|
|
||||||
|
|
||||||
|
|
||||||
# ── JWT helpers ──────────────────────────────────────────────────────
|
# ── JWT helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -212,24 +134,21 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
|||||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
# ── S3 mock fixture ──────────────────────────────────────────────────
|
# ── CLI options ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
S3_TEST_BUCKET = "test-bucket"
|
def pytest_addoption(parser):
|
||||||
S3_TEST_REGION = "us-east-1"
|
parser.addoption(
|
||||||
|
"--preprocess-dir",
|
||||||
|
default=None,
|
||||||
@pytest.fixture
|
help="Override fixture folder for preprocessor tests (must contain cases.yaml + data/)",
|
||||||
def s3_bucket():
|
)
|
||||||
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
|
parser.addoption(
|
||||||
with mock_aws():
|
"--runner-dir",
|
||||||
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
default=None,
|
||||||
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
help="Override fixture folder for agent_runner_v2 eval tests (must contain cases.yaml + data/)",
|
||||||
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
|
)
|
||||||
client = boto3.client("s3", region_name=S3_TEST_REGION)
|
parser.addoption(
|
||||||
client.create_bucket(Bucket=S3_TEST_BUCKET)
|
"--journey-dir",
|
||||||
with patch("app.storage.blob_store.settings") as mock_settings:
|
default=None,
|
||||||
mock_settings.S3_BUCKET = S3_TEST_BUCKET
|
help="Override fixture folder for journey_v2 eval tests (must contain cases.yaml + data/)",
|
||||||
mock_settings.S3_REGION = S3_TEST_REGION
|
)
|
||||||
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
|
||||||
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
|
||||||
yield S3_TEST_BUCKET
|
|
||||||
|
|||||||
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
Normal file
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
Normal file
@@ -0,0 +1,86 @@
|
|||||||
|
# Agent Runner V2 — eval test cases (Step 2, requires real LLM)
|
||||||
|
#
|
||||||
|
# Each case drives one parametrized `test_eval_runner` invocation.
|
||||||
|
#
|
||||||
|
# Keys
|
||||||
|
# ----
|
||||||
|
# id: str unique identifier shown in pytest output
|
||||||
|
# description: str human-readable label
|
||||||
|
# file: str filename inside data/
|
||||||
|
# file_path: str path reported to the executor (affects project-matching via filename)
|
||||||
|
# projects: [alpha|beta] symbolic project names resolved by the test helper
|
||||||
|
#
|
||||||
|
# Optional pre-existing records (dedup tests)
|
||||||
|
# existing_tasks: list of {id, title, status, priority}
|
||||||
|
# existing_notes: list of {id, title, content}
|
||||||
|
# existing_timelines: list of {id, title, date}
|
||||||
|
#
|
||||||
|
# Assertions (one or more)
|
||||||
|
# expect_insert: <table> at least 1 insert row in this table (tasks|notes|timelines)
|
||||||
|
# expect_no_insert: true zero inserts in any table
|
||||||
|
# expect_project_id: <id> any insert must carry this projectId
|
||||||
|
# expect_dedup: true task inserts == 0 OR task updates >= 1 (dedup check)
|
||||||
|
#
|
||||||
|
# Langfuse
|
||||||
|
# score_name: str observation score name
|
||||||
|
|
||||||
|
- id: "2.1"
|
||||||
|
description: "Action email → create_task"
|
||||||
|
file: email_action.html
|
||||||
|
file_path: /emails/ProjectAlpha_action.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_insert: tasks
|
||||||
|
score_name: runner.email_to_task
|
||||||
|
|
||||||
|
- id: "2.2"
|
||||||
|
description: "Informational email → create_note"
|
||||||
|
file: email_info.html
|
||||||
|
file_path: /emails/ProjectAlpha_info.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_insert: notes
|
||||||
|
score_name: runner.email_to_note
|
||||||
|
|
||||||
|
- id: "2.3"
|
||||||
|
description: "Email with meeting date → create_timeline"
|
||||||
|
file: email_date.html
|
||||||
|
file_path: /emails/ProjectAlpha_kickoff.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_insert: timelines
|
||||||
|
score_name: runner.email_to_timeline
|
||||||
|
|
||||||
|
- id: "2.4"
|
||||||
|
description: "Filename contains project name → correct project assigned"
|
||||||
|
file: email_action.html
|
||||||
|
file_path: /emails/ProjectAlpha_report.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_project_id: proj-alpha
|
||||||
|
score_name: runner.project_filename
|
||||||
|
|
||||||
|
- id: "2.5"
|
||||||
|
description: "Email body mentions project → correct project assigned"
|
||||||
|
file: email_action.html
|
||||||
|
file_path: /emails/email_001.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_project_id: proj-alpha
|
||||||
|
score_name: runner.project_content
|
||||||
|
|
||||||
|
- id: "2.6"
|
||||||
|
description: "Newsletter + global rule no-project → no creates"
|
||||||
|
file: email_no_project.html
|
||||||
|
file_path: /emails/newsletter.html
|
||||||
|
projects: [alpha, beta]
|
||||||
|
expect_no_insert: true
|
||||||
|
score_name: runner.no_project
|
||||||
|
|
||||||
|
- id: "2.7"
|
||||||
|
description: "Existing task with same title → dedup (update not create)"
|
||||||
|
file: email_action.html
|
||||||
|
file_path: /emails/ProjectAlpha_followup.html
|
||||||
|
projects: [alpha]
|
||||||
|
existing_tasks:
|
||||||
|
- id: task-existing
|
||||||
|
title: Fix the login bug
|
||||||
|
status: todo
|
||||||
|
priority: medium
|
||||||
|
expect_dedup: true
|
||||||
|
score_name: runner.dedup
|
||||||
7
tests/fixtures/agent_runner_v2/data/email_action.html
vendored
Normal file
7
tests/fixtures/agent_runner_v2/data/email_action.html
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<html><head></head><body>
|
||||||
|
<p><b>From:</b> boss@company.com</p>
|
||||||
|
<p><b>To:</b> dev@company.com</p>
|
||||||
|
<p><b>Subject:</b> Fix the login bug</p>
|
||||||
|
<p><b>Date:</b> 2026-04-07</p>
|
||||||
|
<p>Hi,<br>Please fix the login bug in Project Alpha by Friday. High priority!</p>
|
||||||
|
</body></html>
|
||||||
5
tests/fixtures/agent_runner_v2/data/email_date.html
vendored
Normal file
5
tests/fixtures/agent_runner_v2/data/email_date.html
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<html><head></head><body>
|
||||||
|
<p><b>From:</b> pm@company.com</p>
|
||||||
|
<p><b>Subject:</b> Project Alpha kick-off meeting</p>
|
||||||
|
<p>The kick-off meeting for Project Alpha is scheduled for 2026-04-15 at 10:00.</p>
|
||||||
|
</body></html>
|
||||||
7
tests/fixtures/agent_runner_v2/data/email_info.html
vendored
Normal file
7
tests/fixtures/agent_runner_v2/data/email_info.html
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
<html><head></head><body>
|
||||||
|
<p><b>From:</b> pm@company.com</p>
|
||||||
|
<p><b>To:</b> team@company.com</p>
|
||||||
|
<p><b>Subject:</b> FYI: New policy for Project Alpha</p>
|
||||||
|
<p>Just a heads-up that starting next week all code reviews must be done
|
||||||
|
within 24 hours for Project Alpha. No action needed from you now.</p>
|
||||||
|
</body></html>
|
||||||
5
tests/fixtures/agent_runner_v2/data/email_no_project.html
vendored
Normal file
5
tests/fixtures/agent_runner_v2/data/email_no_project.html
vendored
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
<html><head></head><body>
|
||||||
|
<p><b>From:</b> newsletter@ads.com</p>
|
||||||
|
<p><b>Subject:</b> Weekly newsletter</p>
|
||||||
|
<p>Check out our latest deals on electronics!</p>
|
||||||
|
</body></html>
|
||||||
19
tests/fixtures/journey_v2/cases.yaml
vendored
Normal file
19
tests/fixtures/journey_v2/cases.yaml
vendored
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
# Journey V2 eval test cases — Step 4
|
||||||
|
#
|
||||||
|
# Only case 4.1 is kept as an automated eval. Cases 4.2–4.5 (multi-turn
|
||||||
|
# conversations that expect the LLM to produce a complete AgentConfig)
|
||||||
|
# are non-deterministic and tested manually — results tracked in Langfuse.
|
||||||
|
#
|
||||||
|
# Assertion keys:
|
||||||
|
# expect_question: true → first reply must contain "?"
|
||||||
|
|
||||||
|
- id: "4.1"
|
||||||
|
description: "Journey start explores directory, first reply contains a question"
|
||||||
|
directory: "/test/emails"
|
||||||
|
data_types: ["tasks", "notes", "timelines"]
|
||||||
|
directory_files:
|
||||||
|
- path: "/test/emails/outlook_export_2024.html"
|
||||||
|
content_file: "email_action.html"
|
||||||
|
user_messages: []
|
||||||
|
score_name: "journey.start"
|
||||||
|
expect_question: true
|
||||||
23
tests/fixtures/journey_v2/data/email_action.html
vendored
Normal file
23
tests/fixtures/journey_v2/data/email_action.html
vendored
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Email: Fix the login bug</title>
|
||||||
|
<style>body { font-family: Arial; } .header { color: #666; }</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="header">
|
||||||
|
<p><strong>From:</strong> boss@company.com</p>
|
||||||
|
<p><strong>To:</strong> dev@company.com</p>
|
||||||
|
<p><strong>Subject:</strong> Fix the login bug</p>
|
||||||
|
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:15:00 +0000</p>
|
||||||
|
</div>
|
||||||
|
<div class="body">
|
||||||
|
<p>Hi,</p>
|
||||||
|
<p>Please fix the login bug in Project Alpha as soon as possible.
|
||||||
|
Users are reporting that they can't log in with their Google accounts.
|
||||||
|
This is blocking the whole team. Please resolve it by Friday.</p>
|
||||||
|
<p>Thanks,<br>Boss</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
23
tests/fixtures/journey_v2/data/email_info.html
vendored
Normal file
23
tests/fixtures/journey_v2/data/email_info.html
vendored
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>Email: New policy update</title>
|
||||||
|
<style>body { font-family: Arial; }</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="header">
|
||||||
|
<p><strong>From:</strong> hr@company.com</p>
|
||||||
|
<p><strong>To:</strong> all@company.com</p>
|
||||||
|
<p><strong>Subject:</strong> FYI: New remote work policy effective May 1</p>
|
||||||
|
<p><strong>Date:</strong> Tue, 8 Apr 2026 10:00:00 +0000</p>
|
||||||
|
</div>
|
||||||
|
<div class="body">
|
||||||
|
<p>Hi everyone,</p>
|
||||||
|
<p>Just a heads-up that starting May 1, 2026 the company will be moving to
|
||||||
|
a hybrid work model. You will be expected to come into the office at least
|
||||||
|
two days per week. More details will follow in the employee handbook.</p>
|
||||||
|
<p>Best,<br>HR Team</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
68
tests/fixtures/preprocessors/cases.yaml
vendored
Normal file
68
tests/fixtures/preprocessors/cases.yaml
vendored
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# Preprocessor test cases
|
||||||
|
#
|
||||||
|
# detect: <expected_type> → chiama detect_content_type(filename, content)
|
||||||
|
# process: <content_type> → chiama preprocess(content_type, content)
|
||||||
|
#
|
||||||
|
# Sorgente: file: <nome in data/> oppure generate: binary_noise
|
||||||
|
#
|
||||||
|
# Assertions piatte (solo per process):
|
||||||
|
# no_html: true clean_text senza tag HTML
|
||||||
|
# min_chars: N len(clean_text) >= N
|
||||||
|
# ratio_lt: F len(clean) / len(raw) < F
|
||||||
|
# has_meta: [k, ...] chiavi presenti in metadata
|
||||||
|
# contains: str | [str] substring(s) presenti in clean_text
|
||||||
|
# excludes: str | [str] substring(s) assenti da clean_text
|
||||||
|
# content_type: str result.content_type == questo valore
|
||||||
|
|
||||||
|
- id: "1.1"
|
||||||
|
file: email_action.html
|
||||||
|
detect: email_html
|
||||||
|
|
||||||
|
- id: "1.2"
|
||||||
|
file: generic_page.html
|
||||||
|
detect: generic_html
|
||||||
|
|
||||||
|
- id: "1.3"
|
||||||
|
file: notes.txt
|
||||||
|
detect: plain_text
|
||||||
|
|
||||||
|
- id: "1.4"
|
||||||
|
file: archive.xyz
|
||||||
|
generate: binary_noise
|
||||||
|
detect: unknown
|
||||||
|
|
||||||
|
- id: "1.5"
|
||||||
|
file: email_action.html
|
||||||
|
process: email_html
|
||||||
|
no_html: true
|
||||||
|
min_chars: 50
|
||||||
|
ratio_lt: 0.8
|
||||||
|
|
||||||
|
- id: "1.6"
|
||||||
|
file: email_action.html
|
||||||
|
process: email_html
|
||||||
|
has_meta: [subject, from]
|
||||||
|
|
||||||
|
- id: "1.7"
|
||||||
|
file: email_thread.html
|
||||||
|
process: email_html
|
||||||
|
contains: "Sure, I'll handle the deploy"
|
||||||
|
excludes: "Let's plan the deploy"
|
||||||
|
|
||||||
|
- id: "1.8"
|
||||||
|
file: email_single.html
|
||||||
|
process: email_html
|
||||||
|
contains: "deploy is done"
|
||||||
|
|
||||||
|
- id: "1.9"
|
||||||
|
file: email_heavy.html
|
||||||
|
process: email_html
|
||||||
|
no_html: true
|
||||||
|
min_chars: 30
|
||||||
|
excludes: [border-collapse, font-size]
|
||||||
|
|
||||||
|
- id: "1.10"
|
||||||
|
file: fallback.txt
|
||||||
|
process: unknown
|
||||||
|
min_chars: 1
|
||||||
|
content_type: unknown
|
||||||
25
tests/fixtures/preprocessors/data/email_action.html
vendored
Normal file
25
tests/fixtures/preprocessors/data/email_action.html
vendored
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<title>Fix the login bug</title>
|
||||||
|
<style>
|
||||||
|
body { font-family: Arial, sans-serif; color: #333; margin: 0; padding: 20px; }
|
||||||
|
.header { background: #f5f5f5; padding: 10px; border-bottom: 1px solid #ddd; }
|
||||||
|
.body { padding: 20px; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<div class="header">
|
||||||
|
<p><strong>From:</strong> boss@company.com</p>
|
||||||
|
<p><strong>To:</strong> dev@company.com</p>
|
||||||
|
<p><strong>Subject:</strong> Fix the login bug</p>
|
||||||
|
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:00:00 +0200</p>
|
||||||
|
</div>
|
||||||
|
<div class="body">
|
||||||
|
<p>Hi,</p>
|
||||||
|
<p>Please fix the login bug by Friday. It is blocking the release.</p>
|
||||||
|
<p>Priority: high. Let me know if you need anything.</p>
|
||||||
|
<p>Thanks,<br>Boss</p>
|
||||||
|
</div>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
49
tests/fixtures/preprocessors/data/email_heavy.html
vendored
Normal file
49
tests/fixtures/preprocessors/data/email_heavy.html
vendored
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html>
|
||||||
|
<head>
|
||||||
|
<style>
|
||||||
|
table { border-collapse: collapse; width: 100%; max-width: 600px; margin: 0 auto; }
|
||||||
|
td { padding: 8px 12px; border: 1px solid #dddddd; font-size: 12px; color: #444444; }
|
||||||
|
.header-row { background-color: #003366; color: #ffffff; font-weight: bold; }
|
||||||
|
.label-col { background-color: #f0f0f0; width: 80px; font-weight: bold; }
|
||||||
|
.footer-row { font-size: 10px; color: #999999; text-align: center; }
|
||||||
|
</style>
|
||||||
|
</head>
|
||||||
|
<body bgcolor="#eeeeee">
|
||||||
|
<center>
|
||||||
|
<table cellpadding="0" cellspacing="0">
|
||||||
|
<tr class="header-row">
|
||||||
|
<td colspan="2">Company Internal Update</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="label-col">From:</td>
|
||||||
|
<td>newsletter@corp.com</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="label-col">Subject:</td>
|
||||||
|
<td>Q1 Results Update</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td class="label-col">Date:</td>
|
||||||
|
<td>Apr 7, 2026</td>
|
||||||
|
</tr>
|
||||||
|
<tr>
|
||||||
|
<td colspan="2">
|
||||||
|
<table width="100%" cellpadding="10">
|
||||||
|
<tr>
|
||||||
|
<td>
|
||||||
|
<p style="font-size:14px; font-weight:bold;">Dear Team,</p>
|
||||||
|
<p>Q1 results are in. Revenue up 15% year-over-year.</p>
|
||||||
|
<p>Please review the attached report and share any feedback by EOW.</p>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</td>
|
||||||
|
</tr>
|
||||||
|
<tr class="footer-row">
|
||||||
|
<td colspan="2">Confidential — do not forward outside the company.</td>
|
||||||
|
</tr>
|
||||||
|
</table>
|
||||||
|
</center>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
8
tests/fixtures/preprocessors/data/email_single.html
vendored
Normal file
8
tests/fixtures/preprocessors/data/email_single.html
vendored
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html><body>
|
||||||
|
<p><strong>From:</strong> alice@co.com</p>
|
||||||
|
<p><strong>To:</strong> team@co.com</p>
|
||||||
|
<p><strong>Subject:</strong> Quick update</p>
|
||||||
|
<p><strong>Date:</strong> Tue, 7 Apr 2026 10:30:00 +0200</p>
|
||||||
|
<p>The deploy is done. Everything looks good. No issues so far.</p>
|
||||||
|
</body></html>
|
||||||
24
tests/fixtures/preprocessors/data/email_thread.html
vendored
Normal file
24
tests/fixtures/preprocessors/data/email_thread.html
vendored
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html><body>
|
||||||
|
<div class="message-latest">
|
||||||
|
<p><strong>From:</strong> alice@co.com</p>
|
||||||
|
<p><strong>Subject:</strong> Re: Re: Deploy plan</p>
|
||||||
|
<p>Sure, I'll handle the deploy.</p>
|
||||||
|
</div>
|
||||||
|
|
||||||
|
<p>On Mon, Apr 6, 2026 at 3:00 PM, Bob <bob@co.com> wrote:</p>
|
||||||
|
<blockquote>
|
||||||
|
<p>From: bob@co.com</p>
|
||||||
|
<p>Can you handle the deploy?</p>
|
||||||
|
<p>On Sun, Apr 5, 2026 at 1:00 PM, Alice <alice@co.com> wrote:</p>
|
||||||
|
<blockquote>
|
||||||
|
<p>From: alice@co.com</p>
|
||||||
|
<p>Let's plan the deploy for Monday.</p>
|
||||||
|
<p>On Sat, Apr 4, 2026 at 11:00 AM, Charlie <charlie@co.com> wrote:</p>
|
||||||
|
<blockquote>
|
||||||
|
<p>From: charlie@co.com</p>
|
||||||
|
<p>We need to schedule the deploy. What day works?</p>
|
||||||
|
</blockquote>
|
||||||
|
</blockquote>
|
||||||
|
</blockquote>
|
||||||
|
</body></html>
|
||||||
3
tests/fixtures/preprocessors/data/fallback.txt
vendored
Normal file
3
tests/fixtures/preprocessors/data/fallback.txt
vendored
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
random text content without any structure
|
||||||
|
line two with some words
|
||||||
|
line three and more content here
|
||||||
35
tests/fixtures/preprocessors/data/generic_page.html
vendored
Normal file
35
tests/fixtures/preprocessors/data/generic_page.html
vendored
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
<!DOCTYPE html>
|
||||||
|
<html lang="en">
|
||||||
|
<head>
|
||||||
|
<meta charset="UTF-8">
|
||||||
|
<title>My Web App</title>
|
||||||
|
<link rel="stylesheet" href="styles.css">
|
||||||
|
</head>
|
||||||
|
<body>
|
||||||
|
<nav>
|
||||||
|
<a href="/">Home</a>
|
||||||
|
<a href="/about">About</a>
|
||||||
|
<a href="/contact">Contact</a>
|
||||||
|
</nav>
|
||||||
|
<main>
|
||||||
|
<header>
|
||||||
|
<h1>Welcome to My App</h1>
|
||||||
|
</header>
|
||||||
|
<article>
|
||||||
|
<p>This is a generic web page with no email headers.</p>
|
||||||
|
<p>It has navigation, main content, and a footer.</p>
|
||||||
|
</article>
|
||||||
|
<section>
|
||||||
|
<h2>Features</h2>
|
||||||
|
<ul>
|
||||||
|
<li>Fast</li>
|
||||||
|
<li>Reliable</li>
|
||||||
|
<li>Secure</li>
|
||||||
|
</ul>
|
||||||
|
</section>
|
||||||
|
</main>
|
||||||
|
<footer>
|
||||||
|
<p>© 2026 My App</p>
|
||||||
|
</footer>
|
||||||
|
</body>
|
||||||
|
</html>
|
||||||
15
tests/fixtures/preprocessors/data/notes.txt
vendored
Normal file
15
tests/fixtures/preprocessors/data/notes.txt
vendored
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
Meeting notes - April 7, 2026
|
||||||
|
|
||||||
|
Attendees: Alice, Bob, Charlie
|
||||||
|
|
||||||
|
Discussion points:
|
||||||
|
- Deploy scheduled for Friday
|
||||||
|
- Bug fix for login must be completed by Thursday
|
||||||
|
- Review Q1 numbers before EOW
|
||||||
|
|
||||||
|
Action items:
|
||||||
|
- Alice: fix login bug
|
||||||
|
- Bob: prepare deploy checklist
|
||||||
|
- Charlie: send Q1 report
|
||||||
|
|
||||||
|
Next meeting: April 14, 2026
|
||||||
@@ -28,7 +28,6 @@ from datetime import datetime, timezone
|
|||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from app.core.agent_runner import (
|
from app.core.agent_runner import (
|
||||||
_extract_items_from_content,
|
_extract_items_from_content,
|
||||||
@@ -597,7 +596,7 @@ async def test_run_cloud_agent_provider_fetch_error():
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_run_cloud_agent_refreshed_token_persisted():
|
async def test_run_cloud_agent_refreshed_token_persisted():
|
||||||
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
||||||
from app.integrations import EmailMessage, encrypt_token
|
from app.integrations import encrypt_token
|
||||||
from cryptography.fernet import Fernet as _Fernet
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
|
||||||
fernet_key = _Fernet.generate_key().decode()
|
fernet_key = _Fernet.generate_key().decode()
|
||||||
@@ -791,7 +790,6 @@ async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
|||||||
json={
|
json={
|
||||||
"directory": "/home/user/docs",
|
"directory": "/home/user/docs",
|
||||||
"what_to_extract": ["task", "note"],
|
"what_to_extract": ["task", "note"],
|
||||||
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
|
|
||||||
"batch_interval": "0 */6 * * *",
|
"batch_interval": "0 */6 * * *",
|
||||||
"custom_agent_prompt": "Extract tasks and notes.",
|
"custom_agent_prompt": "Extract tasks and notes.",
|
||||||
"active_agents": 0,
|
"active_agents": 0,
|
||||||
|
|||||||
431
tests/test_agent_runner_v2.py
Normal file
431
tests/test_agent_runner_v2.py
Normal file
@@ -0,0 +1,431 @@
|
|||||||
|
"""Tests for Local Agent V2 runner (Step 2).
|
||||||
|
|
||||||
|
Covers the unified per-file flow:
|
||||||
|
Phase A — detect + preprocess (Python, zero LLM)
|
||||||
|
Phase B — single LLM call with tools (classify + extract + create)
|
||||||
|
|
||||||
|
Fixture-based eval tests (2.1–2.7)
|
||||||
|
-----------------------------------
|
||||||
|
Cases are defined in tests/fixtures/agent_runner_v2/cases.yaml.
|
||||||
|
Email HTML files live in tests/fixtures/agent_runner_v2/data/.
|
||||||
|
Use --runner-dir to point at a custom folder (same structure required).
|
||||||
|
|
||||||
|
Unit tests (no LLM)
|
||||||
|
--------------------
|
||||||
|
2.8 items_created count → items_created == N create_* calls
|
||||||
|
2.9 Device offline → status=error
|
||||||
|
2.10 Empty file → items_processed=0, status=success
|
||||||
|
|
||||||
|
Run:
|
||||||
|
pytest tests/test_agent_runner_v2.py -v
|
||||||
|
pytest tests/test_agent_runner_v2.py -v -k "2_9 or 2_10 or 2_8" # unit only
|
||||||
|
pytest tests/test_agent_runner_v2.py -v -k "eval" # LLM evals only
|
||||||
|
pytest tests/test_agent_runner_v2.py -v --runner-dir /path/to/dir # custom fixtures
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from app.core.agent_runner import (
|
||||||
|
_format_metadata,
|
||||||
|
_format_projects,
|
||||||
|
_get_extraction_rules,
|
||||||
|
_get_no_match_behavior,
|
||||||
|
run_local_agent,
|
||||||
|
)
|
||||||
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
|
from app.core.langfuse_client import get_langfuse
|
||||||
|
from app.models import AgentRunLog, LocalAgentConfig
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
# ── Constants ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "agent_runner_v2"
|
||||||
|
|
||||||
|
_AGENT_CONFIG = {
|
||||||
|
"content_types": [
|
||||||
|
{
|
||||||
|
"id": "email_html",
|
||||||
|
"label": "Email HTML",
|
||||||
|
"detection_hint": "HTML file with From/To/Subject headers",
|
||||||
|
"preprocessing": "email_html",
|
||||||
|
"extraction_prompt": (
|
||||||
|
"If the email contains a direct action request or task assignment → create a task. "
|
||||||
|
"If the email contains informational content, updates, or FYI → create a note. "
|
||||||
|
"If the email mentions a specific date for a meeting or deadline → create a timeline entry."
|
||||||
|
),
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"global_rules": [
|
||||||
|
"Se il file non è riconducibile a nessun progetto, non creare alcuna entità."
|
||||||
|
],
|
||||||
|
"data_types": ["tasks", "notes", "timelines"],
|
||||||
|
}
|
||||||
|
|
||||||
|
# Canonical project definitions, referenced symbolically in cases.yaml.
|
||||||
|
_PROJECTS: dict[str, dict] = {
|
||||||
|
"alpha": {"id": "proj-alpha", "name": "Project Alpha", "status": "active"},
|
||||||
|
"beta": {"id": "proj-beta", "name": "Project Beta", "status": "active"},
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fixture loading ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _fixtures_dir(config) -> Path:
|
||||||
|
override = config.getoption("--runner-dir")
|
||||||
|
return Path(override) if override else _DEFAULT_FIXTURE_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cases(config) -> list[dict]:
|
||||||
|
return yaml.safe_load(
|
||||||
|
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_case_file(case: dict, data_dir: Path) -> str:
|
||||||
|
return (data_dir / case["file"]).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_projects(entries: list[str | dict]) -> list[dict]:
|
||||||
|
"""Resolve project list from YAML: symbolic names and/or inline dicts."""
|
||||||
|
result = []
|
||||||
|
for entry in entries:
|
||||||
|
if isinstance(entry, str):
|
||||||
|
if entry in _PROJECTS:
|
||||||
|
result.append(_PROJECTS[entry])
|
||||||
|
elif isinstance(entry, dict):
|
||||||
|
result.append(entry)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
# ── pytest_generate_tests — parametrize eval tests from YAML ─────────────
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
if "runner_case" not in metafunc.fixturenames:
|
||||||
|
return
|
||||||
|
cases = _load_cases(metafunc.config)
|
||||||
|
metafunc.parametrize("runner_case", cases, ids=[c["id"] for c in cases])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Test helpers ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(
|
||||||
|
agent_config: dict | None = None,
|
||||||
|
directory: str = "/emails",
|
||||||
|
device_id: str = "dev-001",
|
||||||
|
) -> LocalAgentConfig:
|
||||||
|
return LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=_USER_ID,
|
||||||
|
device_id=device_id,
|
||||||
|
name="Test V2 Agent",
|
||||||
|
directory_paths=[directory],
|
||||||
|
data_types=["tasks", "notes", "timelines"],
|
||||||
|
prompt_template="",
|
||||||
|
agent_config=agent_config or _AGENT_CONFIG,
|
||||||
|
file_extensions=[".html", ".eml"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_run_log(agent_id: str) -> AgentRunLog:
|
||||||
|
return AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=_USER_ID,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_manager(online: bool = True) -> DeviceConnectionManager:
|
||||||
|
mgr = DeviceConnectionManager()
|
||||||
|
if online:
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
mgr.register(_USER_ID, "dev-001", ws)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
|
||||||
|
def _make_executor(
|
||||||
|
file_path: str,
|
||||||
|
file_content: str,
|
||||||
|
projects: list[dict] | None = None,
|
||||||
|
existing_tasks: list[dict] | None = None,
|
||||||
|
existing_notes: list[dict] | None = None,
|
||||||
|
existing_timelines: list[dict] | None = None,
|
||||||
|
) -> tuple[Any, list[dict]]:
|
||||||
|
"""Return (async_executor, captured_calls).
|
||||||
|
|
||||||
|
The executor handles all ``execute_on_client`` payloads:
|
||||||
|
directory listing, file reading, project/entity fetching, and CRUD.
|
||||||
|
"""
|
||||||
|
calls: list[dict] = []
|
||||||
|
_projects = projects if projects is not None else list(_PROJECTS.values())
|
||||||
|
|
||||||
|
async def _executor(payload: dict) -> dict:
|
||||||
|
action = payload.get("action", "")
|
||||||
|
table = payload.get("table", "")
|
||||||
|
data = payload.get("data") or {}
|
||||||
|
calls.append({"action": action, "table": table, "data": data})
|
||||||
|
|
||||||
|
if action == "list_directory":
|
||||||
|
return {"entries": [{"type": "file", "path": file_path}]}
|
||||||
|
|
||||||
|
if action == "get_file_metadata":
|
||||||
|
return {"modifiedAt": None}
|
||||||
|
|
||||||
|
if action == "read_file_content":
|
||||||
|
return {"content": file_content}
|
||||||
|
|
||||||
|
if action == "select":
|
||||||
|
if table == "projects":
|
||||||
|
return {"rows": _projects}
|
||||||
|
if table == "tasks":
|
||||||
|
return {"rows": existing_tasks or []}
|
||||||
|
if table == "notes":
|
||||||
|
return {"rows": existing_notes or []}
|
||||||
|
if table == "timelines":
|
||||||
|
return {"rows": existing_timelines or []}
|
||||||
|
return {"rows": []}
|
||||||
|
|
||||||
|
if action == "insert":
|
||||||
|
return {"row": {"id": str(uuid.uuid4()), **data}}
|
||||||
|
|
||||||
|
if action == "update":
|
||||||
|
return {"success": True}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return _executor, calls
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: helper functions ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_projects_empty():
|
||||||
|
assert "(no projects" in _format_projects([])
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_projects_with_data():
|
||||||
|
result = _format_projects([_PROJECTS["alpha"]])
|
||||||
|
assert "proj-alpha" in result
|
||||||
|
assert "Project Alpha" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_metadata_empty():
|
||||||
|
assert _format_metadata({}) == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_format_metadata_email():
|
||||||
|
meta = {"subject": "Fix bug", "from": "boss@co.com", "date": "2026-04-07"}
|
||||||
|
result = _format_metadata(meta)
|
||||||
|
assert "Fix bug" in result
|
||||||
|
assert "boss@co.com" in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_extraction_rules_match():
|
||||||
|
rules = _get_extraction_rules(_AGENT_CONFIG, "email_html")
|
||||||
|
assert "task" in rules.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_extraction_rules_fallback():
|
||||||
|
rules = _get_extraction_rules(_AGENT_CONFIG, "plain_text")
|
||||||
|
assert "extract" in rules.lower()
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_no_match_behavior_from_global_rules():
|
||||||
|
behavior = _get_no_match_behavior(_AGENT_CONFIG)
|
||||||
|
assert behavior # non-empty
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_no_match_behavior_default():
|
||||||
|
behavior = _get_no_match_behavior({})
|
||||||
|
assert "project" in behavior.lower()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: 2.9 — device offline ───────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_2_9_device_offline():
|
||||||
|
"""2.9 No device online → status=error, no executor created."""
|
||||||
|
config = _make_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager(online=False)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_fin.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("not connected" in e for e in kwargs.get("errors", []))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: 2.10 — empty file ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_2_10_empty_file():
|
||||||
|
"""2.10 File with empty content → skipped, items_processed=0, success."""
|
||||||
|
config = _make_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
executor, calls = _make_executor(
|
||||||
|
file_path="/emails/empty.html",
|
||||||
|
file_content="",
|
||||||
|
projects=[_PROJECTS["alpha"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_fin.call_args
|
||||||
|
assert kwargs["items_processed"] == 0
|
||||||
|
assert kwargs["status"] == "success"
|
||||||
|
assert kwargs["items_created"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: 2.8 — items_created count ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_2_8_items_created_count():
|
||||||
|
"""2.8 items_created == number of create_* tool calls per run."""
|
||||||
|
config = _make_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
executor, _calls = _make_executor(
|
||||||
|
file_path="/emails/action.html",
|
||||||
|
file_content="<html><body><p>Fix the login bug in Project Alpha.</p></body></html>",
|
||||||
|
projects=[_PROJECTS["alpha"]],
|
||||||
|
)
|
||||||
|
|
||||||
|
async def mock_run_agent(*, _tool_calls_out=None, **kw) -> str:
|
||||||
|
if _tool_calls_out is not None:
|
||||||
|
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
|
||||||
|
return "Done."
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
||||||
|
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_fin.call_args
|
||||||
|
# Only create_task + create_note count (not update_task).
|
||||||
|
assert kwargs["items_created"] == 2
|
||||||
|
assert kwargs["items_processed"] == 1
|
||||||
|
|
||||||
|
|
||||||
|
# ── Eval: 2.1–2.7 — fixture-driven, real LLM + Langfuse scoring ──────────
|
||||||
|
#
|
||||||
|
# Cases loaded from tests/fixtures/agent_runner_v2/cases.yaml.
|
||||||
|
# Supported assertions (from YAML):
|
||||||
|
# expect_insert: <table> → at least 1 insert in that table
|
||||||
|
# expect_no_insert: true → zero inserts in any table
|
||||||
|
# expect_project_id: <id> → any insert carries this projectId
|
||||||
|
# expect_dedup: true → task inserts == 0 OR task updates >= 1
|
||||||
|
# ─────────────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.eval
|
||||||
|
async def test_eval_runner(runner_case, pytestconfig):
|
||||||
|
"""Parametrized eval test — one invocation per YAML case."""
|
||||||
|
case: dict = runner_case
|
||||||
|
data_dir = _fixtures_dir(pytestconfig) / "data"
|
||||||
|
file_content = _read_case_file(case, data_dir)
|
||||||
|
projects = _resolve_projects(case.get("projects", []))
|
||||||
|
|
||||||
|
config = _make_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
executor, calls = _make_executor(
|
||||||
|
file_path=case["file_path"],
|
||||||
|
file_content=file_content,
|
||||||
|
projects=projects,
|
||||||
|
existing_tasks=case.get("existing_tasks"),
|
||||||
|
existing_notes=case.get("existing_notes"),
|
||||||
|
existing_timelines=case.get("existing_timelines"),
|
||||||
|
)
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
obs_ctx = lf.start_as_current_observation(
|
||||||
|
name=f"eval-runner-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
|
||||||
|
metadata={"step": "2", "case_id": case["id"]},
|
||||||
|
) if lf else nullcontext()
|
||||||
|
|
||||||
|
with obs_ctx as obs:
|
||||||
|
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
||||||
|
await run_local_agent(_USER_ID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_fin.call_args
|
||||||
|
inserts = [c for c in calls if c["action"] == "insert"]
|
||||||
|
score, comment = _evaluate_case(case, calls, kwargs)
|
||||||
|
|
||||||
|
if obs is not None:
|
||||||
|
obs.score(
|
||||||
|
name=case.get("score_name", f"runner.case_{case['id']}"),
|
||||||
|
value=score,
|
||||||
|
comment=comment,
|
||||||
|
)
|
||||||
|
|
||||||
|
if lf:
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
assert score == 1.0, f"[{case['id']}] {case.get('description', '')} — {comment}"
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_case(case: dict, calls: list[dict], finalize_kwargs: dict) -> tuple[float, str]:
|
||||||
|
"""Return (score, comment) for a YAML case given the captured executor calls."""
|
||||||
|
inserts = [c for c in calls if c["action"] == "insert"]
|
||||||
|
|
||||||
|
if case.get("expect_no_insert"):
|
||||||
|
score = 1.0 if len(inserts) == 0 else 0.0
|
||||||
|
return score, f"inserts={len(inserts)} (expected 0)"
|
||||||
|
|
||||||
|
if "expect_insert" in case:
|
||||||
|
tables = case["expect_insert"]
|
||||||
|
if isinstance(tables, str):
|
||||||
|
tables = [tables]
|
||||||
|
missing = [t for t in tables if not any(c["table"] == t for c in inserts)]
|
||||||
|
score = 1.0 if not missing else 0.0
|
||||||
|
counts = {t: sum(1 for c in inserts if c["table"] == t) for t in tables}
|
||||||
|
return score, f"inserts={counts}" + (f" missing={missing}" if missing else "")
|
||||||
|
|
||||||
|
if "expect_project_id" in case:
|
||||||
|
expected_pid = case["expect_project_id"]
|
||||||
|
correct = any(c.get("data", {}).get("projectId") == expected_pid for c in inserts)
|
||||||
|
score = 1.0 if correct else 0.0
|
||||||
|
all_pids = [c.get("data", {}).get("projectId") for c in inserts]
|
||||||
|
return score, f"projectIds={all_pids} (expected {expected_pid!r})"
|
||||||
|
|
||||||
|
if case.get("expect_dedup"):
|
||||||
|
task_creates = [c for c in inserts if c["table"] == "tasks"]
|
||||||
|
task_updates = [c for c in calls if c["action"] == "update" and c["table"] == "tasks"]
|
||||||
|
score = 1.0 if len(task_creates) == 0 or len(task_updates) >= 1 else 0.0
|
||||||
|
return score, f"task_creates={len(task_creates)} task_updates={len(task_updates)}"
|
||||||
|
|
||||||
|
return 0.0, "no assertion defined in case"
|
||||||
@@ -21,7 +21,6 @@ import time
|
|||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import AsyncMock, patch
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
import pytest
|
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Tests for auth routes: register, login, refresh, me.
|
"""Tests for auth routes: register, login, refresh, me, OAuth social login.
|
||||||
|
|
||||||
Exercises the full auth lifecycle through the FastAPI TestClient against the
|
Exercises the full auth lifecycle through the FastAPI TestClient against the
|
||||||
in-memory SQLite test database seeded by ``conftest.py``.
|
in-memory SQLite test database seeded by ``conftest.py``.
|
||||||
@@ -7,9 +7,11 @@ in-memory SQLite test database seeded by ``conftest.py``.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import time
|
import time
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
|
|
||||||
|
from app.auth.oauth_providers import GoogleOAuthProvider, OAuthUserInfo
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from tests.conftest import auth_header, TEST_USER_IDS
|
from tests.conftest import auth_header, TEST_USER_IDS
|
||||||
|
|
||||||
@@ -204,3 +206,153 @@ class TestMe:
|
|||||||
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
|
||||||
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
|
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
|
||||||
assert resp.status_code == 401
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── TestOAuth ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class TestOAuth:
|
||||||
|
"""GET /auth/oauth/google/authorize and POST /auth/oauth/google/callback."""
|
||||||
|
|
||||||
|
FAKE_PROVIDER_USER_ID = "google-sub-12345"
|
||||||
|
FAKE_EMAIL = "oauth@example.com"
|
||||||
|
FAKE_AVATAR = "https://lh3.googleusercontent.com/photo.jpg"
|
||||||
|
|
||||||
|
def _patch_google(self, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", "fake-client-id")
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_SECRET", "fake-client-secret")
|
||||||
|
|
||||||
|
def _userinfo(
|
||||||
|
self,
|
||||||
|
email: str | None = None,
|
||||||
|
email_verified: bool = True,
|
||||||
|
) -> OAuthUserInfo:
|
||||||
|
return OAuthUserInfo(
|
||||||
|
provider_user_id=self.FAKE_PROVIDER_USER_ID,
|
||||||
|
email=email or self.FAKE_EMAIL,
|
||||||
|
email_verified=email_verified,
|
||||||
|
avatar_url=self.FAKE_AVATAR,
|
||||||
|
name="OAuth User",
|
||||||
|
)
|
||||||
|
|
||||||
|
def _authorize(self, client) -> str:
|
||||||
|
"""Call /authorize and return the fresh state token."""
|
||||||
|
resp = client.get("/api/v1/auth/oauth/google/authorize")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
return resp.json()["state"]
|
||||||
|
|
||||||
|
def _callback(self, client, state: str, userinfo: OAuthUserInfo):
|
||||||
|
"""POST /callback with mocked provider exchange_code + get_userinfo."""
|
||||||
|
with (
|
||||||
|
patch.object(
|
||||||
|
GoogleOAuthProvider,
|
||||||
|
"exchange_code",
|
||||||
|
new=AsyncMock(return_value={"access_token": "google-access-tok"}),
|
||||||
|
),
|
||||||
|
patch.object(
|
||||||
|
GoogleOAuthProvider,
|
||||||
|
"get_userinfo",
|
||||||
|
new=AsyncMock(return_value=userinfo),
|
||||||
|
),
|
||||||
|
):
|
||||||
|
return client.post(
|
||||||
|
"/api/v1/auth/oauth/google/callback",
|
||||||
|
json={"code": "auth-code", "state": state},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _decode_sub(self, access_token: str) -> str:
|
||||||
|
return jwt.decode(
|
||||||
|
access_token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)["sub"]
|
||||||
|
|
||||||
|
# -- authorize --
|
||||||
|
|
||||||
|
def test_authorize_returns_url_and_state(self, client, monkeypatch) -> None:
|
||||||
|
self._patch_google(monkeypatch)
|
||||||
|
resp = client.get("/api/v1/auth/oauth/google/authorize")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "url" in data and "state" in data
|
||||||
|
assert "accounts.google.com" in data["url"]
|
||||||
|
assert len(data["state"]) > 0
|
||||||
|
|
||||||
|
def test_authorize_unconfigured_returns_503(self, client, monkeypatch) -> None:
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", "")
|
||||||
|
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_SECRET", "")
|
||||||
|
resp = client.get("/api/v1/auth/oauth/google/authorize")
|
||||||
|
assert resp.status_code == 503
|
||||||
|
|
||||||
|
# -- callback --
|
||||||
|
|
||||||
|
def test_callback_state_mismatch_returns_401(self, client, monkeypatch) -> None:
|
||||||
|
self._patch_google(monkeypatch)
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/auth/oauth/google/callback",
|
||||||
|
json={"code": "code", "state": "not-a-real-state"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
def test_callback_creates_new_user(self, client, monkeypatch) -> None:
|
||||||
|
"""First-time Google login creates a new user and returns valid tokens."""
|
||||||
|
self._patch_google(monkeypatch)
|
||||||
|
state = self._authorize(client)
|
||||||
|
resp = self._callback(client, state, self._userinfo())
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
data = resp.json()
|
||||||
|
assert "access_token" in data and "refresh_token" in data
|
||||||
|
payload = jwt.decode(
|
||||||
|
data["access_token"], settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
assert payload["email"] == self.FAKE_EMAIL
|
||||||
|
|
||||||
|
def test_callback_existing_oauth_link_logs_in(self, client, monkeypatch) -> None:
|
||||||
|
"""Second Google login with the same account re-uses the existing user."""
|
||||||
|
self._patch_google(monkeypatch)
|
||||||
|
userinfo = self._userinfo()
|
||||||
|
|
||||||
|
# First login — creates user + oauth_accounts row
|
||||||
|
resp1 = self._callback(client, self._authorize(client), userinfo)
|
||||||
|
assert resp1.status_code == 200
|
||||||
|
sub1 = self._decode_sub(resp1.json()["access_token"])
|
||||||
|
|
||||||
|
# Second login — finds existing oauth_accounts row → same user
|
||||||
|
resp2 = self._callback(client, self._authorize(client), userinfo)
|
||||||
|
assert resp2.status_code == 200
|
||||||
|
sub2 = self._decode_sub(resp2.json()["access_token"])
|
||||||
|
|
||||||
|
assert sub1 == sub2
|
||||||
|
|
||||||
|
def test_callback_email_match_links_account(self, client, monkeypatch) -> None:
|
||||||
|
"""Verified Google email matching an existing password user links the accounts."""
|
||||||
|
email = "link-target@example.com"
|
||||||
|
reg_resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": email, "password": "TestPass123!"},
|
||||||
|
)
|
||||||
|
assert reg_resp.status_code == 201
|
||||||
|
orig_sub = self._decode_sub(reg_resp.json()["access_token"])
|
||||||
|
|
||||||
|
self._patch_google(monkeypatch)
|
||||||
|
state = self._authorize(client)
|
||||||
|
resp = self._callback(client, state, self._userinfo(email=email, email_verified=True))
|
||||||
|
|
||||||
|
assert resp.status_code == 200
|
||||||
|
oauth_sub = self._decode_sub(resp.json()["access_token"])
|
||||||
|
# OAuth login must resolve to the same user as the original registration
|
||||||
|
assert orig_sub == oauth_sub
|
||||||
|
|
||||||
|
def test_callback_unverified_email_conflict_returns_409(self, client, monkeypatch) -> None:
|
||||||
|
"""Unverified Google email matching an existing account returns 409, not 500."""
|
||||||
|
email = "conflict@example.com"
|
||||||
|
reg_resp = client.post(
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
json={"email": email, "password": "TestPass123!"},
|
||||||
|
)
|
||||||
|
assert reg_resp.status_code == 201
|
||||||
|
|
||||||
|
self._patch_google(monkeypatch)
|
||||||
|
state = self._authorize(client)
|
||||||
|
resp = self._callback(client, state, self._userinfo(email=email, email_verified=False))
|
||||||
|
|
||||||
|
assert resp.status_code == 409
|
||||||
|
|||||||
@@ -18,13 +18,12 @@ from datetime import datetime, timezone
|
|||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
|
|
||||||
from app.core.device_manager import DeviceConnection, DeviceConnectionManager
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
from tests.conftest import TEST_USER_IDS, auth_header, make_jwt
|
from tests.conftest import TEST_USER_IDS, make_jwt
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# Helpers
|
# Helpers
|
||||||
|
|||||||
@@ -40,11 +40,9 @@ Coverage:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch
|
from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
|
|||||||
299
tests/test_journey_v2.py
Normal file
299
tests/test_journey_v2.py
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
"""Tests for Local Agent V2 journey setup (Step 4).
|
||||||
|
|
||||||
|
Covers the chatbot journey that produces a structured AgentConfig JSON
|
||||||
|
instead of a freeform prompt_template string.
|
||||||
|
|
||||||
|
Unit tests (no LLM)
|
||||||
|
--------------------
|
||||||
|
4.6a _extract_agent_config: valid JSON → returns serialised config
|
||||||
|
4.6b _extract_agent_config: invalid JSON → returns None
|
||||||
|
4.6c _extract_agent_config: markers absent → returns None
|
||||||
|
4.6d _extract_agent_config: only START marker → returns None
|
||||||
|
4.6e Session not found → done=True, agent_config=None
|
||||||
|
4.6f Nudge uses AGENT_CONFIG_START/END markers (not old PROMPT_TEMPLATE)
|
||||||
|
|
||||||
|
Eval test (real LLM + Langfuse scoring)
|
||||||
|
----------------------------------------
|
||||||
|
4.1 Journey start explores directory → first reply contains a question
|
||||||
|
|
||||||
|
Cases 4.2–4.5 (multi-turn conversations producing a full AgentConfig) are
|
||||||
|
non-deterministic and tested manually — results tracked in Langfuse.
|
||||||
|
|
||||||
|
Run:
|
||||||
|
pytest tests/test_journey_v2.py -v
|
||||||
|
pytest tests/test_journey_v2.py -v -k "4_6" # unit only
|
||||||
|
pytest tests/test_journey_v2.py -v -k "eval" # single LLM eval
|
||||||
|
pytest tests/test_journey_v2.py -v --journey-dir /p # custom fixtures
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from contextlib import nullcontext
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import (
|
||||||
|
_CONFIG_END,
|
||||||
|
_CONFIG_START,
|
||||||
|
_MAX_TURNS,
|
||||||
|
_extract_agent_config,
|
||||||
|
_sessions,
|
||||||
|
handle_journey_message,
|
||||||
|
handle_journey_start,
|
||||||
|
)
|
||||||
|
from app.core.langfuse_client import get_langfuse
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
|
from app.schemas import AgentConfig
|
||||||
|
from tests.conftest import TEST_USER_IDS
|
||||||
|
|
||||||
|
# ── Constants ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_USER_ID = TEST_USER_IDS["power"]
|
||||||
|
|
||||||
|
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "journey_v2"
|
||||||
|
|
||||||
|
# ── Fixture loading ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _fixtures_dir(config) -> Path:
|
||||||
|
override = config.getoption("--journey-dir")
|
||||||
|
return Path(override) if override else _DEFAULT_FIXTURE_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cases(config) -> list[dict]:
|
||||||
|
return yaml.safe_load(
|
||||||
|
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _read_data_file(filename: str, fixtures_dir: Path) -> str:
|
||||||
|
return (fixtures_dir / "data" / filename).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
# ── pytest_generate_tests ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
if "journey_case" not in metafunc.fixturenames:
|
||||||
|
return
|
||||||
|
cases = _load_cases(metafunc.config)
|
||||||
|
metafunc.parametrize("journey_case", cases, ids=[c["id"] for c in cases])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Executor builder ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _make_fs_executor(directory_files: list[dict], fixtures_dir: Path):
|
||||||
|
"""Return an async callback that simulates filesystem tool responses.
|
||||||
|
|
||||||
|
Matches the signature expected by ``set_client_executor`` / ``execute_on_client``:
|
||||||
|
receives the full ``payload`` dict and returns a result dict.
|
||||||
|
|
||||||
|
``directory_files`` is a list of ``{path, content_file}`` dicts;
|
||||||
|
``content_file`` is relative to ``fixtures_dir/data/``.
|
||||||
|
"""
|
||||||
|
file_map: dict[str, str] = {
|
||||||
|
entry["path"]: _read_data_file(entry["content_file"], fixtures_dir)
|
||||||
|
for entry in directory_files
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _executor(payload: dict) -> dict:
|
||||||
|
action = payload.get("action", "")
|
||||||
|
data = payload.get("data") or {}
|
||||||
|
|
||||||
|
if action == "list_directory":
|
||||||
|
return {"entries": [
|
||||||
|
{"type": "file", "name": p.split("/")[-1], "path": p}
|
||||||
|
for p in file_map
|
||||||
|
]}
|
||||||
|
|
||||||
|
if action == "read_file_content":
|
||||||
|
path = data.get("path", "")
|
||||||
|
return {"content": file_map.get(path, "")}
|
||||||
|
|
||||||
|
if action == "get_file_metadata":
|
||||||
|
path = data.get("path", "")
|
||||||
|
name = path.split("/")[-1]
|
||||||
|
ext = "." + name.rsplit(".", 1)[-1] if "." in name else ""
|
||||||
|
return {"name": name, "extension": ext, "size": 1024,
|
||||||
|
"createdAt": None, "modifiedAt": None}
|
||||||
|
|
||||||
|
return {}
|
||||||
|
|
||||||
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
|
# ── Journey runner helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _run_journey(user_id: str, case: dict, executor) -> dict[str, Any]:
|
||||||
|
"""Drive start + all user_messages for a case. Returns the final reply dict.
|
||||||
|
|
||||||
|
Mirrors ``device_ws._handle_journey_start/message``: sets the client
|
||||||
|
executor (so filesystem tools work) before each handler call.
|
||||||
|
"""
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
try:
|
||||||
|
set_client_executor(executor)
|
||||||
|
reply = await handle_journey_start(user_id, {
|
||||||
|
"agent_type": "local",
|
||||||
|
"directory": case["directory"],
|
||||||
|
"data_types": case["data_types"],
|
||||||
|
"session_id": session_id,
|
||||||
|
})
|
||||||
|
|
||||||
|
for msg in case.get("user_messages", []):
|
||||||
|
if reply.get("done"):
|
||||||
|
break
|
||||||
|
set_client_executor(executor)
|
||||||
|
reply = await handle_journey_message(user_id, {
|
||||||
|
"session_id": reply["session_id"],
|
||||||
|
"message": msg,
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
|
return reply
|
||||||
|
|
||||||
|
|
||||||
|
# ── Assertion helper ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _evaluate_case(case: dict, reply: dict) -> tuple[float, str]:
|
||||||
|
"""Return (score, comment) for a journey case given the final reply dict."""
|
||||||
|
if case.get("expect_question"):
|
||||||
|
has_q = "?" in reply.get("message", "")
|
||||||
|
return (1.0 if has_q else 0.0), f"first_reply_has_question={has_q}"
|
||||||
|
|
||||||
|
return 1.0, "no specific assertion"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit tests ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_4_6a_extract_valid_json():
|
||||||
|
"""_extract_agent_config: valid JSON between markers → returns serialised config."""
|
||||||
|
config = AgentConfig(
|
||||||
|
content_types=[],
|
||||||
|
global_rules=["No project = no entity"],
|
||||||
|
data_types=["tasks"],
|
||||||
|
)
|
||||||
|
text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing"
|
||||||
|
result = _extract_agent_config(text)
|
||||||
|
assert result is not None
|
||||||
|
parsed = AgentConfig.model_validate_json(result)
|
||||||
|
assert parsed.global_rules == ["No project = no entity"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_4_6b_extract_invalid_json():
|
||||||
|
"""_extract_agent_config: malformed JSON between markers → returns None."""
|
||||||
|
text = f"{_CONFIG_START}\n{{not: valid json\n{_CONFIG_END}"
|
||||||
|
assert _extract_agent_config(text) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_4_6c_extract_markers_absent():
|
||||||
|
"""_extract_agent_config: no markers at all → returns None."""
|
||||||
|
assert _extract_agent_config("No markers here at all") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_4_6d_extract_only_start_marker():
|
||||||
|
"""_extract_agent_config: START without END → returns None."""
|
||||||
|
assert _extract_agent_config(f"text {_CONFIG_START} no end marker") is None
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_4_6e_session_not_found():
|
||||||
|
"""4.6e Session not found → done=True, agent_config=None, informative message."""
|
||||||
|
reply = await handle_journey_message(_USER_ID, {
|
||||||
|
"session_id": "nonexistent-session-id",
|
||||||
|
"message": "Hello",
|
||||||
|
})
|
||||||
|
assert reply["done"] is True
|
||||||
|
assert reply["agent_config"] is None
|
||||||
|
assert "not found" in reply["message"].lower() or "expired" in reply["message"].lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_4_6f_nudge_uses_new_markers():
|
||||||
|
"""4.6f Nudge injected after max turns uses AGENT_CONFIG markers, not PROMPT_TEMPLATE."""
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
captured_histories: list[list[dict]] = []
|
||||||
|
|
||||||
|
async def _mock_llm(system_prompt, history, tools, **kwargs) -> str:
|
||||||
|
captured_histories.append(list(history))
|
||||||
|
# Return plain text — no markers — to trigger the nudge path.
|
||||||
|
return "I still need more information from you."
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import JourneySession
|
||||||
|
|
||||||
|
fake_session = JourneySession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=_USER_ID,
|
||||||
|
agent_type="local",
|
||||||
|
directory="/test",
|
||||||
|
data_types=["tasks"],
|
||||||
|
system_prompt="system",
|
||||||
|
langfuse_prompt=None,
|
||||||
|
)
|
||||||
|
# Fill history to the turn limit so the next message triggers the nudge.
|
||||||
|
for i in range(_MAX_TURNS):
|
||||||
|
fake_session.history.append({"role": "user", "content": f"msg {i}"})
|
||||||
|
fake_session.history.append({"role": "assistant", "content": "ok"})
|
||||||
|
_sessions[session_id] = fake_session
|
||||||
|
|
||||||
|
try:
|
||||||
|
with patch("app.api.routes.agent_setup._call_llm_with_tools", side_effect=_mock_llm):
|
||||||
|
await handle_journey_message(_USER_ID, {
|
||||||
|
"session_id": session_id,
|
||||||
|
"message": "one more message to trigger nudge",
|
||||||
|
})
|
||||||
|
finally:
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
|
||||||
|
# Second LLM call receives the nudge appended to history.
|
||||||
|
assert len(captured_histories) >= 2, "Expected ≥ 2 LLM calls (main reply + nudge)"
|
||||||
|
nudge_history = captured_histories[1]
|
||||||
|
user_msgs = " ".join(t["content"] for t in nudge_history if t["role"] == "user")
|
||||||
|
assert _CONFIG_START in user_msgs, f"Nudge must reference {_CONFIG_START}"
|
||||||
|
assert _CONFIG_END in user_msgs, f"Nudge must reference {_CONFIG_END}"
|
||||||
|
assert "PROMPT_TEMPLATE" not in user_msgs, "Old PROMPT_TEMPLATE markers must not appear in nudge"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Eval tests (real LLM + Langfuse) ─────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
@pytest.mark.eval
|
||||||
|
async def test_eval_journey(journey_case, pytestconfig):
|
||||||
|
"""Parametrized eval test — one invocation per YAML case."""
|
||||||
|
case: dict = journey_case
|
||||||
|
fixtures_dir = _fixtures_dir(pytestconfig)
|
||||||
|
executor = _make_fs_executor(case.get("directory_files", []), fixtures_dir)
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
obs_ctx = lf.start_as_current_observation(
|
||||||
|
name=f"eval-journey-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
|
||||||
|
metadata={"step": "4", "case_id": case["id"]},
|
||||||
|
) if lf else nullcontext()
|
||||||
|
|
||||||
|
with obs_ctx as obs:
|
||||||
|
reply = await _run_journey(_USER_ID, case, executor)
|
||||||
|
score, comment = _evaluate_case(case, reply)
|
||||||
|
|
||||||
|
if obs is not None:
|
||||||
|
obs.score(
|
||||||
|
name=case.get("score_name", f"journey.case_{case['id']}"),
|
||||||
|
value=score,
|
||||||
|
comment=comment,
|
||||||
|
)
|
||||||
|
|
||||||
|
if lf:
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
assert score == 1.0, f"[{case['id']}] {case.get('description', '')} — {comment}"
|
||||||
@@ -12,14 +12,15 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import json
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
from unittest.mock import patch
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD
|
from app.core.embeddings import embed_text
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.models import (
|
from app.models import (
|
||||||
@@ -341,3 +342,33 @@ def test_home_request_calls_memory_middleware(client):
|
|||||||
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
|
||||||
assert stored_session_id == session_id
|
assert stored_session_id == session_id
|
||||||
assert stored_message == "Show tasks"
|
assert stored_message == "Show tasks"
|
||||||
|
|
||||||
|
|
||||||
|
# ── embed_text ─────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embed_text_returns_1536_floats():
|
||||||
|
"""embed_text returns a 1536-dim float list when OpenAI responds successfully."""
|
||||||
|
fake_embedding = [0.1] * 1536
|
||||||
|
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.data = [MagicMock(embedding=fake_embedding)]
|
||||||
|
|
||||||
|
mock_client = MagicMock()
|
||||||
|
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.embeddings.AsyncOpenAI", return_value=mock_client):
|
||||||
|
result = await embed_text("test text")
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) == 1536
|
||||||
|
assert all(isinstance(x, float) for x in result)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_embed_text_returns_none_on_failure():
|
||||||
|
"""embed_text returns None when OpenAI raises; must not propagate the exception."""
|
||||||
|
with patch("app.core.embeddings.AsyncOpenAI", side_effect=Exception("no key")):
|
||||||
|
result = await embed_text("test text")
|
||||||
|
|
||||||
|
assert result is None
|
||||||
|
|||||||
@@ -7,10 +7,9 @@ column is stored as JSON in tests (SQLite-compatible).
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
|||||||
97
tests/test_preprocessors.py
Normal file
97
tests/test_preprocessors.py
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
"""Tests for the preprocessor system (Step 1 — Local Agent V2).
|
||||||
|
|
||||||
|
Run:
|
||||||
|
pytest tests/test_preprocessors.py -v
|
||||||
|
pytest tests/test_preprocessors.py -v --preprocess-dir /path/to/folder
|
||||||
|
|
||||||
|
The folder must contain cases.yaml + data/.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import yaml
|
||||||
|
|
||||||
|
from app.core.preprocessors import detect_content_type, preprocess
|
||||||
|
|
||||||
|
_DEFAULT_DIR = Path(__file__).parent / "fixtures" / "preprocessors"
|
||||||
|
|
||||||
|
_GENERATORS = {
|
||||||
|
"binary_noise": "some\x00\x01\x02\x03\x04\x05content" * 20,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _fixtures_dir(config) -> Path:
|
||||||
|
override = config.getoption("--preprocess-dir")
|
||||||
|
return Path(override) if override else _DEFAULT_DIR
|
||||||
|
|
||||||
|
|
||||||
|
def _load_cases(config) -> list[dict]:
|
||||||
|
return yaml.safe_load((_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8"))
|
||||||
|
|
||||||
|
|
||||||
|
def _content(case: dict, data_dir: Path) -> str:
|
||||||
|
if "generate" in case:
|
||||||
|
return _GENERATORS[case["generate"]]
|
||||||
|
return (data_dir / case["file"]).read_text(encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
# ── parametrize at collection time via pytest hook ────────────────────
|
||||||
|
|
||||||
|
def pytest_generate_tests(metafunc):
|
||||||
|
if "preprocess_case" not in metafunc.fixturenames:
|
||||||
|
return
|
||||||
|
cases = _load_cases(metafunc.config)
|
||||||
|
test_name = metafunc.function.__name__
|
||||||
|
if test_name == "test_detect":
|
||||||
|
subset = [c for c in cases if "detect" in c]
|
||||||
|
else:
|
||||||
|
subset = [c for c in cases if "process" in c]
|
||||||
|
metafunc.parametrize("preprocess_case", subset, ids=[c["id"] for c in subset])
|
||||||
|
|
||||||
|
|
||||||
|
# ── detect ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_detect(preprocess_case, pytestconfig) -> None:
|
||||||
|
case = preprocess_case
|
||||||
|
data_dir = _fixtures_dir(pytestconfig) / "data"
|
||||||
|
raw = _content(case, data_dir)
|
||||||
|
filename = case.get("file", "")
|
||||||
|
ct = detect_content_type(filename, raw)
|
||||||
|
expected = case["detect"]
|
||||||
|
assert ct == expected, f"[{case['id']}] expected {expected!r}, got {ct!r}"
|
||||||
|
|
||||||
|
|
||||||
|
# ── preprocess ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def test_preprocess(preprocess_case, pytestconfig) -> None:
|
||||||
|
case = preprocess_case
|
||||||
|
data_dir = _fixtures_dir(pytestconfig) / "data"
|
||||||
|
raw = _content(case, data_dir)
|
||||||
|
result = preprocess(case["process"], raw)
|
||||||
|
|
||||||
|
if case.get("no_html"):
|
||||||
|
assert not re.search(r"<[^>]+>", result.clean_text), "clean_text contains HTML tags"
|
||||||
|
|
||||||
|
if "min_chars" in case:
|
||||||
|
assert len(result.clean_text) >= case["min_chars"], \
|
||||||
|
f"clean_text too short: {len(result.clean_text)} < {case['min_chars']}"
|
||||||
|
|
||||||
|
if "ratio_lt" in case:
|
||||||
|
ratio = len(result.clean_text) / len(raw)
|
||||||
|
assert ratio < case["ratio_lt"], f"compression ratio {ratio:.2f} >= {case['ratio_lt']}"
|
||||||
|
|
||||||
|
for key in case.get("has_meta", []):
|
||||||
|
assert result.metadata.get(key), f"metadata missing {key!r} (got {result.metadata})"
|
||||||
|
|
||||||
|
for item in ([case["contains"]] if isinstance(case.get("contains"), str) else case.get("contains", [])):
|
||||||
|
assert item in result.clean_text, f"clean_text missing {item!r}"
|
||||||
|
|
||||||
|
for item in ([case["excludes"]] if isinstance(case.get("excludes"), str) else case.get("excludes", [])):
|
||||||
|
assert item not in result.clean_text, f"clean_text contains forbidden {item!r}"
|
||||||
|
|
||||||
|
if "content_type" in case:
|
||||||
|
assert result.content_type == case["content_type"], \
|
||||||
|
f"expected content_type {case['content_type']!r}, got {result.content_type!r}"
|
||||||
Reference in New Issue
Block a user