Compare commits
10 Commits
feat/proje
...
feature/de
| Author | SHA1 | Date | |
|---|---|---|---|
| 47bf1881e5 | |||
| 24a9c1b752 | |||
| 706bf88883 | |||
| 4ff0b27084 | |||
| 61d2a18234 | |||
| b3687719b6 | |||
| f80bdfa8f7 | |||
| 617a17db40 | |||
| 92716cb89a | |||
| cfc9d7a942 |
85
.env.example
85
.env.example
@@ -2,7 +2,7 @@
|
|||||||
ENV=dev
|
ENV=dev
|
||||||
|
|
||||||
# ── Database ──────────────────────────────────────────────────────────────────
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
||||||
|
|
||||||
# ── Auth ──────────────────────────────────────────────────────────────────────
|
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||||
JWT_SECRET=replace-with-a-long-random-secret
|
JWT_SECRET=replace-with-a-long-random-secret
|
||||||
@@ -13,82 +13,31 @@ JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
|||||||
# ── LLM ───────────────────────────────────────────────────────────────────────
|
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||||
# LiteLLM model identifiers — change to swap providers without code changes.
|
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||||
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||||
#
|
|
||||||
# API keys — only the key(s) matching your chosen provider(s) are required.
|
|
||||||
# The correct key is picked automatically from the model prefix (e.g.
|
|
||||||
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
|
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
CEREBRAS_API_KEY=
|
LLM_MODEL=gpt-4o
|
||||||
GROQ_API_KEY=
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
DEEPSEEK_API_KEY=
|
|
||||||
|
|
||||||
# Default model used by any agent that does not have a specific override below.
|
|
||||||
LLM_MODEL=gpt-5-mini
|
|
||||||
LLM_EMBED_MODEL=text-embedding-3-small
|
|
||||||
|
|
||||||
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
|
|
||||||
# In Docker, point this to a named-volume path so tokens survive restarts.
|
|
||||||
# GITHUB_COPILOT_TOKEN_DIR=
|
|
||||||
|
|
||||||
# ── Per-agent model overrides ─────────────────────────────────────────────────
|
|
||||||
# Leave a value empty to fall back to LLM_MODEL.
|
|
||||||
# Each agent resolves its API key from the model prefix automatically.
|
|
||||||
#
|
|
||||||
# Intent classifier — routes user messages to the right domain agent.
|
|
||||||
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
|
|
||||||
LLM_MODEL_CLASSIFIER=
|
|
||||||
|
|
||||||
# Home-agent — handles chat from the home screen (all tools available).
|
|
||||||
LLM_MODEL_HOME_AGENT=
|
|
||||||
|
|
||||||
# Floating-agent — handles contextual chat triggered from a task/project/note.
|
|
||||||
LLM_MODEL_FLOATING_AGENT=
|
|
||||||
|
|
||||||
# Unified-processor — processes local directory files (local agent runner).
|
|
||||||
LLM_MODEL_UNIFIED_PROCESSOR=
|
|
||||||
|
|
||||||
# Cloud-processor — fetches and processes data from cloud connectors.
|
|
||||||
LLM_MODEL_CLOUD_PROCESSOR=
|
|
||||||
|
|
||||||
# Brief-agent — produces home and project text briefs.
|
|
||||||
# A small model (e.g. gpt-4o-mini) is sufficient.
|
|
||||||
# LLM_MODEL_BRIEF_AGENT=
|
|
||||||
|
|
||||||
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
|
|
||||||
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
|
|
||||||
# LLM_MODEL_TASK_BRIEF_AGENT=
|
|
||||||
|
|
||||||
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
|
||||||
LLM_MODEL_SETUP_AGENT=
|
|
||||||
|
|
||||||
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
|
|
||||||
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
|
|
||||||
LLM_MODEL_MEMORY_EXTRACTOR=
|
|
||||||
|
|
||||||
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
|
|
||||||
# Defaults to gpt-4o-mini when empty.
|
|
||||||
LLM_MODEL_MEMORY_MINER=
|
|
||||||
|
|
||||||
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
|
|
||||||
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
|
|
||||||
LLM_MODEL_MEMORY_AUDITOR=
|
|
||||||
|
|
||||||
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
|
|
||||||
SCHEDULER_ENABLED=true
|
|
||||||
|
|
||||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
|
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
||||||
|
S3_BUCKET=adiuva
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_ENDPOINT_URL=
|
||||||
|
AWS_ACCESS_KEY_ID=
|
||||||
|
AWS_SECRET_ACCESS_KEY=
|
||||||
|
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
|
||||||
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
# ── Vector Store ──────────────────────────────────────────────────────────────
|
||||||
LANGFUSE_SECRET_KEY=
|
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
|
||||||
LANGFUSE_PUBLIC_KEY=
|
PINECONE_API_KEY=
|
||||||
# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
|
PINECONE_INDEX=adiuva
|
||||||
# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
|
QDRANT_URL=
|
||||||
# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
|
QDRANT_API_KEY=
|
||||||
|
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
|
|||||||
@@ -48,23 +48,23 @@ jobs:
|
|||||||
key: ${{ secrets.SSH_KEY }}
|
key: ${{ secrets.SSH_KEY }}
|
||||||
script: |
|
script: |
|
||||||
set -e
|
set -e
|
||||||
DEPLOY_DIR="/opt/adiuvai-api"
|
DEPLOY_DIR="/opt/adiuva-api"
|
||||||
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||||
TAG="${{ gitea.ref_name }}"
|
TAG="${{ gitea.ref_name }}"
|
||||||
|
|
||||||
# ── Pull latest code ──
|
# ── Pull latest code ──
|
||||||
cd /tmp && rm -rf adiuvai-api-deploy
|
cd /tmp && rm -rf adiuva-api-deploy
|
||||||
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
|
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
|
||||||
|
|
||||||
# ── Sync source (preserve .env) ──
|
# ── Sync source (preserve .env) ──
|
||||||
cp -rf /tmp/adiuvai-api-deploy/app/ \
|
cp -rf /tmp/adiuva-api-deploy/app/ \
|
||||||
/tmp/adiuvai-api-deploy/alembic/ \
|
/tmp/adiuva-api-deploy/alembic/ \
|
||||||
/tmp/adiuvai-api-deploy/alembic.ini \
|
/tmp/adiuva-api-deploy/alembic.ini \
|
||||||
/tmp/adiuvai-api-deploy/Dockerfile \
|
/tmp/adiuva-api-deploy/Dockerfile \
|
||||||
/tmp/adiuvai-api-deploy/docker-compose.yml \
|
/tmp/adiuva-api-deploy/docker-compose.yml \
|
||||||
/tmp/adiuvai-api-deploy/requirements.txt \
|
/tmp/adiuva-api-deploy/requirements.txt \
|
||||||
"$DEPLOY_DIR/"
|
"$DEPLOY_DIR/"
|
||||||
rm -rf /tmp/adiuvai-api-deploy
|
rm -rf /tmp/adiuva-api-deploy
|
||||||
|
|
||||||
# ── Verify .env ──
|
# ── Verify .env ──
|
||||||
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||||
|
|||||||
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Build image
|
- name: Build image
|
||||||
run: docker build -t adiuvai-api:ci .
|
run: docker build -t adiuva-api:ci .
|
||||||
|
|
||||||
- name: Verify gunicorn installed
|
- name: Verify gunicorn installed
|
||||||
run: docker run --rm adiuvai-api:ci gunicorn --version
|
run: docker run --rm adiuva-api:ci gunicorn --version
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -21,16 +21,12 @@ env/
|
|||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.coverage
|
.coverage
|
||||||
tests/fixtures/private*/
|
|
||||||
|
|
||||||
# Docker
|
# Docker
|
||||||
*.log
|
*.log
|
||||||
|
|
||||||
# OS
|
# OS
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
|
||||||
# Smoke scripts (dev-only, not for CI)
|
|
||||||
scripts/smoke_*.py
|
|
||||||
Thumbs.db
|
Thumbs.db
|
||||||
|
|
||||||
# Claude Code
|
# Claude Code
|
||||||
|
|||||||
794
README.md
794
README.md
@@ -1,5 +1,793 @@
|
|||||||
## DEV
|
# Adiuva Cloud API
|
||||||
Run in DEV with command:
|
|
||||||
|
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
|
||||||
|
|
||||||
|
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Table of Contents
|
||||||
|
|
||||||
|
- [Overview](#overview)
|
||||||
|
- [Architecture](#architecture)
|
||||||
|
- [Key Features](#key-features)
|
||||||
|
- [Tech Stack](#tech-stack)
|
||||||
|
- [Getting Started](#getting-started)
|
||||||
|
- [Docker Deployment](#docker-deployment)
|
||||||
|
- [Environment Variables](#environment-variables)
|
||||||
|
- [API Reference](#api-reference)
|
||||||
|
- [Data Model](#data-model)
|
||||||
|
- [AI Agent System](#ai-agent-system)
|
||||||
|
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
||||||
|
- [Middleware](#middleware)
|
||||||
|
- [Storage Layer](#storage-layer)
|
||||||
|
- [Billing & Tiers](#billing--tiers)
|
||||||
|
- [Plugin Marketplace](#plugin-marketplace)
|
||||||
|
- [Testing](#testing)
|
||||||
|
- [Project Structure](#project-structure)
|
||||||
|
- [License](#license)
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
|
||||||
|
|
||||||
|
### Design Principles
|
||||||
|
|
||||||
|
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
|
||||||
|
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
||||||
|
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
|
||||||
|
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
||||||
|
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Architecture
|
||||||
|
|
||||||
```
|
```
|
||||||
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf
|
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
|
||||||
|
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
|
||||||
|
│ Desktop App │────▶│ │
|
||||||
|
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
|
||||||
|
└──────────────┘ │ │
|
||||||
|
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
||||||
|
│ │ Auth Routes │ │ Chat Routes │ │
|
||||||
|
│ │ Billing Routes │ │ ↓ │ │
|
||||||
|
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
||||||
|
│ │ Backup Routes │ │ ↓ classify intent │ │
|
||||||
|
│ │ Plugin Routes │ │ Agent Registry │ │
|
||||||
|
│ │ Vector Routes │ │ ↓ │ │
|
||||||
|
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
|
||||||
|
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
|
||||||
|
│ │ (GPT-4o + LangChain) │ │
|
||||||
|
│ └────────────────────────────┘ │
|
||||||
|
└────────────────────────────────────────────────────────┘
|
||||||
|
│ │ │
|
||||||
|
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
|
||||||
|
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
|
||||||
|
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
|
||||||
|
│ Billing, │ │ backups) │ │ (Vectors) │
|
||||||
|
│ Metadata) │ └───────────────┘ └────────────────┘
|
||||||
|
└────────────┘
|
||||||
|
│
|
||||||
|
┌────────▼───┐
|
||||||
|
│ Stripe │
|
||||||
|
│ (Billing, │
|
||||||
|
│ Connect) │
|
||||||
|
└────────────┘
|
||||||
```
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Key Features
|
||||||
|
|
||||||
|
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
||||||
|
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
||||||
|
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
||||||
|
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
||||||
|
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
||||||
|
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
|
||||||
|
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
|
||||||
|
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
||||||
|
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
||||||
|
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
||||||
|
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
||||||
|
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
|
||||||
|
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
||||||
|
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
|
||||||
|
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Tech Stack
|
||||||
|
|
||||||
|
| Package | Version | Purpose |
|
||||||
|
|---|---|---|
|
||||||
|
| `fastapi` | ≥ 0.115.0 | Web framework |
|
||||||
|
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
|
||||||
|
| `gunicorn` | ≥ 22.0.0 | Production process manager |
|
||||||
|
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
|
||||||
|
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
|
||||||
|
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
|
||||||
|
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
|
||||||
|
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
||||||
|
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
||||||
|
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
||||||
|
| `boto3` | ≥ 1.35.0 | AWS S3 client |
|
||||||
|
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
||||||
|
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
||||||
|
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
||||||
|
| `alembic` | ≥ 1.14.0 | Database migration management |
|
||||||
|
| `bcrypt` | ≥ 4.2.0 | Password hashing |
|
||||||
|
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
|
||||||
|
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
||||||
|
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
||||||
|
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
||||||
|
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
|
||||||
|
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
|
||||||
|
| `pytest` | ≥ 8.0.0 | Test framework |
|
||||||
|
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
||||||
|
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
||||||
|
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
|
||||||
|
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- Python 3.12+
|
||||||
|
- PostgreSQL 16+
|
||||||
|
- An OpenAI API key (for LLM features)
|
||||||
|
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
||||||
|
- AWS credentials (optional — needed for S3 storage in production)
|
||||||
|
|
||||||
|
### Installation
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Clone the repository
|
||||||
|
git clone <repo-url> && cd adiuva-api
|
||||||
|
|
||||||
|
# Create a virtual environment
|
||||||
|
python -m venv .venv && source .venv/bin/activate
|
||||||
|
|
||||||
|
# Install dependencies
|
||||||
|
pip install -r requirements.txt
|
||||||
|
|
||||||
|
# Configure environment
|
||||||
|
cp .env.example .env
|
||||||
|
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
|
||||||
|
```
|
||||||
|
|
||||||
|
### Database Setup
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Start PostgreSQL (or use the Docker Compose database)
|
||||||
|
docker compose up db -d
|
||||||
|
|
||||||
|
# Run migrations
|
||||||
|
alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### Run the Development Server
|
||||||
|
|
||||||
|
```bash
|
||||||
|
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
||||||
|
```
|
||||||
|
|
||||||
|
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Docker Deployment
|
||||||
|
|
||||||
|
### Quick Start
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up --build
|
||||||
|
```
|
||||||
|
|
||||||
|
This starts two services:
|
||||||
|
|
||||||
|
- **app** — FastAPI server on port `8000`
|
||||||
|
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
||||||
|
|
||||||
|
The compose file also includes optional services for fully local deployments:
|
||||||
|
|
||||||
|
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
|
||||||
|
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
|
||||||
|
|
||||||
|
### Dockerfile Details
|
||||||
|
|
||||||
|
The Dockerfile uses a multi-stage build:
|
||||||
|
|
||||||
|
1. **Builder stage** — Installs Python dependencies into a virtual environment.
|
||||||
|
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
|
||||||
|
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Production command (run by the container)
|
||||||
|
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Homelab / Self-Hosted Deployment
|
||||||
|
|
||||||
|
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
|
||||||
|
|
||||||
|
### 1. Start all services
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose up -d
|
||||||
|
```
|
||||||
|
|
||||||
|
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
|
||||||
|
|
||||||
|
### 2. Create the MinIO bucket
|
||||||
|
|
||||||
|
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
|
||||||
|
docker compose exec minio mc mb local/adiuva
|
||||||
|
```
|
||||||
|
|
||||||
|
### 3. Configure your `.env`
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Database (uses the compose PostgreSQL)
|
||||||
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
|
||||||
|
# S3 → MinIO
|
||||||
|
S3_BUCKET=adiuva
|
||||||
|
S3_REGION=us-east-1
|
||||||
|
S3_ENDPOINT_URL=http://minio:9000
|
||||||
|
AWS_ACCESS_KEY_ID=minioadmin
|
||||||
|
AWS_SECRET_ACCESS_KEY=minioadmin
|
||||||
|
|
||||||
|
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
|
||||||
|
QDRANT_URL=http://qdrant:6333
|
||||||
|
QDRANT_API_KEY=
|
||||||
|
PINECONE_API_KEY=
|
||||||
|
|
||||||
|
# Billing — leave empty to stub (no Stripe needed)
|
||||||
|
STRIPE_SECRET_KEY=
|
||||||
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
|
# LLM — the only external service
|
||||||
|
OPENAI_API_KEY=sk-...
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Auth
|
||||||
|
JWT_SECRET=your-secret-here
|
||||||
|
ENV=dev
|
||||||
|
```
|
||||||
|
|
||||||
|
### 4. Run migrations
|
||||||
|
|
||||||
|
```bash
|
||||||
|
docker compose exec app alembic upgrade head
|
||||||
|
```
|
||||||
|
|
||||||
|
### What runs where
|
||||||
|
|
||||||
|
| Service | Runs on | Port | Notes |
|
||||||
|
|---|---|---|---|
|
||||||
|
| FastAPI app | Docker | 8000 | API server |
|
||||||
|
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
|
||||||
|
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
|
||||||
|
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
|
||||||
|
| Stripe | — | — | Stubbed when keys are empty |
|
||||||
|
| OpenAI / LLM | Cloud | — | Only external dependency |
|
||||||
|
|
||||||
|
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Environment Variables
|
||||||
|
|
||||||
|
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
|
||||||
|
|
||||||
|
| Variable | Type | Default | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
|
||||||
|
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
|
||||||
|
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
|
||||||
|
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
||||||
|
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
||||||
|
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
||||||
|
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
|
||||||
|
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
|
||||||
|
| `S3_REGION` | `str` | `us-east-1` | AWS region |
|
||||||
|
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
|
||||||
|
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
|
||||||
|
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
|
||||||
|
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
|
||||||
|
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
|
||||||
|
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
|
||||||
|
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
|
||||||
|
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
|
||||||
|
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
||||||
|
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
||||||
|
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
||||||
|
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## API Reference
|
||||||
|
|
||||||
|
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
|
||||||
|
|
||||||
|
### Health
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
|
||||||
|
|
||||||
|
### Auth
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
|
||||||
|
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
|
||||||
|
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
|
||||||
|
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
|
||||||
|
|
||||||
|
### Chat
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
||||||
|
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
||||||
|
|
||||||
|
### Plans
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
||||||
|
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
||||||
|
|
||||||
|
### Storage (Cloud Records)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
|
||||||
|
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
|
||||||
|
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
|
||||||
|
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
|
||||||
|
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
|
||||||
|
|
||||||
|
### Vectors (Cloud Vector Store)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
|
||||||
|
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
|
||||||
|
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
|
||||||
|
|
||||||
|
### Backup
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
|
||||||
|
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
|
||||||
|
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
|
||||||
|
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
|
||||||
|
|
||||||
|
### Plugins (Marketplace)
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
|
||||||
|
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
|
||||||
|
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
|
||||||
|
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
|
||||||
|
|
||||||
|
### Billing
|
||||||
|
|
||||||
|
| Method | Path | Auth | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
|
||||||
|
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
|
||||||
|
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
|
||||||
|
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Data Model
|
||||||
|
|
||||||
|
9 tables managed by Alembic migrations. Source: `app/models.py`
|
||||||
|
|
||||||
|
### Tables
|
||||||
|
|
||||||
|
| Table | Primary Key | Key Columns | Purpose |
|
||||||
|
|---|---|---|---|
|
||||||
|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
||||||
|
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
||||||
|
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
||||||
|
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
|
||||||
|
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
|
||||||
|
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
|
||||||
|
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
|
||||||
|
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
|
||||||
|
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
|
||||||
|
|
||||||
|
### Enum Types
|
||||||
|
|
||||||
|
| Enum | Values |
|
||||||
|
|---|---|
|
||||||
|
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
||||||
|
| `plugin_status` | `pending_review`, `approved`, `rejected` |
|
||||||
|
| `review_decision` | `approved`, `rejected` |
|
||||||
|
|
||||||
|
### Migrations
|
||||||
|
|
||||||
|
| Version | Description |
|
||||||
|
|---|---|
|
||||||
|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
|
||||||
|
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## AI Agent System
|
||||||
|
|
||||||
|
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
|
||||||
|
|
||||||
|
### Architecture
|
||||||
|
|
||||||
|
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
|
||||||
|
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
||||||
|
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
||||||
|
|
||||||
|
### Registered Agents
|
||||||
|
|
||||||
|
| Agent | Registry Name | Tools | Description |
|
||||||
|
|---|---|---|---|
|
||||||
|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
||||||
|
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
||||||
|
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
|
||||||
|
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
||||||
|
|
||||||
|
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
||||||
|
|
||||||
|
### Switching LLM Providers
|
||||||
|
|
||||||
|
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# OpenAI (default)
|
||||||
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
|
# Anthropic
|
||||||
|
LLM_MODEL=anthropic/claude-3.5-sonnet
|
||||||
|
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
|
||||||
|
|
||||||
|
# Google Gemini
|
||||||
|
LLM_MODEL=gemini/gemini-pro
|
||||||
|
LLM_ROUTER_MODEL=gemini/gemini-flash
|
||||||
|
|
||||||
|
# Local Ollama
|
||||||
|
LLM_MODEL=ollama/llama3
|
||||||
|
LLM_ROUTER_MODEL=ollama/llama3
|
||||||
|
|
||||||
|
# AWS Bedrock
|
||||||
|
LLM_MODEL=bedrock/anthropic.claude-v2
|
||||||
|
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
|
||||||
|
```
|
||||||
|
|
||||||
|
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Orchestration & Execution Plans
|
||||||
|
|
||||||
|
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
||||||
|
|
||||||
|
### Orchestrator
|
||||||
|
|
||||||
|
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
|
||||||
|
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
|
||||||
|
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
|
||||||
|
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
|
||||||
|
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
|
||||||
|
|
||||||
|
### Execution Plans
|
||||||
|
|
||||||
|
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
|
||||||
|
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
|
||||||
|
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
|
||||||
|
|
||||||
|
### Built-in Templates (6)
|
||||||
|
|
||||||
|
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
||||||
|
|
||||||
|
### Built-in Playbooks (2)
|
||||||
|
|
||||||
|
| Playbook | Description |
|
||||||
|
|---|---|
|
||||||
|
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
|
||||||
|
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Middleware
|
||||||
|
|
||||||
|
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
|
||||||
|
|
||||||
|
### JWT Authentication
|
||||||
|
|
||||||
|
Source: `app/api/middleware/auth.py`
|
||||||
|
|
||||||
|
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
|
||||||
|
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
|
||||||
|
- Falls back to `free` when no subscription row exists.
|
||||||
|
- Raises `401 Unauthorized` on invalid or expired tokens.
|
||||||
|
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
||||||
|
|
||||||
|
### Tier-Based Rate Limiter
|
||||||
|
|
||||||
|
Source: `app/api/middleware/rate_limit.py`
|
||||||
|
|
||||||
|
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
|
||||||
|
- Per-user 60-second window sized by subscription tier:
|
||||||
|
|
||||||
|
| Tier | Requests / Minute |
|
||||||
|
|---|---|
|
||||||
|
| Free | 20 |
|
||||||
|
| Pro | 60 |
|
||||||
|
| Power | 120 |
|
||||||
|
| Team | 200 |
|
||||||
|
|
||||||
|
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
|
||||||
|
- **Exempt paths:** register, login, webhook, health
|
||||||
|
|
||||||
|
### Response Sanitizer
|
||||||
|
|
||||||
|
Source: `app/api/middleware/sanitizer.py`
|
||||||
|
|
||||||
|
- Runs only on `/api/v1/chat` endpoints.
|
||||||
|
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
||||||
|
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
||||||
|
- Logs sanitization events as `WARNING`.
|
||||||
|
- Binary responses (storage, backup) are never touched.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Storage Layer
|
||||||
|
|
||||||
|
### Blob Store
|
||||||
|
|
||||||
|
Source: `app/storage/blob_store.py`
|
||||||
|
|
||||||
|
- S3-backed storage for E2E encrypted blobs.
|
||||||
|
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
|
||||||
|
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
|
||||||
|
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
|
||||||
|
- The backend **never inspects or decrypts blob content**.
|
||||||
|
|
||||||
|
### Vector Store
|
||||||
|
|
||||||
|
Source: `app/storage/vector_store.py`
|
||||||
|
|
||||||
|
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
|
||||||
|
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
|
||||||
|
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
|
||||||
|
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
|
||||||
|
- Methods: `upsert()`, `search()`, `delete()`
|
||||||
|
|
||||||
|
### Encryption Utilities
|
||||||
|
|
||||||
|
Source: `app/storage/encryption.py`
|
||||||
|
|
||||||
|
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
|
||||||
|
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
|
||||||
|
- **No decryption key ever reaches the backend.**
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Billing & Tiers
|
||||||
|
|
||||||
|
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
||||||
|
|
||||||
|
### Feature Matrix
|
||||||
|
|
||||||
|
| Feature | Free | Pro | Power | Team |
|
||||||
|
|---|---|---|---|---|
|
||||||
|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
||||||
|
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
||||||
|
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
||||||
|
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
||||||
|
| Batch Builder | — | — | ✓ | ✓ |
|
||||||
|
| Plugin Marketplace | — | — | ✓ | ✓ |
|
||||||
|
| SSO | — | — | — | ✓ |
|
||||||
|
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
||||||
|
|
||||||
|
### Stripe Integration
|
||||||
|
|
||||||
|
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
|
||||||
|
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
|
||||||
|
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
|
||||||
|
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
|
||||||
|
|
||||||
|
### Tier Manager
|
||||||
|
|
||||||
|
- `get_tier(user_id)` — Returns the user's current billing tier.
|
||||||
|
- `check_feature(tier, feature)` — Boolean feature gate check.
|
||||||
|
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
||||||
|
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Plugin Marketplace
|
||||||
|
|
||||||
|
Source: `app/marketplace/`
|
||||||
|
|
||||||
|
### Plugin Registry
|
||||||
|
|
||||||
|
- PostgreSQL-backed catalog of submitted and approved plugins.
|
||||||
|
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
|
||||||
|
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
|
||||||
|
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
|
||||||
|
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
|
||||||
|
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
|
||||||
|
|
||||||
|
### Review Queue
|
||||||
|
|
||||||
|
- Automated security checklist before human review:
|
||||||
|
- Plugin ID must match `^[a-z0-9-]+$`
|
||||||
|
- Permissions must be from the allowed set only
|
||||||
|
- No binary blobs in the manifest
|
||||||
|
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
|
||||||
|
- `get_pending(db)` — Lists plugins awaiting review.
|
||||||
|
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
||||||
|
|
||||||
|
### Revenue Sharing
|
||||||
|
|
||||||
|
- **70% developer / 30% platform** split on all paid plugin sales.
|
||||||
|
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
|
||||||
|
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
|
||||||
|
- Gracefully stubs transfers when Stripe is not configured.
|
||||||
|
|
||||||
|
### Seed Plugins
|
||||||
|
|
||||||
|
| Plugin | Category | Price |
|
||||||
|
|---|---|---|
|
||||||
|
| GitHub Sync | Productivity | Free |
|
||||||
|
| Slack Notifier | Communication | €4.99 |
|
||||||
|
| Time Tracker | Productivity | €9.99 |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Testing
|
||||||
|
|
||||||
|
### Running Tests
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# Run all tests
|
||||||
|
pytest
|
||||||
|
|
||||||
|
# Run a specific test file
|
||||||
|
pytest tests/test_auth.py
|
||||||
|
|
||||||
|
# Run with verbose output
|
||||||
|
pytest -v
|
||||||
|
```
|
||||||
|
|
||||||
|
### Test Infrastructure
|
||||||
|
|
||||||
|
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
||||||
|
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
|
||||||
|
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
||||||
|
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
||||||
|
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
|
||||||
|
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
||||||
|
- **No external dependencies** — all tests run fully offline.
|
||||||
|
|
||||||
|
### Test Coverage
|
||||||
|
|
||||||
|
| File | Coverage |
|
||||||
|
|---|---|
|
||||||
|
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
||||||
|
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
|
||||||
|
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
|
||||||
|
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
|
||||||
|
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
|
||||||
|
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
|
||||||
|
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
|
||||||
|
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
|
||||||
|
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## Project Structure
|
||||||
|
|
||||||
|
```
|
||||||
|
adiuva-api/
|
||||||
|
├── alembic.ini # Alembic configuration
|
||||||
|
├── BACKEND_PLAN.md # Architecture & design decisions
|
||||||
|
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
||||||
|
├── Dockerfile # Multi-stage production build
|
||||||
|
├── requirements.txt # Python dependencies
|
||||||
|
│
|
||||||
|
├── alembic/ # Database migrations
|
||||||
|
│ ├── env.py # Alembic environment config
|
||||||
|
│ ├── script.py.mako # Migration template
|
||||||
|
│ └── versions/
|
||||||
|
│ ├── 001_initial_schema.py # Tables, indexes, FKs
|
||||||
|
│ └── 002_seed_plugins.py # Seed marketplace plugins
|
||||||
|
│
|
||||||
|
├── app/ # Application source
|
||||||
|
│ ├── main.py # FastAPI app factory, middleware, routes
|
||||||
|
│ ├── db.py # Async SQLAlchemy engine & session
|
||||||
|
│ ├── models.py # SQLAlchemy ORM models (9 tables)
|
||||||
|
│ ├── schemas.py # Pydantic request/response schemas
|
||||||
|
│ │
|
||||||
|
│ ├── config/
|
||||||
|
│ │ └── settings.py # Pydantic Settings (env vars)
|
||||||
|
│ │
|
||||||
|
│ ├── agents/ # LLM-powered domain agents
|
||||||
|
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
||||||
|
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
||||||
|
│ │ ├── timeline_agent.py # Milestones (4 tools)
|
||||||
|
│ │ └── note_agent.py # Markdown notes (5 tools)
|
||||||
|
│ │
|
||||||
|
│ ├── core/ # Orchestration engine
|
||||||
|
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||||
|
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
||||||
|
│ │ ├── orchestrator.py # Intent classification & routing
|
||||||
|
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||||
|
│ │
|
||||||
|
│ ├── api/ # HTTP layer
|
||||||
|
│ │ ├── deps.py # Shared FastAPI dependencies
|
||||||
|
│ │ ├── middleware/
|
||||||
|
│ │ │ ├── auth.py # JWT validation, live tier lookup
|
||||||
|
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
||||||
|
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
||||||
|
│ │ └── routes/
|
||||||
|
│ │ ├── auth.py # Register, login, refresh, me
|
||||||
|
│ │ ├── chat.py # Chat + WebSocket streaming
|
||||||
|
│ │ ├── plans.py # Execution plan playbooks
|
||||||
|
│ │ ├── storage.py # E2E encrypted record CRUD
|
||||||
|
│ │ ├── vectors.py # Vector upsert, search, delete
|
||||||
|
│ │ ├── backup.py # Encrypted backup management
|
||||||
|
│ │ ├── plugins.py # Marketplace browse & install
|
||||||
|
│ │ └── billing.py # Stripe checkout & webhooks
|
||||||
|
│ │
|
||||||
|
│ ├── storage/ # Storage backends
|
||||||
|
│ │ ├── blob_store.py # S3 blob storage
|
||||||
|
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
|
||||||
|
│ │ └── encryption.py # Checksum verification utilities
|
||||||
|
│ │
|
||||||
|
│ ├── billing/ # Subscription management
|
||||||
|
│ │ ├── stripe_service.py # Stripe API integration
|
||||||
|
│ │ └── tier_manager.py # Feature matrix & quota enforcement
|
||||||
|
│ │
|
||||||
|
│ └── marketplace/ # Plugin ecosystem
|
||||||
|
│ ├── plugin_registry.py # Catalog CRUD & search
|
||||||
|
│ ├── plugin_review.py # Security checklist & review queue
|
||||||
|
│ └── revenue_share.py # 70/30 split & Stripe Connect
|
||||||
|
│
|
||||||
|
└── tests/ # Test suite
|
||||||
|
├── conftest.py # Fixtures: DB, S3, auth, seeds
|
||||||
|
├── test_auth.py
|
||||||
|
├── test_orchestrator.py
|
||||||
|
├── test_agents.py
|
||||||
|
├── test_storage.py
|
||||||
|
├── test_backup.py
|
||||||
|
├── test_plugins.py
|
||||||
|
├── test_agent_registry.py
|
||||||
|
├── test_execution_plan.py
|
||||||
|
└── test_middleware.py
|
||||||
|
```
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## License
|
||||||
|
|
||||||
|
*To be determined.*
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import re
|
|||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sqlalchemy import pool
|
from sqlalchemy import engine_from_config, pool
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
# Alembic Config object (gives access to alembic.ini values).
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
|||||||
@@ -1,4 +1,5 @@
|
|||||||
"""Initial schema: users, refresh_tokens, subscriptions.
|
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
||||||
|
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
||||||
|
|
||||||
Revision ID: 001
|
Revision ID: 001
|
||||||
Revises:
|
Revises:
|
||||||
@@ -27,6 +28,18 @@ def upgrade() -> None:
|
|||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
END $$;
|
END $$;
|
||||||
""")
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
# ── users ─────────────────────────────────────────────────────────────
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
op.create_table(
|
op.create_table(
|
||||||
@@ -75,10 +88,122 @@ def upgrade() -> None:
|
|||||||
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
|
# ── storage_records ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"storage_records",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("table_name", sa.String(100), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
||||||
|
|
||||||
|
# ── backup_metadata ───────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"backup_metadata",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("s3_key", sa.String(500), nullable=False),
|
||||||
|
sa.Column("version", sa.Integer, nullable=False),
|
||||||
|
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
||||||
|
sa.Column("checksum", sa.String(64), nullable=False),
|
||||||
|
sa.Column("size_bytes", sa.Integer, nullable=False),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugins ───────────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugins",
|
||||||
|
sa.Column("id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
||||||
|
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
||||||
|
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
||||||
|
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
||||||
|
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
||||||
|
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
||||||
|
sa.Column("rejection_reason", sa.Text, nullable=True),
|
||||||
|
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── plugin_installations ──────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_installations",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
||||||
|
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
||||||
|
|
||||||
|
# ── plugin_reviews ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"plugin_reviews",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
||||||
|
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
||||||
|
sa.Column("notes", sa.Text, nullable=True),
|
||||||
|
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
||||||
|
|
||||||
|
# ── revenue_events ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"revenue_events",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("plugin_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
||||||
|
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
||||||
|
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
||||||
|
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
|
op.drop_table("revenue_events")
|
||||||
|
op.drop_table("plugin_reviews")
|
||||||
|
op.drop_table("plugin_installations")
|
||||||
|
op.drop_table("plugins")
|
||||||
|
op.drop_table("backup_metadata")
|
||||||
|
op.drop_table("storage_records")
|
||||||
op.drop_table("subscriptions")
|
op.drop_table("subscriptions")
|
||||||
op.drop_table("refresh_tokens")
|
op.drop_table("refresh_tokens")
|
||||||
op.drop_table("users")
|
op.drop_table("users")
|
||||||
|
|
||||||
|
op.execute("DROP TYPE IF EXISTS review_decision")
|
||||||
|
op.execute("DROP TYPE IF EXISTS plugin_status")
|
||||||
op.execute("DROP TYPE IF EXISTS billing_tier")
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
|
|||||||
92
alembic/versions/002_seed_plugins.py
Normal file
92
alembic/versions/002_seed_plugins.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
||||||
|
|
||||||
|
Revision ID: 002
|
||||||
|
Revises: 001
|
||||||
|
Create Date: 2026-03-03
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "002"
|
||||||
|
down_revision: Union[str, None] = "001"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
{
|
||||||
|
"id": "plugin-github-sync",
|
||||||
|
"name": "GitHub Sync",
|
||||||
|
"description": "Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
"version": "1.0.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 0,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-slack-notify",
|
||||||
|
"name": "Slack Notifier",
|
||||||
|
"description": "Post task and timeline updates to Slack channels.",
|
||||||
|
"version": "1.2.0",
|
||||||
|
"author_name": "Adiuva",
|
||||||
|
"category": "communication",
|
||||||
|
"price_cents": 499,
|
||||||
|
"permissions": json.dumps(["read:tasks", "read:timelines"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"id": "plugin-time-tracker",
|
||||||
|
"name": "Time Tracker",
|
||||||
|
"description": "Track time spent on tasks with automatic reporting.",
|
||||||
|
"version": "0.9.1",
|
||||||
|
"author_name": "Third Party",
|
||||||
|
"category": "productivity",
|
||||||
|
"price_cents": 999,
|
||||||
|
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
"status": "approved",
|
||||||
|
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
"install_count": 0,
|
||||||
|
"avg_rating": 0.0,
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
plugins = sa.table(
|
||||||
|
"plugins",
|
||||||
|
sa.column("id", sa.String),
|
||||||
|
sa.column("name", sa.String),
|
||||||
|
sa.column("description", sa.Text),
|
||||||
|
sa.column("version", sa.String),
|
||||||
|
sa.column("author_name", sa.String),
|
||||||
|
sa.column("category", sa.String),
|
||||||
|
sa.column("price_cents", sa.Integer),
|
||||||
|
sa.column("permissions", sa.Text),
|
||||||
|
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
||||||
|
sa.column("s3_package_key", sa.String),
|
||||||
|
sa.column("install_count", sa.Integer),
|
||||||
|
sa.column("avg_rating", sa.Float),
|
||||||
|
)
|
||||||
|
op.bulk_insert(plugins, _SEED_PLUGINS)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.execute(
|
||||||
|
"DELETE FROM plugins WHERE id IN ("
|
||||||
|
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
||||||
|
")"
|
||||||
|
)
|
||||||
@@ -14,7 +14,7 @@ from alembic import op
|
|||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
revision: str = "003"
|
revision: str = "003"
|
||||||
down_revision: Union[str, None] = "001"
|
down_revision: Union[str, None] = "002"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|||||||
@@ -1,54 +0,0 @@
|
|||||||
"""Phase 1 — confirm pgvector activation on memory_associative.
|
|
||||||
|
|
||||||
Migration 004 created the embedding column as vector(1536) and added the
|
|
||||||
IVFFlat index. This migration is the Phase-1 checkpoint:
|
|
||||||
1. Ensures the pgvector extension is enabled (idempotent).
|
|
||||||
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
|
|
||||||
memory_associative_embedding_idx (creates it only if absent).
|
|
||||||
|
|
||||||
Revision ID: 005
|
|
||||||
Revises: 9a1f2d0b6c7e
|
|
||||||
Create Date: 2026-04-15
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
revision: str = "005"
|
|
||||||
down_revision: Union[str, None] = "e04100e88ace"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Ensure pgvector extension is enabled (also done in 004, idempotent).
|
|
||||||
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
|
|
||||||
|
|
||||||
# Ensure the canonical Phase-1 IVFFlat index exists.
|
|
||||||
# 004 may have created ix_memory_associative_embedding; this adds the
|
|
||||||
# Phase-1 name memory_associative_embedding_idx if it is missing.
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
DO $$
|
|
||||||
BEGIN
|
|
||||||
IF NOT EXISTS (
|
|
||||||
SELECT 1
|
|
||||||
FROM pg_indexes
|
|
||||||
WHERE tablename = 'memory_associative'
|
|
||||||
AND indexname = 'memory_associative_embedding_idx'
|
|
||||||
) THEN
|
|
||||||
CREATE INDEX memory_associative_embedding_idx
|
|
||||||
ON memory_associative
|
|
||||||
USING ivfflat (embedding vector_cosine_ops)
|
|
||||||
WITH (lists = 100);
|
|
||||||
END IF;
|
|
||||||
END $$;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")
|
|
||||||
@@ -1,74 +0,0 @@
|
|||||||
"""Add memory_relations table (Phase 3 — relational tier).
|
|
||||||
|
|
||||||
Revision ID: 006
|
|
||||||
Revises: 1f5975a4f3f4
|
|
||||||
Create Date: 2026-04-16
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
revision: str = "006"
|
|
||||||
down_revision: Union[str, None] = "1f5975a4f3f4"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"memory_relations",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
|
|
||||||
sa.Column(
|
|
||||||
"user_id",
|
|
||||||
postgresql.UUID(as_uuid=False),
|
|
||||||
sa.ForeignKey("users.id", ondelete="CASCADE"),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("subject_label", sa.String(128), nullable=False),
|
|
||||||
sa.Column("subject_type", sa.String(32), nullable=False),
|
|
||||||
sa.Column("predicate", sa.String(64), nullable=False),
|
|
||||||
sa.Column("object_label", sa.String(128), nullable=False),
|
|
||||||
sa.Column("object_type", sa.String(32), nullable=False),
|
|
||||||
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
|
|
||||||
sa.Column(
|
|
||||||
"source_episode_id",
|
|
||||||
postgresql.UUID(as_uuid=False),
|
|
||||||
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
|
||||||
nullable=True,
|
|
||||||
),
|
|
||||||
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
|
|
||||||
sa.Column(
|
|
||||||
"created_at",
|
|
||||||
sa.DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
server_default=sa.func.now(),
|
|
||||||
),
|
|
||||||
sa.Column(
|
|
||||||
"updated_at",
|
|
||||||
sa.DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
server_default=sa.func.now(),
|
|
||||||
),
|
|
||||||
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"memory_relations_user_subject_idx",
|
|
||||||
"memory_relations",
|
|
||||||
["user_id", "subject_label"],
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"memory_relations_user_predicate_idx",
|
|
||||||
"memory_relations",
|
|
||||||
["user_id", "predicate"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
|
|
||||||
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
|
|
||||||
op.drop_table("memory_relations")
|
|
||||||
@@ -1,38 +0,0 @@
|
|||||||
"""add extraction_queue
|
|
||||||
|
|
||||||
Revision ID: 1f5975a4f3f4
|
|
||||||
Revises: 005
|
|
||||||
Create Date: 2026-04-16 17:26:25.790870
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = '1f5975a4f3f4'
|
|
||||||
down_revision: Union[str, None] = '005'
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
'extraction_queue',
|
|
||||||
sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
|
|
||||||
sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
|
|
||||||
sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
|
|
||||||
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
|
|
||||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
|
|
||||||
sa.PrimaryKeyConstraint('id'),
|
|
||||||
)
|
|
||||||
op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
|
|
||||||
op.drop_table('extraction_queue')
|
|
||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Deprecate backend agent config tables.
|
|
||||||
|
|
||||||
The Electron client is now the source of truth for agent configuration
|
|
||||||
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
|
||||||
billing checks and trigger/run logs only.
|
|
||||||
|
|
||||||
Revision ID: 9a1f2d0b6c7e
|
|
||||||
Revises: 818478c251dc
|
|
||||||
Create Date: 2026-03-16
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
revision: str = "9a1f2d0b6c7e"
|
|
||||||
down_revision: Union[str, None] = "818478c251dc"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
bind = op.get_bind()
|
|
||||||
inspector = sa.inspect(bind)
|
|
||||||
existing = set(inspector.get_table_names())
|
|
||||||
|
|
||||||
if "cloud_agent_configs" in existing:
|
|
||||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
|
||||||
op.drop_table("cloud_agent_configs")
|
|
||||||
|
|
||||||
if "local_agent_configs" in existing:
|
|
||||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
|
||||||
op.drop_table("local_agent_configs")
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"local_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("device_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
|
||||||
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
op.create_table(
|
|
||||||
"cloud_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"provider",
|
|
||||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
|
||||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
|
||||||
@@ -1,107 +0,0 @@
|
|||||||
"""Restore agent config tables and add agent_config column.
|
|
||||||
|
|
||||||
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
|
|
||||||
ORM models are still active. This migration recreates them with agent_config
|
|
||||||
added to local_agent_configs.
|
|
||||||
|
|
||||||
Revision ID: a3b9c0d1e2f3
|
|
||||||
Revises: 9a1f2d0b6c7e
|
|
||||||
Create Date: 2026-04-07 00:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "a3b9c0d1e2f3"
|
|
||||||
down_revision: Union[str, None] = "9a1f2d0b6c7e"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# Recreate enum types (idempotent — they may already exist from migration 003)
|
|
||||||
op.execute("""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
op.execute("""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
op.execute("""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
|
|
||||||
bind = op.get_bind()
|
|
||||||
inspector = sa.inspect(bind)
|
|
||||||
existing = set(inspector.get_table_names())
|
|
||||||
|
|
||||||
# ── local_agent_configs (with agent_config column) ────────────────────
|
|
||||||
if "local_agent_configs" not in existing:
|
|
||||||
op.create_table(
|
|
||||||
"local_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("device_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("agent_config", sa.JSON, nullable=True),
|
|
||||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
|
||||||
|
|
||||||
# ── cloud_agent_configs ───────────────────────────────────────────────
|
|
||||||
if "cloud_agent_configs" not in existing:
|
|
||||||
op.create_table(
|
|
||||||
"cloud_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"provider",
|
|
||||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
|
||||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
|
||||||
op.drop_table("cloud_agent_configs")
|
|
||||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
|
||||||
op.drop_table("local_agent_configs")
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
|
|
||||||
|
|
||||||
Revision ID: b4c0d1e2f3a4
|
|
||||||
Revises: a3b9c0d1e2f3
|
|
||||||
Create Date: 2026-04-10 00:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "b4c0d1e2f3a4"
|
|
||||||
down_revision: Union[str, None] = "a3b9c0d1e2f3"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
# ── users: make password_hash nullable (social users have no password) ──
|
|
||||||
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
|
|
||||||
|
|
||||||
# ── users: add avatar_url ─────────────────────────────────────────────
|
|
||||||
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
|
|
||||||
|
|
||||||
# ── oauth_accounts ────────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"oauth_accounts",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("provider", sa.String(50), nullable=False),
|
|
||||||
sa.Column("provider_user_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("provider_email", sa.String(255), nullable=True),
|
|
||||||
sa.Column(
|
|
||||||
"created_at",
|
|
||||||
sa.DateTime(timezone=True),
|
|
||||||
nullable=False,
|
|
||||||
server_default=sa.text("now()"),
|
|
||||||
),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
|
|
||||||
op.drop_table("oauth_accounts")
|
|
||||||
op.drop_column("users", "avatar_url")
|
|
||||||
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)
|
|
||||||
@@ -1,31 +0,0 @@
|
|||||||
"""Add onboarding_completed_at column to users table.
|
|
||||||
|
|
||||||
Revision ID: c5d1e2f3a4b5
|
|
||||||
Revises: b4c0d1e2f3a4
|
|
||||||
Create Date: 2026-04-11 00:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "c5d1e2f3a4b5"
|
|
||||||
down_revision: Union[str, None] = "b4c0d1e2f3a4"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.add_column(
|
|
||||||
"users",
|
|
||||||
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_column("users", "onboarding_completed_at")
|
|
||||||
@@ -1,46 +0,0 @@
|
|||||||
"""Add token tracking columns for folder integration.
|
|
||||||
|
|
||||||
Revision ID: d6e3f4a5b6c7
|
|
||||||
Revises: 006
|
|
||||||
Create Date: 2026-05-11 00:00:00.000000
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects.postgresql import UUID
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = "d6e3f4a5b6c7"
|
|
||||||
down_revision: Union[str, None] = "006"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.add_column(
|
|
||||||
"agent_run_logs",
|
|
||||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
|
||||||
)
|
|
||||||
op.create_table(
|
|
||||||
"monthly_token_usage",
|
|
||||||
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
|
|
||||||
sa.Column("year_month", sa.String(7), nullable=False),
|
|
||||||
sa.Column("feature", sa.String(64), nullable=False),
|
|
||||||
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
|
|
||||||
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
|
|
||||||
)
|
|
||||||
op.create_index(
|
|
||||||
"ix_monthly_token_usage_user_month",
|
|
||||||
"monthly_token_usage",
|
|
||||||
["user_id", "year_month"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
|
|
||||||
op.drop_table("monthly_token_usage")
|
|
||||||
op.drop_column("agent_run_logs", "tokens_used")
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
"""avatar_url_varchar_to_text
|
|
||||||
|
|
||||||
Revision ID: e04100e88ace
|
|
||||||
Revises: c5d1e2f3a4b5
|
|
||||||
Create Date: 2026-04-13 09:13:06.733674
|
|
||||||
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
from alembic import op
|
|
||||||
import sqlalchemy as sa
|
|
||||||
|
|
||||||
|
|
||||||
# revision identifiers, used by Alembic.
|
|
||||||
revision: str = 'e04100e88ace'
|
|
||||||
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
op.alter_column('users', 'avatar_url',
|
|
||||||
existing_type=sa.VARCHAR(length=2048),
|
|
||||||
type_=sa.Text(),
|
|
||||||
existing_nullable=True)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.alter_column('users', 'avatar_url',
|
|
||||||
existing_type=sa.Text(),
|
|
||||||
type_=sa.VARCHAR(length=2048),
|
|
||||||
existing_nullable=True)
|
|
||||||
@@ -1,5 +1,5 @@
|
|||||||
"""Expose tool modules used by deep orchestrator-worker graphs."""
|
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
||||||
|
|
||||||
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
|
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]
|
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
|
|||||||
@@ -1,52 +0,0 @@
|
|||||||
"""Client agent — read-only tools for the clients table."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_clients(search: str = "", limit: int = 20) -> str:
|
|
||||||
"""List clients, optionally filtered by a name/email substring search.
|
|
||||||
|
|
||||||
search: optional substring to match against client name or email.
|
|
||||||
limit: max rows to return (default 20).
|
|
||||||
"""
|
|
||||||
filters: dict[str, Any] = {"limit": limit}
|
|
||||||
if search:
|
|
||||||
filters["search"] = search
|
|
||||||
|
|
||||||
result = await execute_on_client(action="select", table="clients", filters=filters)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No clients found."
|
|
||||||
lines = [
|
|
||||||
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
|
|
||||||
f"company: {r.get('company', '')})"
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_client(id: str) -> str:
|
|
||||||
"""Get full details for one client by UUID.
|
|
||||||
|
|
||||||
id: the client's UUID.
|
|
||||||
"""
|
|
||||||
if not id:
|
|
||||||
return "Client id is required."
|
|
||||||
|
|
||||||
result = await execute_on_client(action="get", table="clients", data={"id": id})
|
|
||||||
row = result.get("row") or result.get("rows", [None])[0] if result else None
|
|
||||||
if not row:
|
|
||||||
return f"Client '{id}' not found."
|
|
||||||
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
|
|
||||||
|
|
||||||
|
|
||||||
CLIENT_TOOLS: list[Any] = [list_clients, get_client]
|
|
||||||
@@ -1,194 +0,0 @@
|
|||||||
"""Filesystem agent — tools for reading local directories and files on Electron.
|
|
||||||
|
|
||||||
These tools delegate to the Electron client via ``execute_on_client()`` using
|
|
||||||
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
|
|
||||||
handles actual disk I/O and responds with ``tool_result`` frames.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import re
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
|
|
||||||
# Max characters returned by read_file_content in journey (exploration) tools.
|
|
||||||
# The journey only needs to understand file structure, not full content.
|
|
||||||
_JOURNEY_READ_MAX_CHARS: int = 4000
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_path(path: str, base: str) -> str:
|
|
||||||
"""Resolve *path* against *base* when *path* is relative.
|
|
||||||
|
|
||||||
The LLM often passes ``"."`` meaning "the configured directory".
|
|
||||||
Without this, Electron resolves ``"."`` relative to its own CWD instead
|
|
||||||
of the user's chosen directory.
|
|
||||||
"""
|
|
||||||
if os.path.isabs(path):
|
|
||||||
return path
|
|
||||||
return str(Path(base) / path)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_directory(path: str) -> str:
|
|
||||||
"""List files and folders in a local directory on the user's device.
|
|
||||||
|
|
||||||
Returns a formatted listing of entries with name, type (file/directory),
|
|
||||||
and full path.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="list_directory",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
|
||||||
if not entries:
|
|
||||||
return f"Directory '{path}' is empty or does not exist."
|
|
||||||
lines: list[str] = []
|
|
||||||
for entry in entries:
|
|
||||||
entry_type = entry.get("type", "unknown")
|
|
||||||
entry_name = entry.get("name", "")
|
|
||||||
entry_path = entry.get("path", "")
|
|
||||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
|
||||||
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def read_file_content(path: str) -> str:
|
|
||||||
"""Read the text content of a local file on the user's device.
|
|
||||||
|
|
||||||
Returns the file content as a string. Large files may be truncated
|
|
||||||
by the Electron client.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="read_file_content",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
content: str = result.get("content", "")
|
|
||||||
if not content:
|
|
||||||
return f"File '{path}' is empty or could not be read."
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_file_metadata(path: str) -> str:
|
|
||||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
|
||||||
|
|
||||||
Returns a formatted summary of the file's metadata.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="get_file_metadata",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
size = result.get("size", "unknown")
|
|
||||||
created = result.get("createdAt", "unknown")
|
|
||||||
modified = result.get("modifiedAt", "unknown")
|
|
||||||
extension = result.get("extension", "unknown")
|
|
||||||
name = result.get("name", path)
|
|
||||||
return (
|
|
||||||
f"File: {name}\n"
|
|
||||||
f" Extension: {extension}\n"
|
|
||||||
f" Size: {size} bytes\n"
|
|
||||||
f" Created: {created}\n"
|
|
||||||
f" Modified: {modified}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
FILESYSTEM_TOOLS: list[Any] = [
|
|
||||||
list_directory,
|
|
||||||
read_file_content,
|
|
||||||
get_file_metadata,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def make_directory_tools(base_directory: str) -> list[Any]:
|
|
||||||
"""Return filesystem tools that resolve relative paths against *base_directory*.
|
|
||||||
|
|
||||||
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
|
|
||||||
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
|
|
||||||
from the LLM are resolved to the correct absolute path before being sent to
|
|
||||||
the Electron client, preventing it from falling back to its own CWD.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _compact_for_journey(raw: str) -> str:
|
|
||||||
"""Strip HTML noise and truncate for journey exploration.
|
|
||||||
|
|
||||||
The journey LLM only needs to understand file structure (headers,
|
|
||||||
first paragraphs). Full CSS/style blocks are pure noise that eat
|
|
||||||
up context window budget.
|
|
||||||
"""
|
|
||||||
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
|
|
||||||
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
|
|
||||||
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
|
|
||||||
if len(text) > _JOURNEY_READ_MAX_CHARS:
|
|
||||||
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
|
|
||||||
return text
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_directory(path: str) -> str: # noqa: F811
|
|
||||||
"""List files and folders in a local directory on the user's device.
|
|
||||||
|
|
||||||
Returns a formatted listing of entries with name, type (file/directory),
|
|
||||||
and full path.
|
|
||||||
"""
|
|
||||||
resolved = _resolve_path(path, base_directory)
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="list_directory",
|
|
||||||
data={"path": resolved},
|
|
||||||
)
|
|
||||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
|
||||||
if not entries:
|
|
||||||
return f"Directory '{resolved}' is empty or does not exist."
|
|
||||||
lines: list[str] = []
|
|
||||||
for entry in entries:
|
|
||||||
entry_type = entry.get("type", "unknown")
|
|
||||||
entry_name = entry.get("name", "")
|
|
||||||
entry_path = entry.get("path", "")
|
|
||||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
|
||||||
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def read_file_content(path: str) -> str: # noqa: F811
|
|
||||||
"""Read the text content of a local file on the user's device.
|
|
||||||
|
|
||||||
Returns the file content as a string. Large files may be truncated
|
|
||||||
by the Electron client.
|
|
||||||
"""
|
|
||||||
resolved = _resolve_path(path, base_directory)
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="read_file_content",
|
|
||||||
data={"path": resolved},
|
|
||||||
)
|
|
||||||
content: str = result.get("content", "")
|
|
||||||
if not content:
|
|
||||||
return f"File '{resolved}' is empty or could not be read."
|
|
||||||
return _compact_for_journey(content)
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_file_metadata(path: str) -> str: # noqa: F811
|
|
||||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
|
||||||
|
|
||||||
Returns a formatted summary of the file's metadata.
|
|
||||||
"""
|
|
||||||
resolved = _resolve_path(path, base_directory)
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="get_file_metadata",
|
|
||||||
data={"path": resolved},
|
|
||||||
)
|
|
||||||
size = result.get("size", "unknown")
|
|
||||||
created = result.get("createdAt", "unknown")
|
|
||||||
modified = result.get("modifiedAt", "unknown")
|
|
||||||
extension = result.get("extension", "unknown")
|
|
||||||
name = result.get("name", resolved)
|
|
||||||
return (
|
|
||||||
f"File: {name}\n"
|
|
||||||
f" Extension: {extension}\n"
|
|
||||||
f" Size: {size} bytes\n"
|
|
||||||
f" Created: {created}\n"
|
|
||||||
f" Modified: {modified}"
|
|
||||||
)
|
|
||||||
|
|
||||||
return [list_directory, read_file_content, get_file_metadata]
|
|
||||||
@@ -1,168 +0,0 @@
|
|||||||
"""Scoped file-read and search tools for the project folder feature."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
|
|
||||||
from app.core.ws_context import execute_on_client
|
|
||||||
|
|
||||||
# Cap returned slice size to keep tool output under control.
|
|
||||||
_MAX_RETURN_CHARS = 50_000
|
|
||||||
_MAX_SEARCH_MATCHES = 20
|
|
||||||
|
|
||||||
|
|
||||||
def _is_unsafe_path(rel: str) -> bool:
|
|
||||||
if not rel:
|
|
||||||
return True
|
|
||||||
norm = rel.replace("\\", "/")
|
|
||||||
if norm.startswith("/"):
|
|
||||||
return True
|
|
||||||
# Windows drive letter
|
|
||||||
if len(rel) >= 2 and rel[1] == ":":
|
|
||||||
return True
|
|
||||||
parts = norm.split("/")
|
|
||||||
return ".." in parts
|
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
|
|
||||||
"""Return the raw Electron tool_result dict for a file read."""
|
|
||||||
return await execute_on_client(
|
|
||||||
action="read_project_folder_file",
|
|
||||||
data={
|
|
||||||
"projectId": project_id,
|
|
||||||
"relativePath": relative_path,
|
|
||||||
"offset": offset,
|
|
||||||
"length": length,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _decode(result: dict) -> tuple[str, str, int]:
|
|
||||||
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
|
|
||||||
extracts text from base64. For images, returns a placeholder string.
|
|
||||||
For text, content is already a sliced utf-8 string.
|
|
||||||
"""
|
|
||||||
kind = result.get("kind", "text")
|
|
||||||
content = result.get("content", "") or ""
|
|
||||||
total = int(result.get("totalSize", 0) or 0)
|
|
||||||
if kind == "image":
|
|
||||||
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
|
|
||||||
if kind == "pdf":
|
|
||||||
return (_extract_pdf_text(content), kind, total)
|
|
||||||
if kind == "docx":
|
|
||||||
return (_extract_docx_text(content), kind, total)
|
|
||||||
return (content, kind, total)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def read_project_folder_file(
|
|
||||||
project_id: str,
|
|
||||||
relative_path: str,
|
|
||||||
offset: int = 0,
|
|
||||||
length: int = _MAX_RETURN_CHARS,
|
|
||||||
) -> str:
|
|
||||||
"""Read a slice of a file inside the project's linked folder.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
project_id: project ID.
|
|
||||||
relative_path: path relative to the linked folder root.
|
|
||||||
offset: char offset to start reading from (0 = beginning).
|
|
||||||
length: max chars to return. Default 50000. Use smaller values to save tokens.
|
|
||||||
|
|
||||||
Returns text content slice with a header showing position. Header tells you
|
|
||||||
when more content is available; call again with the suggested next offset.
|
|
||||||
|
|
||||||
For PDF / DOCX files the backend extracts text first, then applies offset/length
|
|
||||||
on the extracted text. For images returns a placeholder; navigate with the
|
|
||||||
manifest summary instead.
|
|
||||||
"""
|
|
||||||
if _is_unsafe_path(relative_path):
|
|
||||||
return "Access denied"
|
|
||||||
|
|
||||||
result = await _fetch_file(project_id, relative_path, offset, length)
|
|
||||||
text, kind, total_size = _decode(result)
|
|
||||||
|
|
||||||
if not text and kind in ("missing", "error"):
|
|
||||||
return f"File not found or unreadable: {relative_path}"
|
|
||||||
|
|
||||||
if kind in ("pdf", "docx"):
|
|
||||||
# Backend extracted full text — apply offset/length on chars.
|
|
||||||
sliced = text[offset:offset + length]
|
|
||||||
slice_end = min(offset + length, len(text))
|
|
||||||
header = (
|
|
||||||
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
|
|
||||||
f"totalChars={len(text)}]"
|
|
||||||
)
|
|
||||||
if slice_end < len(text):
|
|
||||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
|
||||||
return header + "\n" + sliced
|
|
||||||
|
|
||||||
if kind == "text":
|
|
||||||
slice_end = offset + len(text)
|
|
||||||
header = (
|
|
||||||
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
|
|
||||||
f"totalBytes={total_size}]"
|
|
||||||
)
|
|
||||||
if slice_end < total_size:
|
|
||||||
header += f"\n[More content available — call again with offset={slice_end}.]"
|
|
||||||
return header + "\n" + text
|
|
||||||
|
|
||||||
# image or unknown
|
|
||||||
return text
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def search_project_folder_file(
|
|
||||||
project_id: str,
|
|
||||||
relative_path: str,
|
|
||||||
query: str,
|
|
||||||
context_lines: int = 3,
|
|
||||||
) -> str:
|
|
||||||
"""Search a project folder file for a query string (case-insensitive substring).
|
|
||||||
|
|
||||||
Args:
|
|
||||||
project_id: project ID.
|
|
||||||
relative_path: path relative to the linked folder root.
|
|
||||||
query: text to search for.
|
|
||||||
context_lines: number of lines of context around each match (default 3).
|
|
||||||
|
|
||||||
Returns matching line ranges with surrounding context and 1-based line numbers.
|
|
||||||
Capped at 20 matches; if more exist the header shows the total.
|
|
||||||
|
|
||||||
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
|
|
||||||
Images and binary files are not searchable.
|
|
||||||
"""
|
|
||||||
if _is_unsafe_path(relative_path):
|
|
||||||
return "Access denied"
|
|
||||||
if not query:
|
|
||||||
return "Empty query."
|
|
||||||
|
|
||||||
# For text we still need full file; pass length=very large.
|
|
||||||
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
|
|
||||||
text, kind, _ = _decode(result)
|
|
||||||
|
|
||||||
if not text and kind in ("missing", "error"):
|
|
||||||
return f"File not found or unreadable: {relative_path}"
|
|
||||||
if kind == "image":
|
|
||||||
return "Cannot search inside images."
|
|
||||||
|
|
||||||
lines = text.splitlines()
|
|
||||||
q = query.lower()
|
|
||||||
matches = [i for i, line in enumerate(lines) if q in line.lower()]
|
|
||||||
if not matches:
|
|
||||||
return f"No matches for '{query}' in {relative_path}."
|
|
||||||
|
|
||||||
shown = matches[:_MAX_SEARCH_MATCHES]
|
|
||||||
snippets: list[str] = []
|
|
||||||
for i in shown:
|
|
||||||
start = max(0, i - context_lines)
|
|
||||||
end = min(len(lines), i + context_lines + 1)
|
|
||||||
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
|
|
||||||
snippets.append(block)
|
|
||||||
|
|
||||||
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
|
|
||||||
body = "\n---\n".join(snippets)
|
|
||||||
return header + "\n" + body
|
|
||||||
|
|
||||||
|
|
||||||
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]
|
|
||||||
@@ -1,50 +1,27 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, propose edit)."""
|
"""Note agent — tool definitions for Markdown note CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.note_summarizer import generate_note_summary
|
from app.core.llm import embed
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
|
|
||||||
def _fmt_summary(row: dict) -> str:
|
|
||||||
summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
|
|
||||||
if summary:
|
|
||||||
return f" — {summary}"
|
|
||||||
snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
|
|
||||||
return f" — {snippet}" if snippet else ""
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes with AI summaries, optionally scoped to a project by project_id.
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
|
|
||||||
Returns id, title, and ai_summary for each note so you can decide which
|
|
||||||
note to read in full with get_note before creating or updating.
|
|
||||||
"""
|
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": normalized_project_id or None},
|
filters={"projectId": project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No notes found."
|
return "No notes found."
|
||||||
lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
|
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
||||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@@ -79,10 +56,14 @@ async def create_note(
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
note_id: str = row["id"]
|
# Index the note content in the vector store.
|
||||||
# Generate summary asynchronously — fire-and-forget.
|
vector = await embed(content)
|
||||||
asyncio.create_task(_refresh_summary(note_id, title, content))
|
await execute_on_client(
|
||||||
return f"Note created: '{row['title']}' (id: {note_id})."
|
action="vector_upsert",
|
||||||
|
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
|
return f"Note created: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -91,8 +72,7 @@ async def update_note(
|
|||||||
title: str = "",
|
title: str = "",
|
||||||
content: str = "",
|
content: str = "",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update an existing note directly (no approval required).
|
"""Update an existing note. Only pass fields that should change.
|
||||||
Use propose_note_edit instead when human review is needed.
|
|
||||||
note_id: UUID of the note (required)
|
note_id: UUID of the note (required)
|
||||||
If you need to preserve existing content, call get_note first.
|
If you need to preserve existing content, call get_note first.
|
||||||
"""
|
"""
|
||||||
@@ -107,63 +87,17 @@ async def update_note(
|
|||||||
data={"id": note_id, "updates": updates},
|
data={"id": note_id, "updates": updates},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
|
# Re-index if content changed.
|
||||||
if content:
|
if content:
|
||||||
new_title = title or row.get("title", "")
|
vector = await embed(content)
|
||||||
asyncio.create_task(_refresh_summary(note_id, new_title, content))
|
await execute_on_client(
|
||||||
|
action="vector_upsert",
|
||||||
|
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
||||||
|
vector=vector,
|
||||||
|
)
|
||||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def propose_note_edit(
|
|
||||||
note_id: str,
|
|
||||||
edit_type: str,
|
|
||||||
proposed_content: str,
|
|
||||||
reasoning: str = "",
|
|
||||||
anchor_before: str = "",
|
|
||||||
anchor_text: str = "",
|
|
||||||
agent_id: str = "",
|
|
||||||
run_id: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""Propose an AI edit to an existing note, pending human approval.
|
|
||||||
|
|
||||||
Use this instead of update_note when review_required is true.
|
|
||||||
The user will see the proposal highlighted before it is merged.
|
|
||||||
|
|
||||||
note_id: UUID of the target note (required)
|
|
||||||
edit_type: 'append' | 'insert' | 'replace'
|
|
||||||
- append: adds proposed_content at the end of the note
|
|
||||||
- insert: inserts proposed_content immediately after anchor_before text
|
|
||||||
- replace: replaces the first occurrence of anchor_text with proposed_content
|
|
||||||
proposed_content: the new Markdown text to add or substitute (required)
|
|
||||||
reasoning: brief explanation shown to the user (recommended)
|
|
||||||
anchor_before: for 'insert' — the text snippet that precedes the insertion point
|
|
||||||
anchor_text: for 'replace' — the exact text to be replaced
|
|
||||||
agent_id: agent identifier (for traceability)
|
|
||||||
run_id: run identifier (for traceability)
|
|
||||||
"""
|
|
||||||
if edit_type not in ("append", "insert", "replace"):
|
|
||||||
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
|
|
||||||
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="propose_note_edit",
|
|
||||||
data={
|
|
||||||
"noteId": note_id,
|
|
||||||
"type": edit_type,
|
|
||||||
"proposedContent": proposed_content,
|
|
||||||
"reasoning": reasoning or None,
|
|
||||||
"anchorBefore": anchor_before or None,
|
|
||||||
"anchorText": anchor_text or None,
|
|
||||||
"agentId": agent_id or None,
|
|
||||||
"runId": run_id or None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
edit_id = result.get("id", "?")
|
|
||||||
return (
|
|
||||||
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
|
|
||||||
f"Status: pending user approval."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_note(note_id: str) -> str:
|
async def delete_note(note_id: str) -> str:
|
||||||
"""Delete a note permanently by its UUID."""
|
"""Delete a note permanently by its UUID."""
|
||||||
@@ -171,36 +105,4 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
|
|
||||||
"""Generate and persist the AI summary for a note. Fire-and-forget."""
|
|
||||||
try:
|
|
||||||
summary = await generate_note_summary(title, content)
|
|
||||||
if summary:
|
|
||||||
await execute_on_client(
|
|
||||||
action="update",
|
|
||||||
table="notes",
|
|
||||||
data={
|
|
||||||
"id": note_id,
|
|
||||||
"updates": {
|
|
||||||
"aiSummary": summary,
|
|
||||||
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass # fire-and-forget; errors logged by generate_note_summary
|
|
||||||
|
|
||||||
|
|
||||||
NOTE_TOOLS: list[Any] = [
|
|
||||||
list_notes,
|
|
||||||
get_note,
|
|
||||||
create_note,
|
|
||||||
update_note,
|
|
||||||
propose_note_edit,
|
|
||||||
delete_note,
|
|
||||||
]
|
|
||||||
|
|
||||||
NOTE_READ_TOOLS: list[Any] = [
|
|
||||||
list_notes,
|
|
||||||
get_note,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete)."""
|
"""Project agent — tool definitions for project lifecycle CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -117,17 +117,4 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
PROJECT_TOOLS: list[Any] = [
|
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
create_project,
|
|
||||||
update_project,
|
|
||||||
delete_project,
|
|
||||||
]
|
|
||||||
|
|
||||||
PROJECT_READ_TOOLS: list[Any] = [
|
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,63 +0,0 @@
|
|||||||
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
|
||||||
from app.db import async_session
|
|
||||||
|
|
||||||
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
|
|
||||||
# Each tool closure captures the user_id bound at factory time.
|
|
||||||
|
|
||||||
|
|
||||||
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
|
|
||||||
"""Return a query_relations tool bound to *user_id*."""
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def query_relations(
|
|
||||||
subject_label: str = "",
|
|
||||||
predicate: str = "",
|
|
||||||
object_label: str = "",
|
|
||||||
limit: int = 10,
|
|
||||||
) -> str:
|
|
||||||
"""Query the relational memory graph for entity relationships.
|
|
||||||
|
|
||||||
Returns rows where subject ↔ predicate ↔ object match the given filters.
|
|
||||||
All parameters are optional — omit to retrieve all relations up to limit.
|
|
||||||
|
|
||||||
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
|
|
||||||
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
|
|
||||||
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
|
|
||||||
limit: max rows to return (default 10).
|
|
||||||
"""
|
|
||||||
import logging
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
logger.info(
|
|
||||||
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
|
|
||||||
trace_id or "-", user_id, subject_label, predicate, object_label,
|
|
||||||
)
|
|
||||||
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
rows = await memory.query_relations(
|
|
||||||
user_id=user_id,
|
|
||||||
subject=subject_label or None,
|
|
||||||
predicate=predicate or None,
|
|
||||||
object_=object_label or None,
|
|
||||||
limit=limit,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not rows:
|
|
||||||
return "No relational memory entries found for the given filters."
|
|
||||||
|
|
||||||
lines = [
|
|
||||||
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
|
|
||||||
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
return query_relations
|
|
||||||
@@ -1,23 +1,14 @@
|
|||||||
"""Task agent — full CRUD for tasks and task comments."""
|
"""Task agent — tool definitions for task and task comment CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -26,141 +17,31 @@ def _is_uuid(value: str) -> bool:
|
|||||||
async def list_tasks(
|
async def list_tasks(
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
status: str = "",
|
status: str = "",
|
||||||
priority: str = "",
|
|
||||||
assignee: str = "",
|
|
||||||
search: str = "",
|
search: str = "",
|
||||||
order_by: str = "",
|
order_by: str = "",
|
||||||
order_dir: str = "",
|
|
||||||
due_date_from: int = -1,
|
|
||||||
due_date_to: int = -1,
|
|
||||||
created_at_from: int = -1,
|
|
||||||
created_at_to: int = -1,
|
|
||||||
completed_at_from: int = -1,
|
|
||||||
completed_at_to: int = -1,
|
|
||||||
is_ai_suggested: int = -1,
|
|
||||||
limit: int = 50,
|
|
||||||
offset: int = 0,
|
|
||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks with optional filters. Returns up to `limit` results (default 50).
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
project_id: UUID of the project to scope results to.
|
result = await execute_on_client(
|
||||||
status: filter by status — todo | in_progress | done.
|
action="select",
|
||||||
priority: filter by priority — high | medium | low.
|
table="tasks",
|
||||||
assignee: substring to match against assignee names. OMIT unless the user explicitly
|
filters={
|
||||||
names a person or refers to themselves ("my tasks", "assigned to me", "mine").
|
"projectId": project_id or None,
|
||||||
Do NOT default to the current user.
|
"status": status or None,
|
||||||
search: substring search across title and description.
|
"search": search or None,
|
||||||
order_by: sort field — dueDate | priority | createdAt | completedAt.
|
"orderBy": order_by or None,
|
||||||
order_dir: asc (default) | desc.
|
},
|
||||||
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
)
|
||||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
|
||||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
|
||||||
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
|
||||||
limit: max rows to return (default 50). Use with offset to paginate.
|
|
||||||
offset: skip first N rows (default 0).
|
|
||||||
|
|
||||||
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
|
||||||
Tip — prefer count_tasks for "how many" questions to avoid listing rows.
|
|
||||||
Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.)
|
|
||||||
take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt;
|
|
||||||
do not compute boundaries from the current UTC instant.
|
|
||||||
"""
|
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
filters: dict[str, Any] = {
|
|
||||||
"projectId": normalized_project_id or None,
|
|
||||||
"status": status or None,
|
|
||||||
"priority": priority or None,
|
|
||||||
"search": search or None,
|
|
||||||
"orderBy": order_by or None,
|
|
||||||
"orderDir": order_dir or None,
|
|
||||||
"limit": limit,
|
|
||||||
"offset": offset,
|
|
||||||
}
|
|
||||||
if assignee:
|
|
||||||
filters["assignee"] = assignee
|
|
||||||
if due_date_from != -1:
|
|
||||||
filters["dueDateFrom"] = due_date_from
|
|
||||||
if due_date_to != -1:
|
|
||||||
filters["dueDateTo"] = due_date_to
|
|
||||||
if created_at_from != -1:
|
|
||||||
filters["createdAtFrom"] = created_at_from
|
|
||||||
if created_at_to != -1:
|
|
||||||
filters["createdAtTo"] = created_at_to
|
|
||||||
if completed_at_from != -1:
|
|
||||||
filters["completedAtFrom"] = completed_at_from
|
|
||||||
if completed_at_to != -1:
|
|
||||||
filters["completedAtTo"] = completed_at_to
|
|
||||||
if is_ai_suggested != -1:
|
|
||||||
filters["isAiSuggested"] = is_ai_suggested
|
|
||||||
|
|
||||||
result = await execute_on_client(action="select", table="tasks", filters=filters)
|
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No tasks found matching the given filters."
|
return "No tasks found matching the given filters."
|
||||||
lines = [
|
lines = [
|
||||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, "
|
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
||||||
f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, "
|
|
||||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
|
||||||
for r in rows
|
for r in rows
|
||||||
]
|
]
|
||||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def count_tasks(
|
|
||||||
project_id: str = "",
|
|
||||||
status: str = "",
|
|
||||||
priority: str = "",
|
|
||||||
assignee: str = "",
|
|
||||||
search: str = "",
|
|
||||||
due_date_from: int = -1,
|
|
||||||
due_date_to: int = -1,
|
|
||||||
created_at_from: int = -1,
|
|
||||||
created_at_to: int = -1,
|
|
||||||
completed_at_from: int = -1,
|
|
||||||
completed_at_to: int = -1,
|
|
||||||
is_ai_suggested: int = -1,
|
|
||||||
) -> str:
|
|
||||||
"""Count tasks matching the given filters without returning rows.
|
|
||||||
|
|
||||||
Use this instead of list_tasks for "how many" questions — it is much cheaper.
|
|
||||||
Same filter parameters as list_tasks (no limit/offset/order_by needed).
|
|
||||||
assignee: OMIT unless the user explicitly names a person or refers to themselves
|
|
||||||
("my tasks"). Do NOT default to the current user.
|
|
||||||
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
|
|
||||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
|
||||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
|
||||||
Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block;
|
|
||||||
do not compute boundaries from the current UTC instant.
|
|
||||||
"""
|
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
filters: dict[str, Any] = {
|
|
||||||
"projectId": normalized_project_id or None,
|
|
||||||
"status": status or None,
|
|
||||||
"priority": priority or None,
|
|
||||||
"search": search or None,
|
|
||||||
}
|
|
||||||
if assignee:
|
|
||||||
filters["assignee"] = assignee
|
|
||||||
if due_date_from != -1:
|
|
||||||
filters["dueDateFrom"] = due_date_from
|
|
||||||
if due_date_to != -1:
|
|
||||||
filters["dueDateTo"] = due_date_to
|
|
||||||
if created_at_from != -1:
|
|
||||||
filters["createdAtFrom"] = created_at_from
|
|
||||||
if created_at_to != -1:
|
|
||||||
filters["createdAtTo"] = created_at_to
|
|
||||||
if completed_at_from != -1:
|
|
||||||
filters["completedAtFrom"] = completed_at_from
|
|
||||||
if completed_at_to != -1:
|
|
||||||
filters["completedAtTo"] = completed_at_to
|
|
||||||
if is_ai_suggested != -1:
|
|
||||||
filters["isAiSuggested"] = is_ai_suggested
|
|
||||||
|
|
||||||
result = await execute_on_client(action="count", table="tasks", filters=filters)
|
|
||||||
return f"Task count: {result.get('count', 0)}"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def create_task(
|
async def create_task(
|
||||||
title: str,
|
title: str,
|
||||||
@@ -171,6 +52,7 @@ async def create_task(
|
|||||||
due_date: int = 0,
|
due_date: int = 0,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new task.
|
"""Create a new task.
|
||||||
title: task title (required)
|
title: task title (required)
|
||||||
@@ -181,8 +63,7 @@ async def create_task(
|
|||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
project_id: optional UUID of the parent project
|
project_id: optional UUID of the parent project
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms; 1 when confirmed
|
||||||
completedAt is set automatically when status is 'done'.
|
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -196,12 +77,13 @@ async def create_task(
|
|||||||
"dueDate": due_date or None,
|
"dueDate": due_date or None,
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
return (
|
return (
|
||||||
f"Task created: '{row['title']}' "
|
f"Task created: '{row['title']}' "
|
||||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})"
|
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -215,14 +97,12 @@ async def update_task(
|
|||||||
assignees: str = "",
|
assignees: str = "",
|
||||||
due_date: int = -1,
|
due_date: int = -1,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
|
is_approved: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
task_id: the task's UUID (required)
|
task_id: the task's UUID (required)
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the value
|
||||||
completedAt is managed automatically:
|
|
||||||
- setting status to 'done' records the current timestamp
|
|
||||||
- changing status away from 'done' clears completedAt
|
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
@@ -239,13 +119,15 @@ async def update_task(
|
|||||||
updates["dueDate"] = due_date or None
|
updates["dueDate"] = due_date or None
|
||||||
if project_id:
|
if project_id:
|
||||||
updates["projectId"] = project_id
|
updates["projectId"] = project_id
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
data={"id": task_id, "updates": updates},
|
data={"id": task_id, "updates": updates},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})"
|
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -256,36 +138,21 @@ async def delete_task(task_id: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str:
|
async def list_tasks_due_today() -> str:
|
||||||
"""List all tasks whose due date falls on today's date.
|
"""List all tasks whose due date falls on today's date."""
|
||||||
|
now = datetime.now(tz=timezone.utc)
|
||||||
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
||||||
Always pass the user's timezone so 'today' is computed in their local time.
|
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
||||||
include_done: set True to also include already-completed tasks due today (default False).
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
tz = ZoneInfo(user_timezone or "UTC")
|
|
||||||
except Exception:
|
|
||||||
tz = timezone.utc
|
|
||||||
now_local = datetime.now(tz=tz)
|
|
||||||
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
|
||||||
start_ms = int(start_dt.timestamp() * 1000)
|
|
||||||
end_ms = start_ms + 86_400_000 - 1
|
|
||||||
filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms}
|
|
||||||
if not include_done:
|
|
||||||
filters["status"] = "todo"
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters=filters,
|
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No tasks are due today."
|
return "No tasks are due today."
|
||||||
lines = [
|
lines = [
|
||||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, "
|
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
||||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
|
||||||
for r in rows
|
for r in rows
|
||||||
]
|
]
|
||||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
||||||
@@ -321,11 +188,8 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result.get("row", {})
|
row = result["row"]
|
||||||
row_author = row.get("author", author)
|
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
||||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
|
||||||
row_comment_id = row.get("id", "unknown")
|
|
||||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -335,24 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
return f"Comment {comment_id} deleted."
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
TASK_TOOLS: list[Any] = [
|
|
||||||
list_tasks,
|
|
||||||
count_tasks,
|
|
||||||
create_task,
|
|
||||||
update_task,
|
|
||||||
delete_task,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
add_task_comment,
|
|
||||||
delete_task_comment,
|
|
||||||
]
|
|
||||||
|
|
||||||
TASK_READ_TOOLS: list[Any] = [
|
|
||||||
list_tasks,
|
|
||||||
count_tasks,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -1,147 +1,27 @@
|
|||||||
"""Timeline agent — project milestone management (list, create, update, delete)."""
|
"""Timeline agent — tool definitions for project milestone CRUD."""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
project_id: str = "",
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
type: str = "",
|
result = await execute_on_client(
|
||||||
is_completed: int = -1,
|
action="select",
|
||||||
is_ai_suggested: int = -1,
|
table="timelines",
|
||||||
order_by: str = "",
|
filters={"projectId": project_id or None},
|
||||||
order_dir: str = "",
|
)
|
||||||
date_from: int = -1,
|
|
||||||
date_to: int = -1,
|
|
||||||
created_at_from: int = -1,
|
|
||||||
created_at_to: int = -1,
|
|
||||||
completed_at_from: int = -1,
|
|
||||||
completed_at_to: int = -1,
|
|
||||||
limit: int = 50,
|
|
||||||
offset: int = 0,
|
|
||||||
) -> str:
|
|
||||||
"""List timeline events (milestones, checkpoints, activities) with optional filters.
|
|
||||||
|
|
||||||
project_id: UUID to scope results to a specific project.
|
|
||||||
type: filter by event type — milestone | checkpoint | activity.
|
|
||||||
is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default).
|
|
||||||
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
|
|
||||||
order_by: sort field — date (default) | createdAt | completedAt.
|
|
||||||
order_dir: asc (default) | desc.
|
|
||||||
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
|
||||||
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
|
|
||||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
|
||||||
limit: max rows to return (default 50). Use with offset to paginate.
|
|
||||||
offset: skip first N rows (default 0).
|
|
||||||
|
|
||||||
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
|
|
||||||
Tip — prefer count_timelines for "how many" questions to avoid listing rows.
|
|
||||||
Tip — for natural-language windows ("today", "this week", "last month", etc.)
|
|
||||||
take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt;
|
|
||||||
do not compute boundaries from the current UTC instant.
|
|
||||||
"""
|
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
filters: dict[str, Any] = {
|
|
||||||
"projectId": normalized_project_id or None,
|
|
||||||
"orderBy": order_by or None,
|
|
||||||
"orderDir": order_dir or None,
|
|
||||||
"limit": limit,
|
|
||||||
"offset": offset,
|
|
||||||
}
|
|
||||||
if type:
|
|
||||||
filters["type"] = type
|
|
||||||
if is_completed != -1:
|
|
||||||
filters["isCompleted"] = is_completed
|
|
||||||
if is_ai_suggested != -1:
|
|
||||||
filters["isAiSuggested"] = is_ai_suggested
|
|
||||||
if date_from != -1:
|
|
||||||
filters["dateFrom"] = date_from
|
|
||||||
if date_to != -1:
|
|
||||||
filters["dateTo"] = date_to
|
|
||||||
if created_at_from != -1:
|
|
||||||
filters["createdAtFrom"] = created_at_from
|
|
||||||
if created_at_to != -1:
|
|
||||||
filters["createdAtTo"] = created_at_to
|
|
||||||
if completed_at_from != -1:
|
|
||||||
filters["completedAtFrom"] = completed_at_from
|
|
||||||
if completed_at_to != -1:
|
|
||||||
filters["completedAtTo"] = completed_at_to
|
|
||||||
|
|
||||||
result = await execute_on_client(action="select", table="timelines", filters=filters)
|
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
return "No timeline events found."
|
return "No timelines found."
|
||||||
lines = [
|
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
||||||
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
||||||
f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, "
|
|
||||||
f"projectId: {r.get('projectId')}, id: {r['id']})"
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def count_timelines(
|
|
||||||
project_id: str = "",
|
|
||||||
type: str = "",
|
|
||||||
is_completed: int = -1,
|
|
||||||
is_ai_suggested: int = -1,
|
|
||||||
date_from: int = -1,
|
|
||||||
date_to: int = -1,
|
|
||||||
created_at_from: int = -1,
|
|
||||||
created_at_to: int = -1,
|
|
||||||
completed_at_from: int = -1,
|
|
||||||
completed_at_to: int = -1,
|
|
||||||
) -> str:
|
|
||||||
"""Count timeline events matching the given filters without returning rows.
|
|
||||||
|
|
||||||
Use this instead of list_timelines for "how many" questions — it is much cheaper.
|
|
||||||
Same filter parameters as list_timelines (no limit/offset/order_by needed).
|
|
||||||
|
|
||||||
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
|
|
||||||
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
|
|
||||||
Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block;
|
|
||||||
do not compute boundaries from the current UTC instant.
|
|
||||||
"""
|
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
filters: dict[str, Any] = {"projectId": normalized_project_id or None}
|
|
||||||
if type:
|
|
||||||
filters["type"] = type
|
|
||||||
if is_completed != -1:
|
|
||||||
filters["isCompleted"] = is_completed
|
|
||||||
if is_ai_suggested != -1:
|
|
||||||
filters["isAiSuggested"] = is_ai_suggested
|
|
||||||
if date_from != -1:
|
|
||||||
filters["dateFrom"] = date_from
|
|
||||||
if date_to != -1:
|
|
||||||
filters["dateTo"] = date_to
|
|
||||||
if created_at_from != -1:
|
|
||||||
filters["createdAtFrom"] = created_at_from
|
|
||||||
if created_at_to != -1:
|
|
||||||
filters["createdAtTo"] = created_at_to
|
|
||||||
if completed_at_from != -1:
|
|
||||||
filters["completedAtFrom"] = completed_at_from
|
|
||||||
if completed_at_to != -1:
|
|
||||||
filters["completedAtTo"] = completed_at_to
|
|
||||||
|
|
||||||
result = await execute_on_client(action="count", table="timelines", filters=filters)
|
|
||||||
return f"Timeline event count: {result.get('count', 0)}"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -149,19 +29,15 @@ async def create_timeline(
|
|||||||
project_id: str,
|
project_id: str,
|
||||||
title: str,
|
title: str,
|
||||||
date: int,
|
date: int,
|
||||||
type: str = "milestone",
|
|
||||||
is_completed: int = 0,
|
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a project timeline event.
|
"""Create a project timeline (milestone).
|
||||||
project_id: REQUIRED UUID of the parent project
|
project_id: REQUIRED UUID of the parent project
|
||||||
title: descriptive name for the event
|
title: descriptive name for the milestone
|
||||||
date: Unix timestamp in milliseconds for the event date
|
date: Unix timestamp in milliseconds
|
||||||
type: milestone (default) | checkpoint | activity
|
|
||||||
is_completed: 1 if already completed, 0 if not (default 0)
|
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms
|
||||||
completedAt is set automatically when is_completed is 1.
|
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -170,13 +46,12 @@ async def create_timeline(
|
|||||||
"projectId": project_id,
|
"projectId": project_id,
|
||||||
"title": title,
|
"title": title,
|
||||||
"date": date,
|
"date": date,
|
||||||
"type": type,
|
|
||||||
"isCompleted": is_completed,
|
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})"
|
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -184,87 +59,34 @@ async def update_timeline(
|
|||||||
timeline_id: str,
|
timeline_id: str,
|
||||||
title: str = "",
|
title: str = "",
|
||||||
date: int = -1,
|
date: int = -1,
|
||||||
is_completed: int = -1,
|
is_approved: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update a timeline event. Only pass fields that should change.
|
"""Update a timeline. Only pass fields that should change.
|
||||||
timeline_id: UUID of the event (required)
|
timeline_id: UUID of the timeline (required)
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged
|
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
||||||
|
|
||||||
completedAt is managed automatically:
|
|
||||||
- setting is_completed to 1 records the current timestamp
|
|
||||||
- setting is_completed to 0 clears completedAt
|
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if date != -1:
|
if date != -1:
|
||||||
updates["date"] = date
|
updates["date"] = date
|
||||||
if is_completed != -1:
|
if is_approved != -1:
|
||||||
updates["isCompleted"] = is_completed
|
updates["isApproved"] = is_approved
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
data={"id": timeline_id, "updates": updates},
|
data={"id": timeline_id, "updates": updates},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
return f"Timeline event updated: '{row['title']}' (id: {row['id']})"
|
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def delete_timeline(timeline_id: str) -> str:
|
async def delete_timeline(timeline_id: str) -> str:
|
||||||
"""Delete a timeline event permanently by its UUID."""
|
"""Delete a timeline permanently by its UUID."""
|
||||||
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
||||||
return f"Timeline event {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str:
|
|
||||||
"""List all timeline events whose date falls on today.
|
|
||||||
|
|
||||||
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
|
|
||||||
Always pass the user's timezone so 'today' is computed in their local time.
|
|
||||||
include_completed: set False to exclude already-completed events (default True).
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
from zoneinfo import ZoneInfo
|
|
||||||
tz = ZoneInfo(user_timezone or "UTC")
|
|
||||||
except Exception:
|
|
||||||
tz = timezone.utc
|
|
||||||
now_local = datetime.now(tz=tz)
|
|
||||||
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
|
|
||||||
start_ms = int(start_dt.timestamp() * 1000)
|
|
||||||
end_ms = start_ms + 86_400_000 - 1
|
|
||||||
filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms}
|
|
||||||
if not include_completed:
|
|
||||||
filters["isCompleted"] = 0
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table="timelines",
|
|
||||||
filters=filters,
|
|
||||||
)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No timeline events today."
|
|
||||||
lines = [
|
|
||||||
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
|
|
||||||
f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})"
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
TIMELINE_TOOLS: list[Any] = [
|
|
||||||
list_timelines,
|
|
||||||
count_timelines,
|
|
||||||
list_timelines_today,
|
|
||||||
create_timeline,
|
|
||||||
update_timeline,
|
|
||||||
delete_timeline,
|
|
||||||
]
|
|
||||||
|
|
||||||
TIMELINE_READ_TOOLS: list[Any] = [
|
|
||||||
list_timelines,
|
|
||||||
count_timelines,
|
|
||||||
list_timelines_today,
|
|
||||||
]
|
|
||||||
|
|||||||
@@ -55,49 +55,23 @@ async def get_current_user(
|
|||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup — subscription row is the authoritative source.
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
# In dev, fall back to 'power' (unlimited) so quota limits don't
|
|
||||||
# block local development when no Stripe subscription exists.
|
|
||||||
from app.models import Subscription, User # noqa: PLC0415
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
# Fetch name/surname from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
select(
|
select(User.name, User.surname).where(User.id == user_id)
|
||||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
|
||||||
User.password_hash,
|
|
||||||
).where(User.id == user_id)
|
|
||||||
)
|
)
|
||||||
user_row = user_result.one_or_none()
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
# Convert onboarding_completed_at to epoch ms (int) or None.
|
|
||||||
onboarding_ms: int | None = None
|
|
||||||
if user_row and user_row.onboarding_completed_at is not None:
|
|
||||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
|
||||||
|
|
||||||
# Load decrypted core memory.
|
|
||||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
|
||||||
|
|
||||||
memory_dict: dict[str, str] = {}
|
|
||||||
try:
|
|
||||||
mw = MemoryMiddleware(db)
|
|
||||||
blocks = await mw.list_core_blocks(user_id)
|
|
||||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
|
||||||
except Exception:
|
|
||||||
pass # Non-critical — return empty memory on failure
|
|
||||||
|
|
||||||
return UserProfile(
|
return UserProfile(
|
||||||
id=user_id,
|
id=user_id,
|
||||||
email=email,
|
email=email,
|
||||||
name=user_row.name if user_row else None,
|
name=user_row.name if user_row else None,
|
||||||
surname=user_row.surname if user_row else None,
|
surname=user_row.surname if user_row else None,
|
||||||
avatar_url=user_row.avatar_url if user_row else None,
|
|
||||||
has_password=bool(user_row.password_hash) if user_row else False,
|
|
||||||
tier=tier,
|
tier=tier,
|
||||||
onboarding_completed_at=onboarding_ms,
|
|
||||||
memory=memory_dict,
|
|
||||||
) # type: ignore[arg-type]
|
) # type: ignore[arg-type]
|
||||||
|
|||||||
@@ -8,7 +8,8 @@ that could reveal server-side prompt IP:
|
|||||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||||
- Exact-match known prompt fingerprints
|
- Exact-match known prompt fingerprints
|
||||||
|
|
||||||
The middleware only activates for paths under /api/v1/chat.
|
Binary responses (storage blobs, backup data) are never touched — the
|
||||||
|
middleware only activates for paths under /api/v1/chat.
|
||||||
|
|
||||||
Any sanitisation event is logged as a WARNING with the request path and the
|
Any sanitisation event is logged as a WARNING with the request path and the
|
||||||
names of the fields that were modified.
|
names of the fields that were modified.
|
||||||
|
|||||||
@@ -1,71 +1,74 @@
|
|||||||
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
|
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
Endpoints:
|
||||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
POST /agents/journey/start — start a new journey session
|
||||||
frames to the functions exported here.
|
POST /agents/journey/message — continue the conversation
|
||||||
|
|
||||||
|
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
||||||
|
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. FE sends ``journey_start`` frame with basic agent info (directory,
|
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
||||||
data_types, schedule).
|
2. Server creates a session, calls the LLM with a contextual system prompt,
|
||||||
2. Server creates an in-memory session, sets up a WS executor so the
|
and returns the first question.
|
||||||
setup LLM can use file-system tools, does a first directory scrape,
|
3. Client sends follow-up messages to ``/message``.
|
||||||
and sends back a ``journey_reply`` with the first question.
|
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
||||||
3. FE sends ``journey_message`` frames for each user reply.
|
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
4. Server appends the user message, calls the LLM (which may read files
|
5. Server parses the block, sets ``done=True``, and returns the template.
|
||||||
via tools), and sends back a ``journey_reply``.
|
|
||||||
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
|
The ``prompt_template`` from the final response is meant to be stored in
|
||||||
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
||||||
6. Server parses and validates the JSON with Pydantic, sends
|
by the Electron client (via the agent CRUD endpoints).
|
||||||
``journey_reply`` with ``done=True`` and the serialised config.
|
|
||||||
FE stores it locally.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.agents.filesystem_agent import make_directory_tools
|
from app.api.deps import get_current_user
|
||||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
from app.core.llm import get_llm
|
||||||
from app.core.llm import get_agent_llm, model_for_agent
|
from app.db import get_session
|
||||||
from app.schemas import AgentConfig
|
from app.models import CloudAgentConfig, LocalAgentConfig
|
||||||
|
from app.schemas import (
|
||||||
|
JourneyMessageRequest,
|
||||||
|
JourneyResponse,
|
||||||
|
JourneyStartRequest,
|
||||||
|
UserProfile,
|
||||||
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
_CONFIG_START = "AGENT_CONFIG_START"
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
_CONFIG_END = "AGENT_CONFIG_END"
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
# Minimum turns before we consider nudging the LLM to wrap up.
|
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
||||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
_MAX_TURNS: int = 5
|
||||||
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
|
|
||||||
_MAX_TURNS: int = 15
|
|
||||||
# Max tool-calling steps per LLM invocation.
|
|
||||||
_MAX_TOOL_STEPS: int = 6
|
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class JourneySession:
|
class _JourneySession:
|
||||||
session_id: str
|
session_id: str
|
||||||
user_id: str
|
user_id: str
|
||||||
agent_type: str # "local" | "cloud"
|
agent_type: str # "local" | "cloud"
|
||||||
directory: str
|
|
||||||
data_types: list[str]
|
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
system_prompt: str = ""
|
|
||||||
langfuse_prompt: Any = None
|
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -73,182 +76,103 @@ class JourneySession:
|
|||||||
|
|
||||||
|
|
||||||
# session_id → session
|
# session_id → session
|
||||||
_sessions: dict[str, JourneySession] = {}
|
_sessions: dict[str, _JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
||||||
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
||||||
s = _sessions.get(session_id)
|
s = _sessions.get(session_id)
|
||||||
if s is None or s.is_expired():
|
if s is None or s.is_expired():
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
return None
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
if s.user_id != user_id:
|
if s.user_id != user_id:
|
||||||
return None
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt ─────────────────────────────────────────────────────────
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
_JOURNEY_SYSTEM_PROMPT = """\
|
_LOCAL_PREAMBLE = """\
|
||||||
|
What kind of files are in the directories you want to monitor? \
|
||||||
|
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
||||||
|
|
||||||
|
_CLOUD_PREAMBLE = """\
|
||||||
|
What kind of emails or messages should I look for? \
|
||||||
|
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand what files the user has in their directory and produce a
|
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
||||||
structured AgentConfig JSON that the extraction agent will use as its instruction set.
|
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
||||||
|
|
||||||
You have access to file-system tools to explore the user's directory:
|
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
- list_directory: see folder structure and file names
|
1. The type and format of the source content.
|
||||||
- read_file_content: peek at a file's content
|
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
||||||
- get_file_metadata: check file size, extension, dates
|
3. How fields should be mapped (e.g. email subject → task title).
|
||||||
|
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
5. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
The user's configured directory is: {directory}
|
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
||||||
Target data types: {data_types}
|
these exact markers on their own lines:
|
||||||
|
|
||||||
## Your process
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
### Step 1 — Explore the directory
|
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
||||||
Use list_directory and read_file_content to understand what types of files are present
|
and must return a JSON array of records in this shape:
|
||||||
(HTML emails, plain-text documents, CSVs, etc.).
|
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
||||||
|
|
||||||
### Step 2 — Identify content types
|
|
||||||
For each distinct file type found, decide:
|
|
||||||
- A short id (e.g. "email_html", "plain_text", "csv")
|
|
||||||
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
|
|
||||||
- A human-readable label and optional detection_hint
|
|
||||||
|
|
||||||
### Step 3 — Ask focused questions (one at a time)
|
|
||||||
Cover these topics based on what you discovered:
|
|
||||||
1. How to map content to entity types (task / note / timeline entry)
|
|
||||||
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
|
|
||||||
3. Priority or status rules (e.g. "urgent" in subject → high priority)
|
|
||||||
4. Date extraction (e.g. "by Friday" → dueDate)
|
|
||||||
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
|
||||||
|
|
||||||
### Step 4 — Produce the AgentConfig JSON
|
|
||||||
Once you are ≥ 90% confident, output the final config between these exact markers
|
|
||||||
(each on its own line):
|
|
||||||
|
|
||||||
{config_start}
|
|
||||||
{{
|
|
||||||
"content_types": [
|
|
||||||
{{
|
|
||||||
"id": "email_html",
|
|
||||||
"label": "Email HTML",
|
|
||||||
"detection_hint": "HTML file with From/To/Subject headers",
|
|
||||||
"preprocessing": "email_html",
|
|
||||||
"extraction_prompt": "Detailed extraction instructions for this content type..."
|
|
||||||
}}
|
|
||||||
],
|
|
||||||
"global_rules": [
|
|
||||||
"If the file cannot be matched to any project, do not create any entity."
|
|
||||||
],
|
|
||||||
"data_types": {data_types_json}
|
|
||||||
}}
|
|
||||||
{config_end}
|
|
||||||
|
|
||||||
## Rules for the extraction_prompt field
|
|
||||||
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
|
|
||||||
- Include field mapping rules based on what you found in the directory
|
|
||||||
- Include priority/status/date rules if applicable
|
|
||||||
- Do NOT include projectId logic — the runner handles project assignment automatically
|
|
||||||
- Do NOT mention isAiSuggested — the runner always sets it to 1
|
|
||||||
|
|
||||||
## Constraints
|
|
||||||
- Never ask about projects, projectId, or how to link records to projects
|
|
||||||
- Never include projectId or project creation logic in the generated config
|
|
||||||
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
|
|
||||||
|
|
||||||
|
Rules for the generated template:
|
||||||
|
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
||||||
|
- Include concrete examples of mappings.
|
||||||
|
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
||||||
|
- Set isAiSuggested: true and isApproved: false on every record.
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Begin by exploring the directory, then ask your first question.\
|
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(
|
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
||||||
directory: str,
|
source_description = (
|
||||||
data_types: list[str],
|
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
||||||
existing_config: str | None = None,
|
)
|
||||||
) -> tuple[str, Any]:
|
|
||||||
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
|
||||||
existing_section = (
|
existing_section = (
|
||||||
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
f"```json\n{existing_config}\n```\n"
|
f"---\n{existing_template}\n---\n"
|
||||||
if existing_config
|
if existing_template
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
template, prompt_obj = get_prompt_or_fallback(
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
source_description=source_description,
|
||||||
)
|
template_start=_TEMPLATE_START,
|
||||||
compiled = compile_prompt(
|
template_end=_TEMPLATE_END,
|
||||||
template,
|
|
||||||
prompt_obj,
|
|
||||||
directory=directory,
|
|
||||||
data_types=", ".join(data_types),
|
|
||||||
data_types_json=json.dumps(data_types),
|
|
||||||
config_start=_CONFIG_START,
|
|
||||||
config_end=_CONFIG_END,
|
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
|
max_turns=_MAX_TURNS,
|
||||||
)
|
)
|
||||||
return compiled, prompt_obj
|
|
||||||
|
|
||||||
|
|
||||||
# ── AgentConfig extraction ────────────────────────────────────────────────
|
def _first_question(agent_type: str) -> str:
|
||||||
|
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
||||||
|
|
||||||
|
|
||||||
def _extract_agent_config(text: str) -> str | None:
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
"""Return validated AgentConfig JSON string from between markers, or None.
|
|
||||||
|
|
||||||
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
|
||||||
returning. Returns None if markers are absent or JSON is invalid.
|
def _extract_template(text: str) -> str | None:
|
||||||
"""
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
if _CONFIG_START not in text or _CONFIG_END not in text:
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
return None
|
|
||||||
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
|
|
||||||
end_idx = text.index(_CONFIG_END)
|
|
||||||
raw = text[start_idx:end_idx].strip()
|
|
||||||
if not raw:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = AgentConfig.model_validate_json(raw)
|
|
||||||
return parsed.model_dump_json()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
|
|
||||||
return None
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call with tool support ───────────────────────────────────────────
|
# ── LLM call ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _as_text(content: Any) -> str:
|
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
||||||
if content is None:
|
"""Build LangChain messages from history and invoke the LLM."""
|
||||||
return ""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
parts: list[str] = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, str):
|
|
||||||
parts.append(item)
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
text = item.get("text")
|
|
||||||
if isinstance(text, str):
|
|
||||||
parts.append(text)
|
|
||||||
return "".join(parts)
|
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm_with_tools(
|
|
||||||
system_prompt: str,
|
|
||||||
history: list[dict[str, Any]],
|
|
||||||
tools: list[Any],
|
|
||||||
*,
|
|
||||||
user_id: str = "",
|
|
||||||
session_id: str = "",
|
|
||||||
langfuse_prompt: Any = None,
|
|
||||||
) -> str:
|
|
||||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
|
||||||
|
|
||||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
|
||||||
continue until a final text response is produced.
|
|
||||||
"""
|
|
||||||
lf = get_langfuse()
|
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -256,258 +180,138 @@ async def _call_llm_with_tools(
|
|||||||
else:
|
else:
|
||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_agent_llm("setup", temperature=0.4)
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
response = await llm.ainvoke(messages)
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
return response.content # type: ignore[return-value]
|
||||||
|
|
||||||
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
|
||||||
_lf_ctx.__enter__()
|
|
||||||
|
|
||||||
_span_ctx = (
|
|
||||||
lf.start_as_current_observation(
|
|
||||||
as_type="span",
|
|
||||||
name="journey-setup",
|
|
||||||
input=history[-1]["content"] if history else "",
|
|
||||||
)
|
|
||||||
if lf else None
|
|
||||||
)
|
|
||||||
_span = _span_ctx.__enter__() if _span_ctx else None
|
|
||||||
|
|
||||||
try:
|
|
||||||
for step in range(_MAX_TOOL_STEPS):
|
|
||||||
_gen_ctx = (
|
|
||||||
lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="journey-setup-llm",
|
|
||||||
model=model_for_agent("setup"),
|
|
||||||
prompt=langfuse_prompt,
|
|
||||||
input=messages,
|
|
||||||
)
|
|
||||||
if lf else None
|
|
||||||
)
|
|
||||||
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
if _gen_ctx:
|
|
||||||
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
|
||||||
_gen_ctx.__exit__(None, None, None)
|
|
||||||
|
|
||||||
resp_text = _as_text(response.content)
|
|
||||||
|
|
||||||
# Guard against empty responses (e.g. model returned finish_reason
|
|
||||||
# 'error' which LiteLLM maps to 'stop' with empty content).
|
|
||||||
if not response.tool_calls and not resp_text.strip():
|
|
||||||
logger.warning(
|
|
||||||
"agent_setup: journey LLM returned empty response at step %d — retrying",
|
|
||||||
step,
|
|
||||||
)
|
|
||||||
# Drop the empty AIMessage so we don't pollute history, and retry.
|
|
||||||
continue
|
|
||||||
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
if _span:
|
|
||||||
_span.update(output=resp_text)
|
|
||||||
return resp_text
|
|
||||||
|
|
||||||
for call in response.tool_calls:
|
|
||||||
call_name = str(call.get("name", ""))
|
|
||||||
call_args = call.get("args", {})
|
|
||||||
logger.info(
|
|
||||||
"agent_setup: journey tool_call name=%s args=%s",
|
|
||||||
call_name,
|
|
||||||
json.dumps(call_args, ensure_ascii=True)[:500],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
|
||||||
if tool_fn is None:
|
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
|
||||||
else:
|
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_setup: journey tool_result name=%s output=%s",
|
|
||||||
call_name,
|
|
||||||
str(tool_output)[:800],
|
|
||||||
)
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
||||||
|
|
||||||
# Fallback: exceeded max steps.
|
|
||||||
final = await llm.ainvoke(messages)
|
|
||||||
final_text = _as_text(final.content)
|
|
||||||
if _span:
|
|
||||||
_span.update(output=final_text)
|
|
||||||
return final_text or (
|
|
||||||
"Sorry, I had trouble processing the files. "
|
|
||||||
"Could you try again? If the issue persists, the files might be too large for me to analyse."
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
if _span_ctx:
|
|
||||||
_span_ctx.__exit__(None, None, None)
|
|
||||||
_lf_ctx.__exit__(None, None, None)
|
|
||||||
if lf:
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
# ── Existing-config loader ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def handle_journey_start(
|
async def _load_existing_template(
|
||||||
|
agent_id: str,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
frame: dict[str, Any],
|
db: AsyncSession,
|
||||||
) -> dict[str, Any]:
|
) -> str | None:
|
||||||
"""Handle a ``journey_start`` WS frame.
|
"""Return the prompt_template of an existing agent config, or None."""
|
||||||
|
# Try local first, then cloud.
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local = local_result.scalar_one_or_none()
|
||||||
|
if local is not None:
|
||||||
|
return local.prompt_template
|
||||||
|
|
||||||
Creates a session, runs the setup LLM with directory exploration,
|
cloud_result = await db.execute(
|
||||||
and returns the ``journey_reply`` payload.
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud = cloud_result.scalar_one_or_none()
|
||||||
|
return cloud.prompt_template if cloud is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
|
async def start_journey(
|
||||||
|
body: JourneyStartRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Start a new Chatbot Journey session.
|
||||||
|
|
||||||
|
If ``agent_id`` is provided the session is pre-seeded with the existing
|
||||||
|
agent's ``prompt_template`` so the user can refine it.
|
||||||
"""
|
"""
|
||||||
agent_type = frame.get("agent_type", "local")
|
# Load existing template (may be None).
|
||||||
directory = frame.get("directory", "")
|
existing_template: str | None = None
|
||||||
data_types = frame.get("data_types", [])
|
if body.agent_id:
|
||||||
existing_config = frame.get("existing_config")
|
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
||||||
|
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
||||||
|
# the user may be starting a fresh journey for a not-yet-persisted config).
|
||||||
|
|
||||||
# Use the session_id provided by the FE so the reply matches the
|
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
||||||
# listener key; fall back to a generated one if absent.
|
first_question = _first_question(body.agent_type)
|
||||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
|
||||||
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
|
|
||||||
|
|
||||||
session = JourneySession(
|
session_id = str(uuid.uuid4())
|
||||||
|
session = _JourneySession(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
user_id=user_id,
|
user_id=current_user.id,
|
||||||
agent_type=agent_type,
|
agent_type=body.agent_type,
|
||||||
directory=directory,
|
# Seed history with the AI's first question so it stays consistent.
|
||||||
data_types=data_types,
|
history=[{"role": "assistant", "content": first_question}],
|
||||||
system_prompt=system_prompt,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
)
|
)
|
||||||
|
# Store the system prompt inside the session for reuse in /message.
|
||||||
# Seed with an initial user message — some providers require at least one
|
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
||||||
# user/input message to be present.
|
|
||||||
seed_history: list[dict[str, Any]] = [
|
|
||||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
|
||||||
]
|
|
||||||
ai_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
history=seed_history,
|
|
||||||
tools=make_directory_tools(directory),
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
session.history.extend(seed_history)
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
_sessions[session_id] = session
|
_sessions[session_id] = session
|
||||||
|
|
||||||
logger.info(
|
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
||||||
"agent_setup: journey session %s started for user %s (directory=%s)",
|
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
||||||
session_id,
|
|
||||||
user_id,
|
|
||||||
directory,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Check if the LLM produced the config on the first turn (unlikely but possible).
|
|
||||||
agent_config = _extract_agent_config(ai_reply)
|
|
||||||
done = agent_config is not None
|
|
||||||
|
|
||||||
display_message = ai_reply
|
|
||||||
if done:
|
|
||||||
display_message = (
|
|
||||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
|
||||||
)
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": display_message,
|
|
||||||
"done": done,
|
|
||||||
"agent_config": agent_config,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_journey_message(
|
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
user_id: str,
|
async def send_journey_message(
|
||||||
frame: dict[str, Any],
|
body: JourneyMessageRequest,
|
||||||
) -> dict[str, Any]:
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
"""Handle a ``journey_message`` WS frame.
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Send a message in an existing Chatbot Journey session.
|
||||||
|
|
||||||
Appends the user message, calls the LLM, and returns the
|
The server appends the user's message to the conversation history,
|
||||||
``journey_reply`` payload.
|
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
||||||
|
``prompt_template`` block the response includes ``done=True`` and the
|
||||||
|
extracted template.
|
||||||
"""
|
"""
|
||||||
session_id = frame.get("session_id", "")
|
session = _get_session(body.session_id, current_user.id)
|
||||||
message = frame.get("message", "")
|
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
||||||
|
|
||||||
session = get_journey_session(session_id, user_id)
|
# Append user turn to history.
|
||||||
if session is None:
|
session.history.append({"role": "user", "content": body.message})
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": "Journey session not found or expired. Please start a new setup.",
|
|
||||||
"done": True,
|
|
||||||
"agent_config": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Append user turn.
|
# Call the LLM with the full conversation so far.
|
||||||
session.history.append({"role": "user", "content": message})
|
ai_reply = await _call_llm(system_prompt, session.history)
|
||||||
|
|
||||||
# Call the LLM with tools.
|
|
||||||
session_tools = make_directory_tools(session.directory)
|
|
||||||
ai_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=session.system_prompt,
|
|
||||||
history=session.history,
|
|
||||||
tools=session_tools,
|
|
||||||
user_id=session.user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=session.langfuse_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
# Append AI turn.
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
# Check if the LLM produced the final config.
|
# Check if the LLM produced the final template.
|
||||||
agent_config = _extract_agent_config(ai_reply)
|
prompt_template = _extract_template(ai_reply)
|
||||||
done = agent_config is not None
|
done = prompt_template is not None
|
||||||
|
|
||||||
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
|
|
||||||
if not done:
|
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
|
||||||
if turns >= _MAX_TURNS:
|
|
||||||
nudge_content = (
|
|
||||||
"[System: You have enough information. Please generate the final "
|
|
||||||
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
|
||||||
)
|
|
||||||
session.history.append({"role": "user", "content": nudge_content})
|
|
||||||
|
|
||||||
nudge_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=session.system_prompt,
|
|
||||||
history=session.history,
|
|
||||||
tools=session_tools,
|
|
||||||
user_id=session.user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
langfuse_prompt=session.langfuse_prompt,
|
|
||||||
)
|
|
||||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
|
||||||
|
|
||||||
agent_config = _extract_agent_config(nudge_reply)
|
|
||||||
if agent_config is not None:
|
|
||||||
done = True
|
|
||||||
ai_reply = nudge_reply
|
|
||||||
|
|
||||||
|
# Strip the sentinel markers from the message shown to the user.
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
if _CONFIG_START in ai_reply
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
else "Here is your agent configuration. You can save it or continue refining."
|
|
||||||
)
|
)
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
|
|
||||||
|
|
||||||
return {
|
if done:
|
||||||
"type": "journey_reply",
|
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
||||||
"session_id": session_id,
|
# Clean up the session immediately on completion.
|
||||||
"message": display_message,
|
_sessions.pop(body.session_id, None)
|
||||||
"done": done,
|
else:
|
||||||
"agent_config": agent_config,
|
# Nudge the LLM to wrap up after max turns.
|
||||||
}
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
# Add a system-level nudge as a hidden user message.
|
||||||
|
session.history.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
return JourneyResponse(
|
||||||
|
session_id=body.session_id,
|
||||||
|
message=display_message,
|
||||||
|
done=done,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,45 +1,48 @@
|
|||||||
"""Agent routes.
|
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
||||||
|
|
||||||
Backend responsibilities are intentionally minimal:
|
Endpoints:
|
||||||
GET /agents/catalog — static catalog for UI display
|
GET /agents/catalog — hardcoded agent type catalog
|
||||||
POST /agents/can-create — billing eligibility check
|
GET /agents/local — list user's local agent configs
|
||||||
POST /agents/trigger — trigger a local agent run
|
POST /agents/local — create local agent (tier-gated)
|
||||||
|
PUT /agents/local/{agent_id} — partial update (ownership check)
|
||||||
Agent configuration is owned by the Electron app and is not persisted
|
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
||||||
in backend agent-config tables.
|
GET /agents/cloud — list user's cloud agent configs
|
||||||
|
POST /agents/cloud — create cloud agent (tier-gated)
|
||||||
|
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
||||||
|
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
||||||
|
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
||||||
|
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
from datetime import datetime
|
||||||
import uuid
|
from typing import Any
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, or_, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
from app.core.agent_runner import is_agent_running, run_local_agent
|
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.note_summarizer import generate_note_summary
|
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
AgentCatalogItem,
|
AgentCatalogItem,
|
||||||
AgentCreationCheckRequest,
|
|
||||||
AgentCreationCheckResponse,
|
|
||||||
AgentRunLogResponse,
|
AgentRunLogResponse,
|
||||||
AgentTriggerRequest,
|
CloudAgentConfigCreate,
|
||||||
|
CloudAgentConfigResponse,
|
||||||
|
CloudAgentConfigUpdate,
|
||||||
|
LocalAgentConfigCreate,
|
||||||
|
LocalAgentConfigResponse,
|
||||||
|
LocalAgentConfigUpdate,
|
||||||
UserProfile,
|
UserProfile,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
|
||||||
@@ -53,21 +56,39 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|||||||
return int(dt.timestamp() * 1000) if dt else None
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
def _to_data_types(values: list[str]) -> list[str]:
|
# ── Model → schema converters ─────────────────────────────────────────
|
||||||
normalize = {
|
|
||||||
"task": "tasks", "tasks": "tasks",
|
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
||||||
"note": "notes", "notes": "notes",
|
return LocalAgentConfigResponse(
|
||||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
id=a.id,
|
||||||
"project": "projects", "projects": "projects",
|
name=a.name,
|
||||||
}
|
device_id=a.device_id,
|
||||||
seen: set[str] = set()
|
directory_paths=a.directory_paths,
|
||||||
result: list[str] = []
|
data_types=a.data_types,
|
||||||
for v in values:
|
prompt_template=a.prompt_template,
|
||||||
mapped = normalize.get(v)
|
file_extensions=a.file_extensions,
|
||||||
if mapped and mapped not in seen:
|
schedule_cron=a.schedule_cron,
|
||||||
seen.add(mapped)
|
enabled=a.enabled,
|
||||||
result.append(mapped)
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
return result
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
||||||
|
return CloudAgentConfigResponse(
|
||||||
|
id=a.id,
|
||||||
|
provider=a.provider, # type: ignore[arg-type]
|
||||||
|
name=a.name,
|
||||||
|
data_types=a.data_types,
|
||||||
|
prompt_template=a.prompt_template,
|
||||||
|
schedule_cron=a.schedule_cron,
|
||||||
|
filter_config=a.filter_config,
|
||||||
|
enabled=a.enabled,
|
||||||
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
@@ -84,42 +105,77 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
# ── Ownership-checked lookups ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_local_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> LocalAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_cloud_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> CloudAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier limit helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return combined enabled local + cloud agent count for the user."""
|
||||||
|
local_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(LocalAgentConfig.id)).where(
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
LocalAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
cloud_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(CloudAgentConfig.id)).where(
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
CloudAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
return local_count + cloud_count
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
if limit != -1 and current_count >= limit:
|
if limit != -1 and current_count >= limit:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
)
|
)
|
||||||
return limit
|
|
||||||
|
|
||||||
|
|
||||||
async def _enforce_run_frequency(
|
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
||||||
tier: str,
|
|
||||||
user_id: str,
|
|
||||||
db: AsyncSession,
|
|
||||||
) -> None:
|
|
||||||
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
|
||||||
if limit == -1:
|
|
||||||
return # unlimited
|
|
||||||
|
|
||||||
today_start = datetime.now(timezone.utc).replace(
|
class _RunsPage(BaseModel):
|
||||||
hour=0, minute=0, second=0, microsecond=0
|
total: int
|
||||||
)
|
page: int
|
||||||
result = await db.execute(
|
limit: int
|
||||||
select(func.count(AgentRunLog.id)).where(
|
items: list[AgentRunLogResponse]
|
||||||
AgentRunLog.user_id == user_id,
|
|
||||||
AgentRunLog.started_at >= today_start,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
runs_today: int = result.scalar_one()
|
|
||||||
|
|
||||||
if runs_today >= limit:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
@@ -153,68 +209,229 @@ async def get_agent_catalog(
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
# ── Local agent CRUD ──────────────────────────────────────────────────
|
||||||
async def can_create_agent(
|
|
||||||
body: AgentCreationCheckRequest,
|
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
||||||
|
async def list_local_agents(
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> AgentCreationCheckResponse:
|
db: AsyncSession = Depends(get_session),
|
||||||
"""Check if the user can create one more agent based on billing tier.
|
) -> list[LocalAgentConfigResponse]:
|
||||||
|
"""List all local directory agent configs owned by the authenticated user."""
|
||||||
Since configuration is client-owned, the Electron app sends its current
|
result = await db.execute(
|
||||||
active agent count and the backend applies tier limits.
|
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
||||||
"""
|
|
||||||
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
|
||||||
allowed = limit == -1 or body.active_agents < limit
|
|
||||||
return AgentCreationCheckResponse(
|
|
||||||
allowed=allowed,
|
|
||||||
tier=current_user.tier,
|
|
||||||
active_agents=body.active_agents,
|
|
||||||
limit=limit,
|
|
||||||
)
|
)
|
||||||
|
return [_to_local_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_local_agent(
|
||||||
|
body: LocalAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Create a new local directory agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = LocalAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
name=body.name,
|
||||||
|
device_id=body.device_id,
|
||||||
|
directory_paths=body.directory_paths,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
file_extensions=body.file_extensions,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
||||||
|
async def update_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: LocalAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Partially update a local agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/local/{agent_id}", response_model=dict)
|
||||||
|
async def delete_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
||||||
|
async def list_cloud_agents(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[CloudAgentConfigResponse]:
|
||||||
|
"""List all cloud connector agent configs owned by the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
return [_to_cloud_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_cloud_agent(
|
||||||
|
body: CloudAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Create a new cloud connector agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = CloudAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
provider=body.provider,
|
||||||
|
name=body.name,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
oauth_token_encrypted=body.oauth_token_encrypted,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
filter_config=body.filter_config,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
||||||
|
async def update_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: CloudAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Partially update a cloud agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/cloud/{agent_id}", response_model=dict)
|
||||||
|
async def delete_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Run logs ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/runs", response_model=_RunsPage)
|
||||||
|
async def list_run_logs(
|
||||||
|
agent_id: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=20, ge=1, le=100),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _RunsPage:
|
||||||
|
"""Return paginated run logs for the authenticated user.
|
||||||
|
|
||||||
|
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
||||||
|
"""
|
||||||
|
base_filter = [AgentRunLog.user_id == current_user.id]
|
||||||
|
if agent_id:
|
||||||
|
base_filter.append(AgentRunLog.agent_id == agent_id)
|
||||||
|
|
||||||
|
total = (
|
||||||
|
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
||||||
|
).scalar_one()
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog)
|
||||||
|
.where(*base_filter)
|
||||||
|
.order_by(AgentRunLog.started_at.desc())
|
||||||
|
.offset((page - 1) * limit)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
||||||
|
|
||||||
|
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Manual trigger stub ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
async def trigger_agent_run(
|
async def trigger_agent_run(
|
||||||
body: AgentTriggerRequest,
|
agent_id: str,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
db: AsyncSession = Depends(get_session),
|
db: AsyncSession = Depends(get_session),
|
||||||
) -> AgentRunLogResponse:
|
) -> AgentRunLogResponse:
|
||||||
"""Trigger a local agent run using client-provided configuration."""
|
"""Manually trigger an agent run.
|
||||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
|
||||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
|
||||||
|
|
||||||
last_run_dt = (
|
Looks up the agent config (local or cloud) by ID with ownership check,
|
||||||
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
creates a run log entry with ``status="running"``, and returns it.
|
||||||
if body.last_run_at
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
config = LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=current_user.id,
|
|
||||||
device_id=body.device_id,
|
|
||||||
name="Local Directory Monitor",
|
|
||||||
directory_paths=[body.directory],
|
|
||||||
data_types=_to_data_types(body.what_to_extract),
|
|
||||||
prompt_template=body.custom_agent_prompt or "",
|
|
||||||
agent_config=body.agent_config,
|
|
||||||
file_extensions=[],
|
|
||||||
schedule_cron=body.batch_interval,
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=last_run_dt,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
Actual dispatch to the agent runner is wired in Step 3.4 once
|
||||||
stable_agent_id = body.agent_id or config.id
|
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||||
|
"""
|
||||||
|
# Determine agent type by trying local first, then cloud.
|
||||||
|
# Keep the full config object so we can pass it to the agent runner.
|
||||||
|
local_config: LocalAgentConfig | None = None
|
||||||
|
cloud_config: CloudAgentConfig | None = None
|
||||||
|
|
||||||
if is_agent_running(stable_agent_id):
|
local_result = await db.execute(
|
||||||
raise HTTPException(
|
select(LocalAgentConfig).where(
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
LocalAgentConfig.id == agent_id,
|
||||||
detail="Agent is already running. Only one run per agent is allowed at a time.",
|
LocalAgentConfig.user_id == current_user.id,
|
||||||
)
|
)
|
||||||
|
)
|
||||||
|
local_config = local_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if local_config is not None:
|
||||||
|
agent_type = "local"
|
||||||
|
else:
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud_config = cloud_result.scalar_one_or_none()
|
||||||
|
if cloud_config is not None:
|
||||||
|
agent_type = "cloud"
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
|
||||||
run_log = AgentRunLog(
|
run_log = AgentRunLog(
|
||||||
agent_id=stable_agent_id,
|
agent_id=agent_id,
|
||||||
agent_type="local",
|
agent_type=agent_type,
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
status="running",
|
status="running",
|
||||||
)
|
)
|
||||||
@@ -222,36 +439,14 @@ async def trigger_agent_run(
|
|||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
run_context = {
|
# Dispatch the run as a background task — returns 202 immediately.
|
||||||
"type": "agent_batch",
|
if agent_type == "local" and local_config is not None:
|
||||||
"run_id": run_log.id,
|
asyncio.create_task(
|
||||||
"agent_id": stable_agent_id,
|
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||||
}
|
)
|
||||||
|
elif agent_type == "cloud" and cloud_config is not None:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
|
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||||
)
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|
||||||
|
|
||||||
# ── Note summary endpoint ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class NoteSummarizeRequest(BaseModel):
|
|
||||||
title: str
|
|
||||||
content: str
|
|
||||||
|
|
||||||
|
|
||||||
class NoteSummarizeResponse(BaseModel):
|
|
||||||
summary: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
|
|
||||||
async def summarize_note(
|
|
||||||
body: NoteSummarizeRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> NoteSummarizeResponse:
|
|
||||||
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
|
|
||||||
summary = await generate_note_summary(body.title, body.content)
|
|
||||||
return NoteSummarizeResponse(summary=summary)
|
|
||||||
|
|||||||
@@ -1,68 +1,34 @@
|
|||||||
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
SHA-256 hashes so plaintext never reaches the DB.
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
|
|
||||||
OAuth (Google):
|
|
||||||
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
|
|
||||||
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import json
|
|
||||||
import time
|
import time
|
||||||
import urllib.parse
|
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from fastapi.responses import RedirectResponse
|
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.llm import get_llm
|
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import OAuthAccount, RefreshToken, User
|
from app.models import RefreshToken, User
|
||||||
from app.schemas import AuthTokens, UserProfile
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth provider registry ───────────────────────────────────────────
|
|
||||||
|
|
||||||
def _get_google_provider() -> GoogleOAuthProvider:
|
|
||||||
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_503_SERVICE_UNAVAILABLE,
|
|
||||||
"Google login is not configured on this server",
|
|
||||||
)
|
|
||||||
return GoogleOAuthProvider(
|
|
||||||
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
|
|
||||||
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
|
|
||||||
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
_PROVIDERS = {"google": _get_google_provider}
|
|
||||||
|
|
||||||
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
|
|
||||||
# Production note: replace with Redis for multi-process deployments.
|
|
||||||
_pending_states: dict[str, tuple[str, float]] = {}
|
|
||||||
_STATE_TTL_SECONDS = 600 # 10 minutes
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -265,531 +231,5 @@ async def update_profile(
|
|||||||
email=user.email,
|
email=user.email,
|
||||||
name=user.name,
|
name=user.name,
|
||||||
surname=user.surname,
|
surname=user.surname,
|
||||||
avatar_url=user.avatar_url,
|
|
||||||
tier=current_user.tier,
|
tier=current_user.tier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth helpers ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
|
|
||||||
"""Create a refresh token row and return (plain_token, AuthTokens)."""
|
|
||||||
plain_token = str(uuid.uuid4())
|
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
|
||||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
|
||||||
)
|
|
||||||
rt = RefreshToken(
|
|
||||||
user_id=user.id,
|
|
||||||
token_hash=_hash_token(plain_token),
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
db.add(rt)
|
|
||||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
|
||||||
return plain_token, AuthTokens(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=plain_token,
|
|
||||||
expires_at=expires_at_ms,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth request/response schemas ───────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _OAuthAuthorizeResponse(BaseModel):
|
|
||||||
url: str
|
|
||||||
state: str
|
|
||||||
|
|
||||||
|
|
||||||
class _OAuthCallbackRequest(BaseModel):
|
|
||||||
code: str
|
|
||||||
state: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth routes ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/oauth/{provider}/web-callback",
|
|
||||||
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
|
|
||||||
include_in_schema=False,
|
|
||||||
)
|
|
||||||
async def oauth_web_callback(
|
|
||||||
provider: Literal["google"],
|
|
||||||
code: str,
|
|
||||||
state: str,
|
|
||||||
) -> RedirectResponse:
|
|
||||||
"""Google redirects here after user consent.
|
|
||||||
|
|
||||||
This endpoint immediately redirects to the Electron deep-link URI so the
|
|
||||||
desktop app receives the authorization code. It is intentionally simple —
|
|
||||||
no state validation here (the Electron app + backend callback do that).
|
|
||||||
|
|
||||||
Registered in Google Cloud Console as:
|
|
||||||
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
|
|
||||||
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
|
|
||||||
"""
|
|
||||||
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
|
|
||||||
deep_link = f"adiuvai://oauth/callback?{params}"
|
|
||||||
return RedirectResponse(url=deep_link, status_code=302)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get(
|
|
||||||
"/oauth/{provider}/authorize",
|
|
||||||
response_model=_OAuthAuthorizeResponse,
|
|
||||||
summary="Start OAuth flow — returns the provider consent-screen URL",
|
|
||||||
)
|
|
||||||
async def oauth_authorize(
|
|
||||||
provider: Literal["google"],
|
|
||||||
) -> _OAuthAuthorizeResponse:
|
|
||||||
"""Generate a PKCE state + code_challenge and return the authorization URL.
|
|
||||||
|
|
||||||
The client opens this URL in the system browser. After the user grants
|
|
||||||
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
|
|
||||||
with ``code`` and ``state`` query params. The client then calls
|
|
||||||
``POST /auth/oauth/{provider}/callback`` with those values.
|
|
||||||
"""
|
|
||||||
provider_factory = _PROVIDERS.get(provider)
|
|
||||||
if provider_factory is None:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
|
||||||
|
|
||||||
oauth_provider = provider_factory()
|
|
||||||
state = str(uuid.uuid4())
|
|
||||||
code_verifier, code_challenge = generate_pkce_pair()
|
|
||||||
|
|
||||||
# Purge expired states to prevent unbounded growth.
|
|
||||||
now = time.time()
|
|
||||||
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
|
|
||||||
for s in expired:
|
|
||||||
del _pending_states[s]
|
|
||||||
|
|
||||||
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
|
|
||||||
|
|
||||||
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
|
|
||||||
return _OAuthAuthorizeResponse(url=url, state=state)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post(
|
|
||||||
"/oauth/{provider}/callback",
|
|
||||||
response_model=AuthTokens,
|
|
||||||
summary="Complete OAuth flow — exchange code and issue JWT tokens",
|
|
||||||
)
|
|
||||||
async def oauth_callback(
|
|
||||||
provider: Literal["google"],
|
|
||||||
body: _OAuthCallbackRequest,
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> AuthTokens:
|
|
||||||
"""Validate state, exchange the authorization code, and sign in (or register) the user.
|
|
||||||
|
|
||||||
Resolution order:
|
|
||||||
1. ``oauth_accounts`` row match → existing user, log in.
|
|
||||||
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
|
|
||||||
3. No match → create new user (password_hash=None, avatar from provider).
|
|
||||||
"""
|
|
||||||
provider_factory = _PROVIDERS.get(provider)
|
|
||||||
if provider_factory is None:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
|
||||||
|
|
||||||
# Validate state (CSRF protection).
|
|
||||||
now = time.time()
|
|
||||||
entry = _pending_states.pop(body.state, None)
|
|
||||||
if entry is None or entry[1] < now:
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
|
||||||
|
|
||||||
code_verifier, _ = entry
|
|
||||||
|
|
||||||
oauth_provider = provider_factory()
|
|
||||||
|
|
||||||
# Exchange code for tokens.
|
|
||||||
try:
|
|
||||||
token_data = await oauth_provider.exchange_code(
|
|
||||||
code=body.code,
|
|
||||||
code_verifier=code_verifier,
|
|
||||||
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
|
|
||||||
)
|
|
||||||
|
|
||||||
access_token_google = token_data.get("access_token")
|
|
||||||
if not access_token_google:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
|
|
||||||
|
|
||||||
# Fetch user identity.
|
|
||||||
try:
|
|
||||||
userinfo = await oauth_provider.get_userinfo(access_token_google)
|
|
||||||
except Exception:
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
|
|
||||||
|
|
||||||
# ── Resolution order ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
# 1. Existing OAuth link?
|
|
||||||
oauth_result = await db.execute(
|
|
||||||
select(OAuthAccount).where(
|
|
||||||
OAuthAccount.provider == provider,
|
|
||||||
OAuthAccount.provider_user_id == userinfo.provider_user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
oauth_account = oauth_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if oauth_account is not None:
|
|
||||||
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
|
|
||||||
user = user_result.scalar_one()
|
|
||||||
# Backfill avatar if the user doesn't have one yet.
|
|
||||||
if user.avatar_url is None and userinfo.avatar_url:
|
|
||||||
user.avatar_url = userinfo.avatar_url
|
|
||||||
await db.commit()
|
|
||||||
plain_token, tokens = await _issue_refresh_token(user, db)
|
|
||||||
await db.commit()
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
# 2. Email match with a verified Google email → link accounts.
|
|
||||||
if userinfo.email_verified:
|
|
||||||
email_result = await db.execute(select(User).where(User.email == userinfo.email))
|
|
||||||
existing_user = email_result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if existing_user is not None:
|
|
||||||
new_link = OAuthAccount(
|
|
||||||
user_id=existing_user.id,
|
|
||||||
provider=provider,
|
|
||||||
provider_user_id=userinfo.provider_user_id,
|
|
||||||
provider_email=userinfo.email,
|
|
||||||
)
|
|
||||||
db.add(new_link)
|
|
||||||
if existing_user.avatar_url is None and userinfo.avatar_url:
|
|
||||||
existing_user.avatar_url = userinfo.avatar_url
|
|
||||||
plain_token, tokens = await _issue_refresh_token(existing_user, db)
|
|
||||||
await db.commit()
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
# Guard: if the email is already taken but we couldn't auto-link (e.g.
|
|
||||||
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
|
|
||||||
if not userinfo.email_verified:
|
|
||||||
conflict = await db.execute(select(User).where(User.email == userinfo.email))
|
|
||||||
if conflict.scalar_one_or_none() is not None:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_409_CONFLICT,
|
|
||||||
"An account with this email already exists. "
|
|
||||||
"Please sign in with your password.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. New user — social-only account (no password).
|
|
||||||
new_user = User(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
email=userinfo.email,
|
|
||||||
name=userinfo.name,
|
|
||||||
password_hash=None,
|
|
||||||
avatar_url=userinfo.avatar_url,
|
|
||||||
tier="free",
|
|
||||||
encryption_key=Fernet.generate_key().decode(),
|
|
||||||
)
|
|
||||||
db.add(new_user)
|
|
||||||
await db.flush() # populate new_user.id
|
|
||||||
|
|
||||||
new_oauth = OAuthAccount(
|
|
||||||
user_id=new_user.id,
|
|
||||||
provider=provider,
|
|
||||||
provider_user_id=userinfo.provider_user_id,
|
|
||||||
provider_email=userinfo.email,
|
|
||||||
)
|
|
||||||
db.add(new_oauth)
|
|
||||||
|
|
||||||
plain_token, tokens = await _issue_refresh_token(new_user, db)
|
|
||||||
await db.commit()
|
|
||||||
return tokens
|
|
||||||
|
|
||||||
|
|
||||||
# ── Onboarding helpers ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
|
|
||||||
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
|
|
||||||
|
|
||||||
# We can't call the FastAPI dependency directly, but we can replicate
|
|
||||||
# the core logic inline. Instead, we just re-query the same way.
|
|
||||||
from app.models import Subscription # noqa: PLC0415
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
|
||||||
)
|
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
user_result = await db.execute(
|
|
||||||
select(
|
|
||||||
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
|
||||||
User.password_hash,
|
|
||||||
).where(User.id == user_id)
|
|
||||||
)
|
|
||||||
user_row = user_result.one_or_none()
|
|
||||||
|
|
||||||
onboarding_ms: int | None = None
|
|
||||||
if user_row and user_row.onboarding_completed_at is not None:
|
|
||||||
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
|
||||||
|
|
||||||
memory_dict: dict[str, str] = {}
|
|
||||||
try:
|
|
||||||
mw = MemoryMiddleware(db)
|
|
||||||
blocks = await mw.list_core_blocks(user_id)
|
|
||||||
memory_dict = {b["label"]: b["value"] for b in blocks}
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return UserProfile(
|
|
||||||
id=user_id,
|
|
||||||
email=email,
|
|
||||||
name=user_row.name if user_row else None,
|
|
||||||
surname=user_row.surname if user_row else None,
|
|
||||||
avatar_url=user_row.avatar_url if user_row else None,
|
|
||||||
has_password=bool(user_row.password_hash) if user_row else False,
|
|
||||||
tier=tier,
|
|
||||||
onboarding_completed_at=onboarding_ms,
|
|
||||||
memory=memory_dict,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Onboarding routes ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _UpdateMemoryRequest(BaseModel):
|
|
||||||
memory: dict[str, str] = Field(default_factory=dict)
|
|
||||||
mark_onboarded: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/me/memory", response_model=UserProfile)
|
|
||||||
async def update_memory(
|
|
||||||
body: _UpdateMemoryRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> UserProfile:
|
|
||||||
"""Update core memory key/value pairs and optionally mark onboarding complete."""
|
|
||||||
mw = MemoryMiddleware(db)
|
|
||||||
for key, value in body.memory.items():
|
|
||||||
await mw.update_core(current_user.id, key, value)
|
|
||||||
if body.mark_onboarded:
|
|
||||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
|
||||||
user = result.scalar_one()
|
|
||||||
user.onboarding_completed_at = datetime.now(timezone.utc)
|
|
||||||
await db.commit()
|
|
||||||
return await _build_profile(current_user.id, current_user.email, db)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/me/onboarding/reset")
|
|
||||||
async def reset_onboarding(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
):
|
|
||||||
"""Reset onboarding so the wizard runs again on next login."""
|
|
||||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
|
||||||
user = result.scalar_one()
|
|
||||||
user.onboarding_completed_at = None
|
|
||||||
await db.commit()
|
|
||||||
return {"status": "reset"}
|
|
||||||
|
|
||||||
|
|
||||||
class _NormalizeRequest(BaseModel):
|
|
||||||
inputs: dict[str, str]
|
|
||||||
|
|
||||||
|
|
||||||
class _NormalizeResponse(BaseModel):
|
|
||||||
normalized: dict[str, str]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
|
|
||||||
async def normalize_onboarding(
|
|
||||||
body: _NormalizeRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> _NormalizeResponse:
|
|
||||||
"""One-shot LLM normalization for free-text onboarding answers."""
|
|
||||||
if not body.inputs:
|
|
||||||
return _NormalizeResponse(normalized={})
|
|
||||||
try:
|
|
||||||
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
|
||||||
prompt = (
|
|
||||||
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
|
|
||||||
"Return a JSON object with the same keys and normalized values.\n"
|
|
||||||
"Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
|
|
||||||
f"Input: {json.dumps(body.inputs)}"
|
|
||||||
)
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[
|
|
||||||
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
|
|
||||||
{"role": "user", "content": prompt},
|
|
||||||
],
|
|
||||||
)
|
|
||||||
normalized = json.loads(response.content)
|
|
||||||
return _NormalizeResponse(normalized=normalized)
|
|
||||||
except Exception:
|
|
||||||
# LLM failure must never block onboarding — return inputs unchanged
|
|
||||||
return _NormalizeResponse(normalized=body.inputs)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Password management ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _ChangePasswordRequest(BaseModel):
|
|
||||||
current_password: str = Field(min_length=1)
|
|
||||||
new_password: str = Field(min_length=8)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
|
||||||
async def change_password(
|
|
||||||
body: _ChangePasswordRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Change the authenticated user's password.
|
|
||||||
|
|
||||||
Requires the current password for verification.
|
|
||||||
Returns 400 for social-only users (no password set).
|
|
||||||
"""
|
|
||||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
|
||||||
user = result.scalar_one()
|
|
||||||
|
|
||||||
if user.password_hash is None:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_400_BAD_REQUEST,
|
|
||||||
"This account uses social login and has no password to change",
|
|
||||||
)
|
|
||||||
|
|
||||||
if not _verify_password(body.current_password, user.password_hash):
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
|
||||||
|
|
||||||
user.password_hash = _hash_password(body.new_password)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── OAuth account management ─────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me/oauth-accounts", response_model=list[dict])
|
|
||||||
async def list_oauth_accounts(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[dict]:
|
|
||||||
"""List all OAuth providers linked to the authenticated user."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
accounts = result.scalars().all()
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"provider": a.provider,
|
|
||||||
"provider_email": a.provider_email,
|
|
||||||
"created_at": int(a.created_at.timestamp() * 1000),
|
|
||||||
}
|
|
||||||
for a in accounts
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
|
||||||
async def unlink_oauth_account(
|
|
||||||
provider: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Unlink an OAuth provider from the authenticated user.
|
|
||||||
|
|
||||||
Refuses if the user has no password and this is their only login method.
|
|
||||||
"""
|
|
||||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
|
||||||
user = result.scalar_one()
|
|
||||||
|
|
||||||
oauth_result = await db.execute(
|
|
||||||
select(OAuthAccount).where(
|
|
||||||
OAuthAccount.user_id == current_user.id,
|
|
||||||
OAuthAccount.provider == provider,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
account = oauth_result.scalar_one_or_none()
|
|
||||||
if account is None:
|
|
||||||
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
|
||||||
|
|
||||||
# Safety: don't let users lock themselves out.
|
|
||||||
all_oauth = await db.execute(
|
|
||||||
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
oauth_count = len(all_oauth.scalars().all())
|
|
||||||
|
|
||||||
if user.password_hash is None and oauth_count <= 1:
|
|
||||||
raise HTTPException(
|
|
||||||
status.HTTP_400_BAD_REQUEST,
|
|
||||||
"Cannot unlink the only login method. Set a password first.",
|
|
||||||
)
|
|
||||||
|
|
||||||
await db.delete(account)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Avatar update ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _UpdateAvatarRequest(BaseModel):
|
|
||||||
avatar_url: str = Field(min_length=1)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/me/avatar", response_model=UserProfile)
|
|
||||||
async def update_avatar(
|
|
||||||
body: _UpdateAvatarRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> UserProfile:
|
|
||||||
"""Update the authenticated user's avatar URL.
|
|
||||||
|
|
||||||
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
|
||||||
to its own storage and passes the resulting URL here.
|
|
||||||
"""
|
|
||||||
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
|
||||||
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
|
||||||
|
|
||||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
|
||||||
user = result.scalar_one()
|
|
||||||
user.avatar_url = body.avatar_url
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return await _build_profile(current_user.id, current_user.email, db)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Account deletion ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/me", status_code=status.HTTP_200_OK)
|
|
||||||
async def delete_account(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Permanently delete the authenticated user's account.
|
|
||||||
|
|
||||||
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
|
||||||
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
|
||||||
is cancelled if active.
|
|
||||||
"""
|
|
||||||
# Cancel Stripe subscription if present.
|
|
||||||
try:
|
|
||||||
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
|
||||||
await stripe_service.cancel_subscription(current_user.id, db)
|
|
||||||
except HTTPException:
|
|
||||||
pass # No subscription — that's fine
|
|
||||||
|
|
||||||
# Delete all memory rows (core, associative, episodic, proactive).
|
|
||||||
try:
|
|
||||||
from app.models import ( # noqa: PLC0415
|
|
||||||
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
|
||||||
)
|
|
||||||
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
|
||||||
await db.execute(
|
|
||||||
model.__table__.delete().where(model.user_id == current_user.id)
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass # Non-critical — cascade on User will handle most
|
|
||||||
|
|
||||||
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
|
||||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
|
||||||
user = result.scalar_one()
|
|
||||||
await db.delete(user)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|||||||
171
app/api/routes/backup.py
Normal file
171
app/api/routes/backup.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
||||||
|
PostgreSQL ``backup_metadata`` table.
|
||||||
|
|
||||||
|
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||||
|
treating "history" as a ``{backup_id}`` path parameter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import BackupMetadata as BackupMetadataModel
|
||||||
|
from app.schemas import BackupMetadata, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/backup", tags=["backup"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
|
||||||
|
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return total backup bytes stored by *user_id*."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
||||||
|
BackupMetadataModel.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_backup_quota(
|
||||||
|
user: UserProfile, size_bytes: int, db: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||||
|
current = await _current_backup_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_backup_quota(
|
||||||
|
user.tier, current_bytes=current, additional_bytes=size_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("")
|
||||||
|
async def upload_backup(
|
||||||
|
request: Request,
|
||||||
|
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
||||||
|
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
||||||
|
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Upload an E2E-encrypted backup blob.
|
||||||
|
|
||||||
|
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
||||||
|
"""
|
||||||
|
blob = await request.body()
|
||||||
|
reject_if_tampered(blob, x_backup_checksum)
|
||||||
|
await _check_backup_quota(current_user, len(blob), db)
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
row = BackupMetadataModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
|
s3_key=s3_key,
|
||||||
|
version=x_backup_version,
|
||||||
|
timestamp=x_backup_timestamp,
|
||||||
|
checksum=x_backup_checksum,
|
||||||
|
size_bytes=len(blob),
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history", response_model=list[BackupMetadata])
|
||||||
|
async def backup_history(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[BackupMetadata]:
|
||||||
|
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
BackupMetadata(
|
||||||
|
version=r.version,
|
||||||
|
timestamp=r.timestamp,
|
||||||
|
checksum=r.checksum,
|
||||||
|
chunk_count=1,
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def download_backup(
|
||||||
|
request: Request,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> Response:
|
||||||
|
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
latest = result.scalar_one_or_none()
|
||||||
|
if latest is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
||||||
|
|
||||||
|
ims_header = request.headers.get("If-Modified-Since")
|
||||||
|
if ims_header:
|
||||||
|
try:
|
||||||
|
ims_dt = parsedate_to_datetime(ims_header)
|
||||||
|
ims_ms = int(ims_dt.timestamp() * 1000)
|
||||||
|
if latest.timestamp <= ims_ms:
|
||||||
|
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
||||||
|
except Exception:
|
||||||
|
pass # malformed header — ignore and serve the blob
|
||||||
|
|
||||||
|
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
||||||
|
return Response(
|
||||||
|
content=blob,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={
|
||||||
|
"X-Backup-Version": str(latest.version),
|
||||||
|
"X-Backup-Timestamp": str(latest.timestamp),
|
||||||
|
"X-Checksum": latest.checksum,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{backup_id}", response_model=dict)
|
||||||
|
async def delete_backup(
|
||||||
|
backup_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a specific backup by ID."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel).where(
|
||||||
|
BackupMetadataModel.id == backup_id,
|
||||||
|
BackupMetadataModel.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
target = result.scalar_one_or_none()
|
||||||
|
if target is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
||||||
|
|
||||||
|
await _blob_store.delete(current_user.id, target.s3_key)
|
||||||
|
await db.delete(target)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
@@ -9,7 +9,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
|
from fastapi import APIRouter, Depends, Header, Request, status
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
@@ -83,50 +83,3 @@ async def cancel_subscription(
|
|||||||
"""Cancel the active subscription."""
|
"""Cancel the active subscription."""
|
||||||
await stripe_service.cancel_subscription(current_user.id, db)
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
@router.get("/invoices", response_model=list[dict])
|
|
||||||
async def list_invoices(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Return billing history (invoices) from Stripe.
|
|
||||||
|
|
||||||
Returns an empty list when Stripe is not configured.
|
|
||||||
"""
|
|
||||||
invoices = await stripe_service.list_invoices(current_user.id, db)
|
|
||||||
return invoices
|
|
||||||
|
|
||||||
|
|
||||||
# ── Quota check ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
class QuotaCheckRequest(BaseModel):
|
|
||||||
feature: str
|
|
||||||
estimated_files: int
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/quota/check")
|
|
||||||
async def quota_check(
|
|
||||||
payload: QuotaCheckRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict:
|
|
||||||
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
|
|
||||||
if payload.feature != "folder_index":
|
|
||||||
raise HTTPException(status_code=400, detail="Unknown feature")
|
|
||||||
try:
|
|
||||||
await check_folder_quota(
|
|
||||||
user_id=current_user.id,
|
|
||||||
tier=current_user.tier,
|
|
||||||
estimated_files=payload.estimated_files,
|
|
||||||
db=db,
|
|
||||||
)
|
|
||||||
except QuotaExceeded as exc:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=402,
|
|
||||||
detail={"reason": exc.reason, "message": str(exc)},
|
|
||||||
)
|
|
||||||
return {"ok": True}
|
|
||||||
|
|||||||
@@ -1,116 +1,42 @@
|
|||||||
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
|
"""Chat routes: POST /chat (REST fallback).
|
||||||
|
|
||||||
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import uuid
|
from fastapi import APIRouter, Depends
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException
|
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
|
||||||
from app.core.deep_agent import run_home
|
from app.core.deep_agent import run_home
|
||||||
from app.core.llm import embed
|
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
# ── Embed helpers ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _EmbedRequest(BaseModel):
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class _EmbedResponse(BaseModel):
|
|
||||||
vector: list[float]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Endpoints ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
) -> JSONResponse:
|
) -> JSONResponse:
|
||||||
"""REST fallback for home chat when websocket streaming is unavailable."""
|
"""Route a chat message through the Home deep agent (non-streaming)."""
|
||||||
response = await run_home(
|
|
||||||
user_id=current_user.id,
|
|
||||||
message=body.message,
|
|
||||||
context=body.context.model_dump(),
|
|
||||||
)
|
|
||||||
return JSONResponse(content={"response": response})
|
|
||||||
|
|
||||||
|
|
||||||
class _BriefRequest(BaseModel):
|
|
||||||
mode: Literal["home", "project"]
|
|
||||||
project_id: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class _BriefResponse(BaseModel):
|
|
||||||
response: str
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/brief", response_model=_BriefResponse)
|
|
||||||
async def brief(
|
|
||||||
body: _BriefRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> _BriefResponse:
|
|
||||||
"""REST fallback for brief when the device WebSocket is not ready."""
|
|
||||||
if body.mode == "project":
|
|
||||||
if not body.project_id:
|
|
||||||
raise HTTPException(status_code=422, detail="project_id required for project mode")
|
|
||||||
try:
|
|
||||||
uuid.UUID(body.project_id)
|
|
||||||
except ValueError:
|
|
||||||
raise HTTPException(status_code=422, detail="project_id must be a valid UUID")
|
|
||||||
|
|
||||||
request_id = str(uuid.uuid4())
|
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(
|
memory_context = await memory.enrich_context(current_user.id, body.message)
|
||||||
current_user.id,
|
|
||||||
"",
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=request_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
context = {
|
||||||
"_debug": {"request_id": request_id, "user_id": current_user.id},
|
**body.context.model_dump(),
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
chunks: list[str] = []
|
response_text = await run_home(
|
||||||
if body.mode == "project":
|
user_id=current_user.id,
|
||||||
stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type]
|
message=body.message,
|
||||||
else:
|
context=context,
|
||||||
stream = run_home_brief(current_user.id, context)
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
async for event_type, data in stream:
|
result = ChatResponse(response=response_text)
|
||||||
if event_type == "token" and data:
|
return JSONResponse(content=result.model_dump())
|
||||||
chunks.append(str(data))
|
|
||||||
|
|
||||||
return _BriefResponse(response="".join(chunks))
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/embed", response_model=_EmbedResponse)
|
|
||||||
async def embed_text(
|
|
||||||
body: _EmbedRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> _EmbedResponse:
|
|
||||||
"""Generate a 1536-dim embedding vector for the given text.
|
|
||||||
|
|
||||||
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
|
||||||
Used by Electron (vectordb.ts) for local note search.
|
|
||||||
"""
|
|
||||||
vector = await embed(body.text)
|
|
||||||
return _EmbedResponse(vector=vector)
|
|
||||||
|
|||||||
@@ -14,11 +14,11 @@ Protocol:
|
|||||||
4. Session enters message dispatch loop + heartbeat.
|
4. Session enters message dispatch loop + heartbeat.
|
||||||
|
|
||||||
Incoming frame dispatch:
|
Incoming frame dispatch:
|
||||||
- ``tool_result`` → resolves a pending tool-call Future.
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
- ``journey_start`` → starts a guided setup journey session.
|
- ``agent_data`` → enqueued in the per-run agent data queue.
|
||||||
- ``journey_message`` → continues a journey conversation.
|
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
||||||
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
- unknown types → logged, ignored.
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||||
|
|
||||||
@@ -39,28 +39,21 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|||||||
from jose import JWTError, jwt
|
from jose import JWTError, jwt
|
||||||
from sqlalchemy import update
|
from sqlalchemy import update
|
||||||
|
|
||||||
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.core.agent_runner import trigger_pending_runs
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
|
||||||
from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
|
|
||||||
from app.core.output_formatter import extract_canvas_block
|
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.output_formatter import StreamFormatter
|
from app.core.deep_agent import run_home_stream, run_floating_stream
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
from app.core.ws_context import clear_client_executor, set_client_executor
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
from app.schemas import WsFrameType, WsStreamEnd
|
from app.schemas import WsFrameType
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||||
|
|
||||||
# ── v7 folder index session state ─────────────────────────────────────
|
|
||||||
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
|
|
||||||
_index_sessions: dict[str, dict] = {}
|
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||||
|
|
||||||
@@ -154,6 +147,37 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
"device_ws: tool_result missing id from user=%s", user_id
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_data:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
await queue.put(frame)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data for unknown run user=%s run=%s",
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_complete:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
# Sentinel: signals the agent data stream is finished.
|
||||||
|
await queue.put(None)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.home_request:
|
elif frame_type == WsFrameType.home_request:
|
||||||
asyncio.create_task(
|
asyncio.create_task(
|
||||||
_handle_home_request(websocket, user_id, frame)
|
_handle_home_request(websocket, user_id, frame)
|
||||||
@@ -164,39 +188,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_floating_request(websocket, user_id, frame)
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
elif frame_type == WsFrameType.brief_request:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_brief_request(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.task_brief_request:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_task_brief_request(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.journey_start:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_journey_start(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.journey_message:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_journey_message(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.index_session_start:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_index_session_start(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.index_file_batch:
|
|
||||||
asyncio.create_task(
|
|
||||||
_handle_index_file_batch(websocket, user_id, frame)
|
|
||||||
)
|
|
||||||
|
|
||||||
elif frame_type == WsFrameType.index_session_cancel:
|
|
||||||
await _handle_index_session_cancel(websocket, frame)
|
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
@@ -209,13 +200,35 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
|
|
||||||
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_WS_TOOL_CALL_TIMEOUT = 30 # seconds to wait for Electron tool_result
|
||||||
|
|
||||||
|
|
||||||
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
async def _executor(payload: dict) -> dict:
|
async def _executor(payload: dict) -> dict:
|
||||||
payload["type"] = WsFrameType.tool_call
|
payload["type"] = WsFrameType.tool_call
|
||||||
|
call_id = payload["id"]
|
||||||
|
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
||||||
await websocket.send_text(json.dumps(payload))
|
await websocket.send_text(json.dumps(payload))
|
||||||
future = device_manager.create_pending_call(user_id, payload["id"])
|
future = device_manager.create_pending_call(user_id, call_id)
|
||||||
return await future
|
try:
|
||||||
|
result = await asyncio.wait_for(future, timeout=_WS_TOOL_CALL_TIMEOUT)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
logger.error(
|
||||||
|
"ws_executor: timeout waiting for tool_result id=%s action=%s user=%s",
|
||||||
|
call_id, payload.get("action"), user_id,
|
||||||
|
)
|
||||||
|
# Clean up the pending future so it doesn't leak
|
||||||
|
conn = device_manager._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.pending_calls.pop(call_id, None)
|
||||||
|
return {"error": f"Tool call timed out after {_WS_TOOL_CALL_TIMEOUT}s", "rows": []}
|
||||||
|
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
||||||
|
call_id, type(result).__name__,
|
||||||
|
list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
if result is None:
|
||||||
|
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
||||||
|
return result
|
||||||
return _executor
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
@@ -228,30 +241,14 @@ async def _handle_home_request(
|
|||||||
request_id = frame.get("request_id") or str(uuid4())
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
|
||||||
logger.info(
|
|
||||||
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
project_id,
|
|
||||||
message[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
user_id,
|
|
||||||
message,
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
"format_prefs": frame.get("format_prefs"),
|
|
||||||
**memory_context,
|
**memory_context,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -259,11 +256,12 @@ async def _handle_home_request(
|
|||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
|
event_stream = run_home_stream(
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
user_id, message, context, db_session_factory=async_session
|
||||||
|
)
|
||||||
|
formatter = HomeFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
# Collect text chunks to build the full response for episode storage
|
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -278,15 +276,8 @@ async def _handle_home_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
len("".join(response_chunks)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_floating_request(
|
async def _handle_floating_request(
|
||||||
@@ -299,39 +290,23 @@ async def _handle_floating_request(
|
|||||||
message: str = frame.get("message", "")
|
message: str = frame.get("message", "")
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
scope: dict = frame.get("scope", {})
|
scope: dict = frame.get("scope", {})
|
||||||
logger.info(
|
|
||||||
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
json.dumps(scope, ensure_ascii=True)[:200],
|
|
||||||
message[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Memory: enrich context before LLM call ────────────────────────
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
memory_context = await memory.enrich_context(
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
user_id,
|
|
||||||
message,
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
context: dict = {"scope": scope, **memory_context}
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
|
||||||
"scope": scope,
|
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
"format_prefs": frame.get("format_prefs"),
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
set_client_executor(executor)
|
set_client_executor(executor)
|
||||||
response_chunks: list[str] = []
|
response_chunks: list[str] = []
|
||||||
try:
|
try:
|
||||||
event_stream = run_floating_stream(user_id, message, context)
|
event_stream = run_floating_stream(
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
user_id, message, context, scope=scope,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
formatter = FloatingFormatter(request_id=request_id)
|
||||||
async for ws_frame in formatter.format(event_stream):
|
async for ws_frame in formatter.format(event_stream):
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
@@ -348,413 +323,8 @@ async def _handle_floating_request(
|
|||||||
async with async_session() as db:
|
async with async_session() as db:
|
||||||
memory = MemoryMiddleware(db)
|
memory = MemoryMiddleware(db)
|
||||||
await memory.store_episode(
|
await memory.store_episode(
|
||||||
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
)
|
)
|
||||||
logger.info(
|
|
||||||
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
|
|
||||||
user_id,
|
|
||||||
request_id,
|
|
||||||
session_id,
|
|
||||||
len("".join(response_chunks)),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_brief_request(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a brief_request frame — streams plain-text brief back on the socket.
|
|
||||||
|
|
||||||
No episode storage — briefs are not conversations.
|
|
||||||
"""
|
|
||||||
import uuid as _uuid
|
|
||||||
|
|
||||||
request_id = frame.get("request_id") or str(uuid4())
|
|
||||||
session_id = frame.get("session_id") or str(uuid4())
|
|
||||||
mode: str = frame.get("mode", "home")
|
|
||||||
project_id: str | None = frame.get("project_id")
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
|
|
||||||
user_id, request_id, mode, project_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate project_id for project mode before touching LLM.
|
|
||||||
if mode == "project":
|
|
||||||
try:
|
|
||||||
if not project_id:
|
|
||||||
raise ValueError("project_id required for project mode")
|
|
||||||
_uuid.UUID(project_id)
|
|
||||||
except (ValueError, AttributeError) as exc:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: brief_request invalid project_id user=%s req=%s: %s",
|
|
||||||
user_id, request_id, exc,
|
|
||||||
)
|
|
||||||
await websocket.send_text(
|
|
||||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Enrich context with memory (no user message — use empty string as probe).
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
memory_context = await memory.enrich_context(
|
|
||||||
user_id,
|
|
||||||
"",
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
"format_prefs": frame.get("format_prefs"),
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
|
||||||
set_client_executor(executor)
|
|
||||||
try:
|
|
||||||
if mode == "project":
|
|
||||||
event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type]
|
|
||||||
else:
|
|
||||||
event_stream = run_home_brief(user_id, context)
|
|
||||||
|
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(event_stream):
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(
|
|
||||||
"device_ws: brief_request failed user=%s req=%s: %s",
|
|
||||||
user_id, request_id, exc,
|
|
||||||
)
|
|
||||||
await websocket.send_text(
|
|
||||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"device_ws: brief_request_end user=%s req=%s mode=%s",
|
|
||||||
user_id, request_id, mode,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── v6 Task Brief Handler ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_task_brief_request(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
|
|
||||||
|
|
||||||
Streams the briefing markdown back to the client.
|
|
||||||
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
|
|
||||||
"""
|
|
||||||
request_id = frame.get("request_id") or str(uuid4())
|
|
||||||
session_id = frame.get("session_id") or str(uuid4())
|
|
||||||
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
|
|
||||||
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
|
|
||||||
user_id, request_id, task_id, project_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
if not task_id:
|
|
||||||
await websocket.send_text(
|
|
||||||
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
memory_context = await memory.enrich_context(
|
|
||||||
user_id,
|
|
||||||
f"task brief: {task_id}",
|
|
||||||
trace_id=request_id,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
"format_prefs": frame.get("format_prefs"),
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
|
||||||
set_client_executor(executor)
|
|
||||||
response_chunks: list[str] = []
|
|
||||||
|
|
||||||
try:
|
|
||||||
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
|
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(event_stream):
|
|
||||||
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
|
||||||
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
|
||||||
elif ws_frame.type == "stream_start":
|
|
||||||
await websocket.send_text(ws_frame.model_dump_json())
|
|
||||||
# stream_end is emitted below with mutations — skip formatter's version
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(
|
|
||||||
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
|
|
||||||
user_id, request_id, task_id, exc,
|
|
||||||
)
|
|
||||||
await websocket.send_text(
|
|
||||||
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
|
|
||||||
)
|
|
||||||
return
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
# Extract canvas block then emit stream_end with optional mutations.
|
|
||||||
full_response = "".join(response_chunks)
|
|
||||||
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
|
|
||||||
|
|
||||||
mutations: list[dict] = []
|
|
||||||
if canvas_content:
|
|
||||||
mutations.append({
|
|
||||||
"type": "canvas_draft",
|
|
||||||
"content": canvas_content,
|
|
||||||
"kind": canvas_kind,
|
|
||||||
})
|
|
||||||
|
|
||||||
await websocket.send_text(
|
|
||||||
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
|
|
||||||
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── v4 Journey Handlers ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_journey_start(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a journey_start frame — explores directory and sends first question."""
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
|
||||||
set_client_executor(executor)
|
|
||||||
try:
|
|
||||||
reply = await handle_journey_start(user_id, frame)
|
|
||||||
await websocket.send_text(json.dumps(reply))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error(
|
|
||||||
"device_ws: journey_start failed user=%s: %s", user_id, exc
|
|
||||||
)
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": frame.get("session_id", ""),
|
|
||||||
"message": f"Failed to start journey: {exc}",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
}))
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_journey_message(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Handle a journey_message frame — continues the journey conversation."""
|
|
||||||
executor = await _make_ws_executor(websocket, user_id)
|
|
||||||
set_client_executor(executor)
|
|
||||||
try:
|
|
||||||
reply = await handle_journey_message(user_id, frame)
|
|
||||||
await websocket.send_text(json.dumps(reply))
|
|
||||||
except Exception as exc:
|
|
||||||
session_id = frame.get("session_id", "")
|
|
||||||
logger.error(
|
|
||||||
"device_ws: journey_message failed user=%s session=%s: %s",
|
|
||||||
user_id, session_id, exc,
|
|
||||||
)
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": f"Journey error: {exc}",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
}))
|
|
||||||
finally:
|
|
||||||
clear_client_executor()
|
|
||||||
|
|
||||||
|
|
||||||
# ── v7 Folder Index Handlers ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_index_session_start(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Register a new folder index session. No response sent — client is declaring intent."""
|
|
||||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
|
||||||
project_id: str | None = frame.get("projectId") or frame.get("project_id")
|
|
||||||
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
|
|
||||||
|
|
||||||
if not session_id:
|
|
||||||
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
_index_sessions[session_id] = {
|
|
||||||
"user_id": user_id,
|
|
||||||
"project_id": project_id,
|
|
||||||
"processed": 0,
|
|
||||||
"total": total,
|
|
||||||
"cancelled": False,
|
|
||||||
}
|
|
||||||
logger.info(
|
|
||||||
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
|
|
||||||
user_id, session_id, project_id, total,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_index_session_cancel(
|
|
||||||
websocket: WebSocket,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
|
|
||||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
|
||||||
session = _index_sessions.get(session_id)
|
|
||||||
if session:
|
|
||||||
session["cancelled"] = True
|
|
||||||
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": WsFrameType.index_session_done,
|
|
||||||
"sessionId": session_id,
|
|
||||||
"status": "cancelled",
|
|
||||||
}))
|
|
||||||
_index_sessions.pop(session_id, None)
|
|
||||||
logger.info("device_ws: index_session_cancel session=%s", session_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_index_file_batch(
|
|
||||||
websocket: WebSocket,
|
|
||||||
user_id: str,
|
|
||||||
frame: dict,
|
|
||||||
) -> None:
|
|
||||||
"""Process a batch of files for an index session, streaming results back."""
|
|
||||||
# Lazy imports to avoid heavy load at module startup.
|
|
||||||
from app.core.folder_indexer import ( # noqa: PLC0415
|
|
||||||
summarize_image,
|
|
||||||
summarize_pdf,
|
|
||||||
summarize_docx,
|
|
||||||
summarize_text,
|
|
||||||
)
|
|
||||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
|
||||||
from app.billing.quota import add_token_usage # noqa: PLC0415
|
|
||||||
|
|
||||||
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
|
|
||||||
files: list[dict] = frame.get("files", [])
|
|
||||||
|
|
||||||
session = _index_sessions.get(session_id)
|
|
||||||
if not session or session.get("cancelled"):
|
|
||||||
return
|
|
||||||
|
|
||||||
async with async_session() as db:
|
|
||||||
tier = await tier_manager.get_tier(user_id, db)
|
|
||||||
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
|
||||||
cap: int | None = None if raw_cap == -1 else raw_cap
|
|
||||||
|
|
||||||
for file_info in files:
|
|
||||||
if session.get("cancelled"):
|
|
||||||
return
|
|
||||||
|
|
||||||
# Electron's toSnakeCase converts payload keys, so accept both forms.
|
|
||||||
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
|
|
||||||
kind: str = file_info.get("kind") or "text"
|
|
||||||
content: str = file_info.get("content") or ""
|
|
||||||
ext: str = file_info.get("ext") or ""
|
|
||||||
mime: str = file_info.get("mime") or "application/octet-stream"
|
|
||||||
name: str = rel_path.split("/")[-1] or rel_path
|
|
||||||
|
|
||||||
try:
|
|
||||||
if kind == "image":
|
|
||||||
res = await summarize_image(image_b64=content, mime=mime)
|
|
||||||
elif kind == "pdf":
|
|
||||||
res = await summarize_pdf(pdf_b64=content, name=name)
|
|
||||||
elif kind == "docx":
|
|
||||||
res = await summarize_docx(docx_b64=content, name=name)
|
|
||||||
else:
|
|
||||||
res = await summarize_text(content=content, ext=ext, name=name)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
|
|
||||||
session_id, rel_path, exc,
|
|
||||||
)
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": WsFrameType.index_file_result,
|
|
||||||
"sessionId": session_id,
|
|
||||||
"relPath": rel_path,
|
|
||||||
"summary": None,
|
|
||||||
"tokensUsed": 0,
|
|
||||||
"error": str(exc),
|
|
||||||
}))
|
|
||||||
session["processed"] += 1
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Account for token usage and check cap.
|
|
||||||
usage = await add_token_usage(
|
|
||||||
user_id=user_id,
|
|
||||||
feature="folder_index",
|
|
||||||
tokens=res.tokens_used,
|
|
||||||
db=db,
|
|
||||||
cap=cap,
|
|
||||||
)
|
|
||||||
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": WsFrameType.index_file_result,
|
|
||||||
"sessionId": session_id,
|
|
||||||
"relPath": rel_path,
|
|
||||||
"summary": res.summary,
|
|
||||||
"tokensUsed": res.tokens_used,
|
|
||||||
}))
|
|
||||||
session["processed"] += 1
|
|
||||||
|
|
||||||
if usage.exhausted:
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": WsFrameType.index_session_done,
|
|
||||||
"sessionId": session_id,
|
|
||||||
"status": "quota_exceeded",
|
|
||||||
}))
|
|
||||||
_index_sessions.pop(session_id, None)
|
|
||||||
logger.info(
|
|
||||||
"device_ws: index_session quota_exceeded user=%s session=%s",
|
|
||||||
user_id, session_id,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# After processing the batch, emit progress.
|
|
||||||
processed = session["processed"]
|
|
||||||
total = session["total"]
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": WsFrameType.index_session_progress,
|
|
||||||
"sessionId": session_id,
|
|
||||||
"processed": processed,
|
|
||||||
"total": total,
|
|
||||||
}))
|
|
||||||
|
|
||||||
if processed >= total:
|
|
||||||
await websocket.send_text(json.dumps({
|
|
||||||
"type": WsFrameType.index_session_done,
|
|
||||||
"sessionId": session_id,
|
|
||||||
"status": "completed",
|
|
||||||
}))
|
|
||||||
_index_sessions.pop(session_id, None)
|
|
||||||
logger.info(
|
|
||||||
"device_ws: index_session_done completed user=%s session=%s processed=%d",
|
|
||||||
user_id, session_id, processed,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ─────────────────────────────────────────────────────────
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
@@ -790,3 +360,6 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
user_id,
|
user_id,
|
||||||
exc,
|
exc,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -1,225 +0,0 @@
|
|||||||
"""Memory management routes — view/edit/delete user memory tiers.
|
|
||||||
|
|
||||||
All routes require authentication. Data is always user-scoped.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Annotated
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, status
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy import delete, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
|
||||||
from app.db import get_session
|
|
||||||
from app.models import (
|
|
||||||
ExtractionQueue,
|
|
||||||
MemoryAssociative,
|
|
||||||
MemoryCore,
|
|
||||||
MemoryEpisodic,
|
|
||||||
MemoryProactive,
|
|
||||||
MemoryRelation,
|
|
||||||
)
|
|
||||||
from app.schemas import UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/memory", tags=["memory"])
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_ALLOWED_PREDICATES = {
|
|
||||||
"works_at",
|
|
||||||
"reports_to",
|
|
||||||
"stakeholder_of",
|
|
||||||
"last_contacted_on",
|
|
||||||
"owes_followup",
|
|
||||||
"manages",
|
|
||||||
"collaborates_with",
|
|
||||||
"owns",
|
|
||||||
"member_of",
|
|
||||||
"custom",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Response schemas ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class RelationOut(BaseModel):
|
|
||||||
id: str
|
|
||||||
subject_label: str
|
|
||||||
subject_type: str
|
|
||||||
predicate: str
|
|
||||||
object_label: str
|
|
||||||
object_type: str
|
|
||||||
confidence: float
|
|
||||||
last_confirmed_at: int | None = None # epoch ms
|
|
||||||
|
|
||||||
|
|
||||||
class RelationPatch(BaseModel):
|
|
||||||
subject_label: str | None = None
|
|
||||||
object_label: str | None = None
|
|
||||||
predicate: str | None = None
|
|
||||||
confidence: float | None = Field(None, ge=0.0, le=1.0)
|
|
||||||
|
|
||||||
|
|
||||||
class CoreAddBody(BaseModel):
|
|
||||||
key: str = Field(..., min_length=1, max_length=255)
|
|
||||||
value: str = Field(..., min_length=1)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _relation_to_out(row: MemoryRelation) -> RelationOut:
|
|
||||||
last_ms: int | None = None
|
|
||||||
if row.last_confirmed_at is not None:
|
|
||||||
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
|
|
||||||
return RelationOut(
|
|
||||||
id=row.id,
|
|
||||||
subject_label=row.subject_label,
|
|
||||||
subject_type=row.subject_type,
|
|
||||||
predicate=row.predicate,
|
|
||||||
object_label=row.object_label,
|
|
||||||
object_type=row.object_type,
|
|
||||||
confidence=row.confidence,
|
|
||||||
last_confirmed_at=last_ms,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ───────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/core", response_model=dict[str, str])
|
|
||||||
async def get_core_memory(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Return all core memory k/v pairs (plaintext) for the current user."""
|
|
||||||
mw = MemoryMiddleware(db)
|
|
||||||
blocks = await mw.list_core_blocks(current_user.id)
|
|
||||||
return {b["label"]: b["value"] for b in blocks}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
|
|
||||||
async def delete_core_key(
|
|
||||||
key: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> None:
|
|
||||||
"""Delete a single core memory key (GDPR Art. 17)."""
|
|
||||||
mw = MemoryMiddleware(db)
|
|
||||||
deleted = await mw.delete_core(current_user.id, key)
|
|
||||||
if not deleted:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
|
|
||||||
async def add_core_key(
|
|
||||||
body: CoreAddBody,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Add or overwrite a core memory key/value pair."""
|
|
||||||
mw = MemoryMiddleware(db)
|
|
||||||
await mw.update_core(current_user.id, body.key, body.value)
|
|
||||||
return {body.key: body.value}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/relational", response_model=list[RelationOut])
|
|
||||||
async def get_relational_memory(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[RelationOut]:
|
|
||||||
"""Return all relational memory rows for the current user."""
|
|
||||||
mw = MemoryMiddleware(db)
|
|
||||||
rows = await mw.query_relations(current_user.id, limit=200)
|
|
||||||
return [_relation_to_out(r) for r in rows]
|
|
||||||
|
|
||||||
|
|
||||||
@router.patch("/relational/{relation_id}", response_model=RelationOut)
|
|
||||||
async def patch_relation(
|
|
||||||
relation_id: str,
|
|
||||||
body: RelationPatch,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> RelationOut:
|
|
||||||
"""Edit a relation row's labels, predicate, or confidence."""
|
|
||||||
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
|
|
||||||
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
|
|
||||||
)
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(MemoryRelation).where(
|
|
||||||
MemoryRelation.id == relation_id,
|
|
||||||
MemoryRelation.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
|
||||||
|
|
||||||
if body.subject_label is not None:
|
|
||||||
row.subject_label = body.subject_label
|
|
||||||
if body.object_label is not None:
|
|
||||||
row.object_label = body.object_label
|
|
||||||
if body.predicate is not None:
|
|
||||||
row.predicate = body.predicate
|
|
||||||
if body.confidence is not None:
|
|
||||||
row.confidence = body.confidence
|
|
||||||
row.last_confirmed_at = datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(row)
|
|
||||||
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
|
|
||||||
return _relation_to_out(row)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
|
|
||||||
async def delete_relation(
|
|
||||||
relation_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> None:
|
|
||||||
"""Hard-delete a relation row (GDPR Art. 17)."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(MemoryRelation).where(
|
|
||||||
MemoryRelation.id == relation_id,
|
|
||||||
MemoryRelation.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
|
|
||||||
await db.delete(row)
|
|
||||||
await db.commit()
|
|
||||||
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
|
|
||||||
async def forget_all(
|
|
||||||
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> None:
|
|
||||||
"""Wipe all memory tiers for the current user (GDPR Art. 17).
|
|
||||||
|
|
||||||
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
|
|
||||||
"""
|
|
||||||
if x_confirm != "true":
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
|
||||||
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
|
|
||||||
)
|
|
||||||
|
|
||||||
uid = current_user.id
|
|
||||||
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
|
|
||||||
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
|
|
||||||
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
|
|
||||||
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
|
|
||||||
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
|
|
||||||
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
|
|
||||||
await db.commit()
|
|
||||||
logger.warning("memory: forget_all GDPR wipe user=%s", uid)
|
|
||||||
148
app/api/routes/plugins.py
Normal file
148
app/api/routes/plugins.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""Plugins routes: browse and install plugins from the marketplace.
|
||||||
|
|
||||||
|
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
||||||
|
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.db import get_session
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
||||||
|
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _require_plugin_tier(user: UserProfile) -> None:
|
||||||
|
"""Raise HTTP 403 for users below Power tier."""
|
||||||
|
if user.tier not in ("power", "team"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Plugin marketplace requires Power tier or above",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local detail schema ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _PluginDetail(BaseModel):
|
||||||
|
plugin: PluginManifest
|
||||||
|
install_count: int
|
||||||
|
ratings: list[Any]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("", response_model=PluginListResponse)
|
||||||
|
async def list_plugins(
|
||||||
|
category: str | None = Query(default=None),
|
||||||
|
q: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||||
|
async def get_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _PluginDetail:
|
||||||
|
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Fetch review ratings for this plugin
|
||||||
|
review_result = await db.execute(
|
||||||
|
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
||||||
|
)
|
||||||
|
reviews = review_result.scalars().all()
|
||||||
|
ratings = [
|
||||||
|
{
|
||||||
|
"reviewer_id": r.reviewer_id,
|
||||||
|
"decision": r.decision,
|
||||||
|
"notes": r.notes,
|
||||||
|
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
||||||
|
}
|
||||||
|
for r in reviews
|
||||||
|
]
|
||||||
|
|
||||||
|
return _PluginDetail(
|
||||||
|
plugin=entry["manifest"],
|
||||||
|
install_count=entry["install_count"],
|
||||||
|
ratings=ratings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{plugin_id}/install", response_model=dict)
|
||||||
|
async def install_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||||
|
|
||||||
|
Requires Power tier or above.
|
||||||
|
"""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Record the installation in plugin_installations
|
||||||
|
installation = PluginInstallation(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
)
|
||||||
|
db.add(installation)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
await revenue_share.record_install(
|
||||||
|
db,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
amount_cents=entry["manifest"].price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
||||||
|
return {"ok": True, "download_url": download_url}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{plugin_id}/install", response_model=dict)
|
||||||
|
async def uninstall_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Unregister a plugin installation."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(PluginInstallation).where(
|
||||||
|
PluginInstallation.plugin_id == plugin_id,
|
||||||
|
PluginInstallation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
installation = result.scalar_one_or_none()
|
||||||
|
if installation is not None:
|
||||||
|
await db.delete(installation)
|
||||||
|
await db.commit()
|
||||||
|
await registry.record_uninstall(db, plugin_id)
|
||||||
|
return {"ok": True}
|
||||||
195
app/api/routes/storage.py
Normal file
195
app/api/routes/storage.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
||||||
|
PostgreSQL ``storage_records`` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import StorageRecord
|
||||||
|
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["storage"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local response schemas ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CreateResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordMeta(BaseModel):
|
||||||
|
id: str
|
||||||
|
table: str
|
||||||
|
checksum: str
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return total bytes stored by *user_id*."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
||||||
|
StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
||||||
|
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
||||||
|
current = await _current_usage_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_record_for_user(
|
||||||
|
record_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> StorageRecord:
|
||||||
|
"""Look up a record and verify ownership. Returns 404 on mismatch
|
||||||
|
to prevent user enumeration attacks."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(StorageRecord).where(
|
||||||
|
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_record(
|
||||||
|
body: StorageRecordCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _CreateResponse:
|
||||||
|
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||||
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
await _check_quota(current_user, len(body.blob), db)
|
||||||
|
|
||||||
|
record_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, body.table, record_id, body.blob, body.checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
record = StorageRecord(
|
||||||
|
id=record_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
table_name=body.table,
|
||||||
|
s3_key=s3_key,
|
||||||
|
checksum=body.checksum,
|
||||||
|
size_bytes=len(body.blob),
|
||||||
|
)
|
||||||
|
db.add(record)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(record)
|
||||||
|
|
||||||
|
created_at_ms = int(record.created_at.timestamp() * 1000)
|
||||||
|
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records", response_model=list[_RecordMeta])
|
||||||
|
async def list_records(
|
||||||
|
table: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[_RecordMeta]:
|
||||||
|
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
||||||
|
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
||||||
|
if table is not None:
|
||||||
|
query = query.where(StorageRecord.table_name == table)
|
||||||
|
query = query.offset((page - 1) * limit).limit(limit)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
_RecordMeta(
|
||||||
|
id=r.id,
|
||||||
|
table=r.table_name,
|
||||||
|
checksum=r.checksum,
|
||||||
|
created_at=int(r.created_at.timestamp() * 1000),
|
||||||
|
updated_at=int(r.updated_at.timestamp() * 1000),
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records/{record_id}")
|
||||||
|
async def download_record(
|
||||||
|
record_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> Response:
|
||||||
|
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
blob = await _blob_store.download(current_user.id, record.s3_key)
|
||||||
|
return Response(
|
||||||
|
content=blob,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={"X-Checksum": record.checksum},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/records/{record_id}", response_model=dict)
|
||||||
|
async def update_record(
|
||||||
|
record_id: str,
|
||||||
|
body: StorageRecordUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
|
||||||
|
delta = len(body.blob) - record.size_bytes
|
||||||
|
if delta > 0:
|
||||||
|
await _check_quota(current_user, delta, db)
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
record.s3_key = s3_key
|
||||||
|
record.checksum = body.checksum
|
||||||
|
record.size_bytes = len(body.blob)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/records/{record_id}", response_model=dict)
|
||||||
|
async def delete_record(
|
||||||
|
record_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a record and its S3 blob."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
await _blob_store.delete(current_user.id, record.s3_key)
|
||||||
|
await db.delete(record)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
79
app/api/routes/vectors.py
Normal file
79
app/api/routes/vectors.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.llm import embed
|
||||||
|
from app.schemas import (
|
||||||
|
UserProfile,
|
||||||
|
VectorSearchRequest,
|
||||||
|
VectorSearchResponse,
|
||||||
|
VectorUpsertRequest,
|
||||||
|
)
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
from app.storage.vector_store import VectorStore
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["vectors"])
|
||||||
|
|
||||||
|
_vector_store = VectorStore()
|
||||||
|
|
||||||
|
|
||||||
|
class _VectorDeleteRequest(BaseModel):
|
||||||
|
ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedResponse(BaseModel):
|
||||||
|
vector: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/upsert", response_model=dict)
|
||||||
|
async def upsert_vectors(
|
||||||
|
body: VectorUpsertRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
||||||
|
for item in body.vectors:
|
||||||
|
reject_if_tampered(item.blob, item.checksum)
|
||||||
|
await _vector_store.upsert(current_user.id, body.vectors)
|
||||||
|
return {"upserted": len(body.vectors)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
||||||
|
async def search_vectors(
|
||||||
|
body: VectorSearchRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> VectorSearchResponse:
|
||||||
|
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
||||||
|
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
||||||
|
return VectorSearchResponse(results=results)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/vectors", response_model=dict)
|
||||||
|
async def delete_vectors(
|
||||||
|
body: _VectorDeleteRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete vectors by ID, scoped to the authenticated user."""
|
||||||
|
await _vector_store.delete(current_user.id, body.ids)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/embed", response_model=_EmbedResponse)
|
||||||
|
async def embed_text(
|
||||||
|
body: _EmbedRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _EmbedResponse:
|
||||||
|
"""Generate a 1536-dim embedding vector for the given text.
|
||||||
|
|
||||||
|
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||||
|
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
|
||||||
|
"""
|
||||||
|
vector = await embed(body.text)
|
||||||
|
return _EmbedResponse(vector=vector)
|
||||||
@@ -1 +0,0 @@
|
|||||||
"OAuth provider abstractions and utilities."
|
|
||||||
@@ -1,135 +0,0 @@
|
|||||||
"""OAuth 2.0 + PKCE provider abstractions.
|
|
||||||
|
|
||||||
Each provider implements a three-step flow designed for a desktop (public) client:
|
|
||||||
|
|
||||||
1. get_authorization_url(state, code_challenge) → str
|
|
||||||
Build the provider's consent-screen URL. State and code_challenge are
|
|
||||||
generated server-side; the client opens this URL in the system browser.
|
|
||||||
|
|
||||||
2. exchange_code(code, code_verifier, redirect_uri) → dict
|
|
||||||
Exchange the short-lived authorization code for an access token.
|
|
||||||
The code_verifier proves ownership of the PKCE challenge.
|
|
||||||
|
|
||||||
3. get_userinfo(access_token) → OAuthUserInfo
|
|
||||||
Fetch the canonical user identity from the provider.
|
|
||||||
|
|
||||||
Currently supported providers:
|
|
||||||
- GoogleOAuthProvider (scope: openid email profile)
|
|
||||||
|
|
||||||
Adding a new provider:
|
|
||||||
- Implement the three methods above.
|
|
||||||
- Register in _PROVIDERS inside routes/auth.py.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
import os
|
|
||||||
import urllib.parse
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
import httpx
|
|
||||||
|
|
||||||
|
|
||||||
# ── Data transfer objects ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class OAuthUserInfo:
|
|
||||||
"""Normalized user identity returned by any provider."""
|
|
||||||
|
|
||||||
provider_user_id: str
|
|
||||||
email: str
|
|
||||||
email_verified: bool
|
|
||||||
avatar_url: str | None
|
|
||||||
name: str | None
|
|
||||||
|
|
||||||
|
|
||||||
# ── PKCE helpers ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def generate_pkce_pair() -> tuple[str, str]:
|
|
||||||
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
|
|
||||||
|
|
||||||
The code_verifier is a random 32-byte URL-safe base64 string.
|
|
||||||
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
|
|
||||||
"""
|
|
||||||
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
|
|
||||||
digest = hashlib.sha256(code_verifier.encode()).digest()
|
|
||||||
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
|
||||||
return code_verifier, code_challenge
|
|
||||||
|
|
||||||
|
|
||||||
# ── Google provider ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class GoogleOAuthProvider:
|
|
||||||
"""Google OAuth 2.0 provider (openid email profile scope).
|
|
||||||
|
|
||||||
Uses Google's standard authorization endpoint with PKCE S256.
|
|
||||||
Does NOT use google-auth-oauthlib to keep the flow generic and async.
|
|
||||||
"""
|
|
||||||
|
|
||||||
name = "google"
|
|
||||||
|
|
||||||
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
|
||||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
|
||||||
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
|
||||||
|
|
||||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
|
|
||||||
self.client_id = client_id
|
|
||||||
self.client_secret = client_secret
|
|
||||||
self.redirect_uri = redirect_uri
|
|
||||||
|
|
||||||
def get_authorization_url(self, state: str, code_challenge: str) -> str:
|
|
||||||
"""Build the Google consent-screen URL."""
|
|
||||||
params = {
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"redirect_uri": self.redirect_uri,
|
|
||||||
"response_type": "code",
|
|
||||||
"scope": "openid email profile",
|
|
||||||
"state": state,
|
|
||||||
"code_challenge": code_challenge,
|
|
||||||
"code_challenge_method": "S256",
|
|
||||||
"access_type": "offline",
|
|
||||||
"prompt": "select_account",
|
|
||||||
}
|
|
||||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
|
||||||
|
|
||||||
async def exchange_code(
|
|
||||||
self, code: str, code_verifier: str, redirect_uri: str
|
|
||||||
) -> dict:
|
|
||||||
"""Exchange authorization code for an access token."""
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.post(
|
|
||||||
self._TOKEN_URL,
|
|
||||||
data={
|
|
||||||
"client_id": self.client_id,
|
|
||||||
"client_secret": self.client_secret,
|
|
||||||
"code": code,
|
|
||||||
"code_verifier": code_verifier,
|
|
||||||
"grant_type": "authorization_code",
|
|
||||||
"redirect_uri": redirect_uri,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
return response.json()
|
|
||||||
|
|
||||||
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
|
|
||||||
"""Fetch the authenticated user's identity from Google."""
|
|
||||||
async with httpx.AsyncClient() as client:
|
|
||||||
response = await client.get(
|
|
||||||
self._USERINFO_URL,
|
|
||||||
headers={"Authorization": f"Bearer {access_token}"},
|
|
||||||
)
|
|
||||||
response.raise_for_status()
|
|
||||||
data = response.json()
|
|
||||||
|
|
||||||
return OAuthUserInfo(
|
|
||||||
provider_user_id=data["sub"],
|
|
||||||
email=data["email"],
|
|
||||||
email_verified=data.get("email_verified", False),
|
|
||||||
avatar_url=data.get("picture"),
|
|
||||||
name=data.get("name"),
|
|
||||||
)
|
|
||||||
@@ -1,139 +0,0 @@
|
|||||||
"""Quota checks and atomic token-usage accounting for folder integration."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from sqlalchemy import select, update
|
|
||||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.billing.tier_manager import TierManager
|
|
||||||
from app.models import MonthlyTokenUsage
|
|
||||||
from app.schemas import BillingTier
|
|
||||||
|
|
||||||
|
|
||||||
class QuotaExceeded(Exception):
|
|
||||||
"""Raised when a folder operation cannot proceed under the user's tier."""
|
|
||||||
|
|
||||||
def __init__(self, reason: str, message: str) -> None:
|
|
||||||
super().__init__(message)
|
|
||||||
self.reason = reason # "max_files" | "monthly_tokens"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class TokenUsageResult:
|
|
||||||
tokens_used: int
|
|
||||||
exhausted: bool
|
|
||||||
|
|
||||||
|
|
||||||
def _current_year_month() -> str:
|
|
||||||
return datetime.now(timezone.utc).strftime("%Y-%m")
|
|
||||||
|
|
||||||
|
|
||||||
_tier_manager = TierManager()
|
|
||||||
|
|
||||||
|
|
||||||
async def check_folder_quota(
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
tier: BillingTier,
|
|
||||||
estimated_files: int,
|
|
||||||
db: AsyncSession,
|
|
||||||
) -> None:
|
|
||||||
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
|
|
||||||
would be violated. -1 in either feature means unlimited."""
|
|
||||||
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
|
|
||||||
if max_files != -1 and estimated_files > max_files:
|
|
||||||
raise QuotaExceeded(
|
|
||||||
"max_files",
|
|
||||||
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
|
|
||||||
)
|
|
||||||
|
|
||||||
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
|
|
||||||
if cap == -1:
|
|
||||||
return
|
|
||||||
ym = _current_year_month()
|
|
||||||
row = (
|
|
||||||
await db.execute(
|
|
||||||
select(MonthlyTokenUsage).where(
|
|
||||||
MonthlyTokenUsage.user_id == user_id,
|
|
||||||
MonthlyTokenUsage.year_month == ym,
|
|
||||||
MonthlyTokenUsage.feature == "folder_index",
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one_or_none()
|
|
||||||
used = row.tokens_used if row else 0
|
|
||||||
if used >= cap:
|
|
||||||
raise QuotaExceeded(
|
|
||||||
"monthly_tokens",
|
|
||||||
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def add_token_usage(
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
feature: str,
|
|
||||||
tokens: int,
|
|
||||||
db: AsyncSession,
|
|
||||||
cap: int | None = None,
|
|
||||||
) -> TokenUsageResult:
|
|
||||||
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
|
|
||||||
|
|
||||||
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
|
|
||||||
back to a read-then-write on other engines (e.g. aiosqlite in tests).
|
|
||||||
Returns post-update total and whether cap is exhausted.
|
|
||||||
"""
|
|
||||||
ym = _current_year_month()
|
|
||||||
|
|
||||||
# Detect dialect to choose between native upsert and portable fallback.
|
|
||||||
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
|
|
||||||
|
|
||||||
if dialect_name == "postgresql":
|
|
||||||
# Native atomic upsert — production path.
|
|
||||||
stmt = (
|
|
||||||
pg_insert(MonthlyTokenUsage)
|
|
||||||
.values(
|
|
||||||
user_id=user_id,
|
|
||||||
year_month=ym,
|
|
||||||
feature=feature,
|
|
||||||
tokens_used=tokens,
|
|
||||||
)
|
|
||||||
.on_conflict_do_update(
|
|
||||||
index_elements=["user_id", "year_month", "feature"],
|
|
||||||
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
|
|
||||||
)
|
|
||||||
.returning(MonthlyTokenUsage.tokens_used)
|
|
||||||
)
|
|
||||||
used: int = (await db.execute(stmt)).scalar_one()
|
|
||||||
await db.commit()
|
|
||||||
else:
|
|
||||||
# Portable fallback — used in tests (SQLite) and any non-PG engine.
|
|
||||||
row = (
|
|
||||||
await db.execute(
|
|
||||||
select(MonthlyTokenUsage).where(
|
|
||||||
MonthlyTokenUsage.user_id == user_id,
|
|
||||||
MonthlyTokenUsage.year_month == ym,
|
|
||||||
MonthlyTokenUsage.feature == feature,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).scalar_one_or_none()
|
|
||||||
|
|
||||||
if row is None:
|
|
||||||
row = MonthlyTokenUsage(
|
|
||||||
user_id=user_id,
|
|
||||||
year_month=ym,
|
|
||||||
feature=feature,
|
|
||||||
tokens_used=tokens,
|
|
||||||
)
|
|
||||||
db.add(row)
|
|
||||||
else:
|
|
||||||
row.tokens_used += tokens
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(row)
|
|
||||||
used = row.tokens_used
|
|
||||||
|
|
||||||
exhausted = cap is not None and cap != -1 and used >= cap
|
|
||||||
return TokenUsageResult(tokens_used=used, exhausted=exhausted)
|
|
||||||
@@ -43,8 +43,8 @@ class StripeService:
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tier: str,
|
tier: str,
|
||||||
success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
cancel_url: str = "https://app.adiuvai.app/billing/cancel",
|
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a Stripe checkout session and return the URL.
|
"""Create a Stripe checkout session and return the URL.
|
||||||
|
|
||||||
@@ -200,45 +200,6 @@ class StripeService:
|
|||||||
sub.status = "canceled"
|
sub.status = "canceled"
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
async def list_invoices(
|
|
||||||
self, user_id: str, db: AsyncSession, limit: int = 24
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Return recent invoices for the user from Stripe.
|
|
||||||
|
|
||||||
Returns an empty list when Stripe is not configured or the user has
|
|
||||||
no ``stripe_customer_id``.
|
|
||||||
"""
|
|
||||||
if not self._configured():
|
|
||||||
return []
|
|
||||||
|
|
||||||
from app.models import User # noqa: PLC0415
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(User.stripe_customer_id).where(User.id == user_id)
|
|
||||||
)
|
|
||||||
customer_id = result.scalar_one_or_none()
|
|
||||||
if not customer_id:
|
|
||||||
return []
|
|
||||||
|
|
||||||
try:
|
|
||||||
s = self._client()
|
|
||||||
invoices = s.Invoice.list(customer=customer_id, limit=limit)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"id": inv.id,
|
|
||||||
"amount_due": inv.amount_due,
|
|
||||||
"amount_paid": inv.amount_paid,
|
|
||||||
"currency": inv.currency,
|
|
||||||
"status": inv.status,
|
|
||||||
"created": inv.created * 1000, # epoch ms
|
|
||||||
"invoice_url": inv.hosted_invoice_url,
|
|
||||||
"invoice_pdf": inv.invoice_pdf,
|
|
||||||
}
|
|
||||||
for inv in invoices.auto_paging_iter()
|
|
||||||
]
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
# ── Private DB helpers ───────────────────────────────────────────────
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
async def _upsert_subscription(
|
async def _upsert_subscription(
|
||||||
|
|||||||
@@ -21,58 +21,42 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
"batch_runs_per_day": 5,
|
"cloud_storage_gb": 0,
|
||||||
|
"backup_gb": 0,
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
"real_embeddings": False, # keyword fallback only
|
|
||||||
"realtime_extraction": False, # batch queue (Phase 2)
|
|
||||||
"relational_memory": False, # relational tier (Phase 3) — Pro+
|
|
||||||
"proactive_mining": False, # Power+ only (Phase 5)
|
|
||||||
"folder_max_files": 200,
|
|
||||||
"folder_monthly_tokens": 100_000,
|
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
"batch_runs_per_day": 50,
|
"cloud_storage_gb": 5,
|
||||||
|
"backup_gb": 5,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
|
"plugin_marketplace": False,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
"real_embeddings": True, # pgvector cosine search
|
|
||||||
"realtime_extraction": True, # fire-and-forget asyncio.create_task
|
|
||||||
"relational_memory": True, # person/project predicates
|
|
||||||
"proactive_mining": False, # Power+ only (Phase 5)
|
|
||||||
"folder_max_files": 5000,
|
|
||||||
"folder_monthly_tokens": 2_000_000,
|
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
"batch_runs_per_day": -1, # unlimited
|
"cloud_storage_gb": 25,
|
||||||
|
"backup_gb": 25,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
"sso": False,
|
"sso": False,
|
||||||
"real_embeddings": True,
|
|
||||||
"realtime_extraction": True,
|
|
||||||
"relational_memory": True, # all predicates incl. custom
|
|
||||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
|
||||||
"folder_max_files": -1, # unlimited
|
|
||||||
"folder_monthly_tokens": -1, # unlimited
|
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
"batch_runs_per_day": -1, # unlimited
|
"cloud_storage_gb": -1, # unlimited
|
||||||
|
"backup_gb": -1, # unlimited
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
|
"plugin_marketplace": True,
|
||||||
"sso": True,
|
"sso": True,
|
||||||
"real_embeddings": True,
|
|
||||||
"realtime_extraction": True,
|
|
||||||
"relational_memory": True, # all predicates incl. custom
|
|
||||||
"proactive_mining": True, # scheduled pattern mining (Phase 5)
|
|
||||||
"folder_max_files": -1, # unlimited
|
|
||||||
"folder_monthly_tokens": -1, # unlimited
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -93,18 +77,16 @@ class TierManager:
|
|||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for ``user_id`` from the DB.
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
|
Falls back to ``'free'`` when no subscription row exists.
|
||||||
when no subscription row exists.
|
|
||||||
"""
|
"""
|
||||||
from app.models import Subscription # noqa: PLC0415
|
from app.models import Subscription # noqa: PLC0415
|
||||||
from app.config.settings import settings # noqa: PLC0415
|
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "power" if settings.ENV == "dev" else "free"
|
return "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
# ── Feature access ───────────────────────────────────────────────────
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
@@ -131,19 +113,77 @@ class TierManager:
|
|||||||
)
|
)
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
|
|
||||||
"""Return integer feature value for tier. -1 means unlimited."""
|
|
||||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
|
||||||
if not isinstance(value, int):
|
|
||||||
return 0
|
|
||||||
return value
|
|
||||||
|
|
||||||
# ── Rate limiting ────────────────────────────────────────────────────
|
# ── Rate limiting ────────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
"""Return the requests-per-minute limit for ``tier``."""
|
"""Return the requests-per-minute limit for ``tier``."""
|
||||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
|
# ── Storage quota ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def enforce_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||||
|
|
||||||
|
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||||
|
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||||
|
"""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Storage quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def enforce_backup_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> None:
|
||||||
|
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||||
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup is not available on the '{tier}' tier",
|
||||||
|
)
|
||||||
|
if limit_gb == -1:
|
||||||
|
return # unlimited
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
||||||
|
detail=f"Backup quota exceeded for tier '{tier}'",
|
||||||
|
)
|
||||||
|
|
||||||
|
def check_quota(
|
||||||
|
self,
|
||||||
|
tier: BillingTier,
|
||||||
|
current_bytes: int = 0,
|
||||||
|
additional_bytes: int = 0,
|
||||||
|
) -> bool:
|
||||||
|
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||||
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
|
if limit_gb == 0:
|
||||||
|
return False
|
||||||
|
if limit_gb == -1:
|
||||||
|
return True
|
||||||
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton shared across the app.
|
# Module-level singleton shared across the app.
|
||||||
tier_manager = TierManager()
|
tier_manager = TierManager()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
|
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
||||||
JWT_SECRET: str = "change-me-in-production"
|
JWT_SECRET: str = "change-me-in-production"
|
||||||
JWT_ALGORITHM: str = "HS256"
|
JWT_ALGORITHM: str = "HS256"
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
@@ -12,29 +12,26 @@ class Settings(BaseSettings):
|
|||||||
STRIPE_SECRET_KEY: str = ""
|
STRIPE_SECRET_KEY: str = ""
|
||||||
STRIPE_WEBHOOK_SECRET: str = ""
|
STRIPE_WEBHOOK_SECRET: str = ""
|
||||||
|
|
||||||
|
S3_BUCKET: str = ""
|
||||||
|
S3_REGION: str = "us-east-1"
|
||||||
|
S3_ENDPOINT_URL: str = ""
|
||||||
|
AWS_ACCESS_KEY_ID: str = ""
|
||||||
|
AWS_SECRET_ACCESS_KEY: str = ""
|
||||||
|
|
||||||
|
PINECONE_API_KEY: str = ""
|
||||||
|
PINECONE_INDEX: str = "adiuva"
|
||||||
|
QDRANT_URL: str = ""
|
||||||
|
QDRANT_API_KEY: str = ""
|
||||||
|
|
||||||
OPENAI_API_KEY: str = ""
|
OPENAI_API_KEY: str = ""
|
||||||
ANTHROPIC_API_KEY: str = ""
|
ANTHROPIC_API_KEY: str = ""
|
||||||
GOOGLE_API_KEY: str = ""
|
GOOGLE_API_KEY: str = ""
|
||||||
CEREBRAS_API_KEY: str = ""
|
CEREBRAS_API_KEY: str = ""
|
||||||
GROQ_API_KEY: str = ""
|
|
||||||
DEEPSEEK_API_KEY: str = ""
|
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
|
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
|
||||||
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
|
|
||||||
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
|
||||||
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
|
||||||
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
|
||||||
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
|
||||||
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
|
|
||||||
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
|
|
||||||
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
|
||||||
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
|
|
||||||
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
|
|
||||||
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
|
|
||||||
|
|
||||||
# GitHub Copilot OAuth token storage directory.
|
# GitHub Copilot OAuth token storage directory.
|
||||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||||
@@ -48,39 +45,16 @@ class Settings(BaseSettings):
|
|||||||
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||||
MS_TENANT_ID: str = "common"
|
MS_TENANT_ID: str = "common"
|
||||||
|
|
||||||
# Google Login OAuth credentials — scope: openid email profile.
|
|
||||||
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
|
|
||||||
GOOGLE_AUTH_CLIENT_ID: str = ""
|
|
||||||
GOOGLE_AUTH_CLIENT_SECRET: str = ""
|
|
||||||
# The redirect URI registered in Google Cloud Console.
|
|
||||||
# Google redirects here after consent; this backend route then bounces to
|
|
||||||
# the adiuvai:// deep link so the Electron app receives the code.
|
|
||||||
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
|
|
||||||
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
|
||||||
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
|
||||||
|
|
||||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||||
OAUTH_ENCRYPTION_KEY: str = ""
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = [
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
"app://.",
|
|
||||||
"http://localhost:3000",
|
|
||||||
"http://localhost:5173",
|
|
||||||
"http://localhost:4173", # Vite preview (web SPA)
|
|
||||||
"https://app.adiuvai.com", # Production web portal
|
|
||||||
]
|
|
||||||
|
|
||||||
LANGFUSE_SECRET_KEY: str = ""
|
|
||||||
LANGFUSE_PUBLIC_KEY: str = ""
|
|
||||||
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
|
||||||
|
|
||||||
SCHEDULER_ENABLED: bool = True
|
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -1,30 +0,0 @@
|
|||||||
"""Minimal agent base types retained for compatibility with batch runners."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
|
|
||||||
class BaseAgent(ABC):
|
|
||||||
"""Common base for non-chat agents still using the old base contract."""
|
|
||||||
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
user_id: str = "",
|
|
||||||
shared_memory: dict[str, Any] | None = None,
|
|
||||||
vector_store_context: list[str] | None = None,
|
|
||||||
) -> None:
|
|
||||||
self.user_id = user_id
|
|
||||||
self.shared_memory: dict[str, Any] = shared_memory or {}
|
|
||||||
self.vector_store_context: list[str] = vector_store_context or []
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_name(self) -> str: ...
|
|
||||||
|
|
||||||
@abstractmethod
|
|
||||||
def get_description(self) -> str: ...
|
|
||||||
|
|
||||||
@property
|
|
||||||
def skills(self) -> list[str]:
|
|
||||||
return []
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -1,59 +0,0 @@
|
|||||||
"""In-process TTL buffer for per-session LangChain message history.
|
|
||||||
|
|
||||||
Stores the full message list (including AIMessage with tool_calls and ToolMessage)
|
|
||||||
keyed by (user_id, session_id), so agents can reconstruct tool-call context across
|
|
||||||
conversation turns without it being lossy through the wire.
|
|
||||||
|
|
||||||
Single-process only. For multi-worker deployments, replace the _SessionBuffer
|
|
||||||
implementation with one backed by Redis (serialize LangChain messages to dicts via
|
|
||||||
message_to_dict / messages_from_dict from langchain_core.messages).
|
|
||||||
"""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import time
|
|
||||||
from threading import Lock
|
|
||||||
|
|
||||||
from langchain_core.messages import BaseMessage
|
|
||||||
|
|
||||||
SESSION_TTL_SECONDS = 1800 # 30-minute idle expiry
|
|
||||||
MAX_MESSAGES_PER_SESSION = 80 # cap to avoid unbounded memory growth
|
|
||||||
|
|
||||||
|
|
||||||
class _SessionBuffer:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
|
|
||||||
self._lock = Lock()
|
|
||||||
|
|
||||||
def _evict_stale(self) -> None:
|
|
||||||
now = time.monotonic()
|
|
||||||
stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
|
|
||||||
for k in stale:
|
|
||||||
del self._store[k]
|
|
||||||
|
|
||||||
def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
|
|
||||||
key = (user_id, session_id)
|
|
||||||
with self._lock:
|
|
||||||
entry = self._store.get(key)
|
|
||||||
if entry is None:
|
|
||||||
return None
|
|
||||||
ts, msgs = entry
|
|
||||||
if time.monotonic() - ts > SESSION_TTL_SECONDS:
|
|
||||||
del self._store[key]
|
|
||||||
return None
|
|
||||||
self._store[key] = (time.monotonic(), msgs)
|
|
||||||
return list(msgs)
|
|
||||||
|
|
||||||
def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
|
|
||||||
key = (user_id, session_id)
|
|
||||||
capped = messages[-MAX_MESSAGES_PER_SESSION:]
|
|
||||||
with self._lock:
|
|
||||||
self._evict_stale()
|
|
||||||
self._store[key] = (time.monotonic(), capped)
|
|
||||||
|
|
||||||
def clear(self, user_id: str, session_id: str) -> None:
|
|
||||||
with self._lock:
|
|
||||||
self._store.pop((user_id, session_id), None)
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
|
|
||||||
session_buffer = _SessionBuffer()
|
|
||||||
@@ -1,228 +0,0 @@
|
|||||||
"""Brief agent — produces plain-text home and project status briefs.
|
|
||||||
|
|
||||||
Read-only tool subset only. Never calls _normalize_tagged_list_lines —
|
|
||||||
the brief prompt forbids XML tags, so skipping post-processing is intentional.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from datetime import date
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from app.agents.note_agent import NOTE_READ_TOOLS
|
|
||||||
from app.agents.project_agent import PROJECT_READ_TOOLS
|
|
||||||
from app.agents.task_agent import TASK_READ_TOOLS
|
|
||||||
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
|
|
||||||
from app.core.deep_agent import (
|
|
||||||
_language_instruction,
|
|
||||||
_proactive_hints_injection,
|
|
||||||
_read_only_memory_tools,
|
|
||||||
_relational_memory_injection,
|
|
||||||
_run_single_agent_stream,
|
|
||||||
_trace_id_from_context,
|
|
||||||
build_brief_multi_project_manifest,
|
|
||||||
)
|
|
||||||
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
|
|
||||||
|
|
||||||
_LANGUAGE_NAMES: dict[str, str] = {
|
|
||||||
"en": "English", "it": "Italian", "es": "Spanish",
|
|
||||||
"fr": "French", "de": "German",
|
|
||||||
"english": "English", "italian": "Italian", "italiano": "Italian",
|
|
||||||
"spanish": "Spanish", "español": "Spanish",
|
|
||||||
"french": "French", "français": "French",
|
|
||||||
"german": "German", "deutsch": "German",
|
|
||||||
}
|
|
||||||
|
|
||||||
_HOME_BRIEF_FALLBACK = """\
|
|
||||||
You are the user's personal assistant producing a short daily brief.
|
|
||||||
|
|
||||||
ROLE
|
|
||||||
Act like a calm, attentive secretary writing a stand-up note for your boss.
|
|
||||||
Warm and human, never breezy. Never cheerful filler, never emojis, never
|
|
||||||
"here is your brief" meta-text. The user is opening the app mid-workday and
|
|
||||||
is probably stressed — your job is to lower cognitive load, not add noise.
|
|
||||||
|
|
||||||
TOOLS — always call before writing
|
|
||||||
Pull fresh data every run. Do not invent counts or titles. Use at minimum:
|
|
||||||
- list_tasks_due_today — tasks the user owes today
|
|
||||||
- list_timelines_today — events starting or ending today
|
|
||||||
- list_all_projects — projects currently in progress or at risk
|
|
||||||
- memory_list_blocks / memory_get — personal context about people, clients,
|
|
||||||
payment habits, working preferences
|
|
||||||
If a tool returns nothing, simply omit that topic. Never report zeros.
|
|
||||||
|
|
||||||
WHAT TO INCLUDE
|
|
||||||
1. Tasks due today (title + priority; group the 1-2 most important).
|
|
||||||
2. Timeline events starting or ending today (and anything that starts/ends
|
|
||||||
tomorrow if the user has a very light day).
|
|
||||||
3. Active projects that need a nudge — stalled, blocked, or awaiting input.
|
|
||||||
4. Memory-aware colour where it sharpens the brief. Examples:
|
|
||||||
- "Client Rossi tends to pay late — the Acme invoice is 6 days out."
|
|
||||||
- "You usually dislike meetings before 10:00 — the call at 09:30 is unusual."
|
|
||||||
Only add a memory line when it changes what the user does. Do not pad.
|
|
||||||
|
|
||||||
WHAT TO OMIT
|
|
||||||
- Zero-counts ("no overdue items", "0 meetings today").
|
|
||||||
- Statistics ("2 active projects, 3 completed tasks").
|
|
||||||
- Headers, titles, greetings, sign-offs, dates, emojis, slang.
|
|
||||||
- Meta-phrases ("here is", "let me know if", "hope this helps").
|
|
||||||
- XML/HTML tags of any kind. Plain prose only.
|
|
||||||
|
|
||||||
LIGHT-DAY CLAUSE
|
|
||||||
If tasks + events + active-project-nudges together produce fewer than two
|
|
||||||
sentences of content, also list 1-2 projects in status on_hold or waiting
|
|
||||||
and ask a single, specific question about them — e.g. "Is the Bianchi
|
|
||||||
redesign still paused, or ready to pick back up?" One question max, grounded
|
|
||||||
in a real project name.
|
|
||||||
|
|
||||||
VOICE
|
|
||||||
- Calm. Concise. Human. Short sentences.
|
|
||||||
- Use **bold** sparingly for task titles, project names, and people's names.
|
|
||||||
- No bullet lists. Flow as 2-4 sentences of prose.
|
|
||||||
|
|
||||||
LENGTH
|
|
||||||
2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence.
|
|
||||||
|
|
||||||
Respond in the user's language ({language}). Today is {today}.\
|
|
||||||
"""
|
|
||||||
|
|
||||||
_PROJECT_BRIEF_FALLBACK = """\
|
|
||||||
You are the project assistant producing a short status brief for ONE project.
|
|
||||||
|
|
||||||
ROLE
|
|
||||||
A senior project manager summarising state-of-play for the owner. Factual,
|
|
||||||
sharp, forward-looking. Never reassuring filler, never emojis.
|
|
||||||
|
|
||||||
SCOPE
|
|
||||||
Work only with project_id = {project_id}. Do not mention or pull data from
|
|
||||||
other projects. Use tools to fetch fresh data:
|
|
||||||
- get_project — current status, dates, description
|
|
||||||
- list_tasks(project_id) — open work, split by status
|
|
||||||
- list_timelines(project_id) — milestones hit, upcoming, overdue
|
|
||||||
- list_notes(project_id) — any recent decisions or blockers
|
|
||||||
- memory_get — relevant context about the client, collaborators, constraints
|
|
||||||
|
|
||||||
STRUCTURE — follow exactly, one short paragraph per section, no headers
|
|
||||||
1. **State.** One sentence: current phase, health (on track / at risk / blocked),
|
|
||||||
and why. Cite the concrete signal (overdue milestone, stalled tasks, recent
|
|
||||||
blocker note).
|
|
||||||
2. **What's moving.** What was completed or progressed recently. Name specific
|
|
||||||
tasks or milestones.
|
|
||||||
3. **Next steps.** The 1-3 most important things the user should do next, in
|
|
||||||
priority order. Be concrete — task name, who owns it, when due if known.
|
|
||||||
If waiting on someone else, name them and what the ask is.
|
|
||||||
4. **Risks / memory-flagged items.** One line max. Only include when there is
|
|
||||||
a real risk or a relevant memory (e.g. late-paying client, tight deadline,
|
|
||||||
scope change). Omit the section entirely if nothing to say.
|
|
||||||
|
|
||||||
WHAT TO OMIT
|
|
||||||
- Zero-counts ("no overdue tasks").
|
|
||||||
- Generic advice ("keep up the good work").
|
|
||||||
- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases.
|
|
||||||
- XML/HTML tags or bracketed id lists. Plain prose only.
|
|
||||||
|
|
||||||
VOICE
|
|
||||||
- Direct. Factual. No fluff.
|
|
||||||
- Use **bold** sparingly for task titles, milestone names, and the owner's name.
|
|
||||||
- Short sentences. Prefer verbs over nouns ("Client review is blocking release"
|
|
||||||
not "There is a blocker which is the client review").
|
|
||||||
|
|
||||||
LENGTH
|
|
||||||
4-8 sentences total across the 3-4 sections. Hard cap 8.
|
|
||||||
|
|
||||||
Respond in the user's language ({language}). Today is {today}.\
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_language(context: dict[str, Any]) -> str:
|
|
||||||
core = context.get("core_memory") or {}
|
|
||||||
raw = (core.get("language") or "en").strip().lower()
|
|
||||||
return _LANGUAGE_NAMES.get(raw, raw.title()) or "English"
|
|
||||||
|
|
||||||
|
|
||||||
def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
|
||||||
return [
|
|
||||||
*TASK_READ_TOOLS,
|
|
||||||
*PROJECT_READ_TOOLS,
|
|
||||||
*TIMELINE_READ_TOOLS,
|
|
||||||
*NOTE_READ_TOOLS,
|
|
||||||
*_read_only_memory_tools(user_id, trace_id),
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
async def run_home_brief(
|
|
||||||
user_id: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
|
||||||
"""Stream a plain-text daily home brief.
|
|
||||||
|
|
||||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
|
||||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
|
||||||
"""
|
|
||||||
from app.agents.folder_agent import FOLDER_TOOLS
|
|
||||||
|
|
||||||
trace_id = _trace_id_from_context(context)
|
|
||||||
today = date.today().isoformat()
|
|
||||||
language = _resolve_language(context)
|
|
||||||
|
|
||||||
raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK)
|
|
||||||
system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today)
|
|
||||||
system_prompt += _relational_memory_injection(context)
|
|
||||||
system_prompt += _proactive_hints_injection(context)
|
|
||||||
system_prompt += _language_instruction(context)
|
|
||||||
if today not in system_prompt:
|
|
||||||
system_prompt += f"\nToday is {today}."
|
|
||||||
|
|
||||||
brief_manifest = await build_brief_multi_project_manifest()
|
|
||||||
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
|
|
||||||
|
|
||||||
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
|
|
||||||
async for event in _run_single_agent_stream(
|
|
||||||
user_id=user_id,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
message="Generate the daily brief.",
|
|
||||||
context=context,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
agent_name="brief-agent",
|
|
||||||
tools=tools,
|
|
||||||
):
|
|
||||||
yield event
|
|
||||||
|
|
||||||
|
|
||||||
async def run_project_brief(
|
|
||||||
user_id: str,
|
|
||||||
project_id: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
|
||||||
"""Stream a plain-text project status brief for project_id.
|
|
||||||
|
|
||||||
Yields (event_type, data) tuples identical to _run_single_agent_stream.
|
|
||||||
Do NOT post-process output through _normalize_tagged_list_lines.
|
|
||||||
"""
|
|
||||||
trace_id = _trace_id_from_context(context)
|
|
||||||
today = date.today().isoformat()
|
|
||||||
language = _resolve_language(context)
|
|
||||||
|
|
||||||
raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK)
|
|
||||||
system_prompt = compile_prompt(
|
|
||||||
raw_template, langfuse_prompt,
|
|
||||||
language=language, today=today, project_id=project_id,
|
|
||||||
)
|
|
||||||
system_prompt += _relational_memory_injection(context)
|
|
||||||
system_prompt += _proactive_hints_injection(context)
|
|
||||||
system_prompt += _language_instruction(context)
|
|
||||||
if today not in system_prompt:
|
|
||||||
system_prompt += f"\nToday is {today}."
|
|
||||||
|
|
||||||
tools = _build_read_tools(user_id, trace_id)
|
|
||||||
async for event in _run_single_agent_stream(
|
|
||||||
user_id=user_id,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
message=f"Generate the project status brief for project {project_id}.",
|
|
||||||
context=context,
|
|
||||||
langfuse_prompt=langfuse_prompt,
|
|
||||||
agent_name="brief-agent",
|
|
||||||
tools=tools,
|
|
||||||
):
|
|
||||||
yield event
|
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -3,15 +3,20 @@
|
|||||||
Maintains in-memory state for all active Electron → backend WebSocket
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
connections. One connection per user (latest replaces previous).
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
The manager handles the **tool-call round-trip** pattern:
|
The manager participates in two interaction patterns:
|
||||||
- Backend sends ``tool_call`` frame → Electron executes the action →
|
|
||||||
returns ``tool_result`` frame.
|
|
||||||
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
|
||||||
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
|
||||||
receive the result dict from Electron.
|
|
||||||
|
|
||||||
This pattern is used by all tools (CRUD, file-system, etc.) via
|
1. **Tool-call round-trip** (bidirectional CRUD):
|
||||||
``execute_on_client()`` in ``ws_context.py``.
|
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
||||||
|
``tool_result`` frame.
|
||||||
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
|
receive the result dict from Electron.
|
||||||
|
|
||||||
|
2. **Agent-data streaming** (local directory agent runs):
|
||||||
|
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
||||||
|
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
||||||
|
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
||||||
|
a specific ``run_id`` so the agent runner can iterate frames.
|
||||||
|
|
||||||
The ``device_manager`` module-level singleton is imported by both the
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
device WS route and the agent runner.
|
device WS route and the agent runner.
|
||||||
@@ -37,6 +42,8 @@ class DeviceConnection:
|
|||||||
device_id: str
|
device_id: str
|
||||||
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
# Per-run queues for agent_data / agent_complete frames.
|
||||||
|
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class DeviceConnectionManager:
|
class DeviceConnectionManager:
|
||||||
@@ -146,6 +153,31 @@ class DeviceConnectionManager:
|
|||||||
if fut is not None and not fut.done():
|
if fut is not None and not fut.done():
|
||||||
fut.set_result(result)
|
fut.set_result(result)
|
||||||
|
|
||||||
|
# ── Agent-data queue ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_agent_data_queue(
|
||||||
|
self, user_id: str, run_id: str
|
||||||
|
) -> asyncio.Queue[dict | None]:
|
||||||
|
"""Return (creating if absent) the queue for *run_id* agent frames.
|
||||||
|
|
||||||
|
The agent runner reads from this queue. The device WS handler writes
|
||||||
|
to it. ``None`` is the sentinel that signals the stream is finished.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"get_agent_data_queue: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
if run_id not in conn.agent_data_queues:
|
||||||
|
conn.agent_data_queues[run_id] = asyncio.Queue()
|
||||||
|
return conn.agent_data_queues[run_id]
|
||||||
|
|
||||||
|
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
||||||
|
"""Remove the queue for *run_id* once a run has completed."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.agent_data_queues.pop(run_id, None)
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton — import this everywhere.
|
# Module-level singleton — import this everywhere.
|
||||||
device_manager = DeviceConnectionManager()
|
device_manager = DeviceConnectionManager()
|
||||||
|
|||||||
@@ -1,34 +0,0 @@
|
|||||||
"""OpenAI embedding helper for associative memory tier.
|
|
||||||
|
|
||||||
Single public function: ``embed_text(text) -> list[float] | None``.
|
|
||||||
Returns None on any failure — callers must implement a keyword fallback.
|
|
||||||
Never raises; all exceptions are logged as warnings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_MAX_INPUT_CHARS = 8000
|
|
||||||
_EMBEDDING_MODEL = "text-embedding-3-small"
|
|
||||||
|
|
||||||
|
|
||||||
async def embed_text(text: str) -> list[float] | None:
|
|
||||||
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
|
|
||||||
try:
|
|
||||||
client = AsyncOpenAI()
|
|
||||||
truncated = text[:_MAX_INPUT_CHARS]
|
|
||||||
response = await client.embeddings.create(
|
|
||||||
input=truncated,
|
|
||||||
model=_EMBEDDING_MODEL,
|
|
||||||
)
|
|
||||||
result: list[float] = response.data[0].embedding
|
|
||||||
logger.debug("embeddings: embed_text dims=%d", len(result))
|
|
||||||
return result
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("embeddings: embed_text failed: %s", exc)
|
|
||||||
return None
|
|
||||||
@@ -1,183 +0,0 @@
|
|||||||
"""Per-file summarisation for project folder integration."""
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import io
|
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
from pypdf import PdfReader
|
|
||||||
from docx import Document as DocxDocument
|
|
||||||
|
|
||||||
from app.core.langfuse_client import (
|
|
||||||
compile_prompt,
|
|
||||||
extract_usage,
|
|
||||||
get_langfuse,
|
|
||||||
get_prompt_or_fallback,
|
|
||||||
)
|
|
||||||
from app.core.llm import get_llm
|
|
||||||
|
|
||||||
_TEXT_FALLBACK = (
|
|
||||||
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
|
|
||||||
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
|
|
||||||
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
|
|
||||||
)
|
|
||||||
_IMAGE_FALLBACK = (
|
|
||||||
"You are summarising an image attached to a project folder.\n"
|
|
||||||
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
|
|
||||||
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
|
|
||||||
)
|
|
||||||
_MAX_INPUT_CHARS = 6000
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class IndexResult:
|
|
||||||
summary: str
|
|
||||||
tokens_used: int
|
|
||||||
|
|
||||||
|
|
||||||
async def _llm_text(messages: list) -> object:
|
|
||||||
"""Make the LLM call for text summarisation.
|
|
||||||
|
|
||||||
Defined as a standalone async function so tests can patch it cleanly
|
|
||||||
without needing to mock the LLM object itself.
|
|
||||||
"""
|
|
||||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
|
||||||
return await llm.ainvoke(messages)
|
|
||||||
|
|
||||||
|
|
||||||
async def _llm_vision(messages: list) -> object:
|
|
||||||
"""Make the LLM call for vision (image) summarisation.
|
|
||||||
|
|
||||||
Accepts the message list and returns the response directly, mirroring
|
|
||||||
the ``_llm_text`` caller pattern so tests can patch it at the module level.
|
|
||||||
"""
|
|
||||||
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
|
|
||||||
return await llm.ainvoke(messages)
|
|
||||||
|
|
||||||
|
|
||||||
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
|
|
||||||
"""Return a compact summary of an image file using vision.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
image_b64:
|
|
||||||
Base64-encoded image bytes.
|
|
||||||
mime:
|
|
||||||
MIME type of the image, e.g. ``"image/png"``.
|
|
||||||
file_name:
|
|
||||||
Optional file name, attached to the Langfuse trace as input metadata.
|
|
||||||
"""
|
|
||||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=template),
|
|
||||||
HumanMessage(content=[
|
|
||||||
{"type": "text", "text": "Summarise this image."},
|
|
||||||
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
|
|
||||||
]),
|
|
||||||
]
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf is not None:
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="folder-summarize-image",
|
|
||||||
model="gpt-4o-mini",
|
|
||||||
prompt=prompt_obj,
|
|
||||||
input={"file_name": file_name, "mime": mime},
|
|
||||||
) as gen:
|
|
||||||
response = await _llm_vision(messages)
|
|
||||||
usage = extract_usage(response)
|
|
||||||
gen.update(output=response.content, usage_details=usage)
|
|
||||||
else:
|
|
||||||
response = await _llm_vision(messages)
|
|
||||||
usage = extract_usage(response)
|
|
||||||
summary = (response.content or "").strip()[:500]
|
|
||||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
|
||||||
|
|
||||||
|
|
||||||
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
|
|
||||||
"""Return a compact summary of a text file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
content:
|
|
||||||
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
|
|
||||||
ext:
|
|
||||||
File extension including the leading dot, e.g. ``".md"``.
|
|
||||||
name:
|
|
||||||
File name, e.g. ``"kickoff.md"``.
|
|
||||||
"""
|
|
||||||
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
|
|
||||||
truncated = content[:_MAX_INPUT_CHARS]
|
|
||||||
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=compiled),
|
|
||||||
HumanMessage(content="Summarise this file."),
|
|
||||||
]
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf is not None:
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="folder-summarize-text",
|
|
||||||
model="gpt-4o-mini",
|
|
||||||
prompt=prompt_obj,
|
|
||||||
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
|
|
||||||
) as gen:
|
|
||||||
response = await _llm_text(messages)
|
|
||||||
usage = extract_usage(response)
|
|
||||||
gen.update(output=response.content, usage_details=usage)
|
|
||||||
else:
|
|
||||||
response = await _llm_text(messages)
|
|
||||||
usage = extract_usage(response)
|
|
||||||
summary = (response.content or "").strip()[:500]
|
|
||||||
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_pdf_text(pdf_b64: str) -> str:
|
|
||||||
buf = io.BytesIO(base64.b64decode(pdf_b64))
|
|
||||||
reader = PdfReader(buf)
|
|
||||||
parts: list[str] = []
|
|
||||||
for page in reader.pages:
|
|
||||||
try:
|
|
||||||
parts.append(page.extract_text() or "")
|
|
||||||
except Exception:
|
|
||||||
continue
|
|
||||||
return "\n".join(parts).strip()
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_docx_text(docx_b64: str) -> str:
|
|
||||||
buf = io.BytesIO(base64.b64decode(docx_b64))
|
|
||||||
doc = DocxDocument(buf)
|
|
||||||
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
|
|
||||||
|
|
||||||
|
|
||||||
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
|
|
||||||
"""Return a compact summary of a PDF file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
pdf_b64:
|
|
||||||
Base64-encoded PDF bytes.
|
|
||||||
name:
|
|
||||||
File name, e.g. ``"report.pdf"``.
|
|
||||||
"""
|
|
||||||
text = _extract_pdf_text(pdf_b64)
|
|
||||||
if not text:
|
|
||||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
|
||||||
return await summarize_text(content=text, ext=".pdf", name=name)
|
|
||||||
|
|
||||||
|
|
||||||
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
|
|
||||||
"""Return a compact summary of a DOCX file.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
docx_b64:
|
|
||||||
Base64-encoded DOCX bytes.
|
|
||||||
name:
|
|
||||||
File name, e.g. ``"spec.docx"``.
|
|
||||||
"""
|
|
||||||
text = _extract_docx_text(docx_b64)
|
|
||||||
if not text:
|
|
||||||
return IndexResult(summary="Could not extract text", tokens_used=0)
|
|
||||||
return await summarize_text(content=text, ext=".docx", name=name)
|
|
||||||
@@ -1,190 +0,0 @@
|
|||||||
"""Langfuse observability — singleton client and prompt helpers.
|
|
||||||
|
|
||||||
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
|
|
||||||
all helpers are no-ops so the app works without Langfuse configured.
|
|
||||||
|
|
||||||
Usage
|
|
||||||
-----
|
|
||||||
Tracing::
|
|
||||||
|
|
||||||
from app.core.langfuse_client import get_langfuse
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
|
|
||||||
span.update(input=user_message)
|
|
||||||
# ... do work ...
|
|
||||||
span.update(output=result)
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
Prompt management::
|
|
||||||
|
|
||||||
from app.core.langfuse_client import get_prompt_or_fallback
|
|
||||||
|
|
||||||
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
|
|
||||||
# Use text as the system prompt; pass prompt_obj to generations for linking.
|
|
||||||
|
|
||||||
Linking a prompt to a generation::
|
|
||||||
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="llm-call",
|
|
||||||
model="gpt-4o",
|
|
||||||
prompt=prompt_obj, # links generation → prompt version in the UI
|
|
||||||
input=messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
gen.update(output=response.content, usage=_usage(response))
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import logging
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Any, Generator
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_client: Any = None
|
|
||||||
_initialized: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def get_langfuse() -> Any | None:
|
|
||||||
"""Return the Langfuse singleton, or ``None`` when not configured."""
|
|
||||||
global _client, _initialized
|
|
||||||
if _initialized:
|
|
||||||
return _client
|
|
||||||
_initialized = True
|
|
||||||
|
|
||||||
from app.config.settings import settings # local import to avoid circular deps
|
|
||||||
|
|
||||||
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
|
||||||
logger.debug("langfuse: not configured — observability disabled")
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import Langfuse
|
|
||||||
|
|
||||||
_client = Langfuse(
|
|
||||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
|
||||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
|
||||||
host=settings.LANGFUSE_BASE_URL,
|
|
||||||
)
|
|
||||||
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse: failed to initialize: %s", exc)
|
|
||||||
_client = None
|
|
||||||
|
|
||||||
return _client
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
|
||||||
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
|
||||||
|
|
||||||
Returns ``(raw_template, prompt_obj_or_None)``.
|
|
||||||
|
|
||||||
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
|
|
||||||
on it directly; use :func:`compile_prompt` instead so the correct variable
|
|
||||||
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
|
|
||||||
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
|
||||||
unavailable / the fetch failed. Pass this to generation observations so
|
|
||||||
Langfuse links the generation to the exact prompt version in the UI.
|
|
||||||
"""
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf is None:
|
|
||||||
return fallback, None
|
|
||||||
|
|
||||||
try:
|
|
||||||
prompt = lf.get_prompt(name, label="production", fallback=fallback)
|
|
||||||
# For text-type prompts .prompt holds the raw template string.
|
|
||||||
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
|
|
||||||
return raw, prompt
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
|
|
||||||
return fallback, None
|
|
||||||
|
|
||||||
|
|
||||||
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
|
|
||||||
"""Compile *template* with *variables*, choosing the right syntax.
|
|
||||||
|
|
||||||
* When *prompt_obj* is a real Langfuse prompt object, calls
|
|
||||||
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
|
|
||||||
substitution as defined in the Langfuse UI.
|
|
||||||
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
|
|
||||||
falls back to ``template.format(**variables)`` which handles the
|
|
||||||
``{variable}`` syntax used in the hardcoded fallback strings.
|
|
||||||
|
|
||||||
This keeps callers oblivious to which syntax is in use.
|
|
||||||
"""
|
|
||||||
if prompt_obj is not None:
|
|
||||||
try:
|
|
||||||
compiled = prompt_obj.compile(**variables)
|
|
||||||
# compile() returns a string for text prompts.
|
|
||||||
if isinstance(compiled, str):
|
|
||||||
return compiled
|
|
||||||
# Chat prompts return a list of dicts — join text parts.
|
|
||||||
if isinstance(compiled, list):
|
|
||||||
return "\n".join(
|
|
||||||
m.get("content", "") for m in compiled if isinstance(m, dict)
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
|
|
||||||
getattr(prompt_obj, "name", "?"),
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
return template.format(**variables)
|
|
||||||
|
|
||||||
|
|
||||||
def extract_usage(response: Any) -> dict[str, int]:
|
|
||||||
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
|
||||||
meta = getattr(response, "usage_metadata", None)
|
|
||||||
if not meta:
|
|
||||||
return {}
|
|
||||||
return {
|
|
||||||
"input": int(meta.get("input_tokens", 0)),
|
|
||||||
"output": int(meta.get("output_tokens", 0)),
|
|
||||||
"total": int(meta.get("total_tokens", 0)),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def hash_user_id(user_id: str) -> str:
|
|
||||||
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
|
|
||||||
|
|
||||||
This avoids sending raw database UUIDs to external observability services
|
|
||||||
while still providing a stable, deterministic identifier for per-user
|
|
||||||
metrics in the Langfuse dashboard.
|
|
||||||
"""
|
|
||||||
return hashlib.sha256(user_id.encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def langfuse_context(
|
|
||||||
user_id: str | None = None,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> Generator[None, None, None]:
|
|
||||||
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
|
|
||||||
|
|
||||||
No-op when Langfuse is not configured or parameters are empty.
|
|
||||||
"""
|
|
||||||
lf = get_langfuse()
|
|
||||||
if lf is None or (not user_id and not session_id):
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import propagate_attributes
|
|
||||||
except ImportError:
|
|
||||||
logger.debug("langfuse: propagate_attributes not available — skipping context")
|
|
||||||
yield
|
|
||||||
return
|
|
||||||
|
|
||||||
attrs: dict[str, str] = {}
|
|
||||||
if user_id:
|
|
||||||
attrs["user_id"] = hash_user_id(user_id)
|
|
||||||
if session_id:
|
|
||||||
attrs["session_id"] = session_id
|
|
||||||
|
|
||||||
with propagate_attributes(**attrs):
|
|
||||||
yield
|
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
Every agent and the orchestrator call ``get_llm()``
|
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
||||||
instead of directly constructing a provider-specific class. The model string
|
instead of directly constructing a provider-specific class. The model string
|
||||||
follows the `LiteLLM model naming convention
|
follows the `LiteLLM model naming convention
|
||||||
<https://docs.litellm.ai/docs/providers>`_:
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
@@ -11,15 +11,13 @@ follows the `LiteLLM model naming convention
|
|||||||
* Ollama: ``ollama/llama3``
|
* Ollama: ``ollama/llama3``
|
||||||
* Bedrock: ``bedrock/anthropic.claude-v2``
|
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||||
|
|
||||||
Switch providers by changing **LLM_MODEL** in ``.env``
|
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
||||||
— no code changes required.
|
— no code changes required.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
|
||||||
from collections.abc import Callable
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -34,14 +32,6 @@ from app.config.settings import settings
|
|||||||
# Drop them silently instead of raising UnsupportedParamsError.
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
# Some provider responses include a plain dict in the `usage` field where a
|
|
||||||
# richer Pydantic model is expected. This warning is noisy but non-fatal.
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
|
||||||
category=UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
"""Return the most appropriate API key for the given LiteLLM model string."""
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
@@ -51,10 +41,6 @@ def _api_key_for_model(model: str) -> str | None:
|
|||||||
return settings.GOOGLE_API_KEY or None
|
return settings.GOOGLE_API_KEY or None
|
||||||
if model.startswith("cerebras/"):
|
if model.startswith("cerebras/"):
|
||||||
return settings.CEREBRAS_API_KEY or None
|
return settings.CEREBRAS_API_KEY or None
|
||||||
if model.startswith("groq/"):
|
|
||||||
return settings.GROQ_API_KEY or None
|
|
||||||
if model.startswith("deepseek/"):
|
|
||||||
return settings.DEEPSEEK_API_KEY or None
|
|
||||||
if model.startswith("github_copilot/"):
|
if model.startswith("github_copilot/"):
|
||||||
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||||
# No API key is required; returning None lets LiteLLM handle auth.
|
# No API key is required; returning None lets LiteLLM handle auth.
|
||||||
@@ -100,39 +86,12 @@ def get_llm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
def get_router_llm(
|
||||||
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
|
||||||
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
|
||||||
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
|
||||||
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
|
||||||
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
|
||||||
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
|
|
||||||
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
|
|
||||||
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
|
||||||
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
|
|
||||||
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
|
|
||||||
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
|
|
||||||
"note-summarizer": lambda: "gpt-4o-mini",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def model_for_agent(agent_name: str) -> str:
|
|
||||||
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
|
|
||||||
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
|
|
||||||
|
|
||||||
|
|
||||||
def get_agent_llm(
|
|
||||||
agent_name: str,
|
|
||||||
*,
|
*,
|
||||||
temperature: float = 0,
|
temperature: float = 0,
|
||||||
) -> ChatOpenAI | ChatLiteLLM:
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
|
"""Return the lighter model used for intent classification / routing."""
|
||||||
|
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
||||||
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
|
|
||||||
per-agent override is left empty in ``.env``.
|
|
||||||
"""
|
|
||||||
model = model_for_agent(agent_name)
|
|
||||||
return get_llm(model=model, temperature=temperature)
|
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
async def embed(text: str) -> list[float]:
|
||||||
|
|||||||
@@ -1,450 +0,0 @@
|
|||||||
"""Mem0-style Extract/Update pipeline — Phase 2.
|
|
||||||
|
|
||||||
Runs after every ``store_episode`` call to distil durable facts, preferences,
|
|
||||||
routines, and relations from the latest conversation turn.
|
|
||||||
|
|
||||||
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
|
|
||||||
|
|
||||||
Design notes
|
|
||||||
------------
|
|
||||||
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
|
|
||||||
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
|
|
||||||
- Zero-trust: never logs decrypted user content; relation subject/object labels are
|
|
||||||
treated as identifiers (safe to log per spec).
|
|
||||||
- Must not raise into the request path — caller wraps in asyncio.create_task().
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
|
|
||||||
from app.core.llm import get_agent_llm, model_for_agent
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
|
|
||||||
|
|
||||||
_EXTRACTION_FALLBACK = (
|
|
||||||
"You are a memory extractor for a personal AI secretary. Given the last conversation "
|
|
||||||
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
|
|
||||||
"preferences, routines, and person/project relations worth remembering.\n\n"
|
|
||||||
"Output JSON matching this schema exactly:\n"
|
|
||||||
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
|
|
||||||
'"content": "<short canonical statement>", '
|
|
||||||
'"target_tier": "<core|associative|relational|proactive>", '
|
|
||||||
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
|
|
||||||
"Rules:\n"
|
|
||||||
"- Skip small talk, greetings, one-off questions.\n"
|
|
||||||
"- Max 5 candidates per call.\n"
|
|
||||||
"- Only extract durable information (still true next week).\n"
|
|
||||||
"- For type=relation: subject/predicate/object required.\n"
|
|
||||||
"- Default confidence=0.7.\n\n"
|
|
||||||
"## Last turn\n{last_turn}\n\n"
|
|
||||||
"## Core memory (current)\n{core_memory}\n\n"
|
|
||||||
"## Recent episodes\n{recent_episodes}"
|
|
||||||
)
|
|
||||||
|
|
||||||
_DECIDE_FALLBACK = (
|
|
||||||
"You are a memory update decision engine. Given a new memory candidate and a list of "
|
|
||||||
"existing memories from the same tier, decide what action to take.\n\n"
|
|
||||||
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
|
|
||||||
"- ADD: new information not in existing memories.\n"
|
|
||||||
"- UPDATE: contradicts or supersedes an existing memory.\n"
|
|
||||||
"- DELETE: states something is no longer true.\n"
|
|
||||||
"- NOOP: already captured accurately.\n\n"
|
|
||||||
"## New candidate\n{candidate}\n\n"
|
|
||||||
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Pydantic schemas ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class MemoryCandidate(BaseModel):
|
|
||||||
type: Literal["fact", "preference", "relation", "routine"]
|
|
||||||
content: str
|
|
||||||
target_tier: Literal["core", "associative", "relational", "proactive"]
|
|
||||||
subject: str | None = None
|
|
||||||
predicate: str | None = None
|
|
||||||
object: str | None = None
|
|
||||||
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
|
|
||||||
|
|
||||||
|
|
||||||
class ExtractionResult(BaseModel):
|
|
||||||
candidates: list[MemoryCandidate] = Field(default_factory=list)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def extract_candidates(
|
|
||||||
last_turn: str,
|
|
||||||
core_memory: dict[str, str],
|
|
||||||
recent_episodes: list[str],
|
|
||||||
) -> ExtractionResult:
|
|
||||||
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
|
|
||||||
|
|
||||||
Returns an ExtractionResult (may be empty on failure — never raises).
|
|
||||||
"""
|
|
||||||
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
|
|
||||||
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
|
|
||||||
|
|
||||||
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
|
|
||||||
|
|
||||||
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
|
|
||||||
if prompt_obj is not None:
|
|
||||||
try:
|
|
||||||
system_text = prompt_obj.compile(
|
|
||||||
last_turn=last_turn,
|
|
||||||
core_memory=core_str,
|
|
||||||
recent_episodes=episodes_str,
|
|
||||||
)
|
|
||||||
if isinstance(system_text, list):
|
|
||||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_extraction: compile failed: %s", exc)
|
|
||||||
system_text = template.format(
|
|
||||||
last_turn=last_turn,
|
|
||||||
core_memory=core_str,
|
|
||||||
recent_episodes=episodes_str,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
system_text = template.format(
|
|
||||||
last_turn=last_turn,
|
|
||||||
core_memory=core_str,
|
|
||||||
recent_episodes=episodes_str,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
|
||||||
# Bind JSON mode so the model always returns parseable output.
|
|
||||||
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
try:
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=system_text),
|
|
||||||
HumanMessage(content="Extract memory candidates as JSON."),
|
|
||||||
]
|
|
||||||
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="memory-extraction",
|
|
||||||
model=model_for_agent("memory-extractor"),
|
|
||||||
prompt=prompt_obj,
|
|
||||||
input=messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm_json.ainvoke(messages)
|
|
||||||
gen.update(output=response.content, usage=extract_usage(response))
|
|
||||||
else:
|
|
||||||
response = await llm_json.ainvoke(messages)
|
|
||||||
|
|
||||||
raw = json.loads(response.content)
|
|
||||||
result = ExtractionResult.model_validate(raw)
|
|
||||||
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
|
|
||||||
return result
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
|
|
||||||
return ExtractionResult(candidates=[])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def decide_action(
|
|
||||||
candidate: MemoryCandidate,
|
|
||||||
existing: list[str],
|
|
||||||
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
|
|
||||||
"""Decide what to do with a candidate given existing memories in the same tier.
|
|
||||||
|
|
||||||
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
|
|
||||||
Never raises.
|
|
||||||
"""
|
|
||||||
if not existing:
|
|
||||||
return "ADD"
|
|
||||||
|
|
||||||
candidate_str = f"[{candidate.type}] {candidate.content}"
|
|
||||||
existing_str = "\n".join(f"- {m}" for m in existing)
|
|
||||||
|
|
||||||
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
|
|
||||||
|
|
||||||
if prompt_obj is not None:
|
|
||||||
try:
|
|
||||||
system_text = prompt_obj.compile(
|
|
||||||
candidate=candidate_str,
|
|
||||||
existing_memories=existing_str,
|
|
||||||
)
|
|
||||||
if isinstance(system_text, list):
|
|
||||||
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_extraction: decide compile failed: %s", exc)
|
|
||||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
|
||||||
else:
|
|
||||||
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
|
|
||||||
|
|
||||||
llm = get_agent_llm("memory-extractor", temperature=0)
|
|
||||||
lf = get_langfuse()
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=system_text),
|
|
||||||
HumanMessage(content="Decide action."),
|
|
||||||
]
|
|
||||||
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="memory-decide-action",
|
|
||||||
model=model_for_agent("memory-extractor"),
|
|
||||||
prompt=prompt_obj,
|
|
||||||
input=messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
gen.update(output=response.content, usage=extract_usage(response))
|
|
||||||
else:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
|
|
||||||
verb = response.content.strip().upper()
|
|
||||||
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
|
|
||||||
return verb # type: ignore[return-value]
|
|
||||||
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
|
|
||||||
return "ADD"
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_extraction: decide_action failed: %s", exc)
|
|
||||||
return "ADD"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
|
|
||||||
|
|
||||||
async def run_extraction(
|
|
||||||
db: AsyncSession,
|
|
||||||
user_id: str,
|
|
||||||
last_user_msg: str,
|
|
||||||
last_assistant_msg: str,
|
|
||||||
session_id: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Full Mem0-style extract/update pipeline for one conversation turn.
|
|
||||||
|
|
||||||
Steps:
|
|
||||||
1. Load core memory + last 5 episodes.
|
|
||||||
2. extract_candidates() → up to 5 MemoryCandidate objects.
|
|
||||||
3. For each candidate: find top-3 neighbours → decide_action() → apply.
|
|
||||||
4. Trace via Langfuse.
|
|
||||||
|
|
||||||
Never raises — wraps everything in try/except.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_extraction_inner(
|
|
||||||
db: AsyncSession,
|
|
||||||
user_id: str,
|
|
||||||
last_user_msg: str,
|
|
||||||
last_assistant_msg: str,
|
|
||||||
session_id: str | None,
|
|
||||||
) -> None:
|
|
||||||
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
|
||||||
|
|
||||||
middleware = MemoryMiddleware(db)
|
|
||||||
fernet = await middleware._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
# 1. Load context
|
|
||||||
core: dict[str, str] = await middleware._load_core(user_id, fernet)
|
|
||||||
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
|
|
||||||
|
|
||||||
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
|
|
||||||
async def _run(trace_id: str | None) -> dict[str, Any]:
|
|
||||||
# 2. Extract candidates
|
|
||||||
result = await extract_candidates(last_turn, core, episodes)
|
|
||||||
if not result.candidates:
|
|
||||||
logger.info("memory_extraction: no candidates user=%s", user_id)
|
|
||||||
return {"candidates": 0, "applied": 0}
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"memory_extraction: processing %d candidates user=%s trace=%s",
|
|
||||||
len(result.candidates),
|
|
||||||
user_id,
|
|
||||||
trace_id or "-",
|
|
||||||
)
|
|
||||||
|
|
||||||
# 3. Apply each candidate
|
|
||||||
applied = 0
|
|
||||||
actions: list[str] = []
|
|
||||||
for candidate in result.candidates:
|
|
||||||
try:
|
|
||||||
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
|
|
||||||
applied += 1
|
|
||||||
actions.append(f"{candidate.type}:{candidate.target_tier}")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory_extraction: apply failed candidate=%r user=%s: %s",
|
|
||||||
candidate.content[:80],
|
|
||||||
user_id,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"memory_extraction: applied %d/%d candidates user=%s",
|
|
||||||
applied,
|
|
||||||
len(result.candidates),
|
|
||||||
user_id,
|
|
||||||
)
|
|
||||||
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
|
|
||||||
|
|
||||||
with langfuse_context(user_id=user_id, session_id=session_id):
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="span",
|
|
||||||
name="memory-extraction-pipeline",
|
|
||||||
input={"last_turn_preview": last_turn[:200]},
|
|
||||||
) as span:
|
|
||||||
summary = await _run(trace_id=span.id)
|
|
||||||
span.update(output=summary)
|
|
||||||
try:
|
|
||||||
lf.flush()
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
await _run(trace_id=None)
|
|
||||||
|
|
||||||
|
|
||||||
async def _apply_candidate(
|
|
||||||
middleware: Any,
|
|
||||||
db: AsyncSession,
|
|
||||||
user_id: str,
|
|
||||||
fernet: Any,
|
|
||||||
candidate: MemoryCandidate,
|
|
||||||
trace_id: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Fetch neighbours, decide action, apply to the appropriate tier."""
|
|
||||||
|
|
||||||
neighbours: list[str] = []
|
|
||||||
|
|
||||||
if candidate.target_tier == "core":
|
|
||||||
# For core tier: neighbours are existing core block values for similar keys.
|
|
||||||
blocks = await middleware.list_core_blocks(user_id)
|
|
||||||
neighbours = [b["value"] for b in blocks[:3]]
|
|
||||||
|
|
||||||
elif candidate.target_tier == "associative":
|
|
||||||
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
|
|
||||||
|
|
||||||
elif candidate.target_tier == "relational":
|
|
||||||
# Relation candidates handled specially — passed to upsert_relation directly.
|
|
||||||
# Neighbours: search by subject label if available.
|
|
||||||
neighbours = []
|
|
||||||
|
|
||||||
elif candidate.target_tier == "proactive":
|
|
||||||
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
|
|
||||||
|
|
||||||
action = await decide_action(candidate, neighbours)
|
|
||||||
logger.info(
|
|
||||||
"memory_extraction: candidate type=%s tier=%s action=%s",
|
|
||||||
candidate.type,
|
|
||||||
candidate.target_tier,
|
|
||||||
action,
|
|
||||||
)
|
|
||||||
|
|
||||||
if action == "NOOP":
|
|
||||||
return
|
|
||||||
|
|
||||||
if candidate.target_tier == "relational":
|
|
||||||
# Always upsert relations — decide_action skipped (no neighbour search).
|
|
||||||
if candidate.subject and candidate.predicate and candidate.object:
|
|
||||||
await _upsert_relation(
|
|
||||||
middleware, db, user_id, candidate, trace_id
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
if action in ("ADD", "UPDATE"):
|
|
||||||
if candidate.target_tier == "core":
|
|
||||||
# Derive a short key from the content (first 40 chars, snake_cased).
|
|
||||||
key = _content_to_key(candidate.content)
|
|
||||||
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
|
|
||||||
|
|
||||||
elif candidate.target_tier == "associative":
|
|
||||||
await middleware.store_associative(user_id, candidate.content)
|
|
||||||
|
|
||||||
elif candidate.target_tier == "proactive":
|
|
||||||
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
|
|
||||||
|
|
||||||
elif action == "DELETE":
|
|
||||||
if candidate.target_tier == "core":
|
|
||||||
key = _content_to_key(candidate.content)
|
|
||||||
await middleware.delete_core(user_id, key)
|
|
||||||
|
|
||||||
|
|
||||||
def _content_to_key(content: str) -> str:
|
|
||||||
"""Derive a short snake_case key from a content string (first 40 chars)."""
|
|
||||||
import re # noqa: PLC0415
|
|
||||||
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
|
|
||||||
return slug or "memory"
|
|
||||||
|
|
||||||
|
|
||||||
async def _upsert_relation(
|
|
||||||
middleware: Any,
|
|
||||||
db: AsyncSession,
|
|
||||||
user_id: str,
|
|
||||||
candidate: MemoryCandidate,
|
|
||||||
trace_id: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
|
|
||||||
await middleware.upsert_relation(
|
|
||||||
user_id=user_id,
|
|
||||||
subject=candidate.subject or "unknown",
|
|
||||||
subject_type="unknown",
|
|
||||||
predicate=candidate.predicate or "related_to",
|
|
||||||
object_=candidate.object or "unknown",
|
|
||||||
object_type="unknown",
|
|
||||||
confidence=candidate.confidence,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
|
|
||||||
candidate.subject,
|
|
||||||
candidate.predicate,
|
|
||||||
candidate.object,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _store_proactive_stub(
|
|
||||||
middleware: Any,
|
|
||||||
db: AsyncSession,
|
|
||||||
user_id: str,
|
|
||||||
candidate: MemoryCandidate,
|
|
||||||
fernet: Any,
|
|
||||||
) -> None:
|
|
||||||
"""Store a proactive pattern row directly (MemoryProactive model)."""
|
|
||||||
import uuid # noqa: PLC0415
|
|
||||||
from app.models import MemoryProactive # noqa: PLC0415
|
|
||||||
from app.core.memory_middleware import _encrypt # noqa: PLC0415
|
|
||||||
|
|
||||||
encrypted = _encrypt(fernet, candidate.content)
|
|
||||||
row = MemoryProactive(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
pattern_encrypted=encrypted,
|
|
||||||
confidence=candidate.confidence,
|
|
||||||
source="inferred",
|
|
||||||
)
|
|
||||||
db.add(row)
|
|
||||||
try:
|
|
||||||
await db.commit()
|
|
||||||
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_extraction: store proactive failed: %s", exc)
|
|
||||||
await db.rollback()
|
|
||||||
@@ -1,581 +0,0 @@
|
|||||||
"""Memory maintenance jobs — Phase 3/5.
|
|
||||||
|
|
||||||
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
|
|
||||||
|
|
||||||
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
|
|
||||||
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
|
|
||||||
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
|
|
||||||
|
|
||||||
All are safe to call manually or from tests; they never raise.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
|
|
||||||
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Decay parameters for relations
|
|
||||||
_DECAY_FACTOR = 0.95
|
|
||||||
_DECAY_PERIOD_DAYS = 30
|
|
||||||
_PRUNE_THRESHOLD = 0.2
|
|
||||||
|
|
||||||
# Proactive pattern decay: 10 % per 7 days since last sighting
|
|
||||||
_PROACTIVE_DECAY_FACTOR = 0.9
|
|
||||||
_PROACTIVE_DECAY_PERIOD_DAYS = 7
|
|
||||||
_PROACTIVE_PRUNE_THRESHOLD = 0.2
|
|
||||||
|
|
||||||
# Mining: require at least this many episodes to attempt pattern extraction
|
|
||||||
_MIN_EPISODES_FOR_MINING = 3
|
|
||||||
_MINING_LOOKBACK_DAYS = 30
|
|
||||||
|
|
||||||
# Audit: caps to control token cost
|
|
||||||
_AUDIT_MAX_FACTS = 50
|
|
||||||
_AUDIT_MAX_LABELS = 100
|
|
||||||
|
|
||||||
|
|
||||||
async def decay_relations(db: AsyncSession, user_id: str) -> None:
|
|
||||||
"""Apply confidence decay to all relation rows for a user.
|
|
||||||
|
|
||||||
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
|
|
||||||
Rows whose confidence falls below 0.2 are deleted.
|
|
||||||
|
|
||||||
Never raises — wraps in try/except.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await _decay_relations_inner(db, user_id)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
|
|
||||||
|
|
||||||
|
|
||||||
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
|
|
||||||
result = await db.execute(
|
|
||||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
deleted = 0
|
|
||||||
decayed = 0
|
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
reference = row.last_confirmed_at or row.created_at
|
|
||||||
if reference is None:
|
|
||||||
continue
|
|
||||||
if reference.tzinfo is None:
|
|
||||||
reference = reference.replace(tzinfo=timezone.utc)
|
|
||||||
|
|
||||||
days_elapsed = (now - reference).days
|
|
||||||
if days_elapsed < _DECAY_PERIOD_DAYS:
|
|
||||||
continue
|
|
||||||
|
|
||||||
periods = days_elapsed // _DECAY_PERIOD_DAYS
|
|
||||||
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
|
|
||||||
|
|
||||||
if new_confidence < _PRUNE_THRESHOLD:
|
|
||||||
await db.delete(row)
|
|
||||||
deleted += 1
|
|
||||||
logger.info(
|
|
||||||
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
|
|
||||||
"confidence=%.3f (below threshold)",
|
|
||||||
row.id, user_id, row.subject_label, row.predicate, new_confidence,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
row.confidence = new_confidence
|
|
||||||
decayed += 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
await db.commit()
|
|
||||||
logger.info(
|
|
||||||
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
|
|
||||||
user_id, decayed, deleted,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
|
|
||||||
await db.rollback()
|
|
||||||
|
|
||||||
|
|
||||||
async def drain_extraction_queue(db: AsyncSession) -> None:
|
|
||||||
"""Process pending ExtractionQueue rows for Free-tier users.
|
|
||||||
|
|
||||||
Each row corresponds to a stored episode that should be fed through the
|
|
||||||
Mem0-style extraction pipeline. Rows are deleted after successful processing.
|
|
||||||
Never raises — wraps in try/except.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await _drain_extraction_queue_inner(db)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
|
|
||||||
from app.models import ExtractionQueue # noqa: PLC0415
|
|
||||||
|
|
||||||
result = await db.execute(select(ExtractionQueue))
|
|
||||||
rows = result.scalars().all()
|
|
||||||
|
|
||||||
if not rows:
|
|
||||||
logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))
|
|
||||||
|
|
||||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
|
||||||
|
|
||||||
processed = 0
|
|
||||||
for row in rows:
|
|
||||||
try:
|
|
||||||
await run_extraction(
|
|
||||||
db=db,
|
|
||||||
user_id=row.user_id,
|
|
||||||
last_user_msg="",
|
|
||||||
last_assistant_msg="",
|
|
||||||
session_id=None,
|
|
||||||
)
|
|
||||||
await db.delete(row)
|
|
||||||
await db.commit()
|
|
||||||
processed += 1
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory_maintenance: drain failed row=%s user=%s: %s",
|
|
||||||
row.id, row.user_id, exc,
|
|
||||||
)
|
|
||||||
await db.rollback()
|
|
||||||
|
|
||||||
logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
|
|
||||||
|
|
||||||
|
|
||||||
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
|
|
||||||
"""Mine recurring behavioral patterns from last 30 days of episodes (Power+ only).
|
|
||||||
|
|
||||||
Steps:
|
|
||||||
1. Gate on proactive_mining tier feature.
|
|
||||||
2. Load + decrypt last 30 days of episodic summaries.
|
|
||||||
3. Call gpt-4o-mini to identify recurring patterns.
|
|
||||||
4. Encrypt and store each pattern in memory_proactive.
|
|
||||||
5. Apply decay to existing proactive rows.
|
|
||||||
|
|
||||||
Never raises — wraps in try/except.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await _mine_proactive_patterns_inner(db, user_id)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
|
|
||||||
|
|
||||||
|
|
||||||
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
|
|
||||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
|
||||||
|
|
||||||
tier = await tier_manager.get_tier(user_id, db)
|
|
||||||
if not tier_manager.check_feature(tier, "proactive_mining"):
|
|
||||||
logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Load user Fernet key
|
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if user is None or not user.encryption_key:
|
|
||||||
logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
fernet = Fernet(user.encryption_key.encode())
|
|
||||||
cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
|
|
||||||
|
|
||||||
episodes_result = await db.execute(
|
|
||||||
select(MemoryEpisodic)
|
|
||||||
.where(
|
|
||||||
MemoryEpisodic.user_id == user_id,
|
|
||||||
MemoryEpisodic.created_at >= cutoff,
|
|
||||||
)
|
|
||||||
.order_by(MemoryEpisodic.created_at.asc())
|
|
||||||
)
|
|
||||||
episode_rows = episodes_result.scalars().all()
|
|
||||||
|
|
||||||
if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
|
|
||||||
logger.info(
|
|
||||||
"memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
|
|
||||||
user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
summaries: list[str] = []
|
|
||||||
for ep in episode_rows:
|
|
||||||
try:
|
|
||||||
plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode()
|
|
||||||
summaries.append(plaintext)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if not summaries:
|
|
||||||
return
|
|
||||||
|
|
||||||
patterns = await _extract_proactive_patterns(summaries)
|
|
||||||
if not patterns:
|
|
||||||
logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
stored = 0
|
|
||||||
for pattern_text in patterns:
|
|
||||||
try:
|
|
||||||
encrypted = fernet.encrypt(pattern_text.encode()).decode()
|
|
||||||
row = MemoryProactive(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
pattern_encrypted=encrypted,
|
|
||||||
confidence=0.7,
|
|
||||||
source="inferred",
|
|
||||||
)
|
|
||||||
db.add(row)
|
|
||||||
stored += 1
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await db.commit()
|
|
||||||
logger.info(
|
|
||||||
"memory_maintenance: mine_proactive_patterns user=%s stored=%d",
|
|
||||||
user_id, stored,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
|
|
||||||
await db.rollback()
|
|
||||||
return
|
|
||||||
|
|
||||||
await _decay_proactive_patterns(db, user_id, fernet)
|
|
||||||
|
|
||||||
|
|
||||||
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
|
|
||||||
"""Call memory-miner LLM to identify recurring behavioral/temporal patterns."""
|
|
||||||
from app.core.llm import get_agent_llm # noqa: PLC0415
|
|
||||||
|
|
||||||
llm = get_agent_llm("memory-miner", temperature=0)
|
|
||||||
combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage
|
|
||||||
prompt = (
|
|
||||||
"You are analyzing conversation history for a personal AI secretary. "
|
|
||||||
"Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
|
|
||||||
"'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
|
|
||||||
"Return each pattern as a plain, short English sentence on its own line. "
|
|
||||||
"No numbering, no bullet points, no extra text.\n\n"
|
|
||||||
f"Conversation history:\n{combined}"
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke(prompt)
|
|
||||||
text = response.content if hasattr(response, "content") else str(response)
|
|
||||||
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
|
|
||||||
return lines[:5]
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
|
|
||||||
"""Decay confidence of existing proactive patterns; prune below threshold."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(MemoryProactive).where(MemoryProactive.user_id == user_id)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
deleted = 0
|
|
||||||
decayed = 0
|
|
||||||
|
|
||||||
for row in rows:
|
|
||||||
reference = row.created_at
|
|
||||||
if reference is None:
|
|
||||||
continue
|
|
||||||
if reference.tzinfo is None:
|
|
||||||
reference = reference.replace(tzinfo=timezone.utc)
|
|
||||||
|
|
||||||
days_elapsed = (now - reference).days
|
|
||||||
if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS:
|
|
||||||
continue
|
|
||||||
|
|
||||||
periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS
|
|
||||||
new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)
|
|
||||||
|
|
||||||
if new_confidence < _PROACTIVE_PRUNE_THRESHOLD:
|
|
||||||
await db.delete(row)
|
|
||||||
deleted += 1
|
|
||||||
else:
|
|
||||||
row.confidence = new_confidence
|
|
||||||
decayed += 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
await db.commit()
|
|
||||||
logger.info(
|
|
||||||
"memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
|
|
||||||
user_id, decayed, deleted,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
|
|
||||||
await db.rollback()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
_AUDIT_CONTRADICTIONS_FALLBACK = (
|
|
||||||
"You are auditing a personal AI assistant's memory bank. "
|
|
||||||
"Each fact has an ID in brackets. "
|
|
||||||
"Find pairs that directly contradict each other "
|
|
||||||
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
|
|
||||||
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
|
|
||||||
'Return ONLY a valid JSON array, no markdown fences: '
|
|
||||||
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
|
|
||||||
"If no contradictions, return [].\n\n"
|
|
||||||
"Facts:\n{facts}"
|
|
||||||
)
|
|
||||||
|
|
||||||
_AUDIT_CANONICALIZE_FALLBACK = (
|
|
||||||
"You are auditing entity labels in a personal AI assistant's relational memory. "
|
|
||||||
"These are names of people, companies, projects, or topics. "
|
|
||||||
"Group labels that clearly refer to the same real-world entity "
|
|
||||||
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
|
|
||||||
"Return ONLY a valid JSON array, no markdown fences: "
|
|
||||||
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
|
|
||||||
"Only include groups with at least one variant. Singletons: omit.\n\n"
|
|
||||||
"Labels:\n{labels}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def audit_memory(db: AsyncSession, user_id: str) -> None:
|
|
||||||
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
|
|
||||||
|
|
||||||
Steps:
|
|
||||||
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
|
|
||||||
2. LLM flags rows to delete (direct contradictions); hard-delete them.
|
|
||||||
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
|
|
||||||
4. Rewrite variant labels to their canonical form in-place.
|
|
||||||
|
|
||||||
Never raises — wraps in try/except.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
await _audit_memory_inner(db, user_id)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
|
|
||||||
|
|
||||||
|
|
||||||
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
|
|
||||||
result = await db.execute(select(User).where(User.id == user_id))
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if user is None or not user.encryption_key:
|
|
||||||
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
fernet = Fernet(user.encryption_key.encode())
|
|
||||||
await _scan_associative_contradictions(db, user_id, fernet)
|
|
||||||
await _canonicalize_relation_labels(db, user_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def _scan_associative_contradictions(
|
|
||||||
db: AsyncSession,
|
|
||||||
user_id: str,
|
|
||||||
fernet: Fernet,
|
|
||||||
) -> None:
|
|
||||||
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(MemoryAssociative)
|
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
|
||||||
.order_by(MemoryAssociative.updated_at.desc())
|
|
||||||
.limit(_AUDIT_MAX_FACTS)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
if len(rows) < 2:
|
|
||||||
return
|
|
||||||
|
|
||||||
id_to_text: dict[str, str] = {}
|
|
||||||
for row in rows:
|
|
||||||
try:
|
|
||||||
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
|
|
||||||
id_to_text[row.id] = plaintext
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
if len(id_to_text) < 2:
|
|
||||||
return
|
|
||||||
|
|
||||||
id_list = list(id_to_text.keys())
|
|
||||||
numbered = "\n".join(
|
|
||||||
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
|
|
||||||
)
|
|
||||||
|
|
||||||
template, prompt_obj = get_prompt_or_fallback(
|
|
||||||
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
|
|
||||||
)
|
|
||||||
system_text = compile_prompt(template, prompt_obj, facts=numbered)
|
|
||||||
|
|
||||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
|
||||||
|
|
||||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
|
||||||
lf = get_langfuse()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=system_text),
|
|
||||||
HumanMessage(content="Audit facts for contradictions."),
|
|
||||||
]
|
|
||||||
try:
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="memory-audit-contradictions",
|
|
||||||
model=model_for_agent("memory-auditor"),
|
|
||||||
prompt=prompt_obj,
|
|
||||||
input=messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
gen.update(output=response.content, usage=extract_usage(response))
|
|
||||||
else:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
|
|
||||||
text = response.content if hasattr(response, "content") else str(response)
|
|
||||||
deletions = json.loads(text.strip())
|
|
||||||
if not isinstance(deletions, list):
|
|
||||||
return
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
|
|
||||||
user_id, exc,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
deleted = 0
|
|
||||||
for item in deletions:
|
|
||||||
if not isinstance(item, dict):
|
|
||||||
continue
|
|
||||||
rid = item.get("delete")
|
|
||||||
if not rid or rid not in id_to_text:
|
|
||||||
continue
|
|
||||||
result2 = await db.execute(
|
|
||||||
select(MemoryAssociative).where(
|
|
||||||
MemoryAssociative.id == rid,
|
|
||||||
MemoryAssociative.user_id == user_id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
target = result2.scalar_one_or_none()
|
|
||||||
if target:
|
|
||||||
await db.delete(target)
|
|
||||||
deleted += 1
|
|
||||||
logger.info(
|
|
||||||
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
|
|
||||||
rid, user_id, item.get("reason", ""),
|
|
||||||
)
|
|
||||||
|
|
||||||
if deleted:
|
|
||||||
try:
|
|
||||||
await db.commit()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
|
|
||||||
)
|
|
||||||
await db.rollback()
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
|
|
||||||
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
if not rows:
|
|
||||||
return
|
|
||||||
|
|
||||||
all_labels: set[str] = set()
|
|
||||||
for row in rows:
|
|
||||||
all_labels.add(row.subject_label)
|
|
||||||
all_labels.add(row.object_label)
|
|
||||||
|
|
||||||
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
|
|
||||||
if len(labels_list) < 2:
|
|
||||||
return
|
|
||||||
|
|
||||||
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
|
|
||||||
template, prompt_obj = get_prompt_or_fallback(
|
|
||||||
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
|
|
||||||
)
|
|
||||||
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
|
|
||||||
|
|
||||||
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
|
|
||||||
|
|
||||||
llm = get_agent_llm("memory-auditor", temperature=0)
|
|
||||||
lf = get_langfuse()
|
|
||||||
messages = [
|
|
||||||
SystemMessage(content=system_text),
|
|
||||||
HumanMessage(content="Canonicalize entity labels."),
|
|
||||||
]
|
|
||||||
try:
|
|
||||||
if lf:
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="generation",
|
|
||||||
name="memory-audit-canonicalize",
|
|
||||||
model=model_for_agent("memory-auditor"),
|
|
||||||
prompt=prompt_obj,
|
|
||||||
input=messages,
|
|
||||||
) as gen:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
gen.update(output=response.content, usage=extract_usage(response))
|
|
||||||
else:
|
|
||||||
response = await llm.ainvoke(messages)
|
|
||||||
|
|
||||||
text = response.content if hasattr(response, "content") else str(response)
|
|
||||||
groups = json.loads(text.strip())
|
|
||||||
if not isinstance(groups, list):
|
|
||||||
return
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
|
|
||||||
user_id, exc,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Build variant → canonical map
|
|
||||||
remap: dict[str, str] = {}
|
|
||||||
for group in groups:
|
|
||||||
if not isinstance(group, dict):
|
|
||||||
continue
|
|
||||||
canonical = group.get("canonical", "")
|
|
||||||
variants = group.get("variants") or []
|
|
||||||
if not canonical:
|
|
||||||
continue
|
|
||||||
for v in variants:
|
|
||||||
if isinstance(v, str) and v != canonical:
|
|
||||||
remap[v] = canonical
|
|
||||||
|
|
||||||
if not remap:
|
|
||||||
return
|
|
||||||
|
|
||||||
updated = 0
|
|
||||||
for row in rows:
|
|
||||||
changed = False
|
|
||||||
if row.subject_label in remap:
|
|
||||||
row.subject_label = remap[row.subject_label]
|
|
||||||
changed = True
|
|
||||||
if row.object_label in remap:
|
|
||||||
row.object_label = remap[row.object_label]
|
|
||||||
changed = True
|
|
||||||
if changed:
|
|
||||||
updated += 1
|
|
||||||
|
|
||||||
if updated:
|
|
||||||
try:
|
|
||||||
await db.commit()
|
|
||||||
logger.info(
|
|
||||||
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
|
|
||||||
user_id, updated,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
|
|
||||||
)
|
|
||||||
await db.rollback()
|
|
||||||
@@ -18,10 +18,8 @@ Usage:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from cryptography.fernet import Fernet, InvalidToken
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
@@ -29,22 +27,15 @@ from sqlalchemy import select
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.models import (
|
from app.models import (
|
||||||
ExtractionQueue,
|
|
||||||
MemoryAssociative,
|
MemoryAssociative,
|
||||||
MemoryCore,
|
MemoryCore,
|
||||||
MemoryEpisodic,
|
MemoryEpisodic,
|
||||||
MemoryProactive,
|
MemoryProactive,
|
||||||
MemoryRelation,
|
|
||||||
User,
|
User,
|
||||||
)
|
)
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _now() -> datetime:
|
|
||||||
return datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
|
|
||||||
# Tuning constants
|
# Tuning constants
|
||||||
_ASSOCIATIVE_TOP_K = 5
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
_EPISODIC_RECENT_N = 10
|
_EPISODIC_RECENT_N = 10
|
||||||
@@ -52,60 +43,36 @@ _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
|||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware:
|
class MemoryMiddleware:
|
||||||
"""Enrich orchestrator context with memory and persist interactions after."""
|
"""Enrich agent context with memory and persist interactions after."""
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession) -> None:
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
self._db = db
|
self._db = db
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────────────
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def enrich_context(
|
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
||||||
self,
|
"""Build memory context dict to inject into the agent before LLM call.
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Build memory context dict to inject into the orchestrator before LLM call.
|
|
||||||
|
|
||||||
Returns a dict with keys:
|
Returns a dict with keys:
|
||||||
core_memory — {key: plaintext_value, ...}
|
core_memory — {key: plaintext_value, ...}
|
||||||
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||||
episodic_memory — [plaintext_summary, ...] (most recent N)
|
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||||
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||||
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
|
|
||||||
"""
|
"""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
user_tier: str = user_dbg.get("tier") or "free"
|
|
||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
core = await self._load_core(user_id, fernet)
|
||||||
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
episodic = await self._load_episodic(user_id, fernet)
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
relational = await self._load_relational(user_id, user_tier=user_tier)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
user_tier,
|
|
||||||
len(core),
|
|
||||||
len(associative),
|
|
||||||
len(episodic),
|
|
||||||
len(proactive),
|
|
||||||
len(relational),
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
return {
|
||||||
"core_memory": core,
|
"core_memory": core,
|
||||||
"associative_memory": associative,
|
"associative_memory": associative,
|
||||||
"episodic_memory": episodic,
|
"episodic_memory": episodic,
|
||||||
"proactive_hints": proactive,
|
"proactive_hints": proactive,
|
||||||
"relational_memory": relational,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async def store_episode(
|
async def store_episode(
|
||||||
@@ -114,15 +81,11 @@ class MemoryMiddleware:
|
|||||||
session_id: str,
|
session_id: str,
|
||||||
message: str,
|
message: str,
|
||||||
response: str,
|
response: str,
|
||||||
trace_id: str | None = None,
|
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Summarise and store a completed interaction in episodic memory.
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
The summary is a simple heuristic concatenation (no LLM call) to keep
|
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||||
latency low. After committing the episode row, dispatches the Mem0-style
|
latency low. Full LLM summarisation can be added in a later step.
|
||||||
extraction pipeline:
|
|
||||||
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
|
|
||||||
- Free → enqueue an ExtractionQueue row for the daily cron.
|
|
||||||
"""
|
"""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -131,97 +94,20 @@ class MemoryMiddleware:
|
|||||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
encrypted = _encrypt(fernet, summary)
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
episode = MemoryEpisodic(
|
row = MemoryEpisodic(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
summary_encrypted=encrypted,
|
summary_encrypted=encrypted,
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
)
|
)
|
||||||
self._db.add(episode)
|
self._db.add(row)
|
||||||
episode_id: str = episode.id
|
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
tier = user_dbg.get("tier") or "free"
|
|
||||||
logger.info(
|
|
||||||
"memory: store_episode trace=%s user=%s tier=%s session=%s",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
tier,
|
|
||||||
session_id,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
return
|
|
||||||
|
|
||||||
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
|
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
||||||
await self._dispatch_extraction(
|
|
||||||
user_id=user_id,
|
|
||||||
episode_id=episode_id,
|
|
||||||
last_user_msg=message,
|
|
||||||
last_assistant_msg=response,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def _dispatch_extraction(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
episode_id: str,
|
|
||||||
last_user_msg: str,
|
|
||||||
last_assistant_msg: str,
|
|
||||||
session_id: str | None,
|
|
||||||
) -> None:
|
|
||||||
"""Route extraction to realtime task or batch queue based on user tier."""
|
|
||||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
|
||||||
|
|
||||||
tier = await tier_manager.get_tier(user_id, self._db)
|
|
||||||
|
|
||||||
if tier_manager.check_feature(tier, "realtime_extraction"):
|
|
||||||
# Pro/Power/Team: fire-and-forget in the background.
|
|
||||||
# Must open a fresh session — request session closes after handler returns.
|
|
||||||
from app.core.memory_extraction import run_extraction # noqa: PLC0415
|
|
||||||
from app.db import async_session # noqa: PLC0415
|
|
||||||
|
|
||||||
async def _task() -> None:
|
|
||||||
try:
|
|
||||||
async with async_session() as fresh_db:
|
|
||||||
await run_extraction(
|
|
||||||
db=fresh_db,
|
|
||||||
user_id=user_id,
|
|
||||||
last_user_msg=last_user_msg,
|
|
||||||
last_assistant_msg=last_assistant_msg,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory: extraction task failed user=%s: %s", user_id, exc
|
|
||||||
)
|
|
||||||
|
|
||||||
asyncio.create_task(_task())
|
|
||||||
logger.info("memory: realtime extraction dispatched user=%s", user_id)
|
|
||||||
else:
|
|
||||||
# Free tier: enqueue for daily batch cron.
|
|
||||||
queue_row = ExtractionQueue(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
episode_id=episode_id,
|
|
||||||
)
|
|
||||||
self._db.add(queue_row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info(
|
|
||||||
"memory: extraction enqueued (batch) user=%s episode=%s",
|
|
||||||
user_id,
|
|
||||||
episode_id,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory: extraction queue insert failed user=%s: %s", user_id, exc
|
|
||||||
)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
|
||||||
"""Upsert a core memory key/value for a user."""
|
"""Upsert a core memory key/value for a user."""
|
||||||
fernet = await self._get_fernet(user_id)
|
fernet = await self._get_fernet(user_id)
|
||||||
if fernet is None:
|
if fernet is None:
|
||||||
@@ -247,313 +133,10 @@ class MemoryMiddleware:
|
|||||||
))
|
))
|
||||||
try:
|
try:
|
||||||
await self._db.commit()
|
await self._db.commit()
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
logger.info(
|
|
||||||
"memory: update_core trace=%s user=%s tier=%s key=%s",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
user_dbg.get("tier") or "-",
|
|
||||||
key,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
await self._db.rollback()
|
await self._db.rollback()
|
||||||
|
|
||||||
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
|
||||||
"""Return core memory as editable blocks (label/value)."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore)
|
|
||||||
.where(MemoryCore.user_id == user_id)
|
|
||||||
.order_by(MemoryCore.key.asc())
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
out: list[dict[str, str]] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out.append({"label": row.key, "value": plaintext})
|
|
||||||
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
|
||||||
"""Return a single core memory block value by label."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(
|
|
||||||
MemoryCore.user_id == user_id,
|
|
||||||
MemoryCore.key == label,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
|
|
||||||
return None
|
|
||||||
value = _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
|
|
||||||
return value
|
|
||||||
|
|
||||||
async def delete_core(self, user_id: str, label: str) -> bool:
|
|
||||||
"""Delete a core memory block by label. Returns True if deleted."""
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(
|
|
||||||
MemoryCore.user_id == user_id,
|
|
||||||
MemoryCore.key == label,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
|
|
||||||
return False
|
|
||||||
|
|
||||||
await self._db.delete(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info("memory: delete_core user=%s label=%s", user_id, label)
|
|
||||||
return True
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
|
||||||
"""Append content to a core block, creating it if missing."""
|
|
||||||
current = await self.get_core_block(user_id, label)
|
|
||||||
if current is None:
|
|
||||||
await self.update_core(user_id, label, content)
|
|
||||||
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
|
|
||||||
return
|
|
||||||
await self.update_core(user_id, label, f"{current}\n{content}")
|
|
||||||
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
|
|
||||||
|
|
||||||
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
|
||||||
"""Replace one exact string inside a core block. Returns False if not found."""
|
|
||||||
current = await self.get_core_block(user_id, label)
|
|
||||||
if current is None or old not in current:
|
|
||||||
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
|
|
||||||
return False
|
|
||||||
await self.update_core(user_id, label, current.replace(old, new, 1))
|
|
||||||
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def store_associative(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
content: str,
|
|
||||||
entity_type: str | None = None,
|
|
||||||
entity_id: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Store associative memory; embed if user tier has real_embeddings."""
|
|
||||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
|
||||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
|
||||||
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
encrypted = _encrypt(fernet, content)
|
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
user_tier = user_dbg.get("tier") or "free"
|
|
||||||
|
|
||||||
embedding: list[float] | None = None
|
|
||||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
|
||||||
embedding = await embed_text(content)
|
|
||||||
|
|
||||||
row = MemoryAssociative(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
content_encrypted=encrypted,
|
|
||||||
embedding=embedding,
|
|
||||||
entity_type=entity_type,
|
|
||||||
entity_id=entity_id,
|
|
||||||
)
|
|
||||||
self._db.add(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info(
|
|
||||||
"memory: store_associative user=%s embedded=%s",
|
|
||||||
user_id,
|
|
||||||
embedding is not None,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def upsert_relation(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
subject: str,
|
|
||||||
subject_type: str,
|
|
||||||
predicate: str,
|
|
||||||
object_: str,
|
|
||||||
object_type: str,
|
|
||||||
*,
|
|
||||||
confidence: float = 0.7,
|
|
||||||
source_episode_id: str | None = None,
|
|
||||||
notes: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
|
|
||||||
|
|
||||||
subject_label / object_label are plaintext entity identifiers — not encrypted.
|
|
||||||
notes is optional; encrypted with user Fernet if provided.
|
|
||||||
"""
|
|
||||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
|
||||||
|
|
||||||
user_dbg = await self._get_user_debug(user_id)
|
|
||||||
user_tier = user_dbg.get("tier") or "free"
|
|
||||||
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
|
||||||
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
|
|
||||||
return
|
|
||||||
|
|
||||||
notes_encrypted: bytes | None = None
|
|
||||||
if notes:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet:
|
|
||||||
notes_encrypted = fernet.encrypt(notes.encode())
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryRelation).where(
|
|
||||||
MemoryRelation.user_id == user_id,
|
|
||||||
MemoryRelation.subject_label == subject,
|
|
||||||
MemoryRelation.predicate == predicate,
|
|
||||||
MemoryRelation.object_label == object_,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
existing = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
if existing is not None:
|
|
||||||
existing.subject_type = subject_type
|
|
||||||
existing.object_type = object_type
|
|
||||||
existing.confidence = confidence
|
|
||||||
existing.last_confirmed_at = _now()
|
|
||||||
if notes_encrypted is not None:
|
|
||||||
existing.notes_encrypted = notes_encrypted
|
|
||||||
else:
|
|
||||||
self._db.add(MemoryRelation(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
subject_label=subject,
|
|
||||||
subject_type=subject_type,
|
|
||||||
predicate=predicate,
|
|
||||||
object_label=object_,
|
|
||||||
object_type=object_type,
|
|
||||||
confidence=confidence,
|
|
||||||
source_episode_id=source_episode_id,
|
|
||||||
notes_encrypted=notes_encrypted,
|
|
||||||
))
|
|
||||||
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info(
|
|
||||||
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
|
|
||||||
user_id, subject, predicate, object_,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def query_relations(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
subject: str | None = None,
|
|
||||||
predicate: str | None = None,
|
|
||||||
object_: str | None = None,
|
|
||||||
limit: int = 20,
|
|
||||||
) -> list[MemoryRelation]:
|
|
||||||
"""Query relation rows for a user with optional filters."""
|
|
||||||
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
|
|
||||||
if subject is not None:
|
|
||||||
q = q.where(MemoryRelation.subject_label == subject)
|
|
||||||
if predicate is not None:
|
|
||||||
q = q.where(MemoryRelation.predicate == predicate)
|
|
||||||
if object_ is not None:
|
|
||||||
q = q.where(MemoryRelation.object_label == object_)
|
|
||||||
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
|
|
||||||
result = await self._db.execute(q)
|
|
||||||
return list(result.scalars().all())
|
|
||||||
|
|
||||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
|
||||||
"""Insert a long-term archival memory entry."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
encrypted = _encrypt(fernet, content)
|
|
||||||
row = MemoryAssociative(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
content_encrypted=encrypted,
|
|
||||||
embedding=None,
|
|
||||||
entity_type=source,
|
|
||||||
entity_id=None,
|
|
||||||
)
|
|
||||||
self._db.add(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
|
||||||
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryAssociative)
|
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
|
||||||
.order_by(MemoryAssociative.updated_at.desc())
|
|
||||||
.limit(100)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
needle = query.strip().lower()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
|
||||||
if plaintext is None:
|
|
||||||
continue
|
|
||||||
if not needle or needle in plaintext.lower():
|
|
||||||
out.append(plaintext)
|
|
||||||
if len(out) >= max(top_k, 1):
|
|
||||||
break
|
|
||||||
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
|
||||||
"""Search recall memory (episodic summaries) by keyword."""
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryEpisodic)
|
|
||||||
.where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
|
||||||
.limit(100)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
needle = query.strip().lower()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
|
||||||
if plaintext is None:
|
|
||||||
continue
|
|
||||||
if not needle or needle in plaintext.lower():
|
|
||||||
out.append(plaintext)
|
|
||||||
if len(out) >= max(top_k, 1):
|
|
||||||
break
|
|
||||||
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
|
|
||||||
return out
|
|
||||||
|
|
||||||
# ── Private helpers ───────────────────────────────────────────────────────
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
@@ -565,29 +148,6 @@ class MemoryMiddleware:
|
|||||||
return None
|
return None
|
||||||
return Fernet(user.encryption_key.encode())
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
|
|
||||||
"""Load lightweight user debug fields for trace logs."""
|
|
||||||
from app.config.settings import settings # noqa: PLC0415
|
|
||||||
from app.models import Subscription # noqa: PLC0415
|
|
||||||
|
|
||||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if user is None:
|
|
||||||
return {"tier": None}
|
|
||||||
|
|
||||||
sub_result = await self._db.execute(
|
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
|
||||||
)
|
|
||||||
sub_tier: str | None = sub_result.scalar_one_or_none()
|
|
||||||
if sub_tier:
|
|
||||||
tier = sub_tier
|
|
||||||
elif settings.ENV == "dev":
|
|
||||||
tier = "power"
|
|
||||||
else:
|
|
||||||
tier = user.tier or "free"
|
|
||||||
|
|
||||||
return {"tier": tier}
|
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
@@ -601,49 +161,14 @@ class MemoryMiddleware:
|
|||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_associative(
|
async def _load_associative(
|
||||||
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
|
self, user_id: str, message: str, fernet: Fernet
|
||||||
) -> list[str]:
|
) -> list[str]:
|
||||||
"""Load top-k associative memories.
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
|
Production: uses pgvector cosine similarity on the message embedding.
|
||||||
Free / embedding failure: keyword-ordered fallback (most recent rows).
|
Current implementation: keyword-based fallback (no external embedding call)
|
||||||
|
so tests pass without a live OpenAI key.
|
||||||
"""
|
"""
|
||||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
|
||||||
from app.core.embeddings import embed_text # noqa: PLC0415
|
|
||||||
|
|
||||||
if tier_manager.check_feature(user_tier, "real_embeddings"):
|
|
||||||
vec = await embed_text(message)
|
|
||||||
if vec is not None:
|
|
||||||
try:
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryAssociative)
|
|
||||||
.where(
|
|
||||||
MemoryAssociative.user_id == user_id,
|
|
||||||
MemoryAssociative.embedding.isnot(None),
|
|
||||||
)
|
|
||||||
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
|
|
||||||
.limit(_ASSOCIATIVE_TOP_K)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in rows:
|
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out.append(plaintext)
|
|
||||||
logger.info(
|
|
||||||
"memory: _load_associative user=%s mode=vector hits=%d",
|
|
||||||
user_id,
|
|
||||||
len(out),
|
|
||||||
)
|
|
||||||
return out
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"memory: vector search failed user=%s, falling back to keyword: %s",
|
|
||||||
user_id,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Keyword fallback: most recent rows
|
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryAssociative)
|
select(MemoryAssociative)
|
||||||
.where(MemoryAssociative.user_id == user_id)
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
@@ -651,24 +176,17 @@ class MemoryMiddleware:
|
|||||||
.limit(_ASSOCIATIVE_TOP_K)
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
)
|
)
|
||||||
rows = result.scalars().all()
|
rows = result.scalars().all()
|
||||||
out = []
|
out: list[str] = []
|
||||||
for row in rows:
|
for row in rows:
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
if plaintext is not None:
|
if plaintext is not None:
|
||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_episodic(
|
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
fernet: Fernet,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> list[str]:
|
|
||||||
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
|
||||||
if session_id:
|
|
||||||
query = query.where(MemoryEpisodic.session_id == session_id)
|
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
query
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
.order_by(MemoryEpisodic.created_at.desc())
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
.limit(_EPISODIC_RECENT_N)
|
.limit(_EPISODIC_RECENT_N)
|
||||||
)
|
)
|
||||||
@@ -680,26 +198,6 @@ class MemoryMiddleware:
|
|||||||
out.append(plaintext)
|
out.append(plaintext)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
|
|
||||||
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
|
|
||||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
|
||||||
|
|
||||||
if not tier_manager.check_feature(user_tier, "relational_memory"):
|
|
||||||
return []
|
|
||||||
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryRelation)
|
|
||||||
.where(MemoryRelation.user_id == user_id)
|
|
||||||
.order_by(MemoryRelation.confidence.desc())
|
|
||||||
.limit(10)
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
out = [
|
|
||||||
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
result = await self._db.execute(
|
result = await self._db.execute(
|
||||||
select(MemoryProactive)
|
select(MemoryProactive)
|
||||||
|
|||||||
@@ -1,51 +0,0 @@
|
|||||||
"""Note summarizer — generates a compact AI summary for a note.
|
|
||||||
|
|
||||||
Called fire-and-forget from create_note / update_note tools so the
|
|
||||||
``notes.ai_summary`` column stays current without blocking the agent loop.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
from app.core.langfuse_client import get_prompt_or_fallback
|
|
||||||
from app.core.llm import get_agent_llm
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_FALLBACK_PROMPT = """\
|
|
||||||
Summarize this note in <=250 characters. Be terse and dense.
|
|
||||||
Keep proper nouns, dates, decisions, and action items.
|
|
||||||
Do not start with "This note".
|
|
||||||
Respond with the summary text only — no intro, no labels.
|
|
||||||
|
|
||||||
Title: {title}
|
|
||||||
Content: {content}"""
|
|
||||||
|
|
||||||
_MAX_CONTENT_CHARS = 4000
|
|
||||||
|
|
||||||
|
|
||||||
async def generate_note_summary(title: str, content: str) -> str:
|
|
||||||
"""Return a <=250-char summary of *title* + *content*.
|
|
||||||
|
|
||||||
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
|
|
||||||
fallback. Truncates *content* to 4000 chars before sending to avoid
|
|
||||||
token waste on large notes.
|
|
||||||
"""
|
|
||||||
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
|
|
||||||
trimmed = content[:_MAX_CONTENT_CHARS]
|
|
||||||
system_prompt = template.format(title=title, content=trimmed)
|
|
||||||
|
|
||||||
try:
|
|
||||||
llm = get_agent_llm("note-summarizer")
|
|
||||||
response = await llm.ainvoke([
|
|
||||||
SystemMessage(content=system_prompt),
|
|
||||||
HumanMessage(content="Generate the summary."),
|
|
||||||
])
|
|
||||||
text = response.content if isinstance(response.content, str) else ""
|
|
||||||
return text.strip()[:250]
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("note_summarizer: failed to generate summary: %s", exc)
|
|
||||||
return ""
|
|
||||||
@@ -1,71 +1,141 @@
|
|||||||
"""Output formatter for deep-agent stream events."""
|
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
||||||
|
|
||||||
|
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
||||||
|
* ``("token", str)`` — supervisor text token
|
||||||
|
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
||||||
|
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||||
|
|
||||||
|
HomeFormatter:
|
||||||
|
* Streams text tokens as-is → emits ``WsStreamText``
|
||||||
|
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
||||||
|
for the frontend to parse and render as interactive components)
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
|
||||||
|
FloatingFormatter:
|
||||||
|
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
||||||
|
* Streams text tokens → emits ``WsStreamText``
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import logging
|
||||||
from collections.abc import AsyncGenerator
|
from collections.abc import AsyncGenerator
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
|
WsStreamEnd,
|
||||||
_CANVAS_BLOCK_RE = re.compile(
|
WsStreamStart,
|
||||||
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
|
WsStreamText,
|
||||||
re.DOTALL | re.IGNORECASE,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
|
# Map sub-agent tool name → floating domain / entity type
|
||||||
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
|
_AGENT_DOMAIN: dict[str, str] = {
|
||||||
|
"task_agent": "tasks",
|
||||||
Returns ``(visible_text, canvas_content, canvas_kind)``.
|
"timeline_agent": "timelines",
|
||||||
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
|
"note_agent": "notes",
|
||||||
"""
|
"project_agent": "projects",
|
||||||
match = _CANVAS_BLOCK_RE.search(text)
|
}
|
||||||
if not match:
|
|
||||||
return text, None, None
|
|
||||||
|
|
||||||
canvas_kind = match.group(1).strip()
|
|
||||||
canvas_content = match.group(2).strip()
|
|
||||||
visible = text[: match.start()] + text[match.end() :]
|
|
||||||
visible = visible.strip()
|
|
||||||
return visible, canvas_content, canvas_kind
|
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
class StreamFormatter:
|
class HomeFormatter:
|
||||||
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||||
|
|
||||||
|
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
||||||
|
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
||||||
|
is responsible for parsing those and rendering interactive components.
|
||||||
|
Mutations are attached to ``WsStreamEnd``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
def __init__(self, request_id: str) -> None:
|
||||||
self.request_id = request_id
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
async def format(
|
async def format(
|
||||||
self,
|
self,
|
||||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
started = False
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
async for event_type, data in event_stream:
|
||||||
if event_type == "floating_domain":
|
if event_type == "token":
|
||||||
if isinstance(data, dict):
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FloatingFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
||||||
|
|
||||||
|
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
||||||
|
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
||||||
|
``WsStreamText``. No block parsing for floating context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
domain_sent = False
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "tool_end" and not domain_sent:
|
||||||
|
# Sniff domain from the first sub-agent that completes
|
||||||
|
name = data.get("name", "")
|
||||||
|
domain = _AGENT_DOMAIN.get(name, "tasks")
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain=domain, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
|
||||||
|
elif event_type == "token":
|
||||||
|
if not domain_sent:
|
||||||
|
# First token arrived before any tool_end — default domain
|
||||||
yield WsFloatingDomain(
|
yield WsFloatingDomain(
|
||||||
request_id=self.request_id,
|
request_id=self.request_id,
|
||||||
domain=data,
|
domain="tasks", # type: ignore[arg-type]
|
||||||
)
|
)
|
||||||
continue
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
if event_type != "token":
|
elif event_type == "mutations":
|
||||||
continue
|
self._mutations = data or []
|
||||||
|
|
||||||
if not started:
|
# If no events triggered domain_sent (edge case), still emit structure
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
if not domain_sent:
|
||||||
started = True
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
text = str(data or "")
|
domain="tasks", # type: ignore[arg-type]
|
||||||
if text:
|
)
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=text)
|
|
||||||
|
|
||||||
if not started:
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|||||||
@@ -1,104 +0,0 @@
|
|||||||
"""Preprocessor registry: detect content type and dispatch to handlers.
|
|
||||||
|
|
||||||
Public API
|
|
||||||
----------
|
|
||||||
detect_content_type(filename, raw_content) -> str
|
|
||||||
Heuristic detection based on file extension and content patterns.
|
|
||||||
|
|
||||||
preprocess(content_type, raw_content) -> PreprocessResult
|
|
||||||
Dispatch to the appropriate handler.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
|
|
||||||
from app.core.preprocessors.base import PreprocessResult
|
|
||||||
|
|
||||||
# ── Heuristics ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
# Patterns that strongly suggest an email HTML file
|
|
||||||
_EMAIL_SIGNALS = re.compile(
|
|
||||||
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Patterns that suggest a generic HTML page (not an email)
|
|
||||||
_GENERIC_HTML_SIGNALS = re.compile(
|
|
||||||
r"<(nav|main|header|footer|article|section)\b",
|
|
||||||
re.IGNORECASE,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def detect_content_type(filename: str, raw_content: str) -> str:
|
|
||||||
"""Return a content-type string for the given file.
|
|
||||||
|
|
||||||
Supported types: ``"email_html"``, ``"generic_html"``,
|
|
||||||
``"plain_text"``, ``"unknown"``.
|
|
||||||
"""
|
|
||||||
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
|
||||||
|
|
||||||
if ext == "txt":
|
|
||||||
return "plain_text"
|
|
||||||
|
|
||||||
if ext in ("html", "htm", "eml", "mhtml", "mht"):
|
|
||||||
# Prefer email detection over generic HTML
|
|
||||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
|
||||||
return "email_html"
|
|
||||||
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
|
|
||||||
return "generic_html"
|
|
||||||
# .html without clear signals — check for any email header
|
|
||||||
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
|
|
||||||
return "email_html"
|
|
||||||
return "generic_html"
|
|
||||||
|
|
||||||
# Plain text files with email headers
|
|
||||||
if ext in ("", "txt") or not ext:
|
|
||||||
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
|
||||||
return "email_html"
|
|
||||||
|
|
||||||
# Detect binary content
|
|
||||||
try:
|
|
||||||
raw_content.encode("utf-8")
|
|
||||||
except (UnicodeEncodeError, AttributeError):
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
# Non-text bytes heuristic: high ratio of non-printable chars
|
|
||||||
sample = raw_content[:512]
|
|
||||||
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
|
|
||||||
if len(sample) > 0 and non_printable / len(sample) > 0.1:
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
return "unknown"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Generic fallback handler ──────────────────────────────────────────
|
|
||||||
|
|
||||||
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
|
|
||||||
"""Strip HTML tags if present, return text as-is."""
|
|
||||||
try:
|
|
||||||
from bs4 import BeautifulSoup
|
|
||||||
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
|
|
||||||
except ImportError:
|
|
||||||
# No BeautifulSoup — strip tags with a simple regex
|
|
||||||
text = re.sub(r"<[^>]+>", "", raw_content)
|
|
||||||
|
|
||||||
text = re.sub(r"\n{3,}", "\n\n", text).strip()
|
|
||||||
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
|
|
||||||
|
|
||||||
|
|
||||||
# ── Dispatch ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
|
|
||||||
"""Dispatch *raw_content* to the handler registered for *content_type*.
|
|
||||||
|
|
||||||
Falls back to the generic handler for unknown types.
|
|
||||||
"""
|
|
||||||
if content_type == "email_html":
|
|
||||||
from app.core.preprocessors.email_html import preprocess_email_html
|
|
||||||
return preprocess_email_html(raw_content)
|
|
||||||
|
|
||||||
return _preprocess_generic(raw_content, content_type)
|
|
||||||
|
|
||||||
|
|
||||||
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
"""Base types for the preprocessor system."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class PreprocessResult:
|
|
||||||
"""Output of a preprocessor handler.
|
|
||||||
|
|
||||||
Attributes
|
|
||||||
----------
|
|
||||||
content_type:
|
|
||||||
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
|
|
||||||
clean_text:
|
|
||||||
Human-readable text stripped of markup/binary noise.
|
|
||||||
metadata:
|
|
||||||
Dict of extracted metadata (keys vary by handler).
|
|
||||||
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
content_type: str
|
|
||||||
clean_text: str
|
|
||||||
metadata: dict = field(default_factory=dict)
|
|
||||||
@@ -1,111 +0,0 @@
|
|||||||
"""Preprocessor for email HTML files.
|
|
||||||
|
|
||||||
Handles:
|
|
||||||
- HTML stripping via BeautifulSoup
|
|
||||||
- Metadata extraction (Subject, From, To, Date)
|
|
||||||
- Thread splitting — isolates the latest reply
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import TYPE_CHECKING
|
|
||||||
|
|
||||||
from app.core.preprocessors.base import PreprocessResult
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
pass
|
|
||||||
|
|
||||||
# ── Thread split markers ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
# Matches patterns like:
|
|
||||||
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
|
|
||||||
# "-----Original Message-----"
|
|
||||||
# "> " (plain-text quote prefix)
|
|
||||||
_THREAD_PATTERNS = [
|
|
||||||
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
|
|
||||||
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
|
|
||||||
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
|
|
||||||
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
|
|
||||||
]
|
|
||||||
|
|
||||||
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
|
|
||||||
|
|
||||||
_META_PATTERNS: dict[str, list[re.Pattern]] = {
|
|
||||||
"subject": [
|
|
||||||
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
|
|
||||||
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"from": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"From:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"to": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"To:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
"date": [
|
|
||||||
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
|
||||||
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
|
|
||||||
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
|
|
||||||
],
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_metadata(raw_html: str, text: str) -> dict:
|
|
||||||
"""Extract Subject/From/To/Date from raw HTML or plain text."""
|
|
||||||
metadata: dict[str, str] = {}
|
|
||||||
for field, patterns in _META_PATTERNS.items():
|
|
||||||
for pat in patterns:
|
|
||||||
m = pat.search(raw_html) or pat.search(text)
|
|
||||||
if m:
|
|
||||||
metadata[field] = m.group(1).strip()
|
|
||||||
break
|
|
||||||
return metadata
|
|
||||||
|
|
||||||
|
|
||||||
def _split_thread(text: str) -> str:
|
|
||||||
"""Return only the latest message in a threaded email."""
|
|
||||||
earliest_pos: int | None = None
|
|
||||||
for pat in _THREAD_PATTERNS:
|
|
||||||
m = pat.search(text)
|
|
||||||
if m and (earliest_pos is None or m.start() < earliest_pos):
|
|
||||||
earliest_pos = m.start()
|
|
||||||
|
|
||||||
if earliest_pos is not None and earliest_pos > 0:
|
|
||||||
return text[:earliest_pos].strip()
|
|
||||||
return text.strip()
|
|
||||||
|
|
||||||
|
|
||||||
def preprocess_email_html(raw_content: str) -> PreprocessResult:
|
|
||||||
"""Strip HTML, extract metadata, split thread from an email HTML file."""
|
|
||||||
try:
|
|
||||||
from bs4 import BeautifulSoup # lazy import — optional dep
|
|
||||||
except ImportError as exc:
|
|
||||||
raise ImportError(
|
|
||||||
"beautifulsoup4 is required for email_html preprocessing. "
|
|
||||||
"Install it with: pip install beautifulsoup4"
|
|
||||||
) from exc
|
|
||||||
|
|
||||||
# Parse with lxml if available, fall back to html.parser
|
|
||||||
try:
|
|
||||||
soup = BeautifulSoup(raw_content, "lxml")
|
|
||||||
except Exception:
|
|
||||||
soup = BeautifulSoup(raw_content, "html.parser")
|
|
||||||
|
|
||||||
# Remove noise tags
|
|
||||||
for tag in soup(["style", "script", "head", "noscript"]):
|
|
||||||
tag.decompose()
|
|
||||||
|
|
||||||
clean_text = soup.get_text(separator="\n")
|
|
||||||
# Collapse excessive blank lines
|
|
||||||
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
|
|
||||||
|
|
||||||
metadata = _extract_metadata(raw_content, clean_text)
|
|
||||||
latest_message = _split_thread(clean_text)
|
|
||||||
|
|
||||||
return PreprocessResult(
|
|
||||||
content_type="email_html",
|
|
||||||
clean_text=latest_message,
|
|
||||||
metadata=metadata,
|
|
||||||
)
|
|
||||||
@@ -7,40 +7,21 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
import logging
|
||||||
from contextvars import ContextVar
|
from contextvars import ContextVar
|
||||||
from typing import Any, Callable, Coroutine
|
from typing import Any, Callable, Coroutine
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def _key_to_camel(key: str) -> str:
|
|
||||||
return _SNAKE_TO_CAMEL_RE.sub(lambda m: m.group(1).upper(), key)
|
|
||||||
|
|
||||||
|
|
||||||
def _keys_to_camel(obj: Any) -> Any:
|
|
||||||
"""Recursively convert dict keys from snake_case to camelCase.
|
|
||||||
|
|
||||||
Mirrors the JS-side ``toCamelCase`` applied to incoming WS frames in
|
|
||||||
``adiuvAI/src/main/api/backend-client.ts``. The Electron executor wraps
|
|
||||||
tool_result payloads in ``toSnakeCase`` before sending; this restores the
|
|
||||||
camelCase schema property names that the tool code expects to read.
|
|
||||||
"""
|
|
||||||
if isinstance(obj, dict):
|
|
||||||
return {_key_to_camel(k): _keys_to_camel(v) for k, v in obj.items()}
|
|
||||||
if isinstance(obj, list):
|
|
||||||
return [_keys_to_camel(v) for v in obj]
|
|
||||||
return obj
|
|
||||||
|
|
||||||
# Holds the execute callback for the current WS session.
|
# Holds the execute callback for the current WS session.
|
||||||
# Set by the chat WS handler before the orchestrator runs; cleared after.
|
# Set by the chat WS handler before the deep agent runs; cleared after.
|
||||||
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
"_client_executor"
|
"_client_executor"
|
||||||
)
|
)
|
||||||
|
|
||||||
# Optional collector that captures raw execute_on_client results.
|
# Optional collector that captures raw execute_on_client results.
|
||||||
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
|
# Set by the deep agent tool loop to capture CRUD mutations.
|
||||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
"_tool_result_collector", default=None
|
"_tool_result_collector", default=None
|
||||||
)
|
)
|
||||||
@@ -103,13 +84,17 @@ async def execute_on_client(
|
|||||||
if limit is not None:
|
if limit is not None:
|
||||||
payload["limit"] = limit
|
payload["limit"] = limit
|
||||||
|
|
||||||
|
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
result = await callback(payload)
|
result = await callback(payload)
|
||||||
result = _keys_to_camel(result)
|
if result is None:
|
||||||
|
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
else:
|
||||||
|
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
collector = _tool_result_collector.get(None)
|
collector = _tool_result_collector.get(None)
|
||||||
if collector is not None:
|
if collector is not None and action in ("insert", "update", "delete"):
|
||||||
collector.append({
|
collector.append({
|
||||||
"action": action,
|
"action": action,
|
||||||
"table": table,
|
"table": table,
|
||||||
"data": result,
|
"data": data or {},
|
||||||
})
|
})
|
||||||
return result
|
return result
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|||||||
94
app/main.py
94
app/main.py
@@ -4,10 +4,6 @@ import logging
|
|||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
|
||||||
from app.api.middleware.sanitizer import SanitizerMiddleware
|
|
||||||
from app.config.settings import settings
|
|
||||||
|
|
||||||
logging.basicConfig(
|
logging.basicConfig(
|
||||||
level=logging.INFO,
|
level=logging.INFO,
|
||||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
@@ -15,88 +11,16 @@ logging.basicConfig(
|
|||||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||||
async def _memory_audit_cron_tick() -> None:
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
|
from app.config.settings import settings
|
||||||
import logging # noqa: PLC0415
|
|
||||||
_log = logging.getLogger(__name__)
|
|
||||||
_log.info("memory audit cron tick: starting")
|
|
||||||
try:
|
|
||||||
from app.db import async_session # noqa: PLC0415
|
|
||||||
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
|
|
||||||
from app.models import User # noqa: PLC0415
|
|
||||||
from sqlalchemy import select # noqa: PLC0415
|
|
||||||
|
|
||||||
async with async_session() as db:
|
|
||||||
result = await db.execute(select(User.id))
|
|
||||||
user_ids: list[str] = list(result.scalars().all())
|
|
||||||
|
|
||||||
for uid in user_ids:
|
|
||||||
try:
|
|
||||||
async with async_session() as db:
|
|
||||||
await audit_memory(db, uid)
|
|
||||||
except Exception as exc:
|
|
||||||
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
|
|
||||||
|
|
||||||
_log.info("memory audit cron tick: done users=%d", len(user_ids))
|
|
||||||
except Exception as exc:
|
|
||||||
_log.warning("memory audit cron tick: failed: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
async def _memory_cron_tick() -> None:
|
|
||||||
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
|
|
||||||
import logging # noqa: PLC0415
|
|
||||||
_log = logging.getLogger(__name__)
|
|
||||||
_log.info("memory cron tick: starting")
|
|
||||||
try:
|
|
||||||
from app.db import async_session # noqa: PLC0415
|
|
||||||
from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns # noqa: PLC0415
|
|
||||||
from app.billing.tier_manager import tier_manager # noqa: PLC0415
|
|
||||||
from app.models import User # noqa: PLC0415
|
|
||||||
from sqlalchemy import select # noqa: PLC0415
|
|
||||||
|
|
||||||
async with async_session() as db:
|
|
||||||
await drain_extraction_queue(db)
|
|
||||||
|
|
||||||
# mine proactive patterns for every Power+ user
|
|
||||||
async with async_session() as db:
|
|
||||||
result = await db.execute(select(User.id))
|
|
||||||
user_ids: list[str] = list(result.scalars().all())
|
|
||||||
|
|
||||||
for uid in user_ids:
|
|
||||||
try:
|
|
||||||
async with async_session() as db:
|
|
||||||
tier = await tier_manager.get_tier(uid, db)
|
|
||||||
if tier_manager.check_feature(tier, "proactive_mining"):
|
|
||||||
await mine_proactive_patterns(db, uid)
|
|
||||||
except Exception as exc:
|
|
||||||
_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)
|
|
||||||
|
|
||||||
_log.info("memory cron tick: done users=%d", len(user_ids))
|
|
||||||
except Exception as exc:
|
|
||||||
_log.warning("memory cron tick: failed: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
@asynccontextmanager
|
||||||
async def lifespan(app: FastAPI):
|
async def lifespan(app: FastAPI):
|
||||||
# Startup: ensure agent tool modules are loaded.
|
# Startup: initialise DB connection pool
|
||||||
import app.agents # noqa: F401
|
|
||||||
|
|
||||||
scheduler = None
|
|
||||||
if settings.SCHEDULER_ENABLED:
|
|
||||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler # noqa: PLC0415
|
|
||||||
|
|
||||||
scheduler = AsyncIOScheduler()
|
|
||||||
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
|
||||||
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
|
||||||
scheduler.start()
|
|
||||||
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
|
||||||
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
if scheduler is not None:
|
|
||||||
scheduler.shutdown(wait=False)
|
|
||||||
|
|
||||||
# Shutdown: dispose SQLAlchemy connection pool
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
from app.db import engine
|
from app.db import engine
|
||||||
await engine.dispose()
|
await engine.dispose()
|
||||||
@@ -104,7 +28,7 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="AdiuvAI Cloud API",
|
title="Adiuva Cloud API",
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
redoc_url=None,
|
redoc_url=None,
|
||||||
@@ -124,14 +48,18 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agents, auth, billing, chat, device_ws, memory
|
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
|
app.include_router(agent_setup.router, prefix="/api/v1")
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
app.include_router(memory.router, prefix="/api/v1")
|
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
async def health() -> dict:
|
async def health() -> dict:
|
||||||
|
|||||||
7
app/marketplace/__init__.py
Normal file
7
app/marketplace/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Plugin marketplace package.
|
||||||
|
|
||||||
|
Three service classes introduced in Step 10:
|
||||||
|
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
||||||
|
- ``ReviewQueue`` — approval workflow + security checklist
|
||||||
|
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
||||||
|
"""
|
||||||
212
app/marketplace/plugin_registry.py
Normal file
212
app/marketplace/plugin_registry.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""Plugin catalog registry backed by PostgreSQL.
|
||||||
|
|
||||||
|
Maintains the authoritative list of plugins, their review status, and
|
||||||
|
aggregate install counts. All data is persisted in the ``plugins`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import Plugin
|
||||||
|
from app.schemas import PluginListResponse, PluginManifest
|
||||||
|
|
||||||
|
_PAGE_SIZE = 20
|
||||||
|
|
||||||
|
|
||||||
|
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
||||||
|
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
||||||
|
try:
|
||||||
|
permissions = json.loads(p.permissions) if p.permissions else []
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
permissions = []
|
||||||
|
return PluginManifest(
|
||||||
|
id=p.id,
|
||||||
|
name=p.name,
|
||||||
|
description=p.description,
|
||||||
|
version=p.version,
|
||||||
|
author=p.author_name,
|
||||||
|
permissions=permissions,
|
||||||
|
category=p.category,
|
||||||
|
price_cents=p.price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginRegistry:
|
||||||
|
"""PostgreSQL-backed plugin catalog.
|
||||||
|
|
||||||
|
All methods accept an ``AsyncSession`` parameter so the calling route
|
||||||
|
controls the session lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Queries ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def list_plugins(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
category: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||||
|
base = select(Plugin).where(Plugin.status == "approved")
|
||||||
|
|
||||||
|
if category:
|
||||||
|
base = base.where(Plugin.category == category)
|
||||||
|
if query:
|
||||||
|
pattern = f"%{query}%"
|
||||||
|
base = base.where(
|
||||||
|
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count
|
||||||
|
count_q = select(func.count()).select_from(base.subquery())
|
||||||
|
total = (await db.execute(count_q)).scalar_one()
|
||||||
|
|
||||||
|
# Sort
|
||||||
|
if sort == "installs":
|
||||||
|
base = base.order_by(Plugin.install_count.desc())
|
||||||
|
elif sort == "rating":
|
||||||
|
base = base.order_by(Plugin.avg_rating.desc())
|
||||||
|
else: # newest
|
||||||
|
base = base.order_by(Plugin.created_at.desc())
|
||||||
|
|
||||||
|
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
||||||
|
rows = (await db.execute(base)).scalars().all()
|
||||||
|
|
||||||
|
return PluginListResponse(
|
||||||
|
plugins=[_plugin_to_manifest(r) for r in rows],
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
p = result.scalar_one_or_none()
|
||||||
|
if p is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"manifest": _plugin_to_manifest(p),
|
||||||
|
"status": p.status,
|
||||||
|
"install_count": p.install_count,
|
||||||
|
"avg_rating": p.avg_rating,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Mutations ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def submit_plugin(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
manifest: PluginManifest,
|
||||||
|
package_s3_key: str,
|
||||||
|
) -> str:
|
||||||
|
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
||||||
|
|
||||||
|
Returns the plugin_id. If a plugin with the same id already exists
|
||||||
|
it is overwritten (re-submission after rejection).
|
||||||
|
"""
|
||||||
|
plugin_id = manifest.id
|
||||||
|
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = existing.scalar_one_or_none()
|
||||||
|
|
||||||
|
if row is not None:
|
||||||
|
row.name = manifest.name
|
||||||
|
row.description = manifest.description
|
||||||
|
row.version = manifest.version
|
||||||
|
row.author_name = manifest.author
|
||||||
|
row.category = manifest.category
|
||||||
|
row.price_cents = manifest.price_cents
|
||||||
|
row.permissions = json.dumps(manifest.permissions)
|
||||||
|
row.status = "pending_review"
|
||||||
|
row.s3_package_key = package_s3_key
|
||||||
|
row.rejection_reason = None
|
||||||
|
else:
|
||||||
|
row = Plugin(
|
||||||
|
id=plugin_id,
|
||||||
|
name=manifest.name,
|
||||||
|
description=manifest.description,
|
||||||
|
version=manifest.version,
|
||||||
|
author_name=manifest.author,
|
||||||
|
category=manifest.category,
|
||||||
|
price_cents=manifest.price_cents,
|
||||||
|
permissions=json.dumps(manifest.permissions),
|
||||||
|
status="pending_review",
|
||||||
|
s3_package_key=package_s3_key,
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
|
return plugin_id
|
||||||
|
|
||||||
|
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'approved'``.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
row.status = "approved"
|
||||||
|
row.rejection_reason = None
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
row.status = "rejected"
|
||||||
|
row.rejection_reason = reason
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = row.install_count + 1
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = max(0, row.install_count - 1)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||||
|
|
||||||
|
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
|
"""Return all entries with status='pending_review'."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Plugin).where(Plugin.status == "pending_review")
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"manifest": _plugin_to_manifest(r),
|
||||||
|
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
registry = PluginRegistry()
|
||||||
125
app/marketplace/plugin_review.py
Normal file
125
app/marketplace/plugin_review.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Plugin review workflow backed by PostgreSQL.
|
||||||
|
|
||||||
|
Manages the approval queue for newly submitted plugins and enforces a
|
||||||
|
security checklist before any plugin is made visible in the marketplace.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_review import review_queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import PluginReview as PluginReviewModel
|
||||||
|
from app.schemas import PluginManifest
|
||||||
|
|
||||||
|
# ── Security policy ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"read:tasks",
|
||||||
|
"write:tasks",
|
||||||
|
"read:projects",
|
||||||
|
"write:projects",
|
||||||
|
"read:notes",
|
||||||
|
"write:notes",
|
||||||
|
"read:timelines",
|
||||||
|
"write:timelines",
|
||||||
|
"read:calendar",
|
||||||
|
"write:calendar",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_manifest(manifest: PluginManifest) -> None:
|
||||||
|
"""Enforce the plugin security checklist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``ValueError`` on the first violation found. Callers should catch
|
||||||
|
this and return HTTP 422 / reject the submission.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
1. Plugin id matches ``^[a-z0-9-]+$``
|
||||||
|
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
||||||
|
3. No manifest field contains raw binary data
|
||||||
|
"""
|
||||||
|
if not _PLUGIN_ID_RE.match(manifest.id):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid plugin id format: '{manifest.id}'. "
|
||||||
|
"Only lowercase letters, digits, and hyphens are allowed."
|
||||||
|
)
|
||||||
|
|
||||||
|
for perm in manifest.permissions:
|
||||||
|
if perm not in ALLOWED_PERMISSIONS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown permission: '{perm}'. "
|
||||||
|
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field_name, value in manifest.model_dump().items():
|
||||||
|
if isinstance(value, (bytes, bytearray)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Binary content is not allowed in manifest field '{field_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewQueue:
|
||||||
|
"""Approval queue for pending plugin submissions.
|
||||||
|
|
||||||
|
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
||||||
|
Review records are persisted in the ``plugin_reviews`` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
|
"""Return all plugins currently awaiting review.
|
||||||
|
|
||||||
|
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||||
|
"""
|
||||||
|
entries = await registry.get_pending_entries(db)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"plugin_id": e["manifest"].id,
|
||||||
|
"manifest": e["manifest"],
|
||||||
|
"submitted_at": e["submitted_at"],
|
||||||
|
}
|
||||||
|
for e in entries
|
||||||
|
]
|
||||||
|
|
||||||
|
async def submit_review(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
plugin_id: str,
|
||||||
|
reviewer_id: str,
|
||||||
|
decision: Literal["approved", "rejected"],
|
||||||
|
notes: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Record a review decision and update the plugin's status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``KeyError`` if *plugin_id* is not found in the registry.
|
||||||
|
"""
|
||||||
|
if decision == "approved":
|
||||||
|
await registry.approve_plugin(db, plugin_id)
|
||||||
|
else:
|
||||||
|
await registry.reject_plugin(db, plugin_id, reason=notes)
|
||||||
|
|
||||||
|
review = PluginReviewModel(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
reviewer_id=reviewer_id,
|
||||||
|
decision=decision,
|
||||||
|
notes=notes,
|
||||||
|
)
|
||||||
|
db.add(review)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
review_queue = ReviewQueue()
|
||||||
233
app/marketplace/revenue_share.py
Normal file
233
app/marketplace/revenue_share.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
||||||
|
|
||||||
|
Records every plugin installation as a revenue event and facilitates
|
||||||
|
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
||||||
|
in the ``revenue_events`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from sqlalchemy import extract, func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import Plugin, RevenueEvent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Revenue split constants ───────────────────────────────────────────
|
||||||
|
|
||||||
|
DEVELOPER_SHARE: float = 0.70
|
||||||
|
PLATFORM_SHARE: float = 0.30
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueShare:
|
||||||
|
"""Records installation revenue events and coordinates developer payouts.
|
||||||
|
|
||||||
|
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
||||||
|
is not configured, consistent with the rest of the billing layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe_configured() -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe() -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Core operations ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def record_install(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
plugin_id: str,
|
||||||
|
user_id: str,
|
||||||
|
amount_cents: int,
|
||||||
|
) -> None:
|
||||||
|
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
||||||
|
|
||||||
|
For free plugins (``amount_cents == 0``) no payment is initiated but
|
||||||
|
the event is still recorded for analytics.
|
||||||
|
|
||||||
|
For paid plugins the developer receives 70 % via a Stripe Connect
|
||||||
|
destination charge. If Stripe is not configured or the charge fails
|
||||||
|
the installation still succeeds (the event is recorded and the install
|
||||||
|
count is incremented) — a warning is logged for monitoring.
|
||||||
|
"""
|
||||||
|
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
||||||
|
stripe_transfer_id: str | None = None
|
||||||
|
|
||||||
|
if amount_cents > 0 and self._stripe_configured():
|
||||||
|
# Look up the plugin's author Stripe account from the DB
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = result.scalar_one_or_none()
|
||||||
|
developer_stripe_account: str | None = None
|
||||||
|
if plugin_row and plugin_row.author_id:
|
||||||
|
# Future: look up user.stripe_connect_account_id
|
||||||
|
developer_stripe_account = None # no real account yet
|
||||||
|
|
||||||
|
if developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
transfer = s.Transfer.create(
|
||||||
|
amount=developer_share_cents,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Revenue share for plugin {plugin_id}",
|
||||||
|
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
stripe_transfer_id = transfer["id"]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Stripe Connect transfer failed for plugin %s: %s",
|
||||||
|
plugin_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"No Stripe account on file for plugin %s developer; "
|
||||||
|
"skipping transfer.",
|
||||||
|
plugin_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
event = RevenueEvent(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=user_id,
|
||||||
|
amount_cents=amount_cents,
|
||||||
|
developer_share_cents=developer_share_cents,
|
||||||
|
stripe_transfer_id=stripe_transfer_id,
|
||||||
|
)
|
||||||
|
db.add(event)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await registry.record_install(db, plugin_id)
|
||||||
|
|
||||||
|
async def get_earnings(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
developer_id: str,
|
||||||
|
period: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return aggregated earnings for *developer_id*.
|
||||||
|
|
||||||
|
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
||||||
|
|
||||||
|
Returns::
|
||||||
|
|
||||||
|
{
|
||||||
|
"developer_id": str,
|
||||||
|
"period": str | None,
|
||||||
|
"total_installs": int,
|
||||||
|
"total_revenue_cents": int,
|
||||||
|
"developer_share_cents": int,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Find plugin ids belonging to this developer (by author_name match)
|
||||||
|
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
||||||
|
plugin_result = await db.execute(plugin_q)
|
||||||
|
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
||||||
|
|
||||||
|
if not developer_plugin_ids:
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": 0,
|
||||||
|
"total_revenue_cents": 0,
|
||||||
|
"developer_share_cents": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
query = select(
|
||||||
|
func.count().label("total_installs"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
||||||
|
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
||||||
|
|
||||||
|
if period:
|
||||||
|
# Filter by YYYY-MM: extract year and month from created_at
|
||||||
|
try:
|
||||||
|
year, month = period.split("-")
|
||||||
|
query = query.where(
|
||||||
|
extract("year", RevenueEvent.created_at) == int(year),
|
||||||
|
extract("month", RevenueEvent.created_at) == int(month),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass # invalid period format — return all
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
row = result.one()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": row.total_installs,
|
||||||
|
"total_revenue_cents": row.total_revenue,
|
||||||
|
"developer_share_cents": row.dev_share,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
||||||
|
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||||
|
|
||||||
|
Marks processed events with ``paid_at`` timestamp.
|
||||||
|
Stubs gracefully when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
year, month = period.split("-")
|
||||||
|
year_int, month_int = int(year), int(month)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid period format: %s", period)
|
||||||
|
return
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(RevenueEvent).where(
|
||||||
|
RevenueEvent.plugin_id == plugin_id,
|
||||||
|
RevenueEvent.paid_at.is_(None),
|
||||||
|
extract("year", RevenueEvent.created_at) == year_int,
|
||||||
|
extract("month", RevenueEvent.created_at) == month_int,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
unpaid = list(result.scalars().all())
|
||||||
|
|
||||||
|
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
||||||
|
if total_dev_share <= 0 or not unpaid:
|
||||||
|
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._stripe_configured():
|
||||||
|
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = plugin_result.scalar_one_or_none()
|
||||||
|
developer_stripe_account: str | None = None # Future: fetch from DB
|
||||||
|
if plugin_row and developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
s.Transfer.create(
|
||||||
|
amount=total_dev_share,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Payout for plugin {plugin_id} period {period}",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
paid_ts = datetime.now(timezone.utc)
|
||||||
|
for event in unpaid:
|
||||||
|
event.paid_at = paid_ts
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
revenue_share = RevenueShare()
|
||||||
293
app/models.py
293
app/models.py
@@ -1,20 +1,23 @@
|
|||||||
"""SQLAlchemy ORM models for all persistent tables.
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
Only auth, billing, agent config, and memory data live here.
|
Only auth, billing, storage metadata, and marketplace data live here.
|
||||||
User content (notes, tasks, etc.) lives exclusively on the client.
|
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||||
|
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||||
|
|
||||||
Table inventory:
|
Table inventory:
|
||||||
users — account credentials + tier
|
users — account credentials + tier
|
||||||
refresh_tokens — hashed refresh token store
|
refresh_tokens — hashed refresh token store
|
||||||
subscriptions — Stripe subscription records
|
subscriptions — Stripe subscription records
|
||||||
local_agent_configs — per-device batch agent configs
|
storage_records — S3 blob metadata (no plaintext)
|
||||||
cloud_agent_configs — OAuth-backed cloud agent configs
|
backup_metadata — encrypted backup manifests
|
||||||
agent_run_logs — execution history for all agents
|
plugins — marketplace plugin catalog
|
||||||
|
plugin_installations — per-user install records
|
||||||
|
plugin_reviews — admin review decisions
|
||||||
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
memory_core — per-user persistent key/value preferences (encrypted)
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
memory_episodic — per-user session summaries (encrypted)
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
memory_proactive — per-user behavioral patterns (encrypted)
|
memory_proactive — per-user behavioral patterns (encrypted)
|
||||||
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -22,8 +25,8 @@ from __future__ import annotations
|
|||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from pgvector.sqlalchemy import Vector
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
|
BigInteger,
|
||||||
Boolean,
|
Boolean,
|
||||||
DateTime,
|
DateTime,
|
||||||
Enum,
|
Enum,
|
||||||
@@ -31,9 +34,9 @@ from sqlalchemy import (
|
|||||||
ForeignKey,
|
ForeignKey,
|
||||||
Integer,
|
Integer,
|
||||||
JSON,
|
JSON,
|
||||||
LargeBinary,
|
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
Uuid,
|
Uuid,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
@@ -55,6 +58,8 @@ def _now() -> datetime:
|
|||||||
# ── Enum types ────────────────────────────────────────────────────────────
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
|
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
||||||
|
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
||||||
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
||||||
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
||||||
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||||
@@ -72,8 +77,7 @@ class User(Base):
|
|||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
@@ -82,9 +86,6 @@ class User(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=True, default=None
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
@@ -95,9 +96,6 @@ class User(Base):
|
|||||||
subscription: Mapped[Subscription | None] = relationship(
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
|
||||||
back_populates="user", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Base):
|
class RefreshToken(Base):
|
||||||
@@ -118,25 +116,6 @@ class RefreshToken(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
class OAuthAccount(Base):
|
|
||||||
__tablename__ = "oauth_accounts"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
|
||||||
provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
user: Mapped[User] = relationship(back_populates="oauth_accounts")
|
|
||||||
|
|
||||||
|
|
||||||
class Subscription(Base):
|
class Subscription(Base):
|
||||||
__tablename__ = "subscriptions"
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
@@ -158,6 +137,151 @@ class Subscription(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="subscription")
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecord(Base):
|
||||||
|
__tablename__ = "storage_records"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class BackupMetadata(Base):
|
||||||
|
__tablename__ = "backup_metadata"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
||||||
|
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
||||||
|
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
||||||
|
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class Plugin(Base):
|
||||||
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
||||||
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||||
|
# nullable until developer account system is built
|
||||||
|
author_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||||
|
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||||
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||||
|
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||||
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
||||||
|
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
submitted_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
installations: Mapped[list[PluginInstallation]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
reviews: Mapped[list[PluginReview]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
||||||
|
back_populates="plugin", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallation(Base):
|
||||||
|
__tablename__ = "plugin_installations"
|
||||||
|
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
installed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
||||||
|
|
||||||
|
|
||||||
|
class PluginReview(Base):
|
||||||
|
__tablename__ = "plugin_reviews"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
reviewer_id: Mapped[str | None] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
|
)
|
||||||
|
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
||||||
|
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
|
reviewed_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueEvent(Base):
|
||||||
|
__tablename__ = "revenue_events"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
plugin_id: Mapped[str] = mapped_column(
|
||||||
|
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
|
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfig(Base):
|
class LocalAgentConfig(Base):
|
||||||
__tablename__ = "local_agent_configs"
|
__tablename__ = "local_agent_configs"
|
||||||
|
|
||||||
@@ -172,7 +296,6 @@ class LocalAgentConfig(Base):
|
|||||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
|
||||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
@@ -243,7 +366,6 @@ class AgentRunLog(Base):
|
|||||||
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
||||||
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
|
||||||
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
started_at: Mapped[datetime] = mapped_column(
|
started_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
@@ -264,17 +386,6 @@ class AgentRunLog(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class MonthlyTokenUsage(Base):
|
|
||||||
__tablename__ = "monthly_token_usage"
|
|
||||||
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
|
|
||||||
)
|
|
||||||
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
|
|
||||||
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
|
|
||||||
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
|
|
||||||
|
|
||||||
|
|
||||||
# ── Memory models ─────────────────────────────────────────────────────────────
|
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -314,8 +425,8 @@ class MemoryAssociative(Base):
|
|||||||
nullable=False, index=True,
|
nullable=False, index=True,
|
||||||
)
|
)
|
||||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
# vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
|
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
||||||
embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
|
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
@@ -363,85 +474,3 @@ class MemoryProactive(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ExtractionQueue(Base):
|
|
||||||
"""Batch extraction queue for Free-tier users (Phase 2).
|
|
||||||
|
|
||||||
Pro/Power/Team users get realtime asyncio.create_task() extraction.
|
|
||||||
Free users get a queue row here; a daily cron (Phase 5) drains it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "extraction_queue"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
|
||||||
nullable=False, index=True,
|
|
||||||
)
|
|
||||||
episode_id: Mapped[str | None] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), nullable=True,
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryRelation(Base):
|
|
||||||
"""Per-user entity/relation graph row (Mem0g-light, Phase 3).
|
|
||||||
|
|
||||||
subject_label/object_label are plaintext entity identifiers (not user content).
|
|
||||||
notes_encrypted is optional Fernet-encrypted per-user commentary.
|
|
||||||
confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at.
|
|
||||||
"""
|
|
||||||
|
|
||||||
__tablename__ = "memory_relations"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
|
||||||
nullable=False, index=True,
|
|
||||||
)
|
|
||||||
subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
|
||||||
subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
|
||||||
predicate: Mapped[str] = mapped_column(String(64), nullable=False)
|
|
||||||
object_label: Mapped[str] = mapped_column(String(128), nullable=False)
|
|
||||||
object_type: Mapped[str] = mapped_column(String(32), nullable=False)
|
|
||||||
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7)
|
|
||||||
source_episode_id: Mapped[str | None] = mapped_column(
|
|
||||||
Uuid(as_uuid=False),
|
|
||||||
ForeignKey("memory_episodic.id", ondelete="SET NULL"),
|
|
||||||
nullable=True,
|
|
||||||
)
|
|
||||||
notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
last_confirmed_at: Mapped[datetime | None] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=True
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin(Base):
|
|
||||||
"""Plugin marketplace catalog entry."""
|
|
||||||
|
|
||||||
__tablename__ = "plugins"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
description: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
version: Mapped[str] = mapped_column(String(50), nullable=False)
|
|
||||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
category: Mapped[str] = mapped_column(String(100), nullable=False)
|
|
||||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
|
||||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
|
|
||||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
|
||||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|||||||
296
app/schemas.py
296
app/schemas.py
@@ -30,16 +30,6 @@ class UserProfile(BaseModel):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
surname: str | None = None
|
surname: str | None = None
|
||||||
tier: BillingTier
|
tier: BillingTier
|
||||||
avatar_url: str | None = None
|
|
||||||
has_password: bool = True
|
|
||||||
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
|
||||||
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
|
||||||
|
|
||||||
|
|
||||||
class OAuthAccountInfo(BaseModel):
|
|
||||||
provider: str
|
|
||||||
provider_email: str | None = None
|
|
||||||
created_at: int # epoch ms
|
|
||||||
|
|
||||||
|
|
||||||
# ── Chat ─────────────────────────────────────────────────────────────
|
# ── Chat ─────────────────────────────────────────────────────────────
|
||||||
@@ -60,6 +50,88 @@ class ChatResponse(BaseModel):
|
|||||||
response: str
|
response: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Backup ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class BackupMetadata(BaseModel):
|
||||||
|
version: int
|
||||||
|
timestamp: int
|
||||||
|
checksum: str
|
||||||
|
chunk_count: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
||||||
|
|
||||||
|
class StorageRecord(BaseModel):
|
||||||
|
id: str
|
||||||
|
user_id: str
|
||||||
|
table: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordCreate(BaseModel):
|
||||||
|
table: str
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
class StorageRecordUpdate(BaseModel):
|
||||||
|
blob: bytes
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
||||||
|
|
||||||
|
class VectorItem(BaseModel):
|
||||||
|
id: str
|
||||||
|
blob: bytes # encrypted vector + metadata — backend never decrypts
|
||||||
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
|
class VectorUpsertRequest(BaseModel):
|
||||||
|
vectors: list[VectorItem]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchRequest(BaseModel):
|
||||||
|
query_blob: bytes # encrypted query — backend never decrypts
|
||||||
|
top_k: int = 10
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResult(BaseModel):
|
||||||
|
id: str
|
||||||
|
score: float
|
||||||
|
blob: bytes
|
||||||
|
|
||||||
|
|
||||||
|
class VectorSearchResponse(BaseModel):
|
||||||
|
results: list[VectorSearchResult]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Plugin Marketplace ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class PluginManifest(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
description: str
|
||||||
|
version: str
|
||||||
|
author: str
|
||||||
|
permissions: list[str]
|
||||||
|
category: str
|
||||||
|
price_cents: int = 0
|
||||||
|
|
||||||
|
|
||||||
|
class PluginListResponse(BaseModel):
|
||||||
|
plugins: list[PluginManifest]
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
|
||||||
|
|
||||||
|
class PluginInstallRequest(BaseModel):
|
||||||
|
plugin_id: str
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||||
|
|
||||||
class WsFrameType(str, Enum):
|
class WsFrameType(str, Enum):
|
||||||
@@ -70,6 +142,9 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
|
agent_run = "agent_run"
|
||||||
|
agent_data = "agent_data"
|
||||||
|
agent_complete = "agent_complete"
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
@@ -81,21 +156,6 @@ class WsFrameType(str, Enum):
|
|||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
# ── v4 journey frame types ────────────────────────────────────────
|
|
||||||
journey_start = "journey_start"
|
|
||||||
journey_message = "journey_message"
|
|
||||||
journey_reply = "journey_reply"
|
|
||||||
# ── v5 brief frame types ──────────────────────────────────────────
|
|
||||||
brief_request = "brief_request"
|
|
||||||
# ── v6 task brief frame types ─────────────────────────────────────
|
|
||||||
task_brief_request = "task_brief_request"
|
|
||||||
# ── v7 folder index frame types ───────────────────────────────────
|
|
||||||
index_session_start = "index_session_start"
|
|
||||||
index_file_batch = "index_file_batch"
|
|
||||||
index_session_cancel = "index_session_cancel"
|
|
||||||
index_file_result = "index_file_result"
|
|
||||||
index_session_progress = "index_session_progress"
|
|
||||||
index_session_done = "index_session_done"
|
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -148,19 +208,34 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentRun(BaseModel):
|
||||||
|
"""Server → Client: trigger an agent run on the connected device."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
||||||
|
run_id: str
|
||||||
|
agent_id: str
|
||||||
|
config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentData(BaseModel):
|
||||||
|
"""Client → Server: files read by the local agent."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
||||||
|
run_id: str
|
||||||
|
files: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentComplete(BaseModel):
|
||||||
|
"""Client → Server: Electron signals it has finished reading files."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
||||||
|
run_id: str
|
||||||
|
files_read: int
|
||||||
|
errors: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
class FormatPrefsModel(BaseModel):
|
|
||||||
"""User display preferences sent by Electron on each request."""
|
|
||||||
|
|
||||||
timezone: str = "UTC"
|
|
||||||
date_format: str = "dd/MM/yyyy"
|
|
||||||
time_format: str = "24h"
|
|
||||||
locale: str = "en-US"
|
|
||||||
now_iso: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingScope(BaseModel):
|
class WsFloatingScope(BaseModel):
|
||||||
"""Scope for a floating request — narrows the agent to a specific entity."""
|
"""Scope for a floating request — narrows the agent to a specific entity."""
|
||||||
|
|
||||||
@@ -174,7 +249,6 @@ class WsHomeRequest(BaseModel):
|
|||||||
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
type: Literal[WsFrameType.home_request] = WsFrameType.home_request
|
||||||
message: str
|
message: str
|
||||||
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
conversation_history: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
format_prefs: FormatPrefsModel | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingRequest(BaseModel):
|
class WsFloatingRequest(BaseModel):
|
||||||
@@ -183,18 +257,6 @@ class WsFloatingRequest(BaseModel):
|
|||||||
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
|
||||||
message: str
|
message: str
|
||||||
scope: WsFloatingScope
|
scope: WsFloatingScope
|
||||||
format_prefs: FormatPrefsModel | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsBriefRequest(BaseModel):
|
|
||||||
"""Client → Server: Request a plain-text brief (home or project)."""
|
|
||||||
|
|
||||||
type: Literal[WsFrameType.brief_request] = WsFrameType.brief_request
|
|
||||||
request_id: str | None = None
|
|
||||||
session_id: str | None = None
|
|
||||||
mode: Literal["home", "project"]
|
|
||||||
project_id: str | None = None
|
|
||||||
format_prefs: FormatPrefsModel | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsStreamStart(BaseModel):
|
class WsStreamStart(BaseModel):
|
||||||
@@ -217,16 +279,7 @@ class WsStreamEnd(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
error: str | None = None
|
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
mutations: list[dict[str, Any]] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsDomain(BaseModel):
|
|
||||||
"""Structured floating domain payload for UI routing decisions."""
|
|
||||||
|
|
||||||
type: Literal["task", "timeline", "project", "node"]
|
|
||||||
id: str | None = None
|
|
||||||
section: Literal["task", "timeline", "note"] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -234,28 +287,7 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: WsDomain
|
domain: Literal["tasks", "timelines", "notes", "projects"]
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Config V2 ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class ContentTypeConfig(BaseModel):
|
|
||||||
"""Per-type extraction config produced by the journey chatbot."""
|
|
||||||
|
|
||||||
id: str
|
|
||||||
label: str = ""
|
|
||||||
detection_hint: str = ""
|
|
||||||
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
|
|
||||||
extraction_prompt: str
|
|
||||||
|
|
||||||
|
|
||||||
class AgentConfig(BaseModel):
|
|
||||||
"""Structured agent configuration (replaces freeform prompt_template)."""
|
|
||||||
|
|
||||||
content_types: list[ContentTypeConfig] = []
|
|
||||||
global_rules: list[str] = []
|
|
||||||
data_types: list[str] = []
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -264,29 +296,84 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
config_schema: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckRequest(BaseModel):
|
# ── Local Agent Config ────────────────────────────────────────────────
|
||||||
active_agents: int = Field(ge=0, default=0)
|
|
||||||
|
class LocalAgentConfigCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
device_id: str
|
||||||
|
directory_paths: list[str]
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
file_extensions: list[str]
|
||||||
|
schedule_cron: str
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckResponse(BaseModel):
|
class LocalAgentConfigUpdate(BaseModel):
|
||||||
allowed: bool
|
name: str | None = None
|
||||||
tier: BillingTier
|
device_id: str | None = None
|
||||||
active_agents: int
|
directory_paths: list[str] | None = None
|
||||||
limit: int
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
file_extensions: list[str] | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentTriggerRequest(BaseModel):
|
class LocalAgentConfigResponse(BaseModel):
|
||||||
directory: str = Field(min_length=1)
|
id: str
|
||||||
device_id: str = Field(default="")
|
name: str
|
||||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
device_id: str
|
||||||
what_to_extract: list[str] = Field(min_length=1)
|
directory_paths: list[str]
|
||||||
batch_interval: str = Field(min_length=1)
|
data_types: list[str]
|
||||||
custom_agent_prompt: str | None = None
|
prompt_template: str
|
||||||
agent_config: dict | None = None
|
file_extensions: list[str]
|
||||||
active_agents: int = Field(ge=0, default=0)
|
schedule_cron: str
|
||||||
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Agent Config ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CloudAgentConfigCreate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
oauth_token_encrypted: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigUpdate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"] | None = None
|
||||||
|
name: str | None = None
|
||||||
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
oauth_token_encrypted: str | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigResponse(BaseModel):
|
||||||
|
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -305,3 +392,18 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class JourneyStartRequest(BaseModel):
|
||||||
|
agent_type: Literal["local", "cloud"]
|
||||||
|
agent_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyMessageRequest(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyResponse(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
done: bool
|
||||||
|
prompt_template: str | None = None
|
||||||
|
|||||||
1
app/storage/__init__.py
Normal file
1
app/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
||||||
106
app/storage/blob_store.py
Normal file
106
app/storage/blob_store.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""S3-backed store for E2E-encrypted blobs.
|
||||||
|
|
||||||
|
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
||||||
|
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class BlobStore:
|
||||||
|
"""Thin wrapper around boto3 S3.
|
||||||
|
|
||||||
|
All blobs must be E2E encrypted by the client before upload.
|
||||||
|
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
||||||
|
but cannot decrypt the inner client-side payload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"region_name": settings.S3_REGION,
|
||||||
|
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
|
||||||
|
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
|
||||||
|
}
|
||||||
|
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
|
||||||
|
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
|
||||||
|
return boto3.client("s3", **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _key(user_id: str, table: str, record_id: str) -> str:
|
||||||
|
return f"{user_id}/{table}/{record_id}"
|
||||||
|
|
||||||
|
async def upload(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
table: str,
|
||||||
|
record_id: str,
|
||||||
|
blob: bytes,
|
||||||
|
checksum: str,
|
||||||
|
) -> str:
|
||||||
|
"""Store *blob* in S3 and return the S3 key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Owner of the blob (used as key prefix).
|
||||||
|
table: Logical table name (e.g. ``"tasks"``).
|
||||||
|
record_id: Record UUID.
|
||||||
|
blob: Raw bytes (pre-encrypted by client).
|
||||||
|
checksum: SHA-256 hex digest supplied by the client; stored as
|
||||||
|
object metadata for download-time verification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The S3 key under which the blob was stored.
|
||||||
|
"""
|
||||||
|
key = self._key(user_id, table, record_id)
|
||||||
|
self._client().put_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=key,
|
||||||
|
Body=blob,
|
||||||
|
ServerSideEncryption="AES256", # SSE-S3 at rest
|
||||||
|
Metadata={"checksum": checksum},
|
||||||
|
)
|
||||||
|
return key
|
||||||
|
|
||||||
|
async def download(self, user_id: str, s3_key: str) -> bytes:
|
||||||
|
"""Retrieve the blob stored at *s3_key*.
|
||||||
|
|
||||||
|
*user_id* is retained in the signature so higher-level code can
|
||||||
|
enforce ownership without re-parsing the key.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
||||||
|
object does not exist.
|
||||||
|
"""
|
||||||
|
response = self._client().get_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
)
|
||||||
|
return response["Body"].read()
|
||||||
|
|
||||||
|
async def delete(self, user_id: str, s3_key: str) -> None:
|
||||||
|
"""Delete the object at *s3_key*.
|
||||||
|
|
||||||
|
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
||||||
|
not exist.
|
||||||
|
"""
|
||||||
|
self._client().delete_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
||||||
|
"""Return all S3 keys for a given user + table combination.
|
||||||
|
|
||||||
|
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
||||||
|
"""
|
||||||
|
prefix = f"{user_id}/{table}/"
|
||||||
|
response = self._client().list_objects_v2(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Prefix=prefix,
|
||||||
|
)
|
||||||
|
return [obj["Key"] for obj in response.get("Contents", [])]
|
||||||
32
app/storage/encryption.py
Normal file
32
app/storage/encryption.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Integrity verification only — the backend NEVER decrypts user data."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
||||||
|
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
||||||
|
|
||||||
|
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
||||||
|
timing-based side-channel attacks.
|
||||||
|
"""
|
||||||
|
computed = hashlib.sha256(blob).hexdigest()
|
||||||
|
return hmac.compare_digest(computed, checksum)
|
||||||
|
|
||||||
|
|
||||||
|
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
||||||
|
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
||||||
|
|
||||||
|
Call this before storing or forwarding any client-provided blob.
|
||||||
|
The backend never holds decryption keys — this check only verifies
|
||||||
|
that the opaque bytes arrived intact.
|
||||||
|
"""
|
||||||
|
if not verify_checksum(blob, checksum):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Checksum mismatch: blob integrity check failed",
|
||||||
|
)
|
||||||
205
app/storage/vector_store.py
Normal file
205
app/storage/vector_store.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
||||||
|
|
||||||
|
Vectors are pre-encrypted blobs from the client. The backend stores them
|
||||||
|
alongside a deterministic 32-dim float representation derived from the blob's
|
||||||
|
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
||||||
|
is a known trade-off documented in the backend plan.
|
||||||
|
|
||||||
|
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
||||||
|
``user_id`` payload field on a shared collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pinecone import Pinecone
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
|
||||||
|
_QDRANT_COLLECTION = "adiuva_vectors"
|
||||||
|
|
||||||
|
|
||||||
|
def _blob_to_vector(blob: bytes) -> list[float]:
|
||||||
|
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
||||||
|
|
||||||
|
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
||||||
|
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
||||||
|
semantic meaning on encrypted data.
|
||||||
|
"""
|
||||||
|
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStore:
|
||||||
|
"""Thin wrapper around Pinecone or Qdrant.
|
||||||
|
|
||||||
|
The backend to use is selected at runtime:
|
||||||
|
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
||||||
|
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _use_pinecone(self) -> bool:
|
||||||
|
return bool(settings.PINECONE_API_KEY)
|
||||||
|
|
||||||
|
# ── Pinecone helpers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _pinecone_index(self) -> Any:
|
||||||
|
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
||||||
|
return pc.Index(settings.PINECONE_INDEX)
|
||||||
|
|
||||||
|
# ── Qdrant helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _qdrant_client(self) -> Any:
|
||||||
|
return QdrantClient(
|
||||||
|
url=settings.QDRANT_URL,
|
||||||
|
api_key=settings.QDRANT_API_KEY or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
"""Store encrypted vectors in the backend.
|
||||||
|
|
||||||
|
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
||||||
|
so it can be returned verbatim during search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Used as Pinecone namespace or Qdrant payload field.
|
||||||
|
vectors: List of encrypted vector items from the client.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
await self._pinecone_upsert(user_id, vectors)
|
||||||
|
else:
|
||||||
|
await self._qdrant_upsert(user_id, vectors)
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
query_blob: bytes,
|
||||||
|
top_k: int,
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
"""Query the vector store and return encrypted result blobs.
|
||||||
|
|
||||||
|
The query vector is derived from *query_blob* using the same
|
||||||
|
deterministic mapping as upsert.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Scopes the search to this user's namespace.
|
||||||
|
query_blob: Encrypted query from the client.
|
||||||
|
top_k: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
return await self._pinecone_search(user_id, query_blob, top_k)
|
||||||
|
return await self._qdrant_search(user_id, query_blob, top_k)
|
||||||
|
|
||||||
|
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
"""Remove vectors by ID, scoped to *user_id*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Namespace / payload filter to prevent cross-user deletion.
|
||||||
|
vector_ids: List of vector IDs to remove.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
await self._pinecone_delete(user_id, vector_ids)
|
||||||
|
else:
|
||||||
|
await self._qdrant_delete(user_id, vector_ids)
|
||||||
|
|
||||||
|
# ── Pinecone implementation ───────────────────────────────────────
|
||||||
|
|
||||||
|
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
records = [
|
||||||
|
{
|
||||||
|
"id": v.id,
|
||||||
|
"values": _blob_to_vector(v.blob),
|
||||||
|
"metadata": {
|
||||||
|
"blob": base64.b64encode(v.blob).decode(),
|
||||||
|
"checksum": v.checksum,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for v in vectors
|
||||||
|
]
|
||||||
|
index.upsert(vectors=records, namespace=user_id)
|
||||||
|
|
||||||
|
async def _pinecone_search(
|
||||||
|
self, user_id: str, query_blob: bytes, top_k: int
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
query_vector = _blob_to_vector(query_blob)
|
||||||
|
response = index.query(
|
||||||
|
vector=query_vector,
|
||||||
|
top_k=top_k,
|
||||||
|
namespace=user_id,
|
||||||
|
include_metadata=True,
|
||||||
|
)
|
||||||
|
results: list[VectorSearchResult] = []
|
||||||
|
for match in response.get("matches", []):
|
||||||
|
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
||||||
|
results.append(
|
||||||
|
VectorSearchResult(
|
||||||
|
id=match["id"],
|
||||||
|
score=match["score"],
|
||||||
|
blob=blob_bytes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
index.delete(ids=vector_ids, namespace=user_id)
|
||||||
|
|
||||||
|
# ── Qdrant implementation ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
points = [
|
||||||
|
PointStruct(
|
||||||
|
id=v.id,
|
||||||
|
vector=_blob_to_vector(v.blob),
|
||||||
|
payload={
|
||||||
|
"blob": base64.b64encode(v.blob).decode(),
|
||||||
|
"checksum": v.checksum,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for v in vectors
|
||||||
|
]
|
||||||
|
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
||||||
|
|
||||||
|
async def _qdrant_search(
|
||||||
|
self, user_id: str, query_blob: bytes, top_k: int
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
query_vector = _blob_to_vector(query_blob)
|
||||||
|
hits = client.search(
|
||||||
|
collection_name=_QDRANT_COLLECTION,
|
||||||
|
query_vector=query_vector,
|
||||||
|
query_filter=Filter(
|
||||||
|
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
||||||
|
),
|
||||||
|
limit=top_k,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
VectorSearchResult(
|
||||||
|
id=str(hit.id),
|
||||||
|
score=hit.score,
|
||||||
|
blob=base64.b64decode(hit.payload["blob"]),
|
||||||
|
)
|
||||||
|
for hit in hits
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
client.delete(
|
||||||
|
collection_name=_QDRANT_COLLECTION,
|
||||||
|
points_selector=PointIdsList(points=vector_ids),
|
||||||
|
)
|
||||||
@@ -7,7 +7,7 @@ services:
|
|||||||
- path: .env
|
- path: .env
|
||||||
required: false
|
required: false
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||||
volumes:
|
volumes:
|
||||||
- copilot_tokens:/root/.config/litellm/github_copilot
|
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||||
@@ -21,7 +21,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: postgres
|
||||||
POSTGRES_DB: adiuvai
|
POSTGRES_DB: adiuva
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
@@ -36,6 +36,37 @@ services:
|
|||||||
# image: redis:7-alpine
|
# image: redis:7-alpine
|
||||||
# restart: unless-stopped
|
# restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local S3-compatible storage (MinIO) ──
|
||||||
|
minio:
|
||||||
|
image: minio/minio:latest
|
||||||
|
command: server /data --console-address ":9001"
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
- "9001:9001"
|
||||||
|
environment:
|
||||||
|
MINIO_ROOT_USER: minioadmin
|
||||||
|
MINIO_ROOT_PASSWORD: minioadmin
|
||||||
|
volumes:
|
||||||
|
- minio_data:/data
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "mc", "ready", "local"]
|
||||||
|
interval: 5s
|
||||||
|
timeout: 5s
|
||||||
|
retries: 5
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local vector store (Qdrant) ──
|
||||||
|
qdrant:
|
||||||
|
image: qdrant/qdrant:latest
|
||||||
|
ports:
|
||||||
|
- "6333:6333"
|
||||||
|
- "6334:6334"
|
||||||
|
volumes:
|
||||||
|
- qdrant_data:/qdrant/storage
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
|
minio_data:
|
||||||
|
qdrant_data:
|
||||||
copilot_tokens:
|
copilot_tokens:
|
||||||
|
|||||||
@@ -4,6 +4,8 @@ gunicorn>=22.0.0
|
|||||||
langchain>=0.3.0
|
langchain>=0.3.0
|
||||||
langchain-openai>=0.3.0
|
langchain-openai>=0.3.0
|
||||||
langchain-litellm>=0.1.0
|
langchain-litellm>=0.1.0
|
||||||
|
langgraph>=0.3.0
|
||||||
|
deepagents>=0.4.10
|
||||||
litellm>=1.50.0
|
litellm>=1.50.0
|
||||||
pydantic>=2.10.0
|
pydantic>=2.10.0
|
||||||
pydantic-settings>=2.7.0
|
pydantic-settings>=2.7.0
|
||||||
@@ -32,12 +34,4 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
pgvector>=0.2.5
|
|
||||||
langfuse>=3.3.1
|
|
||||||
beautifulsoup4>=4.12.0
|
|
||||||
lxml>=5.0.0
|
|
||||||
PyYAML>=6.0.0
|
|
||||||
apscheduler>=3.10.0
|
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
pypdf>=4.0
|
|
||||||
python-docx>=1.1
|
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
@@ -6,23 +6,26 @@ a per-test session, and a FastAPI ``TestClient`` wired to use it.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
from collections.abc import AsyncGenerator, Generator
|
from collections.abc import AsyncGenerator, Generator
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import boto3
|
||||||
import pytest
|
import pytest
|
||||||
import pytest_asyncio
|
import pytest_asyncio
|
||||||
from fastapi.testclient import TestClient
|
from fastapi.testclient import TestClient
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
|
from moto import mock_aws
|
||||||
from sqlalchemy import StaticPool, event
|
from sqlalchemy import StaticPool, event
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
from app.db import Base, get_session
|
from app.db import Base, get_session
|
||||||
from app.main import app
|
from app.main import app
|
||||||
from app.models import Subscription, User
|
from app.models import Plugin, Subscription, User
|
||||||
|
|
||||||
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
# ── Fixed test user IDs (one per tier) ───────────────────────────────
|
||||||
|
|
||||||
@@ -106,6 +109,79 @@ def client(db_session: AsyncSession) -> Generator[TestClient, None, None]: # n
|
|||||||
app.dependency_overrides.pop(get_session, None)
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Seed data helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SEED_PLUGINS = [
|
||||||
|
Plugin(
|
||||||
|
id="plugin-github-sync",
|
||||||
|
name="GitHub Sync",
|
||||||
|
description="Sync tasks with GitHub Issues and pull requests.",
|
||||||
|
version="1.0.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=0,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-github-sync/1.0.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-slack-notify",
|
||||||
|
name="Slack Notifier",
|
||||||
|
description="Post task and timeline updates to Slack channels.",
|
||||||
|
version="1.2.0",
|
||||||
|
author_name="Adiuva",
|
||||||
|
category="communication",
|
||||||
|
price_cents=499,
|
||||||
|
permissions=json.dumps(["read:tasks", "read:timelines"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-slack-notify/1.2.0/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
Plugin(
|
||||||
|
id="plugin-time-tracker",
|
||||||
|
name="Time Tracker",
|
||||||
|
description="Track time spent on tasks with automatic reporting.",
|
||||||
|
version="0.9.1",
|
||||||
|
author_name="Third Party",
|
||||||
|
category="productivity",
|
||||||
|
price_cents=999,
|
||||||
|
permissions=json.dumps(["read:tasks", "write:tasks"]),
|
||||||
|
status="approved",
|
||||||
|
s3_package_key="plugins/plugin-time-tracker/0.9.1/package.zip",
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest_asyncio.fixture
|
||||||
|
async def seed_plugins(db_session: AsyncSession) -> list[Plugin]:
|
||||||
|
"""Insert the 3 default approved plugins and return them."""
|
||||||
|
plugins = []
|
||||||
|
for template in _SEED_PLUGINS:
|
||||||
|
p = Plugin(
|
||||||
|
id=template.id,
|
||||||
|
name=template.name,
|
||||||
|
description=template.description,
|
||||||
|
version=template.version,
|
||||||
|
author_name=template.author_name,
|
||||||
|
category=template.category,
|
||||||
|
price_cents=template.price_cents,
|
||||||
|
permissions=template.permissions,
|
||||||
|
status=template.status,
|
||||||
|
s3_package_key=template.s3_package_key,
|
||||||
|
install_count=template.install_count,
|
||||||
|
avg_rating=template.avg_rating,
|
||||||
|
)
|
||||||
|
db_session.add(p)
|
||||||
|
plugins.append(p)
|
||||||
|
await db_session.commit()
|
||||||
|
return plugins
|
||||||
|
|
||||||
|
|
||||||
# ── JWT helpers ──────────────────────────────────────────────────────
|
# ── JWT helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -136,53 +212,24 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
|
|||||||
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
|
||||||
|
|
||||||
|
|
||||||
# ── Convenience aliases and per-tier user fixtures ────────────────────
|
# ── S3 mock fixture ──────────────────────────────────────────────────
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
S3_TEST_BUCKET = "test-bucket"
|
||||||
async def db(db_session: AsyncSession) -> AsyncSession:
|
S3_TEST_REGION = "us-east-1"
|
||||||
"""Alias for db_session — used by folder quota tests."""
|
|
||||||
return db_session
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def test_user_free(db_session: AsyncSession):
|
|
||||||
"""Return the seeded free-tier User row."""
|
|
||||||
result = await db_session.execute(
|
|
||||||
select(User).where(User.id == TEST_USER_IDS["free"])
|
|
||||||
)
|
|
||||||
return result.scalar_one()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest_asyncio.fixture
|
|
||||||
async def test_user_power(db_session: AsyncSession):
|
|
||||||
"""Return the seeded power-tier User row."""
|
|
||||||
result = await db_session.execute(
|
|
||||||
select(User).where(User.id == TEST_USER_IDS["power"])
|
|
||||||
)
|
|
||||||
return result.scalar_one()
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def auth_headers_free() -> dict[str, str]:
|
def s3_bucket():
|
||||||
"""Authorization header for the seeded free-tier user."""
|
"""Create a mocked S3 bucket via moto and patch BlobStore settings."""
|
||||||
return auth_header("free")
|
with mock_aws():
|
||||||
|
os.environ.setdefault("AWS_ACCESS_KEY_ID", "testing")
|
||||||
|
os.environ.setdefault("AWS_SECRET_ACCESS_KEY", "testing")
|
||||||
# ── CLI options ───────────────────────────────────────────────────────
|
os.environ.setdefault("AWS_DEFAULT_REGION", S3_TEST_REGION)
|
||||||
|
client = boto3.client("s3", region_name=S3_TEST_REGION)
|
||||||
def pytest_addoption(parser):
|
client.create_bucket(Bucket=S3_TEST_BUCKET)
|
||||||
parser.addoption(
|
with patch("app.storage.blob_store.settings") as mock_settings:
|
||||||
"--preprocess-dir",
|
mock_settings.S3_BUCKET = S3_TEST_BUCKET
|
||||||
default=None,
|
mock_settings.S3_REGION = S3_TEST_REGION
|
||||||
help="Override fixture folder for preprocessor tests (must contain cases.yaml + data/)",
|
mock_settings.AWS_ACCESS_KEY_ID = "testing"
|
||||||
)
|
mock_settings.AWS_SECRET_ACCESS_KEY = "testing"
|
||||||
parser.addoption(
|
yield S3_TEST_BUCKET
|
||||||
"--runner-dir",
|
|
||||||
default=None,
|
|
||||||
help="Override fixture folder for agent_runner_v2 eval tests (must contain cases.yaml + data/)",
|
|
||||||
)
|
|
||||||
parser.addoption(
|
|
||||||
"--journey-dir",
|
|
||||||
default=None,
|
|
||||||
help="Override fixture folder for journey_v2 eval tests (must contain cases.yaml + data/)",
|
|
||||||
)
|
|
||||||
|
|||||||
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
86
tests/fixtures/agent_runner_v2/cases.yaml
vendored
@@ -1,86 +0,0 @@
|
|||||||
# Agent Runner V2 — eval test cases (Step 2, requires real LLM)
|
|
||||||
#
|
|
||||||
# Each case drives one parametrized `test_eval_runner` invocation.
|
|
||||||
#
|
|
||||||
# Keys
|
|
||||||
# ----
|
|
||||||
# id: str unique identifier shown in pytest output
|
|
||||||
# description: str human-readable label
|
|
||||||
# file: str filename inside data/
|
|
||||||
# file_path: str path reported to the executor (affects project-matching via filename)
|
|
||||||
# projects: [alpha|beta] symbolic project names resolved by the test helper
|
|
||||||
#
|
|
||||||
# Optional pre-existing records (dedup tests)
|
|
||||||
# existing_tasks: list of {id, title, status, priority}
|
|
||||||
# existing_notes: list of {id, title, content}
|
|
||||||
# existing_timelines: list of {id, title, date}
|
|
||||||
#
|
|
||||||
# Assertions (one or more)
|
|
||||||
# expect_insert: <table> at least 1 insert row in this table (tasks|notes|timelines)
|
|
||||||
# expect_no_insert: true zero inserts in any table
|
|
||||||
# expect_project_id: <id> any insert must carry this projectId
|
|
||||||
# expect_dedup: true task inserts == 0 OR task updates >= 1 (dedup check)
|
|
||||||
#
|
|
||||||
# Langfuse
|
|
||||||
# score_name: str observation score name
|
|
||||||
|
|
||||||
- id: "2.1"
|
|
||||||
description: "Action email → create_task"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_action.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: tasks
|
|
||||||
score_name: runner.email_to_task
|
|
||||||
|
|
||||||
- id: "2.2"
|
|
||||||
description: "Informational email → create_note"
|
|
||||||
file: email_info.html
|
|
||||||
file_path: /emails/ProjectAlpha_info.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: notes
|
|
||||||
score_name: runner.email_to_note
|
|
||||||
|
|
||||||
- id: "2.3"
|
|
||||||
description: "Email with meeting date → create_timeline"
|
|
||||||
file: email_date.html
|
|
||||||
file_path: /emails/ProjectAlpha_kickoff.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_insert: timelines
|
|
||||||
score_name: runner.email_to_timeline
|
|
||||||
|
|
||||||
- id: "2.4"
|
|
||||||
description: "Filename contains project name → correct project assigned"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_report.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_project_id: proj-alpha
|
|
||||||
score_name: runner.project_filename
|
|
||||||
|
|
||||||
- id: "2.5"
|
|
||||||
description: "Email body mentions project → correct project assigned"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/email_001.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_project_id: proj-alpha
|
|
||||||
score_name: runner.project_content
|
|
||||||
|
|
||||||
- id: "2.6"
|
|
||||||
description: "Newsletter + global rule no-project → no creates"
|
|
||||||
file: email_no_project.html
|
|
||||||
file_path: /emails/newsletter.html
|
|
||||||
projects: [alpha, beta]
|
|
||||||
expect_no_insert: true
|
|
||||||
score_name: runner.no_project
|
|
||||||
|
|
||||||
- id: "2.7"
|
|
||||||
description: "Existing task with same title → dedup (update not create)"
|
|
||||||
file: email_action.html
|
|
||||||
file_path: /emails/ProjectAlpha_followup.html
|
|
||||||
projects: [alpha]
|
|
||||||
existing_tasks:
|
|
||||||
- id: task-existing
|
|
||||||
title: Fix the login bug
|
|
||||||
status: todo
|
|
||||||
priority: medium
|
|
||||||
expect_dedup: true
|
|
||||||
score_name: runner.dedup
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> boss@company.com</p>
|
|
||||||
<p><b>To:</b> dev@company.com</p>
|
|
||||||
<p><b>Subject:</b> Fix the login bug</p>
|
|
||||||
<p><b>Date:</b> 2026-04-07</p>
|
|
||||||
<p>Hi,<br>Please fix the login bug in Project Alpha by Friday. High priority!</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> pm@company.com</p>
|
|
||||||
<p><b>Subject:</b> Project Alpha kick-off meeting</p>
|
|
||||||
<p>The kick-off meeting for Project Alpha is scheduled for 2026-04-15 at 10:00.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> pm@company.com</p>
|
|
||||||
<p><b>To:</b> team@company.com</p>
|
|
||||||
<p><b>Subject:</b> FYI: New policy for Project Alpha</p>
|
|
||||||
<p>Just a heads-up that starting next week all code reviews must be done
|
|
||||||
within 24 hours for Project Alpha. No action needed from you now.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
<html><head></head><body>
|
|
||||||
<p><b>From:</b> newsletter@ads.com</p>
|
|
||||||
<p><b>Subject:</b> Weekly newsletter</p>
|
|
||||||
<p>Check out our latest deals on electronics!</p>
|
|
||||||
</body></html>
|
|
||||||
19
tests/fixtures/journey_v2/cases.yaml
vendored
19
tests/fixtures/journey_v2/cases.yaml
vendored
@@ -1,19 +0,0 @@
|
|||||||
# Journey V2 eval test cases — Step 4
|
|
||||||
#
|
|
||||||
# Only case 4.1 is kept as an automated eval. Cases 4.2–4.5 (multi-turn
|
|
||||||
# conversations that expect the LLM to produce a complete AgentConfig)
|
|
||||||
# are non-deterministic and tested manually — results tracked in Langfuse.
|
|
||||||
#
|
|
||||||
# Assertion keys:
|
|
||||||
# expect_question: true → first reply must contain "?"
|
|
||||||
|
|
||||||
- id: "4.1"
|
|
||||||
description: "Journey start explores directory, first reply contains a question"
|
|
||||||
directory: "/test/emails"
|
|
||||||
data_types: ["tasks", "notes", "timelines"]
|
|
||||||
directory_files:
|
|
||||||
- path: "/test/emails/outlook_export_2024.html"
|
|
||||||
content_file: "email_action.html"
|
|
||||||
user_messages: []
|
|
||||||
score_name: "journey.start"
|
|
||||||
expect_question: true
|
|
||||||
23
tests/fixtures/journey_v2/data/email_action.html
vendored
23
tests/fixtures/journey_v2/data/email_action.html
vendored
@@ -1,23 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>Email: Fix the login bug</title>
|
|
||||||
<style>body { font-family: Arial; } .header { color: #666; }</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> boss@company.com</p>
|
|
||||||
<p><strong>To:</strong> dev@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> Fix the login bug</p>
|
|
||||||
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:15:00 +0000</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi,</p>
|
|
||||||
<p>Please fix the login bug in Project Alpha as soon as possible.
|
|
||||||
Users are reporting that they can't log in with their Google accounts.
|
|
||||||
This is blocking the whole team. Please resolve it by Friday.</p>
|
|
||||||
<p>Thanks,<br>Boss</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
23
tests/fixtures/journey_v2/data/email_info.html
vendored
23
tests/fixtures/journey_v2/data/email_info.html
vendored
@@ -1,23 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>Email: New policy update</title>
|
|
||||||
<style>body { font-family: Arial; }</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> hr@company.com</p>
|
|
||||||
<p><strong>To:</strong> all@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> FYI: New remote work policy effective May 1</p>
|
|
||||||
<p><strong>Date:</strong> Tue, 8 Apr 2026 10:00:00 +0000</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi everyone,</p>
|
|
||||||
<p>Just a heads-up that starting May 1, 2026 the company will be moving to
|
|
||||||
a hybrid work model. You will be expected to come into the office at least
|
|
||||||
two days per week. More details will follow in the employee handbook.</p>
|
|
||||||
<p>Best,<br>HR Team</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
68
tests/fixtures/preprocessors/cases.yaml
vendored
68
tests/fixtures/preprocessors/cases.yaml
vendored
@@ -1,68 +0,0 @@
|
|||||||
# Preprocessor test cases
|
|
||||||
#
|
|
||||||
# detect: <expected_type> → chiama detect_content_type(filename, content)
|
|
||||||
# process: <content_type> → chiama preprocess(content_type, content)
|
|
||||||
#
|
|
||||||
# Sorgente: file: <nome in data/> oppure generate: binary_noise
|
|
||||||
#
|
|
||||||
# Assertions piatte (solo per process):
|
|
||||||
# no_html: true clean_text senza tag HTML
|
|
||||||
# min_chars: N len(clean_text) >= N
|
|
||||||
# ratio_lt: F len(clean) / len(raw) < F
|
|
||||||
# has_meta: [k, ...] chiavi presenti in metadata
|
|
||||||
# contains: str | [str] substring(s) presenti in clean_text
|
|
||||||
# excludes: str | [str] substring(s) assenti da clean_text
|
|
||||||
# content_type: str result.content_type == questo valore
|
|
||||||
|
|
||||||
- id: "1.1"
|
|
||||||
file: email_action.html
|
|
||||||
detect: email_html
|
|
||||||
|
|
||||||
- id: "1.2"
|
|
||||||
file: generic_page.html
|
|
||||||
detect: generic_html
|
|
||||||
|
|
||||||
- id: "1.3"
|
|
||||||
file: notes.txt
|
|
||||||
detect: plain_text
|
|
||||||
|
|
||||||
- id: "1.4"
|
|
||||||
file: archive.xyz
|
|
||||||
generate: binary_noise
|
|
||||||
detect: unknown
|
|
||||||
|
|
||||||
- id: "1.5"
|
|
||||||
file: email_action.html
|
|
||||||
process: email_html
|
|
||||||
no_html: true
|
|
||||||
min_chars: 50
|
|
||||||
ratio_lt: 0.8
|
|
||||||
|
|
||||||
- id: "1.6"
|
|
||||||
file: email_action.html
|
|
||||||
process: email_html
|
|
||||||
has_meta: [subject, from]
|
|
||||||
|
|
||||||
- id: "1.7"
|
|
||||||
file: email_thread.html
|
|
||||||
process: email_html
|
|
||||||
contains: "Sure, I'll handle the deploy"
|
|
||||||
excludes: "Let's plan the deploy"
|
|
||||||
|
|
||||||
- id: "1.8"
|
|
||||||
file: email_single.html
|
|
||||||
process: email_html
|
|
||||||
contains: "deploy is done"
|
|
||||||
|
|
||||||
- id: "1.9"
|
|
||||||
file: email_heavy.html
|
|
||||||
process: email_html
|
|
||||||
no_html: true
|
|
||||||
min_chars: 30
|
|
||||||
excludes: [border-collapse, font-size]
|
|
||||||
|
|
||||||
- id: "1.10"
|
|
||||||
file: fallback.txt
|
|
||||||
process: unknown
|
|
||||||
min_chars: 1
|
|
||||||
content_type: unknown
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<title>Fix the login bug</title>
|
|
||||||
<style>
|
|
||||||
body { font-family: Arial, sans-serif; color: #333; margin: 0; padding: 20px; }
|
|
||||||
.header { background: #f5f5f5; padding: 10px; border-bottom: 1px solid #ddd; }
|
|
||||||
.body { padding: 20px; }
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<div class="header">
|
|
||||||
<p><strong>From:</strong> boss@company.com</p>
|
|
||||||
<p><strong>To:</strong> dev@company.com</p>
|
|
||||||
<p><strong>Subject:</strong> Fix the login bug</p>
|
|
||||||
<p><strong>Date:</strong> Mon, 7 Apr 2026 09:00:00 +0200</p>
|
|
||||||
</div>
|
|
||||||
<div class="body">
|
|
||||||
<p>Hi,</p>
|
|
||||||
<p>Please fix the login bug by Friday. It is blocking the release.</p>
|
|
||||||
<p>Priority: high. Let me know if you need anything.</p>
|
|
||||||
<p>Thanks,<br>Boss</p>
|
|
||||||
</div>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,49 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html>
|
|
||||||
<head>
|
|
||||||
<style>
|
|
||||||
table { border-collapse: collapse; width: 100%; max-width: 600px; margin: 0 auto; }
|
|
||||||
td { padding: 8px 12px; border: 1px solid #dddddd; font-size: 12px; color: #444444; }
|
|
||||||
.header-row { background-color: #003366; color: #ffffff; font-weight: bold; }
|
|
||||||
.label-col { background-color: #f0f0f0; width: 80px; font-weight: bold; }
|
|
||||||
.footer-row { font-size: 10px; color: #999999; text-align: center; }
|
|
||||||
</style>
|
|
||||||
</head>
|
|
||||||
<body bgcolor="#eeeeee">
|
|
||||||
<center>
|
|
||||||
<table cellpadding="0" cellspacing="0">
|
|
||||||
<tr class="header-row">
|
|
||||||
<td colspan="2">Company Internal Update</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">From:</td>
|
|
||||||
<td>newsletter@corp.com</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">Subject:</td>
|
|
||||||
<td>Q1 Results Update</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td class="label-col">Date:</td>
|
|
||||||
<td>Apr 7, 2026</td>
|
|
||||||
</tr>
|
|
||||||
<tr>
|
|
||||||
<td colspan="2">
|
|
||||||
<table width="100%" cellpadding="10">
|
|
||||||
<tr>
|
|
||||||
<td>
|
|
||||||
<p style="font-size:14px; font-weight:bold;">Dear Team,</p>
|
|
||||||
<p>Q1 results are in. Revenue up 15% year-over-year.</p>
|
|
||||||
<p>Please review the attached report and share any feedback by EOW.</p>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
</td>
|
|
||||||
</tr>
|
|
||||||
<tr class="footer-row">
|
|
||||||
<td colspan="2">Confidential — do not forward outside the company.</td>
|
|
||||||
</tr>
|
|
||||||
</table>
|
|
||||||
</center>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html><body>
|
|
||||||
<p><strong>From:</strong> alice@co.com</p>
|
|
||||||
<p><strong>To:</strong> team@co.com</p>
|
|
||||||
<p><strong>Subject:</strong> Quick update</p>
|
|
||||||
<p><strong>Date:</strong> Tue, 7 Apr 2026 10:30:00 +0200</p>
|
|
||||||
<p>The deploy is done. Everything looks good. No issues so far.</p>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,24 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html><body>
|
|
||||||
<div class="message-latest">
|
|
||||||
<p><strong>From:</strong> alice@co.com</p>
|
|
||||||
<p><strong>Subject:</strong> Re: Re: Deploy plan</p>
|
|
||||||
<p>Sure, I'll handle the deploy.</p>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<p>On Mon, Apr 6, 2026 at 3:00 PM, Bob <bob@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: bob@co.com</p>
|
|
||||||
<p>Can you handle the deploy?</p>
|
|
||||||
<p>On Sun, Apr 5, 2026 at 1:00 PM, Alice <alice@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: alice@co.com</p>
|
|
||||||
<p>Let's plan the deploy for Monday.</p>
|
|
||||||
<p>On Sat, Apr 4, 2026 at 11:00 AM, Charlie <charlie@co.com> wrote:</p>
|
|
||||||
<blockquote>
|
|
||||||
<p>From: charlie@co.com</p>
|
|
||||||
<p>We need to schedule the deploy. What day works?</p>
|
|
||||||
</blockquote>
|
|
||||||
</blockquote>
|
|
||||||
</blockquote>
|
|
||||||
</body></html>
|
|
||||||
@@ -1,3 +0,0 @@
|
|||||||
random text content without any structure
|
|
||||||
line two with some words
|
|
||||||
line three and more content here
|
|
||||||
@@ -1,35 +0,0 @@
|
|||||||
<!DOCTYPE html>
|
|
||||||
<html lang="en">
|
|
||||||
<head>
|
|
||||||
<meta charset="UTF-8">
|
|
||||||
<title>My Web App</title>
|
|
||||||
<link rel="stylesheet" href="styles.css">
|
|
||||||
</head>
|
|
||||||
<body>
|
|
||||||
<nav>
|
|
||||||
<a href="/">Home</a>
|
|
||||||
<a href="/about">About</a>
|
|
||||||
<a href="/contact">Contact</a>
|
|
||||||
</nav>
|
|
||||||
<main>
|
|
||||||
<header>
|
|
||||||
<h1>Welcome to My App</h1>
|
|
||||||
</header>
|
|
||||||
<article>
|
|
||||||
<p>This is a generic web page with no email headers.</p>
|
|
||||||
<p>It has navigation, main content, and a footer.</p>
|
|
||||||
</article>
|
|
||||||
<section>
|
|
||||||
<h2>Features</h2>
|
|
||||||
<ul>
|
|
||||||
<li>Fast</li>
|
|
||||||
<li>Reliable</li>
|
|
||||||
<li>Secure</li>
|
|
||||||
</ul>
|
|
||||||
</section>
|
|
||||||
</main>
|
|
||||||
<footer>
|
|
||||||
<p>© 2026 My App</p>
|
|
||||||
</footer>
|
|
||||||
</body>
|
|
||||||
</html>
|
|
||||||
15
tests/fixtures/preprocessors/data/notes.txt
vendored
15
tests/fixtures/preprocessors/data/notes.txt
vendored
@@ -1,15 +0,0 @@
|
|||||||
Meeting notes - April 7, 2026
|
|
||||||
|
|
||||||
Attendees: Alice, Bob, Charlie
|
|
||||||
|
|
||||||
Discussion points:
|
|
||||||
- Deploy scheduled for Friday
|
|
||||||
- Bug fix for login must be completed by Thursday
|
|
||||||
- Review Q1 numbers before EOW
|
|
||||||
|
|
||||||
Action items:
|
|
||||||
- Alice: fix login bug
|
|
||||||
- Bob: prepare deploy checklist
|
|
||||||
- Charlie: send Q1 report
|
|
||||||
|
|
||||||
Next meeting: April 14, 2026
|
|
||||||
871
tests/test_agent_runner.py
Normal file
871
tests/test_agent_runner.py
Normal file
@@ -0,0 +1,871 @@
|
|||||||
|
"""Tests for Step 3.4: agent_runner module.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
Unit:
|
||||||
|
- _is_overdue — cron schedule overdue detection
|
||||||
|
- _extract_items_from_content — LLM extraction + JSON parsing + validation
|
||||||
|
- _send_insert_to_client — tool_call frame construction + timeout
|
||||||
|
- run_local_agent — end-to-end local agent happy path
|
||||||
|
- run_local_agent — device offline path
|
||||||
|
- run_local_agent — file-read timeout path
|
||||||
|
- run_local_agent — LLM extraction error path
|
||||||
|
- run_cloud_agent — stub returns error immediately
|
||||||
|
- trigger_pending_runs — overdue local + cloud dispatched
|
||||||
|
- trigger_pending_runs — non-overdue skipped
|
||||||
|
- trigger_pending_runs — device_id filter for local agents
|
||||||
|
|
||||||
|
Integration:
|
||||||
|
- POST /agents/{id}/run — 404 on unknown agent
|
||||||
|
- POST /agents/{id}/run — creates run log + dispatches background task
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.core.agent_runner import (
|
||||||
|
_extract_items_from_content,
|
||||||
|
_is_overdue,
|
||||||
|
_send_insert_to_client,
|
||||||
|
run_cloud_agent,
|
||||||
|
run_local_agent,
|
||||||
|
trigger_pending_runs,
|
||||||
|
)
|
||||||
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FREE_UID = TEST_USER_IDS["free"]
|
||||||
|
_PRO_UID = TEST_USER_IDS["pro"]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
|
||||||
|
return LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
name="Test Local Agent",
|
||||||
|
directory_paths=["/home/user/emails"],
|
||||||
|
data_types=["tasks", "notes"],
|
||||||
|
prompt_template="Extract tasks and notes from this document.",
|
||||||
|
file_extensions=[".txt", ".eml"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
|
||||||
|
return CloudAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
provider="gmail",
|
||||||
|
name="Test Gmail Agent",
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks from email.",
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
|
||||||
|
return AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
|
||||||
|
mgr = DeviceConnectionManager()
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
mgr.register(user_id, device_id, ws)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _is_overdue
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_is_overdue_never_run():
|
||||||
|
"""An agent that has never run is always overdue."""
|
||||||
|
assert _is_overdue("0 */6 * * *", None) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_very_recently_run():
|
||||||
|
"""An agent that just ran is not overdue."""
|
||||||
|
last = datetime.now(timezone.utc)
|
||||||
|
assert _is_overdue("0 */6 * * *", last) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_long_ago():
|
||||||
|
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
|
||||||
|
from datetime import timedelta
|
||||||
|
last = datetime.now(timezone.utc) - timedelta(days=2)
|
||||||
|
assert _is_overdue("0 */6 * * *", last) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_invalid_cron_returns_false():
|
||||||
|
"""Unparseable cron must not raise and should return False (fail-safe)."""
|
||||||
|
assert _is_overdue("not a cron", None) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_naive_datetime():
|
||||||
|
"""Naive datetime objects are handled without raising."""
|
||||||
|
from datetime import timedelta
|
||||||
|
last = datetime.utcnow() - timedelta(days=1) # naive
|
||||||
|
# Should not raise.
|
||||||
|
result = _is_overdue("0 */6 * * *", last)
|
||||||
|
assert isinstance(result, bool)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_items_from_content
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_happy_path():
|
||||||
|
"""LLM returns valid JSON array; items with allowed tables are returned."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
|
||||||
|
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content(
|
||||||
|
"Extract tasks and notes.",
|
||||||
|
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
|
||||||
|
["tasks", "notes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(items) == 2
|
||||||
|
assert items[0]["table"] == "tasks"
|
||||||
|
assert items[0]["data"]["title"] == "Buy milk"
|
||||||
|
assert items[1]["table"] == "notes"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_strips_forbidden_fields():
|
||||||
|
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{
|
||||||
|
"table": "tasks",
|
||||||
|
"data": {
|
||||||
|
"title": "Review PR",
|
||||||
|
"id": "should-be-removed",
|
||||||
|
"createdAt": 99999,
|
||||||
|
"isAiSuggested": 0,
|
||||||
|
"isApproved": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
|
||||||
|
|
||||||
|
assert len(items) == 1
|
||||||
|
data = items[0]["data"]
|
||||||
|
assert "id" not in data
|
||||||
|
assert "createdAt" not in data
|
||||||
|
assert "isAiSuggested" not in data
|
||||||
|
assert "isApproved" not in data
|
||||||
|
assert data["title"] == "Review PR"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_invalid_json_returns_empty():
|
||||||
|
"""LLM returning invalid JSON must return empty list without raising."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = "Sorry, I cannot extract anything."
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||||
|
|
||||||
|
assert items == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_disallowed_table_filtered():
|
||||||
|
"""Items whose table is not in data_types are discarded."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Valid task"}},
|
||||||
|
{"table": "projects", "data": {"name": "Should be filtered"}},
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
# Only "tasks" is in data_types — "projects" should be filtered.
|
||||||
|
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
|
||||||
|
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0]["table"] == "tasks"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_empty_data_types_returns_empty():
|
||||||
|
"""If no allowed data_types match, skip LLM call and return immediately."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract.", "content", [])
|
||||||
|
|
||||||
|
mock_llm.ainvoke.assert_not_called()
|
||||||
|
assert items == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_llm_error_propagates():
|
||||||
|
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
with pytest.raises(RuntimeError, match="API unavailable"):
|
||||||
|
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _send_insert_to_client
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_insert_to_client_happy_path():
|
||||||
|
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
sent_payloads: list[dict] = []
|
||||||
|
original_send = mgr.send_frame
|
||||||
|
|
||||||
|
async def _capture_send(uid: str, frame: dict) -> None:
|
||||||
|
sent_payloads.append(frame)
|
||||||
|
# Immediately resolve the pending call with a success result.
|
||||||
|
call_id = frame["id"]
|
||||||
|
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
|
||||||
|
|
||||||
|
mgr.send_frame = _capture_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(sent_payloads) == 1
|
||||||
|
payload = sent_payloads[0]
|
||||||
|
assert payload["action"] == "insert"
|
||||||
|
assert payload["table"] == "tasks"
|
||||||
|
assert payload["data"]["title"] == "Buy milk"
|
||||||
|
assert payload["data"]["isAiSuggested"] == 1
|
||||||
|
assert payload["data"]["isApproved"] == 0
|
||||||
|
assert result["row"]["title"] == "Buy milk"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_insert_to_client_timeout():
|
||||||
|
"""asyncio.TimeoutError is raised when Electron does not respond."""
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
async def _slow_send(uid: str, frame: dict) -> None:
|
||||||
|
# Never resolve the pending call.
|
||||||
|
pass
|
||||||
|
|
||||||
|
mgr.send_frame = _slow_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_local_agent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_device_offline():
|
||||||
|
"""run_local_agent marks run as error when device is offline."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = DeviceConnectionManager() # Empty — no device registered.
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("not connected" in e for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_happy_path():
|
||||||
|
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
# Build a fake agent_data frame (will be queued after send).
|
||||||
|
file_frame = {
|
||||||
|
"type": "agent_data",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
|
||||||
|
}
|
||||||
|
agent_complete_frame = None # sentinel
|
||||||
|
|
||||||
|
sent_frames: list[dict] = []
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
sent_frames.append(frame)
|
||||||
|
if frame.get("type") == "agent_run":
|
||||||
|
# Simulate Electron responding with file data then agent_complete.
|
||||||
|
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||||
|
await q.put(file_frame)
|
||||||
|
await q.put(agent_complete_frame)
|
||||||
|
elif frame.get("type") == "tool_call":
|
||||||
|
# Resolve the pending insert immediately.
|
||||||
|
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "success"
|
||||||
|
assert kwargs["items_processed"] == 1
|
||||||
|
assert kwargs["items_created"] == 1
|
||||||
|
assert kwargs["errors"] == []
|
||||||
|
assert kwargs["update_config_last_run"] is True
|
||||||
|
|
||||||
|
# Verify agent_run frame was sent.
|
||||||
|
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||||
|
assert len(agent_run_frames) == 1
|
||||||
|
assert agent_run_frames[0]["agent_id"] == config.id
|
||||||
|
assert "paths" in agent_run_frames[0]["config"]
|
||||||
|
|
||||||
|
# Verify insert frame was sent with AI flags.
|
||||||
|
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
|
||||||
|
assert len(insert_frames) == 1
|
||||||
|
assert insert_frames[0]["data"]["isAiSuggested"] == 1
|
||||||
|
assert insert_frames[0]["data"]["isApproved"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_file_read_timeout():
|
||||||
|
"""run_local_agent marks run as partial/error when device stops sending files."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
# Don't put anything in the queue — simulate stalled device.
|
||||||
|
pass
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error" # No items created, so error (not partial).
|
||||||
|
assert any("timed out" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_llm_extraction_error():
|
||||||
|
"""LLM errors per-file are recorded; run continues for remaining files."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
file_frame = {
|
||||||
|
"type": "agent_data",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"files": [
|
||||||
|
{"path": "/file1.eml", "content": "Email one."},
|
||||||
|
{"path": "/file2.eml", "content": "Email two."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
if frame.get("type") == "agent_run":
|
||||||
|
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||||
|
await q.put(file_frame)
|
||||||
|
await q.put(None) # agent_complete sentinel
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert kwargs["items_processed"] == 2 # Both files attempted.
|
||||||
|
assert kwargs["items_created"] == 0
|
||||||
|
assert len(kwargs["errors"]) == 2 # One error per file.
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_cloud_agent (stub)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_device_offline():
|
||||||
|
"""Cloud agent aborts immediately when no device is connected."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = DeviceConnectionManager() # empty — no devices registered
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_no_oauth_token():
|
||||||
|
"""Cloud agent errors when no OAuth token is stored."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = None
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_token_decrypt_failure():
|
||||||
|
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
valid_key = _Fernet.generate_key().decode()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.integrations.settings") as mock_settings:
|
||||||
|
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_happy_path_gmail():
|
||||||
|
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
|
||||||
|
from app.integrations import EmailMessage, encrypt_token
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
|
||||||
|
fernet_key = _Fernet.generate_key().decode()
|
||||||
|
credentials = {
|
||||||
|
"token": "access_abc",
|
||||||
|
"refresh_token": "refresh_xyz",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"client_id": "cid",
|
||||||
|
"client_secret": "csec",
|
||||||
|
}
|
||||||
|
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.provider = "gmail"
|
||||||
|
config.prompt_template = "Extract tasks from this email."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as ms:
|
||||||
|
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||||
|
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
sample_email = EmailMessage(
|
||||||
|
id="msg001",
|
||||||
|
subject="Action required",
|
||||||
|
sender="boss@company.com",
|
||||||
|
body_text="Please fix the bug by Friday.",
|
||||||
|
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as mock_int_settings, \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
|
||||||
|
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
|
||||||
|
patch("app.core.agent_runner.async_session"):
|
||||||
|
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
|
||||||
|
mock_gmail = AsyncMock()
|
||||||
|
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
|
||||||
|
mock_gmail.refreshed_credentials = None
|
||||||
|
|
||||||
|
with patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_gmail):
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_extract.assert_called_once()
|
||||||
|
mock_insert.assert_called_once()
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "success"
|
||||||
|
assert kwargs["items_processed"] == 1
|
||||||
|
assert kwargs["items_created"] == 1
|
||||||
|
assert kwargs["config_type"] == "cloud"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_provider_fetch_error():
|
||||||
|
"""Cloud agent records error status when provider fetch raises RuntimeError."""
|
||||||
|
credentials = {"token": "abc"}
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
|
||||||
|
config.prompt_template = "Extract tasks."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
mock_provider = AsyncMock()
|
||||||
|
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
|
||||||
|
mock_provider.refreshed_credentials = None
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
|
||||||
|
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||||
|
patch("app.core.agent_runner.async_session"):
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_refreshed_token_persisted():
|
||||||
|
"""When the provider refreshes its token, the new ciphertext is written to DB."""
|
||||||
|
from app.integrations import EmailMessage, encrypt_token
|
||||||
|
from cryptography.fernet import Fernet as _Fernet
|
||||||
|
|
||||||
|
fernet_key = _Fernet.generate_key().decode()
|
||||||
|
credentials = {"token": "old_token", "refresh_token": "rt_old"}
|
||||||
|
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
|
||||||
|
|
||||||
|
config = _make_cloud_config()
|
||||||
|
config.prompt_template = "Extract tasks."
|
||||||
|
config.data_types = ["tasks"]
|
||||||
|
|
||||||
|
with patch("app.integrations.settings") as ms:
|
||||||
|
ms.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
config.oauth_token_encrypted = encrypt_token(credentials)
|
||||||
|
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
mock_provider = AsyncMock()
|
||||||
|
mock_provider.fetch_messages = AsyncMock(return_value=[])
|
||||||
|
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
|
||||||
|
|
||||||
|
# Track DB writes via mock async_session.
|
||||||
|
mock_cfg_row = MagicMock()
|
||||||
|
mock_cfg_row.oauth_token_encrypted = None
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||||
|
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
|
||||||
|
cfg_result = MagicMock()
|
||||||
|
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
|
||||||
|
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||||
|
mock_db.commit = AsyncMock()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
|
||||||
|
patch("app.integrations.decrypt_token", return_value=credentials), \
|
||||||
|
patch("app.integrations.get_provider", return_value=mock_provider), \
|
||||||
|
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
|
||||||
|
patch("app.core.agent_runner.async_session", return_value=mock_db), \
|
||||||
|
patch("app.integrations.settings") as mock_int_settings:
|
||||||
|
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
# The new encrypted token should have been written to the config row.
|
||||||
|
mock_encrypt.assert_called_once_with(fresh_credentials)
|
||||||
|
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_finalize_run_updates_cloud_config_last_run_at():
|
||||||
|
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
|
||||||
|
from app.core.agent_runner import _finalize_run
|
||||||
|
|
||||||
|
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
|
||||||
|
run_log.id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
mock_cfg = MagicMock()
|
||||||
|
mock_cfg.last_run_at = None
|
||||||
|
|
||||||
|
cfg_result = MagicMock()
|
||||||
|
cfg_result.scalar_one_or_none.return_value = mock_cfg
|
||||||
|
|
||||||
|
mock_db = AsyncMock()
|
||||||
|
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
|
||||||
|
mock_db.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_db.merge = AsyncMock(return_value=run_log)
|
||||||
|
mock_db.execute = AsyncMock(return_value=cfg_result)
|
||||||
|
mock_db.commit = AsyncMock()
|
||||||
|
|
||||||
|
config_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session", return_value=mock_db):
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="success",
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config_id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
|
||||||
|
# CloudAgentConfig.last_run_at should have been set.
|
||||||
|
assert mock_cfg.last_run_at is not None
|
||||||
|
mock_db.commit.assert_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# trigger_pending_runs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_no_overdue():
|
||||||
|
"""If no agents are overdue trigger_pending_runs does nothing."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
config = _make_local_config()
|
||||||
|
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
||||||
|
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_device_id_filter():
|
||||||
|
"""Local agents are only triggered for the matching device_id."""
|
||||||
|
# The DB query already filters by device_id, so we verify the SELECT
|
||||||
|
# includes the device_id filter by checking that a config bound to a
|
||||||
|
# different device is never dispatched.
|
||||||
|
#
|
||||||
|
# Since trigger_pending_runs queries with device_id == "dev-001",
|
||||||
|
# simulate the DB returning an empty list (as it would for a mismatch).
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager(device_id="dev-001")
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_dispatches_overdue():
|
||||||
|
"""Overdue local agent triggers run_local_agent sequentially."""
|
||||||
|
config = _make_local_config() # last_run_at=None → always overdue
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
call_order: list[str] = []
|
||||||
|
|
||||||
|
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
||||||
|
call_order.append("run_local")
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
||||||
|
# First call: query configs. Subsequent calls: create run_log.
|
||||||
|
mock_query_ctx = AsyncMock()
|
||||||
|
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
||||||
|
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_query_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
|
||||||
|
run_log_obj = AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=_FREE_UID,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
mock_insert_ctx = AsyncMock()
|
||||||
|
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
||||||
|
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_insert_ctx.add = MagicMock()
|
||||||
|
mock_insert_ctx.commit = AsyncMock()
|
||||||
|
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
||||||
|
|
||||||
|
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
assert call_order == ["run_local"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration: POST /agents/{id}/run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
"""Route all get_session calls to the test SQLite session."""
|
||||||
|
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_run_unknown_agent(client):
|
||||||
|
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/v1/agents/{uuid.uuid4()}/run",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||||
|
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
||||||
|
# Create the local agent config in the DB.
|
||||||
|
config = LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=TEST_USER_IDS["power"],
|
||||||
|
device_id="dev-001",
|
||||||
|
name="My Agent",
|
||||||
|
directory_paths=["/home/user/docs"],
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks.",
|
||||||
|
file_extensions=[".txt"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
db_session.add(config)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
dispatched: list = []
|
||||||
|
|
||||||
|
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||||
|
dispatched.append((user_id, cfg.id))
|
||||||
|
|
||||||
|
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||||
|
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
||||||
|
patch("asyncio.create_task") as mock_create_task:
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/v1/agents/{config.id}/run",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 202
|
||||||
|
data = resp.json()
|
||||||
|
assert data["agent_id"] == config.id
|
||||||
|
assert data["status"] == "running"
|
||||||
|
assert data["agent_type"] == "local"
|
||||||
|
|
||||||
|
# Verify create_task was called (dispatching background run).
|
||||||
|
mock_create_task.assert_called_once()
|
||||||
@@ -1,430 +0,0 @@
|
|||||||
"""Tests for Local Agent V2 runner (Step 2).
|
|
||||||
|
|
||||||
Covers the unified per-file flow:
|
|
||||||
Phase A — detect + preprocess (Python, zero LLM)
|
|
||||||
Phase B — single LLM call with tools (classify + extract + create)
|
|
||||||
|
|
||||||
Fixture-based eval tests (2.1–2.7)
|
|
||||||
-----------------------------------
|
|
||||||
Cases are defined in tests/fixtures/agent_runner_v2/cases.yaml.
|
|
||||||
Email HTML files live in tests/fixtures/agent_runner_v2/data/.
|
|
||||||
Use --runner-dir to point at a custom folder (same structure required).
|
|
||||||
|
|
||||||
Unit tests (no LLM)
|
|
||||||
--------------------
|
|
||||||
2.8 items_created count → items_created == N create_* calls
|
|
||||||
2.9 Device offline → status=error
|
|
||||||
2.10 Empty file → items_processed=0, status=success
|
|
||||||
|
|
||||||
Run:
|
|
||||||
pytest tests/test_agent_runner_v2.py -v
|
|
||||||
pytest tests/test_agent_runner_v2.py -v -k "2_9 or 2_10 or 2_8" # unit only
|
|
||||||
pytest tests/test_agent_runner_v2.py -v -k "eval" # LLM evals only
|
|
||||||
pytest tests/test_agent_runner_v2.py -v --runner-dir /path/to/dir # custom fixtures
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from contextlib import nullcontext
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from unittest.mock import AsyncMock, MagicMock, patch
|
|
||||||
|
|
||||||
import pytest
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
from app.core.agent_runner import (
|
|
||||||
_format_metadata,
|
|
||||||
_format_projects,
|
|
||||||
_get_extraction_rules,
|
|
||||||
_get_no_match_behavior,
|
|
||||||
run_local_agent,
|
|
||||||
)
|
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
|
||||||
from app.core.langfuse_client import get_langfuse
|
|
||||||
from app.models import AgentRunLog, LocalAgentConfig
|
|
||||||
from tests.conftest import TEST_USER_IDS
|
|
||||||
|
|
||||||
# ── Constants ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_USER_ID = TEST_USER_IDS["power"]
|
|
||||||
|
|
||||||
_DEFAULT_FIXTURE_DIR = Path(__file__).parent / "fixtures" / "agent_runner_v2"
|
|
||||||
|
|
||||||
_AGENT_CONFIG = {
|
|
||||||
"content_types": [
|
|
||||||
{
|
|
||||||
"id": "email_html",
|
|
||||||
"label": "Email HTML",
|
|
||||||
"detection_hint": "HTML file with From/To/Subject headers",
|
|
||||||
"preprocessing": "email_html",
|
|
||||||
"extraction_prompt": (
|
|
||||||
"If the email contains a direct action request or task assignment → create a task. "
|
|
||||||
"If the email contains informational content, updates, or FYI → create a note. "
|
|
||||||
"If the email mentions a specific date for a meeting or deadline → create a timeline entry."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
],
|
|
||||||
"global_rules": [
|
|
||||||
"Se il file non è riconducibile a nessun progetto, non creare alcuna entità."
|
|
||||||
],
|
|
||||||
"data_types": ["tasks", "notes", "timelines"],
|
|
||||||
}
|
|
||||||
|
|
||||||
# Canonical project definitions, referenced symbolically in cases.yaml.
|
|
||||||
_PROJECTS: dict[str, dict] = {
|
|
||||||
"alpha": {"id": "proj-alpha", "name": "Project Alpha", "status": "active"},
|
|
||||||
"beta": {"id": "proj-beta", "name": "Project Beta", "status": "active"},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fixture loading ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _fixtures_dir(config) -> Path:
|
|
||||||
override = config.getoption("--runner-dir")
|
|
||||||
return Path(override) if override else _DEFAULT_FIXTURE_DIR
|
|
||||||
|
|
||||||
|
|
||||||
def _load_cases(config) -> list[dict]:
|
|
||||||
return yaml.safe_load(
|
|
||||||
(_fixtures_dir(config) / "cases.yaml").read_text(encoding="utf-8")
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _read_case_file(case: dict, data_dir: Path) -> str:
|
|
||||||
return (data_dir / case["file"]).read_text(encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
def _resolve_projects(entries: list[str | dict]) -> list[dict]:
|
|
||||||
"""Resolve project list from YAML: symbolic names and/or inline dicts."""
|
|
||||||
result = []
|
|
||||||
for entry in entries:
|
|
||||||
if isinstance(entry, str):
|
|
||||||
if entry in _PROJECTS:
|
|
||||||
result.append(_PROJECTS[entry])
|
|
||||||
elif isinstance(entry, dict):
|
|
||||||
result.append(entry)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
# ── pytest_generate_tests — parametrize eval tests from YAML ─────────────
|
|
||||||
|
|
||||||
|
|
||||||
def pytest_generate_tests(metafunc):
|
|
||||||
if "runner_case" not in metafunc.fixturenames:
|
|
||||||
return
|
|
||||||
cases = _load_cases(metafunc.config)
|
|
||||||
metafunc.parametrize("runner_case", cases, ids=[c["id"] for c in cases])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Test helpers ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _make_config(
|
|
||||||
agent_config: dict | None = None,
|
|
||||||
directory: str = "/emails",
|
|
||||||
device_id: str = "dev-001",
|
|
||||||
) -> LocalAgentConfig:
|
|
||||||
return LocalAgentConfig(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=_USER_ID,
|
|
||||||
device_id=device_id,
|
|
||||||
name="Test V2 Agent",
|
|
||||||
directory_paths=[directory],
|
|
||||||
data_types=["tasks", "notes", "timelines"],
|
|
||||||
prompt_template="",
|
|
||||||
agent_config=agent_config or _AGENT_CONFIG,
|
|
||||||
file_extensions=[".html", ".eml"],
|
|
||||||
schedule_cron="0 */6 * * *",
|
|
||||||
enabled=True,
|
|
||||||
last_run_at=None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_run_log(agent_id: str) -> AgentRunLog:
|
|
||||||
return AgentRunLog(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=_USER_ID,
|
|
||||||
status="running",
|
|
||||||
started_at=datetime.now(timezone.utc),
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _make_manager(online: bool = True) -> DeviceConnectionManager:
|
|
||||||
mgr = DeviceConnectionManager()
|
|
||||||
if online:
|
|
||||||
ws = MagicMock()
|
|
||||||
ws.send_text = AsyncMock()
|
|
||||||
mgr.register(_USER_ID, "dev-001", ws)
|
|
||||||
return mgr
|
|
||||||
|
|
||||||
|
|
||||||
def _make_executor(
|
|
||||||
file_path: str,
|
|
||||||
file_content: str,
|
|
||||||
projects: list[dict] | None = None,
|
|
||||||
existing_tasks: list[dict] | None = None,
|
|
||||||
existing_notes: list[dict] | None = None,
|
|
||||||
existing_timelines: list[dict] | None = None,
|
|
||||||
) -> tuple[Any, list[dict]]:
|
|
||||||
"""Return (async_executor, captured_calls).
|
|
||||||
|
|
||||||
The executor handles all ``execute_on_client`` payloads:
|
|
||||||
directory listing, file reading, project/entity fetching, and CRUD.
|
|
||||||
"""
|
|
||||||
calls: list[dict] = []
|
|
||||||
_projects = projects if projects is not None else list(_PROJECTS.values())
|
|
||||||
|
|
||||||
async def _executor(payload: dict) -> dict:
|
|
||||||
action = payload.get("action", "")
|
|
||||||
table = payload.get("table", "")
|
|
||||||
data = payload.get("data") or {}
|
|
||||||
calls.append({"action": action, "table": table, "data": data})
|
|
||||||
|
|
||||||
if action == "list_directory":
|
|
||||||
return {"entries": [{"type": "file", "path": file_path}]}
|
|
||||||
|
|
||||||
if action == "get_file_metadata":
|
|
||||||
return {"modifiedAt": None}
|
|
||||||
|
|
||||||
if action == "read_file_content":
|
|
||||||
return {"content": file_content}
|
|
||||||
|
|
||||||
if action == "select":
|
|
||||||
if table == "projects":
|
|
||||||
return {"rows": _projects}
|
|
||||||
if table == "tasks":
|
|
||||||
return {"rows": existing_tasks or []}
|
|
||||||
if table == "notes":
|
|
||||||
return {"rows": existing_notes or []}
|
|
||||||
if table == "timelines":
|
|
||||||
return {"rows": existing_timelines or []}
|
|
||||||
return {"rows": []}
|
|
||||||
|
|
||||||
if action == "insert":
|
|
||||||
return {"row": {"id": str(uuid.uuid4()), **data}}
|
|
||||||
|
|
||||||
if action == "update":
|
|
||||||
return {"success": True}
|
|
||||||
|
|
||||||
return {}
|
|
||||||
|
|
||||||
return _executor, calls
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: helper functions ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_projects_empty():
|
|
||||||
assert "(no projects" in _format_projects([])
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_projects_with_data():
|
|
||||||
result = _format_projects([_PROJECTS["alpha"]])
|
|
||||||
assert "proj-alpha" in result
|
|
||||||
assert "Project Alpha" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_metadata_empty():
|
|
||||||
assert _format_metadata({}) == ""
|
|
||||||
|
|
||||||
|
|
||||||
def test_format_metadata_email():
|
|
||||||
meta = {"subject": "Fix bug", "from": "boss@co.com", "date": "2026-04-07"}
|
|
||||||
result = _format_metadata(meta)
|
|
||||||
assert "Fix bug" in result
|
|
||||||
assert "boss@co.com" in result
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_extraction_rules_match():
|
|
||||||
rules = _get_extraction_rules(_AGENT_CONFIG, "email_html")
|
|
||||||
assert "task" in rules.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_extraction_rules_fallback():
|
|
||||||
rules = _get_extraction_rules(_AGENT_CONFIG, "plain_text")
|
|
||||||
assert "extract" in rules.lower()
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_no_match_behavior_from_global_rules():
|
|
||||||
behavior = _get_no_match_behavior(_AGENT_CONFIG)
|
|
||||||
assert behavior # non-empty
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_no_match_behavior_default():
|
|
||||||
behavior = _get_no_match_behavior({})
|
|
||||||
assert "project" in behavior.lower()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.9 — device offline ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_9_device_offline():
|
|
||||||
"""2.9 No device online → status=error, no executor created."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager(online=False)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
assert kwargs["status"] == "error"
|
|
||||||
assert any("not connected" in e for e in kwargs.get("errors", []))
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.10 — empty file ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_10_empty_file():
|
|
||||||
"""2.10 File with empty content → skipped, items_processed=0, success."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, calls = _make_executor(
|
|
||||||
file_path="/emails/empty.html",
|
|
||||||
file_content="",
|
|
||||||
projects=[_PROJECTS["alpha"]],
|
|
||||||
)
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
assert kwargs["items_processed"] == 0
|
|
||||||
assert kwargs["status"] == "success"
|
|
||||||
assert kwargs["items_created"] == 0
|
|
||||||
|
|
||||||
|
|
||||||
# ── Unit: 2.8 — items_created count ─────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_2_8_items_created_count():
|
|
||||||
"""2.8 items_created == number of create_* tool calls per run."""
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, _calls = _make_executor(
|
|
||||||
file_path="/emails/action.html",
|
|
||||||
file_content="<html><body><p>Fix the login bug in Project Alpha.</p></body></html>",
|
|
||||||
projects=[_PROJECTS["alpha"]],
|
|
||||||
)
|
|
||||||
|
|
||||||
async def mock_run_agent(*, _tool_calls_out=None, **kw) -> str:
|
|
||||||
if _tool_calls_out is not None:
|
|
||||||
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
|
|
||||||
return "Done."
|
|
||||||
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
# Only create_task + create_note count (not update_task).
|
|
||||||
assert kwargs["items_created"] == 2
|
|
||||||
assert kwargs["items_processed"] == 1
|
|
||||||
|
|
||||||
|
|
||||||
# ── Eval: 2.1–2.7 — fixture-driven, real LLM + Langfuse scoring ──────────
|
|
||||||
#
|
|
||||||
# Cases loaded from tests/fixtures/agent_runner_v2/cases.yaml.
|
|
||||||
# Supported assertions (from YAML):
|
|
||||||
# expect_insert: <table> → at least 1 insert in that table
|
|
||||||
# expect_no_insert: true → zero inserts in any table
|
|
||||||
# expect_project_id: <id> → any insert carries this projectId
|
|
||||||
# expect_dedup: true → task inserts == 0 OR task updates >= 1
|
|
||||||
# ─────────────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.eval
|
|
||||||
async def test_eval_runner(runner_case, pytestconfig):
|
|
||||||
"""Parametrized eval test — one invocation per YAML case."""
|
|
||||||
case: dict = runner_case
|
|
||||||
data_dir = _fixtures_dir(pytestconfig) / "data"
|
|
||||||
file_content = _read_case_file(case, data_dir)
|
|
||||||
projects = _resolve_projects(case.get("projects", []))
|
|
||||||
|
|
||||||
config = _make_config()
|
|
||||||
run_log = _make_run_log(config.id)
|
|
||||||
mgr = _make_manager()
|
|
||||||
|
|
||||||
executor, calls = _make_executor(
|
|
||||||
file_path=case["file_path"],
|
|
||||||
file_content=file_content,
|
|
||||||
projects=projects,
|
|
||||||
existing_tasks=case.get("existing_tasks"),
|
|
||||||
existing_notes=case.get("existing_notes"),
|
|
||||||
existing_timelines=case.get("existing_timelines"),
|
|
||||||
)
|
|
||||||
|
|
||||||
lf = get_langfuse()
|
|
||||||
obs_ctx = lf.start_as_current_observation(
|
|
||||||
name=f"eval-runner-{case['id']}-{case.get('score_name', 'unknown').replace('.', '-')}",
|
|
||||||
metadata={"step": "2", "case_id": case["id"]},
|
|
||||||
) if lf else nullcontext()
|
|
||||||
|
|
||||||
with obs_ctx as obs:
|
|
||||||
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
|
|
||||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
|
|
||||||
await run_local_agent(_USER_ID, config, run_log, mgr)
|
|
||||||
|
|
||||||
_, kwargs = mock_fin.call_args
|
|
||||||
score, comment = _evaluate_case(case, calls, kwargs)
|
|
||||||
|
|
||||||
if obs is not None:
|
|
||||||
obs.score(
|
|
||||||
name=case.get("score_name", f"runner.case_{case['id']}"),
|
|
||||||
value=score,
|
|
||||||
comment=comment,
|
|
||||||
)
|
|
||||||
|
|
||||||
if lf:
|
|
||||||
lf.flush()
|
|
||||||
|
|
||||||
assert score == 1.0, f"[{case['id']}] {case.get('description', '')} — {comment}"
|
|
||||||
|
|
||||||
|
|
||||||
def _evaluate_case(case: dict, calls: list[dict], finalize_kwargs: dict) -> tuple[float, str]:
|
|
||||||
"""Return (score, comment) for a YAML case given the captured executor calls."""
|
|
||||||
inserts = [c for c in calls if c["action"] == "insert"]
|
|
||||||
|
|
||||||
if case.get("expect_no_insert"):
|
|
||||||
score = 1.0 if len(inserts) == 0 else 0.0
|
|
||||||
return score, f"inserts={len(inserts)} (expected 0)"
|
|
||||||
|
|
||||||
if "expect_insert" in case:
|
|
||||||
tables = case["expect_insert"]
|
|
||||||
if isinstance(tables, str):
|
|
||||||
tables = [tables]
|
|
||||||
missing = [t for t in tables if not any(c["table"] == t for c in inserts)]
|
|
||||||
score = 1.0 if not missing else 0.0
|
|
||||||
counts = {t: sum(1 for c in inserts if c["table"] == t) for t in tables}
|
|
||||||
return score, f"inserts={counts}" + (f" missing={missing}" if missing else "")
|
|
||||||
|
|
||||||
if "expect_project_id" in case:
|
|
||||||
expected_pid = case["expect_project_id"]
|
|
||||||
correct = any(c.get("data", {}).get("projectId") == expected_pid for c in inserts)
|
|
||||||
score = 1.0 if correct else 0.0
|
|
||||||
all_pids = [c.get("data", {}).get("projectId") for c in inserts]
|
|
||||||
return score, f"projectIds={all_pids} (expected {expected_pid!r})"
|
|
||||||
|
|
||||||
if case.get("expect_dedup"):
|
|
||||||
task_creates = [c for c in inserts if c["table"] == "tasks"]
|
|
||||||
task_updates = [c for c in calls if c["action"] == "update" and c["table"] == "tasks"]
|
|
||||||
score = 1.0 if len(task_creates) == 0 or len(task_updates) >= 1 else 0.0
|
|
||||||
return score, f"task_creates={len(task_creates)} task_updates={len(task_updates)}"
|
|
||||||
|
|
||||||
return 0.0, "no assertion defined in case"
|
|
||||||
243
tests/test_agent_setup.py
Normal file
243
tests/test_agent_setup.py
Normal file
@@ -0,0 +1,243 @@
|
|||||||
|
"""Tests for the Chatbot Journey endpoints.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
1. Start journey for local agent → session_id + first question, done=False
|
||||||
|
2. Start journey for cloud agent → contextual email-focused question
|
||||||
|
3. Start journey with existing agent_id → session seeded, first question returned
|
||||||
|
4. Start journey with non-existent agent_id → still succeeds (graceful fallback)
|
||||||
|
5. Message: continue conversation → done=False, follow-up question returned
|
||||||
|
6. Message: LLM wraps up → done=True + prompt_template extracted correctly
|
||||||
|
7. Message with max-turns nudge → no crash, returns response
|
||||||
|
8. Invalid session_id → 404
|
||||||
|
9. Expired session → 404
|
||||||
|
10. Session ownership: user B cannot access user A's session
|
||||||
|
11. No JWT on /start → 401
|
||||||
|
12. No JWT on /message → 401
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.routes.agent_setup import (
|
||||||
|
_SESSION_TTL_SECONDS,
|
||||||
|
_TEMPLATE_END,
|
||||||
|
_TEMPLATE_START,
|
||||||
|
_extract_template,
|
||||||
|
_sessions,
|
||||||
|
)
|
||||||
|
from app.models import LocalAgentConfig
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _start(client: TestClient, agent_type: str = "local", agent_id: str | None = None, tier: str = "power") -> dict:
|
||||||
|
body: dict = {"agent_type": agent_type}
|
||||||
|
if agent_id:
|
||||||
|
body["agent_id"] = agent_id
|
||||||
|
resp = client.post("/api/v1/agents/journey/start", json=body, headers=auth_header(tier))
|
||||||
|
return resp
|
||||||
|
|
||||||
|
|
||||||
|
def _message(client: TestClient, session_id: str, message: str, tier: str = "power") -> dict:
|
||||||
|
return client.post(
|
||||||
|
"/api/v1/agents/journey/message",
|
||||||
|
json={"session_id": session_id, "message": message},
|
||||||
|
headers=auth_header(tier),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Unit: _extract_template ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_template_present():
|
||||||
|
text = f"Some preamble.\n{_TEMPLATE_START}\nExtract tasks from emails.\n{_TEMPLATE_END}\nTrailing text."
|
||||||
|
result = _extract_template(text)
|
||||||
|
assert result == "Extract tasks from emails."
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_template_absent():
|
||||||
|
assert _extract_template("No markers here.") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_extract_template_empty_content():
|
||||||
|
text = f"{_TEMPLATE_START}\n{_TEMPLATE_END}"
|
||||||
|
assert _extract_template(text) is None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Start journey ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_local(client: TestClient):
|
||||||
|
resp = _start(client, agent_type="local")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert "session_id" in body
|
||||||
|
assert body["done"] is False
|
||||||
|
assert body["prompt_template"] is None
|
||||||
|
assert len(body["message"]) > 0
|
||||||
|
# Local question should be about files/directories
|
||||||
|
assert any(w in body["message"].lower() for w in ("file", "director", "document", "monitor"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_cloud(client: TestClient):
|
||||||
|
resp = _start(client, agent_type="cloud")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["done"] is False
|
||||||
|
# Cloud question should mention emails or messages
|
||||||
|
assert any(w in body["message"].lower() for w in ("email", "message", "communication"))
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_with_agent_id(client: TestClient, db_session: AsyncSession):
|
||||||
|
"""When agent_id is provided, session should be created even if agent doesn't exist."""
|
||||||
|
fake_agent_id = str(uuid.uuid4())
|
||||||
|
resp = _start(client, agent_type="local", agent_id=fake_agent_id)
|
||||||
|
# Should succeed gracefully even if the agent_id doesn't exist
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["done"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_with_existing_agent(client: TestClient, db_session: AsyncSession):
|
||||||
|
"""When a real local agent is provided, session is seeded with its prompt_template."""
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
user_id = TEST_USER_IDS["power"]
|
||||||
|
agent = LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
name="Test Agent",
|
||||||
|
device_id="device-1",
|
||||||
|
directory_paths=["/home/user/emails"],
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks from .eml files.",
|
||||||
|
file_extensions=[".eml"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _seed():
|
||||||
|
db_session.add(agent)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
asyncio.get_event_loop().run_until_complete(_seed())
|
||||||
|
|
||||||
|
resp = _start(client, agent_type="local", agent_id=agent.id)
|
||||||
|
assert resp.status_code == 200
|
||||||
|
body = resp.json()
|
||||||
|
assert body["done"] is False
|
||||||
|
# The session should be stored
|
||||||
|
assert body["session_id"] in _sessions
|
||||||
|
|
||||||
|
|
||||||
|
def test_start_journey_requires_auth(client: TestClient):
|
||||||
|
resp = client.post("/api/v1/agents/journey/start", json={"agent_type": "local"})
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
# ── Message ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_continues_conversation(client: TestClient):
|
||||||
|
"""A mid-journey reply (no template markers) returns done=False."""
|
||||||
|
follow_up = "That looks good. Can you tell me more about priority rules?"
|
||||||
|
|
||||||
|
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
||||||
|
start_resp = _start(client, agent_type="local")
|
||||||
|
assert start_resp.status_code == 200
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
msg_resp = _message(client, session_id, "I have .eml and .txt files")
|
||||||
|
assert msg_resp.status_code == 200
|
||||||
|
body = msg_resp.json()
|
||||||
|
assert body["done"] is False
|
||||||
|
assert body["prompt_template"] is None
|
||||||
|
assert body["message"] == follow_up
|
||||||
|
assert body["session_id"] == session_id
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_produces_template(client: TestClient):
|
||||||
|
"""When the LLM includes PROMPT_TEMPLATE markers, done=True and prompt_template is set."""
|
||||||
|
final_template = "Extract tasks from email. Subject → title. 'urgent' → high priority."
|
||||||
|
llm_response = (
|
||||||
|
"Great, I have all the information I need.\n"
|
||||||
|
f"{_TEMPLATE_START}\n{final_template}\n{_TEMPLATE_END}\n"
|
||||||
|
)
|
||||||
|
|
||||||
|
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=llm_response)):
|
||||||
|
start_resp = _start(client, agent_type="cloud")
|
||||||
|
assert start_resp.status_code == 200
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
msg_resp = _message(client, session_id, "Only invoices from clients")
|
||||||
|
assert msg_resp.status_code == 200
|
||||||
|
body = msg_resp.json()
|
||||||
|
assert body["done"] is True
|
||||||
|
assert body["prompt_template"] == final_template
|
||||||
|
# Session should be cleaned up
|
||||||
|
assert session_id not in _sessions
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_invalid_session(client: TestClient):
|
||||||
|
resp = _message(client, "nonexistent-session-id", "hello")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_wrong_owner(client: TestClient):
|
||||||
|
"""User B cannot access user A's session."""
|
||||||
|
start_resp = _start(client, agent_type="local", tier="power")
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
# user with "pro" tier (different user_id) tries to send a message
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/agents/journey/message",
|
||||||
|
json={"session_id": session_id, "message": "hello"},
|
||||||
|
headers=auth_header("pro"), # different user
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_expired_session(client: TestClient):
|
||||||
|
"""Expired sessions return 404."""
|
||||||
|
start_resp = _start(client, agent_type="local")
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
# Manually expire the session
|
||||||
|
_sessions[session_id].created_at = time.monotonic() - _SESSION_TTL_SECONDS - 1
|
||||||
|
|
||||||
|
resp = _message(client, session_id, "hello")
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_requires_auth(client: TestClient):
|
||||||
|
resp = client.post(
|
||||||
|
"/api/v1/agents/journey/message",
|
||||||
|
json={"session_id": "any", "message": "hello"},
|
||||||
|
)
|
||||||
|
assert resp.status_code == 401
|
||||||
|
|
||||||
|
|
||||||
|
def test_message_max_turns_nudge(client: TestClient):
|
||||||
|
"""After _MAX_TURNS user messages, a system nudge is appended but no crash occurs."""
|
||||||
|
from app.api.routes.agent_setup import _MAX_TURNS
|
||||||
|
|
||||||
|
follow_up = "Tell me more about priority rules."
|
||||||
|
|
||||||
|
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
|
||||||
|
start_resp = _start(client, agent_type="local")
|
||||||
|
session_id = start_resp.json()["session_id"]
|
||||||
|
|
||||||
|
for i in range(_MAX_TURNS):
|
||||||
|
resp = _message(client, session_id, f"Answer {i + 1}")
|
||||||
|
assert resp.status_code == 200
|
||||||
|
# While no template produced, session must still exist
|
||||||
|
if resp.json()["done"]:
|
||||||
|
break # LLM decided to wrap up early — also fine
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user