Compare commits
31 Commits
229e20d073
...
feature/ba
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e668e3fd20 | ||
|
|
7ccdad431f | ||
|
|
4073863dc6 | ||
|
|
a85f8fde29 | ||
|
|
90500a3462 | ||
|
|
c1a8ac7669 | ||
|
|
c510cbaae5 | ||
|
|
ce139bbac3 | ||
|
|
3cf067faea | ||
|
|
7253f6fe72 | ||
|
|
41db3a7089 | ||
|
|
cc94194fd1 | ||
|
|
96c91e386d | ||
|
|
c0aef71141 | ||
|
|
467abc8d42 | ||
|
|
5753f8def9 | ||
|
|
e672b58b6f | ||
|
|
d8add7e8cb | ||
|
|
c6c4578f9a | ||
|
|
3aa0b36a6c | ||
|
|
fa231a3642 | ||
|
|
d91c98f86d | ||
|
|
c0619f5c4d | ||
|
|
da282229ff | ||
|
|
7fa6ad5760 | ||
|
|
dcd14220ca | ||
|
|
3cc32569d9 | ||
|
|
bf445ac2ce | ||
|
|
a2d6d689e4 | ||
|
|
aa8bcbf0d8 | ||
|
|
1ce1d492b0 |
78
.env.example
78
.env.example
@@ -2,55 +2,69 @@
|
|||||||
ENV=dev
|
ENV=dev
|
||||||
|
|
||||||
# ── Database ──────────────────────────────────────────────────────────────────
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
|
||||||
|
|
||||||
# ── Redis ─────────────────────────────────────────────────────────────────────
|
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||||
REDIS_URL=redis://localhost:6379/0
|
JWT_SECRET=replace-with-a-long-random-secret
|
||||||
|
JWT_ALGORITHM=HS256
|
||||||
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
|
|
||||||
# Public key for optional local JWT verification (Traefik ForwardAuth handles
|
|
||||||
# this in production — services trust X-User-* headers from Traefik).
|
|
||||||
# Generate keypair:
|
|
||||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
|
||||||
# openssl rsa -in private.pem -pubout -out public.pem
|
|
||||||
# Paste PEM content with literal \n for newlines.
|
|
||||||
JWT_PUBLIC_KEY=
|
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
# ── LLM ───────────────────────────────────────────────────────────────────────
|
# ── LLM ───────────────────────────────────────────────────────────────────────
|
||||||
# LiteLLM model identifiers — change to swap providers without code changes.
|
# LiteLLM model identifiers — change to swap providers without code changes.
|
||||||
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
|
||||||
|
#
|
||||||
|
# API keys — only the key(s) matching your chosen provider(s) are required.
|
||||||
|
# The correct key is picked automatically from the model prefix (e.g.
|
||||||
|
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
|
||||||
OPENAI_API_KEY=
|
OPENAI_API_KEY=
|
||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
LLM_MODEL=gpt-4o
|
CEREBRAS_API_KEY=
|
||||||
|
|
||||||
|
# Default model used by any agent that does not have a specific override below.
|
||||||
|
LLM_MODEL=gpt-5-mini
|
||||||
|
LLM_EMBED_MODEL=text-embedding-3-small
|
||||||
|
|
||||||
|
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
|
||||||
|
# In Docker, point this to a named-volume path so tokens survive restarts.
|
||||||
|
# GITHUB_COPILOT_TOKEN_DIR=
|
||||||
|
|
||||||
|
# ── Per-agent model overrides ─────────────────────────────────────────────────
|
||||||
|
# Leave a value empty to fall back to LLM_MODEL.
|
||||||
|
# Each agent resolves its API key from the model prefix automatically.
|
||||||
|
#
|
||||||
|
# Intent classifier — routes user messages to the right domain agent.
|
||||||
|
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
|
||||||
|
LLM_MODEL_CLASSIFIER=
|
||||||
|
|
||||||
|
# Home-agent — handles chat from the home screen (all tools available).
|
||||||
|
LLM_MODEL_HOME_AGENT=
|
||||||
|
|
||||||
|
# Floating-agent — handles contextual chat triggered from a task/project/note.
|
||||||
|
LLM_MODEL_FLOATING_AGENT=
|
||||||
|
|
||||||
|
# Unified-processor — processes local directory files (local agent runner).
|
||||||
|
LLM_MODEL_UNIFIED_PROCESSOR=
|
||||||
|
|
||||||
|
# Cloud-processor — fetches and processes data from cloud connectors.
|
||||||
|
LLM_MODEL_CLOUD_PROCESSOR=
|
||||||
|
|
||||||
|
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
|
||||||
|
LLM_MODEL_SETUP_AGENT=
|
||||||
|
|
||||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
STRIPE_WEBHOOK_SECRET=
|
STRIPE_WEBHOOK_SECRET=
|
||||||
|
|
||||||
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
|
|
||||||
S3_BUCKET=adiuva
|
|
||||||
S3_REGION=us-east-1
|
|
||||||
S3_ENDPOINT_URL=
|
|
||||||
AWS_ACCESS_KEY_ID=
|
|
||||||
AWS_SECRET_ACCESS_KEY=
|
|
||||||
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
|
|
||||||
|
|
||||||
# ── Vector Store ──────────────────────────────────────────────────────────────
|
# ── Langfuse (leave empty to disable observability) ───────────────────────────
|
||||||
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
|
LANGFUSE_SECRET_KEY=
|
||||||
PINECONE_API_KEY=
|
LANGFUSE_PUBLIC_KEY=
|
||||||
PINECONE_INDEX=adiuva
|
# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
|
||||||
QDRANT_URL=
|
# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
|
||||||
QDRANT_API_KEY=
|
# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
|
||||||
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
|
|
||||||
|
|
||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||||
|
|
||||||
# ── Langfuse (observability) ─────────────────────────────────────────────────
|
|
||||||
LANGFUSE_SECRET_KEY=sk-lf-...
|
|
||||||
LANGFUSE_PUBLIC_KEY=pk-lf-...
|
|
||||||
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL
|
|
||||||
@@ -48,23 +48,23 @@ jobs:
|
|||||||
key: ${{ secrets.SSH_KEY }}
|
key: ${{ secrets.SSH_KEY }}
|
||||||
script: |
|
script: |
|
||||||
set -e
|
set -e
|
||||||
DEPLOY_DIR="/opt/adiuva-api"
|
DEPLOY_DIR="/opt/adiuvai-api"
|
||||||
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
|
||||||
TAG="${{ gitea.ref_name }}"
|
TAG="${{ gitea.ref_name }}"
|
||||||
|
|
||||||
# ── Pull latest code ──
|
# ── Pull latest code ──
|
||||||
cd /tmp && rm -rf adiuva-api-deploy
|
cd /tmp && rm -rf adiuvai-api-deploy
|
||||||
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
|
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
|
||||||
|
|
||||||
# ── Sync source (preserve .env) ──
|
# ── Sync source (preserve .env) ──
|
||||||
cp -rf /tmp/adiuva-api-deploy/app/ \
|
cp -rf /tmp/adiuvai-api-deploy/app/ \
|
||||||
/tmp/adiuva-api-deploy/alembic/ \
|
/tmp/adiuvai-api-deploy/alembic/ \
|
||||||
/tmp/adiuva-api-deploy/alembic.ini \
|
/tmp/adiuvai-api-deploy/alembic.ini \
|
||||||
/tmp/adiuva-api-deploy/Dockerfile \
|
/tmp/adiuvai-api-deploy/Dockerfile \
|
||||||
/tmp/adiuva-api-deploy/docker-compose.yml \
|
/tmp/adiuvai-api-deploy/docker-compose.yml \
|
||||||
/tmp/adiuva-api-deploy/requirements.txt \
|
/tmp/adiuvai-api-deploy/requirements.txt \
|
||||||
"$DEPLOY_DIR/"
|
"$DEPLOY_DIR/"
|
||||||
rm -rf /tmp/adiuva-api-deploy
|
rm -rf /tmp/adiuvai-api-deploy
|
||||||
|
|
||||||
# ── Verify .env ──
|
# ── Verify .env ──
|
||||||
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
if [ ! -f "$DEPLOY_DIR/.env" ]; then
|
||||||
|
|||||||
4
.github/workflows/ci.yml
vendored
4
.github/workflows/ci.yml
vendored
@@ -58,7 +58,7 @@ jobs:
|
|||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
|
|
||||||
- name: Build image
|
- name: Build image
|
||||||
run: docker build -t adiuva-api:ci .
|
run: docker build -t adiuvai-api:ci .
|
||||||
|
|
||||||
- name: Verify gunicorn installed
|
- name: Verify gunicorn installed
|
||||||
run: docker run --rm adiuva-api:ci gunicorn --version
|
run: docker run --rm adiuvai-api:ci gunicorn --version
|
||||||
|
|||||||
4
.gitignore
vendored
4
.gitignore
vendored
@@ -13,9 +13,6 @@ env/
|
|||||||
# Environment variables
|
# Environment variables
|
||||||
.env
|
.env
|
||||||
|
|
||||||
# Cryptographic keys
|
|
||||||
*.pem
|
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
@@ -24,6 +21,7 @@ env/
|
|||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.coverage
|
.coverage
|
||||||
|
tests/fixtures/private*/
|
||||||
|
|
||||||
# Docker
|
# Docker
|
||||||
*.log
|
*.log
|
||||||
|
|||||||
793
README.md
793
README.md
@@ -1,793 +0,0 @@
|
|||||||
# Adiuva Cloud API
|
|
||||||
|
|
||||||
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
|
|
||||||
|
|
||||||
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Table of Contents
|
|
||||||
|
|
||||||
- [Overview](#overview)
|
|
||||||
- [Architecture](#architecture)
|
|
||||||
- [Key Features](#key-features)
|
|
||||||
- [Tech Stack](#tech-stack)
|
|
||||||
- [Getting Started](#getting-started)
|
|
||||||
- [Docker Deployment](#docker-deployment)
|
|
||||||
- [Environment Variables](#environment-variables)
|
|
||||||
- [API Reference](#api-reference)
|
|
||||||
- [Data Model](#data-model)
|
|
||||||
- [AI Agent System](#ai-agent-system)
|
|
||||||
- [Orchestration & Execution Plans](#orchestration--execution-plans)
|
|
||||||
- [Middleware](#middleware)
|
|
||||||
- [Storage Layer](#storage-layer)
|
|
||||||
- [Billing & Tiers](#billing--tiers)
|
|
||||||
- [Plugin Marketplace](#plugin-marketplace)
|
|
||||||
- [Testing](#testing)
|
|
||||||
- [Project Structure](#project-structure)
|
|
||||||
- [License](#license)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Overview
|
|
||||||
|
|
||||||
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
|
|
||||||
|
|
||||||
### Design Principles
|
|
||||||
|
|
||||||
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
|
|
||||||
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
|
|
||||||
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
|
|
||||||
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
|
|
||||||
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Architecture
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
|
|
||||||
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
|
|
||||||
│ Desktop App │────▶│ │
|
|
||||||
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
|
|
||||||
└──────────────┘ │ │
|
|
||||||
│ ┌──────────────────┐ ┌────────────────────────────┐ │
|
|
||||||
│ │ Auth Routes │ │ Chat Routes │ │
|
|
||||||
│ │ Billing Routes │ │ ↓ │ │
|
|
||||||
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
|
|
||||||
│ │ Backup Routes │ │ ↓ classify intent │ │
|
|
||||||
│ │ Plugin Routes │ │ Agent Registry │ │
|
|
||||||
│ │ Vector Routes │ │ ↓ │ │
|
|
||||||
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
|
|
||||||
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
|
|
||||||
│ │ (GPT-4o + LangChain) │ │
|
|
||||||
│ └────────────────────────────┘ │
|
|
||||||
└────────────────────────────────────────────────────────┘
|
|
||||||
│ │ │
|
|
||||||
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
|
|
||||||
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
|
|
||||||
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
|
|
||||||
│ Billing, │ │ backups) │ │ (Vectors) │
|
|
||||||
│ Metadata) │ └───────────────┘ └────────────────┘
|
|
||||||
└────────────┘
|
|
||||||
│
|
|
||||||
┌────────▼───┐
|
|
||||||
│ Stripe │
|
|
||||||
│ (Billing, │
|
|
||||||
│ Connect) │
|
|
||||||
└────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Key Features
|
|
||||||
|
|
||||||
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
|
|
||||||
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
|
|
||||||
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
|
|
||||||
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
|
|
||||||
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
|
|
||||||
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
|
|
||||||
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
|
|
||||||
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
|
|
||||||
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
|
|
||||||
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
|
|
||||||
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
|
|
||||||
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
|
|
||||||
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
|
|
||||||
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
|
|
||||||
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Tech Stack
|
|
||||||
|
|
||||||
| Package | Version | Purpose |
|
|
||||||
|---|---|---|
|
|
||||||
| `fastapi` | ≥ 0.115.0 | Web framework |
|
|
||||||
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
|
|
||||||
| `gunicorn` | ≥ 22.0.0 | Production process manager |
|
|
||||||
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
|
|
||||||
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
|
|
||||||
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
|
|
||||||
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
|
|
||||||
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
|
|
||||||
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
|
|
||||||
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
|
|
||||||
| `boto3` | ≥ 1.35.0 | AWS S3 client |
|
|
||||||
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
|
|
||||||
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
|
|
||||||
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
|
|
||||||
| `alembic` | ≥ 1.14.0 | Database migration management |
|
|
||||||
| `bcrypt` | ≥ 4.2.0 | Password hashing |
|
|
||||||
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
|
|
||||||
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
|
|
||||||
| `websockets` | ≥ 14.0 | WebSocket protocol support |
|
|
||||||
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
|
|
||||||
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
|
|
||||||
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
|
|
||||||
| `pytest` | ≥ 8.0.0 | Test framework |
|
|
||||||
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
|
|
||||||
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
|
|
||||||
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
|
|
||||||
| `ruff` | ≥ 0.8.0 | Linter and formatter |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Getting Started
|
|
||||||
|
|
||||||
### Prerequisites
|
|
||||||
|
|
||||||
- Python 3.12+
|
|
||||||
- PostgreSQL 16+
|
|
||||||
- An OpenAI API key (for LLM features)
|
|
||||||
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
|
|
||||||
- AWS credentials (optional — needed for S3 storage in production)
|
|
||||||
|
|
||||||
### Installation
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Clone the repository
|
|
||||||
git clone <repo-url> && cd adiuva-api
|
|
||||||
|
|
||||||
# Create a virtual environment
|
|
||||||
python -m venv .venv && source .venv/bin/activate
|
|
||||||
|
|
||||||
# Install dependencies
|
|
||||||
pip install -r requirements.txt
|
|
||||||
|
|
||||||
# Configure environment
|
|
||||||
cp .env.example .env
|
|
||||||
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
|
|
||||||
```
|
|
||||||
|
|
||||||
### Database Setup
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Start PostgreSQL (or use the Docker Compose database)
|
|
||||||
docker compose up db -d
|
|
||||||
|
|
||||||
# Run migrations
|
|
||||||
alembic upgrade head
|
|
||||||
```
|
|
||||||
|
|
||||||
### Run the Development Server
|
|
||||||
|
|
||||||
```bash
|
|
||||||
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
|
|
||||||
```
|
|
||||||
|
|
||||||
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Docker Deployment
|
|
||||||
|
|
||||||
### Quick Start
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose up --build
|
|
||||||
```
|
|
||||||
|
|
||||||
This starts two services:
|
|
||||||
|
|
||||||
- **app** — FastAPI server on port `8000`
|
|
||||||
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
|
|
||||||
|
|
||||||
The compose file also includes optional services for fully local deployments:
|
|
||||||
|
|
||||||
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
|
|
||||||
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
|
|
||||||
|
|
||||||
### Dockerfile Details
|
|
||||||
|
|
||||||
The Dockerfile uses a multi-stage build:
|
|
||||||
|
|
||||||
1. **Builder stage** — Installs Python dependencies into a virtual environment.
|
|
||||||
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
|
|
||||||
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Production command (run by the container)
|
|
||||||
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Homelab / Self-Hosted Deployment
|
|
||||||
|
|
||||||
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
|
|
||||||
|
|
||||||
### 1. Start all services
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose up -d
|
|
||||||
```
|
|
||||||
|
|
||||||
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
|
|
||||||
|
|
||||||
### 2. Create the MinIO bucket
|
|
||||||
|
|
||||||
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
|
|
||||||
docker compose exec minio mc mb local/adiuva
|
|
||||||
```
|
|
||||||
|
|
||||||
### 3. Configure your `.env`
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Database (uses the compose PostgreSQL)
|
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
|
|
||||||
# S3 → MinIO
|
|
||||||
S3_BUCKET=adiuva
|
|
||||||
S3_REGION=us-east-1
|
|
||||||
S3_ENDPOINT_URL=http://minio:9000
|
|
||||||
AWS_ACCESS_KEY_ID=minioadmin
|
|
||||||
AWS_SECRET_ACCESS_KEY=minioadmin
|
|
||||||
|
|
||||||
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
|
|
||||||
QDRANT_URL=http://qdrant:6333
|
|
||||||
QDRANT_API_KEY=
|
|
||||||
PINECONE_API_KEY=
|
|
||||||
|
|
||||||
# Billing — leave empty to stub (no Stripe needed)
|
|
||||||
STRIPE_SECRET_KEY=
|
|
||||||
STRIPE_WEBHOOK_SECRET=
|
|
||||||
|
|
||||||
# LLM — the only external service
|
|
||||||
OPENAI_API_KEY=sk-...
|
|
||||||
LLM_MODEL=gpt-4o
|
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
|
||||||
# Auth
|
|
||||||
JWT_SECRET=your-secret-here
|
|
||||||
ENV=dev
|
|
||||||
```
|
|
||||||
|
|
||||||
### 4. Run migrations
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker compose exec app alembic upgrade head
|
|
||||||
```
|
|
||||||
|
|
||||||
### What runs where
|
|
||||||
|
|
||||||
| Service | Runs on | Port | Notes |
|
|
||||||
|---|---|---|---|
|
|
||||||
| FastAPI app | Docker | 8000 | API server |
|
|
||||||
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
|
|
||||||
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
|
|
||||||
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
|
|
||||||
| Stripe | — | — | Stubbed when keys are empty |
|
|
||||||
| OpenAI / LLM | Cloud | — | Only external dependency |
|
|
||||||
|
|
||||||
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Environment Variables
|
|
||||||
|
|
||||||
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
|
|
||||||
|
|
||||||
| Variable | Type | Default | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
|
|
||||||
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
|
|
||||||
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
|
|
||||||
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
|
|
||||||
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
|
|
||||||
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
|
|
||||||
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
|
|
||||||
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
|
|
||||||
| `S3_REGION` | `str` | `us-east-1` | AWS region |
|
|
||||||
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
|
|
||||||
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
|
|
||||||
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
|
|
||||||
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
|
|
||||||
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
|
|
||||||
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
|
|
||||||
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
|
|
||||||
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
|
|
||||||
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
|
|
||||||
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
|
|
||||||
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
|
|
||||||
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## API Reference
|
|
||||||
|
|
||||||
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
|
|
||||||
|
|
||||||
### Health
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
|
|
||||||
|
|
||||||
### Auth
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
|
|
||||||
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
|
|
||||||
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
|
|
||||||
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
|
|
||||||
|
|
||||||
### Chat
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
|
|
||||||
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
|
|
||||||
|
|
||||||
### Plans
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
|
|
||||||
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
|
|
||||||
|
|
||||||
### Storage (Cloud Records)
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
|
|
||||||
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
|
|
||||||
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
|
|
||||||
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
|
|
||||||
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
|
|
||||||
|
|
||||||
### Vectors (Cloud Vector Store)
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
|
|
||||||
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
|
|
||||||
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
|
|
||||||
|
|
||||||
### Backup
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
|
|
||||||
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
|
|
||||||
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
|
|
||||||
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
|
|
||||||
|
|
||||||
### Plugins (Marketplace)
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
|
|
||||||
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
|
|
||||||
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
|
|
||||||
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
|
|
||||||
|
|
||||||
### Billing
|
|
||||||
|
|
||||||
| Method | Path | Auth | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
|
|
||||||
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
|
|
||||||
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
|
|
||||||
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Data Model
|
|
||||||
|
|
||||||
9 tables managed by Alembic migrations. Source: `app/models.py`
|
|
||||||
|
|
||||||
### Tables
|
|
||||||
|
|
||||||
| Table | Primary Key | Key Columns | Purpose |
|
|
||||||
|---|---|---|---|
|
|
||||||
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
|
|
||||||
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
|
|
||||||
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
|
|
||||||
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
|
|
||||||
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
|
|
||||||
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
|
|
||||||
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
|
|
||||||
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
|
|
||||||
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
|
|
||||||
|
|
||||||
### Enum Types
|
|
||||||
|
|
||||||
| Enum | Values |
|
|
||||||
|---|---|
|
|
||||||
| `billing_tier` | `free`, `pro`, `power`, `team` |
|
|
||||||
| `plugin_status` | `pending_review`, `approved`, `rejected` |
|
|
||||||
| `review_decision` | `approved`, `rejected` |
|
|
||||||
|
|
||||||
### Migrations
|
|
||||||
|
|
||||||
| Version | Description |
|
|
||||||
|---|---|
|
|
||||||
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
|
|
||||||
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## AI Agent System
|
|
||||||
|
|
||||||
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
|
|
||||||
|
|
||||||
### Architecture
|
|
||||||
|
|
||||||
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
|
|
||||||
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
|
|
||||||
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
|
|
||||||
|
|
||||||
### Registered Agents
|
|
||||||
|
|
||||||
| Agent | Registry Name | Tools | Description |
|
|
||||||
|---|---|---|---|
|
|
||||||
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
|
|
||||||
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
|
|
||||||
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
|
|
||||||
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
|
|
||||||
|
|
||||||
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
|
|
||||||
|
|
||||||
### Switching LLM Providers
|
|
||||||
|
|
||||||
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# OpenAI (default)
|
|
||||||
LLM_MODEL=gpt-4o
|
|
||||||
LLM_ROUTER_MODEL=gpt-4o-mini
|
|
||||||
|
|
||||||
# Anthropic
|
|
||||||
LLM_MODEL=anthropic/claude-3.5-sonnet
|
|
||||||
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
|
|
||||||
|
|
||||||
# Google Gemini
|
|
||||||
LLM_MODEL=gemini/gemini-pro
|
|
||||||
LLM_ROUTER_MODEL=gemini/gemini-flash
|
|
||||||
|
|
||||||
# Local Ollama
|
|
||||||
LLM_MODEL=ollama/llama3
|
|
||||||
LLM_ROUTER_MODEL=ollama/llama3
|
|
||||||
|
|
||||||
# AWS Bedrock
|
|
||||||
LLM_MODEL=bedrock/anthropic.claude-v2
|
|
||||||
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
|
|
||||||
```
|
|
||||||
|
|
||||||
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Orchestration & Execution Plans
|
|
||||||
|
|
||||||
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
|
|
||||||
|
|
||||||
### Orchestrator
|
|
||||||
|
|
||||||
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
|
|
||||||
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
|
|
||||||
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
|
|
||||||
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
|
|
||||||
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
|
|
||||||
|
|
||||||
### Execution Plans
|
|
||||||
|
|
||||||
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
|
|
||||||
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
|
|
||||||
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
|
|
||||||
|
|
||||||
### Built-in Templates (6)
|
|
||||||
|
|
||||||
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
|
|
||||||
|
|
||||||
### Built-in Playbooks (2)
|
|
||||||
|
|
||||||
| Playbook | Description |
|
|
||||||
|---|---|
|
|
||||||
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
|
|
||||||
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Middleware
|
|
||||||
|
|
||||||
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
|
|
||||||
|
|
||||||
### JWT Authentication
|
|
||||||
|
|
||||||
Source: `app/api/middleware/auth.py`
|
|
||||||
|
|
||||||
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
|
|
||||||
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
|
|
||||||
- Falls back to `free` when no subscription row exists.
|
|
||||||
- Raises `401 Unauthorized` on invalid or expired tokens.
|
|
||||||
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
|
|
||||||
|
|
||||||
### Tier-Based Rate Limiter
|
|
||||||
|
|
||||||
Source: `app/api/middleware/rate_limit.py`
|
|
||||||
|
|
||||||
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
|
|
||||||
- Per-user 60-second window sized by subscription tier:
|
|
||||||
|
|
||||||
| Tier | Requests / Minute |
|
|
||||||
|---|---|
|
|
||||||
| Free | 20 |
|
|
||||||
| Pro | 60 |
|
|
||||||
| Power | 120 |
|
|
||||||
| Team | 200 |
|
|
||||||
|
|
||||||
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
|
|
||||||
- **Exempt paths:** register, login, webhook, health
|
|
||||||
|
|
||||||
### Response Sanitizer
|
|
||||||
|
|
||||||
Source: `app/api/middleware/sanitizer.py`
|
|
||||||
|
|
||||||
- Runs only on `/api/v1/chat` endpoints.
|
|
||||||
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
|
|
||||||
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
|
|
||||||
- Logs sanitization events as `WARNING`.
|
|
||||||
- Binary responses (storage, backup) are never touched.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Storage Layer
|
|
||||||
|
|
||||||
### Blob Store
|
|
||||||
|
|
||||||
Source: `app/storage/blob_store.py`
|
|
||||||
|
|
||||||
- S3-backed storage for E2E encrypted blobs.
|
|
||||||
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
|
|
||||||
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
|
|
||||||
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
|
|
||||||
- The backend **never inspects or decrypts blob content**.
|
|
||||||
|
|
||||||
### Vector Store
|
|
||||||
|
|
||||||
Source: `app/storage/vector_store.py`
|
|
||||||
|
|
||||||
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
|
|
||||||
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
|
|
||||||
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
|
|
||||||
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
|
|
||||||
- Methods: `upsert()`, `search()`, `delete()`
|
|
||||||
|
|
||||||
### Encryption Utilities
|
|
||||||
|
|
||||||
Source: `app/storage/encryption.py`
|
|
||||||
|
|
||||||
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
|
|
||||||
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
|
|
||||||
- **No decryption key ever reaches the backend.**
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Billing & Tiers
|
|
||||||
|
|
||||||
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
|
|
||||||
|
|
||||||
### Feature Matrix
|
|
||||||
|
|
||||||
| Feature | Free | Pro | Power | Team |
|
|
||||||
|---|---|---|---|---|
|
|
||||||
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
|
|
||||||
| Batch Active | 2 | 10 | Unlimited | Unlimited |
|
|
||||||
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
|
||||||
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
|
|
||||||
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
|
|
||||||
| Batch Builder | — | — | ✓ | ✓ |
|
|
||||||
| Plugin Marketplace | — | — | ✓ | ✓ |
|
|
||||||
| SSO | — | — | — | ✓ |
|
|
||||||
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
|
|
||||||
|
|
||||||
### Stripe Integration
|
|
||||||
|
|
||||||
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
|
|
||||||
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
|
|
||||||
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
|
|
||||||
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
|
|
||||||
|
|
||||||
### Tier Manager
|
|
||||||
|
|
||||||
- `get_tier(user_id)` — Returns the user's current billing tier.
|
|
||||||
- `check_feature(tier, feature)` — Boolean feature gate check.
|
|
||||||
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
|
|
||||||
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Plugin Marketplace
|
|
||||||
|
|
||||||
Source: `app/marketplace/`
|
|
||||||
|
|
||||||
### Plugin Registry
|
|
||||||
|
|
||||||
- PostgreSQL-backed catalog of submitted and approved plugins.
|
|
||||||
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
|
|
||||||
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
|
|
||||||
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
|
|
||||||
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
|
|
||||||
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
|
|
||||||
|
|
||||||
### Review Queue
|
|
||||||
|
|
||||||
- Automated security checklist before human review:
|
|
||||||
- Plugin ID must match `^[a-z0-9-]+$`
|
|
||||||
- Permissions must be from the allowed set only
|
|
||||||
- No binary blobs in the manifest
|
|
||||||
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
|
|
||||||
- `get_pending(db)` — Lists plugins awaiting review.
|
|
||||||
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
|
|
||||||
|
|
||||||
### Revenue Sharing
|
|
||||||
|
|
||||||
- **70% developer / 30% platform** split on all paid plugin sales.
|
|
||||||
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
|
|
||||||
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
|
|
||||||
- Gracefully stubs transfers when Stripe is not configured.
|
|
||||||
|
|
||||||
### Seed Plugins
|
|
||||||
|
|
||||||
| Plugin | Category | Price |
|
|
||||||
|---|---|---|
|
|
||||||
| GitHub Sync | Productivity | Free |
|
|
||||||
| Slack Notifier | Communication | €4.99 |
|
|
||||||
| Time Tracker | Productivity | €9.99 |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Testing
|
|
||||||
|
|
||||||
### Running Tests
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# Run all tests
|
|
||||||
pytest
|
|
||||||
|
|
||||||
# Run a specific test file
|
|
||||||
pytest tests/test_auth.py
|
|
||||||
|
|
||||||
# Run with verbose output
|
|
||||||
pytest -v
|
|
||||||
```
|
|
||||||
|
|
||||||
### Test Infrastructure
|
|
||||||
|
|
||||||
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
|
|
||||||
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
|
|
||||||
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
|
|
||||||
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
|
|
||||||
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
|
|
||||||
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
|
|
||||||
- **No external dependencies** — all tests run fully offline.
|
|
||||||
|
|
||||||
### Test Coverage
|
|
||||||
|
|
||||||
| File | Coverage |
|
|
||||||
|---|---|
|
|
||||||
| `test_auth.py` | Register, login, token access, refresh, expiration |
|
|
||||||
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
|
|
||||||
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
|
|
||||||
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
|
|
||||||
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
|
|
||||||
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
|
|
||||||
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
|
|
||||||
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
|
|
||||||
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Project Structure
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── alembic.ini # Alembic configuration
|
|
||||||
├── BACKEND_PLAN.md # Architecture & design decisions
|
|
||||||
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
|
|
||||||
├── Dockerfile # Multi-stage production build
|
|
||||||
├── requirements.txt # Python dependencies
|
|
||||||
│
|
|
||||||
├── alembic/ # Database migrations
|
|
||||||
│ ├── env.py # Alembic environment config
|
|
||||||
│ ├── script.py.mako # Migration template
|
|
||||||
│ └── versions/
|
|
||||||
│ ├── 001_initial_schema.py # Tables, indexes, FKs
|
|
||||||
│ └── 002_seed_plugins.py # Seed marketplace plugins
|
|
||||||
│
|
|
||||||
├── app/ # Application source
|
|
||||||
│ ├── main.py # FastAPI app factory, middleware, routes
|
|
||||||
│ ├── db.py # Async SQLAlchemy engine & session
|
|
||||||
│ ├── models.py # SQLAlchemy ORM models (9 tables)
|
|
||||||
│ ├── schemas.py # Pydantic request/response schemas
|
|
||||||
│ │
|
|
||||||
│ ├── config/
|
|
||||||
│ │ └── settings.py # Pydantic Settings (env vars)
|
|
||||||
│ │
|
|
||||||
│ ├── agents/ # LLM-powered domain agents
|
|
||||||
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
|
|
||||||
│ │ ├── project_agent.py # Project lifecycle (6 tools)
|
|
||||||
│ │ ├── timeline_agent.py # Milestones (4 tools)
|
|
||||||
│ │ └── note_agent.py # Markdown notes (5 tools)
|
|
||||||
│ │
|
|
||||||
│ ├── core/ # Orchestration engine
|
|
||||||
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
|
||||||
│ │ ├── llm.py # LiteLLM factory (get_llm)
|
|
||||||
│ │ ├── orchestrator.py # Intent classification & routing
|
|
||||||
│ │ └── execution_plan.py # Plan builder, templates, cache
|
|
||||||
│ │
|
|
||||||
│ ├── api/ # HTTP layer
|
|
||||||
│ │ ├── deps.py # Shared FastAPI dependencies
|
|
||||||
│ │ ├── middleware/
|
|
||||||
│ │ │ ├── auth.py # JWT validation, live tier lookup
|
|
||||||
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
|
|
||||||
│ │ │ └── sanitizer.py # Prompt IP leak protection
|
|
||||||
│ │ └── routes/
|
|
||||||
│ │ ├── auth.py # Register, login, refresh, me
|
|
||||||
│ │ ├── chat.py # Chat + WebSocket streaming
|
|
||||||
│ │ ├── plans.py # Execution plan playbooks
|
|
||||||
│ │ ├── storage.py # E2E encrypted record CRUD
|
|
||||||
│ │ ├── vectors.py # Vector upsert, search, delete
|
|
||||||
│ │ ├── backup.py # Encrypted backup management
|
|
||||||
│ │ ├── plugins.py # Marketplace browse & install
|
|
||||||
│ │ └── billing.py # Stripe checkout & webhooks
|
|
||||||
│ │
|
|
||||||
│ ├── storage/ # Storage backends
|
|
||||||
│ │ ├── blob_store.py # S3 blob storage
|
|
||||||
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
|
|
||||||
│ │ └── encryption.py # Checksum verification utilities
|
|
||||||
│ │
|
|
||||||
│ ├── billing/ # Subscription management
|
|
||||||
│ │ ├── stripe_service.py # Stripe API integration
|
|
||||||
│ │ └── tier_manager.py # Feature matrix & quota enforcement
|
|
||||||
│ │
|
|
||||||
│ └── marketplace/ # Plugin ecosystem
|
|
||||||
│ ├── plugin_registry.py # Catalog CRUD & search
|
|
||||||
│ ├── plugin_review.py # Security checklist & review queue
|
|
||||||
│ └── revenue_share.py # 70/30 split & Stripe Connect
|
|
||||||
│
|
|
||||||
└── tests/ # Test suite
|
|
||||||
├── conftest.py # Fixtures: DB, S3, auth, seeds
|
|
||||||
├── test_auth.py
|
|
||||||
├── test_orchestrator.py
|
|
||||||
├── test_agents.py
|
|
||||||
├── test_storage.py
|
|
||||||
├── test_backup.py
|
|
||||||
├── test_plugins.py
|
|
||||||
├── test_agent_registry.py
|
|
||||||
├── test_execution_plan.py
|
|
||||||
└── test_middleware.py
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## License
|
|
||||||
|
|
||||||
*To be determined.*
|
|
||||||
|
|||||||
@@ -16,7 +16,7 @@ import re
|
|||||||
from logging.config import fileConfig
|
from logging.config import fileConfig
|
||||||
|
|
||||||
from alembic import context
|
from alembic import context
|
||||||
from sqlalchemy import engine_from_config, pool
|
from sqlalchemy import pool
|
||||||
from sqlalchemy.ext.asyncio import create_async_engine
|
from sqlalchemy.ext.asyncio import create_async_engine
|
||||||
|
|
||||||
# Alembic Config object (gives access to alembic.ini values).
|
# Alembic Config object (gives access to alembic.ini values).
|
||||||
|
|||||||
@@ -1,5 +1,4 @@
|
|||||||
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
|
"""Initial schema: users, refresh_tokens, subscriptions.
|
||||||
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
|
|
||||||
|
|
||||||
Revision ID: 001
|
Revision ID: 001
|
||||||
Revises:
|
Revises:
|
||||||
@@ -28,18 +27,6 @@ def upgrade() -> None:
|
|||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
END $$;
|
END $$;
|
||||||
""")
|
""")
|
||||||
op.execute("""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
op.execute("""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
""")
|
|
||||||
|
|
||||||
# ── users ─────────────────────────────────────────────────────────────
|
# ── users ─────────────────────────────────────────────────────────────
|
||||||
op.create_table(
|
op.create_table(
|
||||||
@@ -88,122 +75,10 @@ def upgrade() -> None:
|
|||||||
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
|
||||||
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
|
||||||
|
|
||||||
# ── storage_records ───────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"storage_records",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("table_name", sa.String(100), nullable=False),
|
|
||||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
|
||||||
sa.Column("checksum", sa.String(64), nullable=False),
|
|
||||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
|
|
||||||
|
|
||||||
# ── backup_metadata ───────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"backup_metadata",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("s3_key", sa.String(500), nullable=False),
|
|
||||||
sa.Column("version", sa.Integer, nullable=False),
|
|
||||||
sa.Column("timestamp", sa.BigInteger, nullable=False),
|
|
||||||
sa.Column("checksum", sa.String(64), nullable=False),
|
|
||||||
sa.Column("size_bytes", sa.Integer, nullable=False),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
|
|
||||||
|
|
||||||
# ── plugins ───────────────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"plugins",
|
|
||||||
sa.Column("id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("description", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
|
|
||||||
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
|
|
||||||
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
|
|
||||||
sa.Column("category", sa.String(100), nullable=False, server_default=""),
|
|
||||||
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
|
|
||||||
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
|
|
||||||
sa.Column("s3_package_key", sa.String(500), nullable=True),
|
|
||||||
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
|
|
||||||
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
|
|
||||||
sa.Column("rejection_reason", sa.Text, nullable=True),
|
|
||||||
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── plugin_installations ──────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"plugin_installations",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
|
|
||||||
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
|
|
||||||
|
|
||||||
# ── plugin_reviews ────────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"plugin_reviews",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
|
|
||||||
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
|
|
||||||
sa.Column("notes", sa.Text, nullable=True),
|
|
||||||
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
|
||||||
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
|
|
||||||
|
|
||||||
# ── revenue_events ────────────────────────────────────────────────────
|
|
||||||
op.create_table(
|
|
||||||
"revenue_events",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("plugin_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
|
|
||||||
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
|
|
||||||
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
|
|
||||||
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
|
|
||||||
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
def downgrade() -> None:
|
||||||
op.drop_table("revenue_events")
|
|
||||||
op.drop_table("plugin_reviews")
|
|
||||||
op.drop_table("plugin_installations")
|
|
||||||
op.drop_table("plugins")
|
|
||||||
op.drop_table("backup_metadata")
|
|
||||||
op.drop_table("storage_records")
|
|
||||||
op.drop_table("subscriptions")
|
op.drop_table("subscriptions")
|
||||||
op.drop_table("refresh_tokens")
|
op.drop_table("refresh_tokens")
|
||||||
op.drop_table("users")
|
op.drop_table("users")
|
||||||
|
|
||||||
op.execute("DROP TYPE IF EXISTS review_decision")
|
|
||||||
op.execute("DROP TYPE IF EXISTS plugin_status")
|
|
||||||
op.execute("DROP TYPE IF EXISTS billing_tier")
|
op.execute("DROP TYPE IF EXISTS billing_tier")
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
|
|
||||||
|
|
||||||
Revision ID: 002
|
|
||||||
Revises: 001
|
|
||||||
Create Date: 2026-03-03
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
|
|
||||||
revision: str = "002"
|
|
||||||
down_revision: Union[str, None] = "001"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
_SEED_PLUGINS = [
|
|
||||||
{
|
|
||||||
"id": "plugin-github-sync",
|
|
||||||
"name": "GitHub Sync",
|
|
||||||
"description": "Sync tasks with GitHub Issues and pull requests.",
|
|
||||||
"version": "1.0.0",
|
|
||||||
"author_name": "Adiuva",
|
|
||||||
"category": "productivity",
|
|
||||||
"price_cents": 0,
|
|
||||||
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "plugin-slack-notify",
|
|
||||||
"name": "Slack Notifier",
|
|
||||||
"description": "Post task and timeline updates to Slack channels.",
|
|
||||||
"version": "1.2.0",
|
|
||||||
"author_name": "Adiuva",
|
|
||||||
"category": "communication",
|
|
||||||
"price_cents": 499,
|
|
||||||
"permissions": json.dumps(["read:tasks", "read:timelines"]),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"id": "plugin-time-tracker",
|
|
||||||
"name": "Time Tracker",
|
|
||||||
"description": "Track time spent on tasks with automatic reporting.",
|
|
||||||
"version": "0.9.1",
|
|
||||||
"author_name": "Third Party",
|
|
||||||
"category": "productivity",
|
|
||||||
"price_cents": 999,
|
|
||||||
"permissions": json.dumps(["read:tasks", "write:tasks"]),
|
|
||||||
"status": "approved",
|
|
||||||
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
|
|
||||||
"install_count": 0,
|
|
||||||
"avg_rating": 0.0,
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
plugins = sa.table(
|
|
||||||
"plugins",
|
|
||||||
sa.column("id", sa.String),
|
|
||||||
sa.column("name", sa.String),
|
|
||||||
sa.column("description", sa.Text),
|
|
||||||
sa.column("version", sa.String),
|
|
||||||
sa.column("author_name", sa.String),
|
|
||||||
sa.column("category", sa.String),
|
|
||||||
sa.column("price_cents", sa.Integer),
|
|
||||||
sa.column("permissions", sa.Text),
|
|
||||||
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
|
|
||||||
sa.column("s3_package_key", sa.String),
|
|
||||||
sa.column("install_count", sa.Integer),
|
|
||||||
sa.column("avg_rating", sa.Float),
|
|
||||||
)
|
|
||||||
op.bulk_insert(plugins, _SEED_PLUGINS)
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.execute(
|
|
||||||
"DELETE FROM plugins WHERE id IN ("
|
|
||||||
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
|
|
||||||
")"
|
|
||||||
)
|
|
||||||
@@ -14,7 +14,7 @@ from alembic import op
|
|||||||
from sqlalchemy.dialects import postgresql
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
revision: str = "003"
|
revision: str = "003"
|
||||||
down_revision: Union[str, None] = "002"
|
down_revision: Union[str, None] = "001"
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,107 @@
|
|||||||
|
"""Restore agent config tables and add agent_config column.
|
||||||
|
|
||||||
|
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
|
||||||
|
ORM models are still active. This migration recreates them with agent_config
|
||||||
|
added to local_agent_configs.
|
||||||
|
|
||||||
|
Revision ID: a3b9c0d1e2f3
|
||||||
|
Revises: 9a1f2d0b6c7e
|
||||||
|
Create Date: 2026-04-07 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "a3b9c0d1e2f3"
|
||||||
|
down_revision: Union[str, None] = "9a1f2d0b6c7e"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# Recreate enum types (idempotent — they may already exist from migration 003)
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
op.execute("""
|
||||||
|
DO $$ BEGIN
|
||||||
|
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||||
|
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||||
|
END $$;
|
||||||
|
""")
|
||||||
|
|
||||||
|
bind = op.get_bind()
|
||||||
|
inspector = sa.inspect(bind)
|
||||||
|
existing = set(inspector.get_table_names())
|
||||||
|
|
||||||
|
# ── local_agent_configs (with agent_config column) ────────────────────
|
||||||
|
if "local_agent_configs" not in existing:
|
||||||
|
op.create_table(
|
||||||
|
"local_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("device_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("agent_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
# ── cloud_agent_configs ───────────────────────────────────────────────
|
||||||
|
if "cloud_agent_configs" not in existing:
|
||||||
|
op.create_table(
|
||||||
|
"cloud_agent_configs",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column(
|
||||||
|
"provider",
|
||||||
|
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||||
|
nullable=False,
|
||||||
|
),
|
||||||
|
sa.Column("name", sa.String(255), nullable=False),
|
||||||
|
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||||
|
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||||
|
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||||
|
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||||
|
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||||
|
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||||
|
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||||
|
op.drop_table("cloud_agent_configs")
|
||||||
|
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||||
|
op.drop_table("local_agent_configs")
|
||||||
56
alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
56
alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
|
||||||
|
|
||||||
|
Revision ID: b4c0d1e2f3a4
|
||||||
|
Revises: a3b9c0d1e2f3
|
||||||
|
Create Date: 2026-04-10 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
from sqlalchemy.dialects import postgresql
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "b4c0d1e2f3a4"
|
||||||
|
down_revision: Union[str, None] = "a3b9c0d1e2f3"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
# ── users: make password_hash nullable (social users have no password) ──
|
||||||
|
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
|
||||||
|
|
||||||
|
# ── users: add avatar_url ─────────────────────────────────────────────
|
||||||
|
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
|
||||||
|
|
||||||
|
# ── oauth_accounts ────────────────────────────────────────────────────
|
||||||
|
op.create_table(
|
||||||
|
"oauth_accounts",
|
||||||
|
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||||
|
sa.Column("provider", sa.String(50), nullable=False),
|
||||||
|
sa.Column("provider_user_id", sa.String(255), nullable=False),
|
||||||
|
sa.Column("provider_email", sa.String(255), nullable=True),
|
||||||
|
sa.Column(
|
||||||
|
"created_at",
|
||||||
|
sa.DateTime(timezone=True),
|
||||||
|
nullable=False,
|
||||||
|
server_default=sa.text("now()"),
|
||||||
|
),
|
||||||
|
sa.PrimaryKeyConstraint("id"),
|
||||||
|
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||||
|
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
|
||||||
|
)
|
||||||
|
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
|
||||||
|
op.drop_table("oauth_accounts")
|
||||||
|
op.drop_column("users", "avatar_url")
|
||||||
|
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)
|
||||||
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
31
alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
"""Add onboarding_completed_at column to users table.
|
||||||
|
|
||||||
|
Revision ID: c5d1e2f3a4b5
|
||||||
|
Revises: b4c0d1e2f3a4
|
||||||
|
Create Date: 2026-04-11 00:00:00.000000
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = "c5d1e2f3a4b5"
|
||||||
|
down_revision: Union[str, None] = "b4c0d1e2f3a4"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.add_column(
|
||||||
|
"users",
|
||||||
|
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.drop_column("users", "onboarding_completed_at")
|
||||||
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
34
alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
"""avatar_url_varchar_to_text
|
||||||
|
|
||||||
|
Revision ID: e04100e88ace
|
||||||
|
Revises: c5d1e2f3a4b5
|
||||||
|
Create Date: 2026-04-13 09:13:06.733674
|
||||||
|
|
||||||
|
"""
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
|
||||||
|
# revision identifiers, used by Alembic.
|
||||||
|
revision: str = 'e04100e88ace'
|
||||||
|
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
op.alter_column('users', 'avatar_url',
|
||||||
|
existing_type=sa.VARCHAR(length=2048),
|
||||||
|
type_=sa.Text(),
|
||||||
|
existing_nullable=True)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
op.alter_column('users', 'avatar_url',
|
||||||
|
existing_type=sa.Text(),
|
||||||
|
type_=sa.VARCHAR(length=2048),
|
||||||
|
existing_nullable=True)
|
||||||
@@ -7,12 +7,31 @@ handles actual disk I/O and responds with ``tool_result`` frames.
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
import re
|
||||||
|
from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
|
# Max characters returned by read_file_content in journey (exploration) tools.
|
||||||
|
# The journey only needs to understand file structure, not full content.
|
||||||
|
_JOURNEY_READ_MAX_CHARS: int = 4000
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_path(path: str, base: str) -> str:
|
||||||
|
"""Resolve *path* against *base* when *path* is relative.
|
||||||
|
|
||||||
|
The LLM often passes ``"."`` meaning "the configured directory".
|
||||||
|
Without this, Electron resolves ``"."`` relative to its own CWD instead
|
||||||
|
of the user's chosen directory.
|
||||||
|
"""
|
||||||
|
if os.path.isabs(path):
|
||||||
|
return path
|
||||||
|
return str(Path(base) / path)
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_directory(path: str) -> str:
|
async def list_directory(path: str) -> str:
|
||||||
@@ -83,3 +102,93 @@ FILESYSTEM_TOOLS: list[Any] = [
|
|||||||
read_file_content,
|
read_file_content,
|
||||||
get_file_metadata,
|
get_file_metadata,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def make_directory_tools(base_directory: str) -> list[Any]:
|
||||||
|
"""Return filesystem tools that resolve relative paths against *base_directory*.
|
||||||
|
|
||||||
|
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
|
||||||
|
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
|
||||||
|
from the LLM are resolved to the correct absolute path before being sent to
|
||||||
|
the Electron client, preventing it from falling back to its own CWD.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _compact_for_journey(raw: str) -> str:
|
||||||
|
"""Strip HTML noise and truncate for journey exploration.
|
||||||
|
|
||||||
|
The journey LLM only needs to understand file structure (headers,
|
||||||
|
first paragraphs). Full CSS/style blocks are pure noise that eat
|
||||||
|
up context window budget.
|
||||||
|
"""
|
||||||
|
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
|
||||||
|
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
|
||||||
|
if len(text) > _JOURNEY_READ_MAX_CHARS:
|
||||||
|
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
|
||||||
|
return text
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def list_directory(path: str) -> str: # noqa: F811
|
||||||
|
"""List files and folders in a local directory on the user's device.
|
||||||
|
|
||||||
|
Returns a formatted listing of entries with name, type (file/directory),
|
||||||
|
and full path.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="list_directory",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
entries: list[dict[str, Any]] = result.get("entries", [])
|
||||||
|
if not entries:
|
||||||
|
return f"Directory '{resolved}' is empty or does not exist."
|
||||||
|
lines: list[str] = []
|
||||||
|
for entry in entries:
|
||||||
|
entry_type = entry.get("type", "unknown")
|
||||||
|
entry_name = entry.get("name", "")
|
||||||
|
entry_path = entry.get("path", "")
|
||||||
|
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
||||||
|
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def read_file_content(path: str) -> str: # noqa: F811
|
||||||
|
"""Read the text content of a local file on the user's device.
|
||||||
|
|
||||||
|
Returns the file content as a string. Large files may be truncated
|
||||||
|
by the Electron client.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="read_file_content",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
content: str = result.get("content", "")
|
||||||
|
if not content:
|
||||||
|
return f"File '{resolved}' is empty or could not be read."
|
||||||
|
return _compact_for_journey(content)
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def get_file_metadata(path: str) -> str: # noqa: F811
|
||||||
|
"""Get metadata for a local file: size, creation date, modification date, extension.
|
||||||
|
|
||||||
|
Returns a formatted summary of the file's metadata.
|
||||||
|
"""
|
||||||
|
resolved = _resolve_path(path, base_directory)
|
||||||
|
result = await execute_on_client(
|
||||||
|
action="get_file_metadata",
|
||||||
|
data={"path": resolved},
|
||||||
|
)
|
||||||
|
size = result.get("size", "unknown")
|
||||||
|
created = result.get("createdAt", "unknown")
|
||||||
|
modified = result.get("modifiedAt", "unknown")
|
||||||
|
extension = result.get("extension", "unknown")
|
||||||
|
name = result.get("name", resolved)
|
||||||
|
return (
|
||||||
|
f"File: {name}\n"
|
||||||
|
f" Extension: {extension}\n"
|
||||||
|
f" Size: {size} bytes\n"
|
||||||
|
f" Created: {created}\n"
|
||||||
|
f" Modified: {modified}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return [list_directory, read_file_content, get_file_metadata]
|
||||||
|
|||||||
@@ -18,21 +18,6 @@ _UUID_RE = re.compile(
|
|||||||
def _is_uuid(value: str) -> bool:
|
def _is_uuid(value: str) -> bool:
|
||||||
return bool(_UUID_RE.match(value))
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
NOTE_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
|
|||||||
@@ -8,22 +8,6 @@ from langchain_core.tools import tool
|
|||||||
|
|
||||||
from app.core.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
PROJECT_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_projects(
|
async def list_projects(
|
||||||
|
|||||||
@@ -18,23 +18,6 @@ _UUID_RE = re.compile(
|
|||||||
def _is_uuid(value: str) -> bool:
|
def _is_uuid(value: str) -> bool:
|
||||||
return bool(_UUID_RE.match(value))
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
TASK_SYSTEM_PROMPT = (
|
|
||||||
"You are a task management assistant for a project workspace.\n"
|
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -17,20 +17,6 @@ _UUID_RE = re.compile(
|
|||||||
def _is_uuid(value: str) -> bool:
|
def _is_uuid(value: str) -> bool:
|
||||||
return bool(_UUID_RE.match(value))
|
return bool(_UUID_RE.match(value))
|
||||||
|
|
||||||
TIMELINE_SYSTEM_PROMPT = (
|
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all timelines across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
|
|||||||
@@ -65,16 +65,39 @@ async def get_current_user(
|
|||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
# Fetch name/surname from user row.
|
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
select(User.name, User.surname).where(User.id == user_id)
|
select(
|
||||||
|
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||||
|
User.password_hash,
|
||||||
|
).where(User.id == user_id)
|
||||||
)
|
)
|
||||||
user_row = user_result.one_or_none()
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
# Convert onboarding_completed_at to epoch ms (int) or None.
|
||||||
|
onboarding_ms: int | None = None
|
||||||
|
if user_row and user_row.onboarding_completed_at is not None:
|
||||||
|
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||||
|
|
||||||
|
# Load decrypted core memory.
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
|
||||||
|
|
||||||
|
memory_dict: dict[str, str] = {}
|
||||||
|
try:
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
blocks = await mw.list_core_blocks(user_id)
|
||||||
|
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||||
|
except Exception:
|
||||||
|
pass # Non-critical — return empty memory on failure
|
||||||
|
|
||||||
return UserProfile(
|
return UserProfile(
|
||||||
id=user_id,
|
id=user_id,
|
||||||
email=email,
|
email=email,
|
||||||
name=user_row.name if user_row else None,
|
name=user_row.name if user_row else None,
|
||||||
surname=user_row.surname if user_row else None,
|
surname=user_row.surname if user_row else None,
|
||||||
|
avatar_url=user_row.avatar_url if user_row else None,
|
||||||
|
has_password=bool(user_row.password_hash) if user_row else False,
|
||||||
tier=tier,
|
tier=tier,
|
||||||
|
onboarding_completed_at=onboarding_ms,
|
||||||
|
memory=memory_dict,
|
||||||
) # type: ignore[arg-type]
|
) # type: ignore[arg-type]
|
||||||
|
|||||||
@@ -8,8 +8,7 @@ that could reveal server-side prompt IP:
|
|||||||
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||||
- Exact-match known prompt fingerprints
|
- Exact-match known prompt fingerprints
|
||||||
|
|
||||||
Binary responses (storage blobs, backup data) are never touched — the
|
The middleware only activates for paths under /api/v1/chat.
|
||||||
middleware only activates for paths under /api/v1/chat.
|
|
||||||
|
|
||||||
Any sanitisation event is logged as a WARNING with the request path and the
|
Any sanitisation event is logged as a WARNING with the request path and the
|
||||||
names of the fields that were modified.
|
names of the fields that were modified.
|
||||||
|
|||||||
@@ -1,11 +1,11 @@
|
|||||||
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
|
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
|
||||||
|
|
||||||
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
The journey is driven entirely through WebSocket frames (no REST endpoints).
|
||||||
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
The device WS handler dispatches ``journey_start`` and ``journey_message``
|
||||||
frames to the functions exported here.
|
frames to the functions exported here.
|
||||||
|
|
||||||
Journey flow:
|
Journey flow:
|
||||||
1. FE sends ``journey_start`` frame with basic agent config (directory,
|
1. FE sends ``journey_start`` frame with basic agent info (directory,
|
||||||
data_types, schedule).
|
data_types, schedule).
|
||||||
2. Server creates an in-memory session, sets up a WS executor so the
|
2. Server creates an in-memory session, sets up a WS executor so the
|
||||||
setup LLM can use file-system tools, does a first directory scrape,
|
setup LLM can use file-system tools, does a first directory scrape,
|
||||||
@@ -13,10 +13,11 @@ Journey flow:
|
|||||||
3. FE sends ``journey_message`` frames for each user reply.
|
3. FE sends ``journey_message`` frames for each user reply.
|
||||||
4. Server appends the user message, calls the LLM (which may read files
|
4. Server appends the user message, calls the LLM (which may read files
|
||||||
via tools), and sends back a ``journey_reply``.
|
via tools), and sends back a ``journey_reply``.
|
||||||
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
|
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
|
||||||
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
|
||||||
6. Server parses the block, sends ``journey_reply`` with ``done=True``
|
6. Server parses and validates the JSON with Pydantic, sends
|
||||||
and the template. FE stores it locally.
|
``journey_reply`` with ``done=True`` and the serialised config.
|
||||||
|
FE stores it locally.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -30,8 +31,10 @@ from typing import Any
|
|||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
||||||
|
|
||||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
from app.agents.filesystem_agent import make_directory_tools
|
||||||
from app.core.llm import get_llm
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
|
from app.schemas import AgentConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
@@ -39,9 +42,9 @@ logger = logging.getLogger(__name__)
|
|||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
|
||||||
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
_CONFIG_START = "AGENT_CONFIG_START"
|
||||||
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
_CONFIG_END = "AGENT_CONFIG_END"
|
||||||
|
|
||||||
# Minimum turns before we consider nudging the LLM to wrap up.
|
# Minimum turns before we consider nudging the LLM to wrap up.
|
||||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
||||||
@@ -62,6 +65,7 @@ class JourneySession:
|
|||||||
data_types: list[str]
|
data_types: list[str]
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
system_prompt: str = ""
|
system_prompt: str = ""
|
||||||
|
langfuse_prompt: Any = None
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
def is_expired(self) -> bool:
|
||||||
@@ -83,61 +87,76 @@ def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
|||||||
return s
|
return s
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt builder ─────────────────────────────────────────────────
|
# ── System prompt ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
_JOURNEY_SYSTEM_PROMPT = """\
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
Your job is to understand exactly what data the user wants to extract from their
|
Your job is to understand what files the user has in their directory and produce a
|
||||||
local directory and produce a detailed prompt_template that a separate AI will use
|
structured AgentConfig JSON that the extraction agent will use as its instruction set.
|
||||||
as its instruction set.
|
|
||||||
|
|
||||||
The extraction agent already has this base behaviour built in:
|
|
||||||
- Reads each file using file-system tools.
|
|
||||||
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
|
|
||||||
- Sets isAiSuggested=1 on every new record.
|
|
||||||
- Only extracts data explicitly present in the files — it never invents information.
|
|
||||||
The user's custom prompt is appended AFTER this base behaviour, so focus on
|
|
||||||
what to look for and how to map it — not on the general extraction mechanics.
|
|
||||||
|
|
||||||
You have access to file-system tools to explore the user's directory:
|
You have access to file-system tools to explore the user's directory:
|
||||||
- list_directory: to see folder structure
|
- list_directory: see folder structure and file names
|
||||||
- read_file_content: to peek at file contents
|
- read_file_content: peek at a file's content
|
||||||
- get_file_metadata: to check file info
|
- get_file_metadata: check file size, extension, dates
|
||||||
|
|
||||||
The user's configured directory is: {directory}
|
The user's configured directory is: {directory}
|
||||||
Target data types: {data_types}
|
Target data types: {data_types}
|
||||||
|
|
||||||
IMPORTANT — project assignment is handled automatically by the main agent runner
|
## Your process
|
||||||
before the custom prompt is ever used. You MUST NOT ask the user about projects,
|
|
||||||
projectId, or how to link records to projects. Never include projectId logic or
|
|
||||||
project creation instructions in the generated prompt_template.
|
|
||||||
|
|
||||||
Start by exploring the directory to understand its structure. Then ask concise,
|
### Step 1 — Explore the directory
|
||||||
focused questions one at a time. Cover these topics (not necessarily in this order):
|
Use list_directory and read_file_content to understand what types of files are present
|
||||||
1. The type and format of the source content (confirmed by your exploration).
|
(HTML emails, plain-text documents, CSVs, etc.).
|
||||||
2. How fields should be mapped (e.g. filename → task title).
|
|
||||||
3. Priority or status rules (e.g. "urgent" keyword → high priority).
|
|
||||||
4. Any special handling, date extraction, or exclusions.
|
|
||||||
|
|
||||||
Once you reach 90% confidence, output the final prompt_template between these exact
|
### Step 2 — Identify content types
|
||||||
markers on their own lines:
|
For each distinct file type found, decide:
|
||||||
|
- A short id (e.g. "email_html", "plain_text", "csv")
|
||||||
|
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
|
||||||
|
- A human-readable label and optional detection_hint
|
||||||
|
|
||||||
{template_start}
|
### Step 3 — Ask focused questions (one at a time)
|
||||||
<the complete extraction prompt here>
|
Cover these topics based on what you discovered:
|
||||||
{template_end}
|
1. How to map content to entity types (task / note / timeline entry)
|
||||||
|
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
|
||||||
|
3. Priority or status rules (e.g. "urgent" in subject → high priority)
|
||||||
|
4. Date extraction (e.g. "by Friday" → dueDate)
|
||||||
|
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
|
||||||
|
|
||||||
The prompt_template must be a self-contained instruction for an AI that reads files
|
### Step 4 — Produce the AgentConfig JSON
|
||||||
and must perform CRUD operations using tools to create records. It should specify:
|
Once you are ≥ 90% confident, output the final config between these exact markers
|
||||||
- What entity types to create (tasks, notes, timelines) — never projects.
|
(each on its own line):
|
||||||
- How to map file content to record fields (camelCase: title, status, priority,
|
|
||||||
dueDate, content, etc.) — never include projectId.
|
{config_start}
|
||||||
- That isAiSuggested must be set to 1 on every new record.
|
{{
|
||||||
- Concrete examples of mappings based on what you discovered in the directory.
|
"content_types": [
|
||||||
|
{{
|
||||||
|
"id": "email_html",
|
||||||
|
"label": "Email HTML",
|
||||||
|
"detection_hint": "HTML file with From/To/Subject headers",
|
||||||
|
"preprocessing": "email_html",
|
||||||
|
"extraction_prompt": "Detailed extraction instructions for this content type..."
|
||||||
|
}}
|
||||||
|
],
|
||||||
|
"global_rules": [
|
||||||
|
"If the file cannot be matched to any project, do not create any entity."
|
||||||
|
],
|
||||||
|
"data_types": {data_types_json}
|
||||||
|
}}
|
||||||
|
{config_end}
|
||||||
|
|
||||||
|
## Rules for the extraction_prompt field
|
||||||
|
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
|
||||||
|
- Include field mapping rules based on what you found in the directory
|
||||||
|
- Include priority/status/date rules if applicable
|
||||||
|
- Do NOT include projectId logic — the runner handles project assignment automatically
|
||||||
|
- Do NOT mention isAiSuggested — the runner always sets it to 1
|
||||||
|
|
||||||
|
## Constraints
|
||||||
|
- Never ask about projects, projectId, or how to link records to projects
|
||||||
|
- Never include projectId or project creation logic in the generated config
|
||||||
|
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
|
||||||
|
|
||||||
{existing_section}\
|
{existing_section}\
|
||||||
Keep asking clarifying questions until you are at least 90% confident you have
|
|
||||||
enough information to generate an accurate prompt_template. Once you reach that
|
|
||||||
confidence level, stop asking and produce the final template immediately.
|
|
||||||
Begin by exploring the directory, then ask your first question.\
|
Begin by exploring the directory, then ask your first question.\
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -145,33 +164,53 @@ Begin by exploring the directory, then ask your first question.\
|
|||||||
def _build_system_prompt(
|
def _build_system_prompt(
|
||||||
directory: str,
|
directory: str,
|
||||||
data_types: list[str],
|
data_types: list[str],
|
||||||
existing_template: str | None = None,
|
existing_config: str | None = None,
|
||||||
) -> str:
|
) -> tuple[str, Any]:
|
||||||
|
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
|
||||||
existing_section = (
|
existing_section = (
|
||||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
|
||||||
f"---\n{existing_template}\n---\n"
|
f"```json\n{existing_config}\n```\n"
|
||||||
if existing_template
|
if existing_config
|
||||||
else ""
|
else ""
|
||||||
)
|
)
|
||||||
return _SYSTEM_PROMPT_TEMPLATE.format(
|
template, prompt_obj = get_prompt_or_fallback(
|
||||||
|
"journey_system", _JOURNEY_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
compiled = compile_prompt(
|
||||||
|
template,
|
||||||
|
prompt_obj,
|
||||||
directory=directory,
|
directory=directory,
|
||||||
data_types=", ".join(data_types),
|
data_types=", ".join(data_types),
|
||||||
template_start=_TEMPLATE_START,
|
data_types_json=json.dumps(data_types),
|
||||||
template_end=_TEMPLATE_END,
|
config_start=_CONFIG_START,
|
||||||
|
config_end=_CONFIG_END,
|
||||||
existing_section=existing_section,
|
existing_section=existing_section,
|
||||||
)
|
)
|
||||||
|
return compiled, prompt_obj
|
||||||
|
|
||||||
|
|
||||||
# ── Template extraction ───────────────────────────────────────────────────
|
# ── AgentConfig extraction ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _extract_template(text: str) -> str | None:
|
def _extract_agent_config(text: str) -> str | None:
|
||||||
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
"""Return validated AgentConfig JSON string from between markers, or None.
|
||||||
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
|
||||||
|
Parses the JSON with Pydantic to ensure it conforms to the schema before
|
||||||
|
returning. Returns None if markers are absent or JSON is invalid.
|
||||||
|
"""
|
||||||
|
if _CONFIG_START not in text or _CONFIG_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
|
||||||
|
end_idx = text.index(_CONFIG_END)
|
||||||
|
raw = text[start_idx:end_idx].strip()
|
||||||
|
if not raw:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
parsed = AgentConfig.model_validate_json(raw)
|
||||||
|
return parsed.model_dump_json()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
|
||||||
return None
|
return None
|
||||||
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
|
||||||
end_idx = text.index(_TEMPLATE_END)
|
|
||||||
return text[start_idx:end_idx].strip() or None
|
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call with tool support ───────────────────────────────────────────
|
# ── LLM call with tool support ───────────────────────────────────────────
|
||||||
@@ -199,12 +238,17 @@ async def _call_llm_with_tools(
|
|||||||
system_prompt: str,
|
system_prompt: str,
|
||||||
history: list[dict[str, Any]],
|
history: list[dict[str, Any]],
|
||||||
tools: list[Any],
|
tools: list[Any],
|
||||||
|
*,
|
||||||
|
user_id: str = "",
|
||||||
|
session_id: str = "",
|
||||||
|
langfuse_prompt: Any = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
"""Build LangChain messages from history and invoke the LLM with tools.
|
||||||
|
|
||||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
Handles tool-calling loops: if the LLM calls tools, execute them and
|
||||||
continue until a final text response is produced.
|
continue until a final text response is produced.
|
||||||
"""
|
"""
|
||||||
|
lf = get_langfuse()
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
for turn in history:
|
for turn in history:
|
||||||
if turn["role"] == "user":
|
if turn["role"] == "user":
|
||||||
@@ -212,16 +256,59 @@ async def _call_llm_with_tools(
|
|||||||
else:
|
else:
|
||||||
messages.append(AIMessage(content=turn["content"]))
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
llm = get_llm(model=None, temperature=0.4)
|
llm = get_agent_llm("setup", temperature=0.4)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
for _ in range(_MAX_TOOL_STEPS):
|
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
|
_span_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name="journey-setup",
|
||||||
|
input=history[-1]["content"] if history else "",
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||||
|
|
||||||
|
try:
|
||||||
|
for step in range(_MAX_TOOL_STEPS):
|
||||||
|
_gen_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="journey-setup-llm",
|
||||||
|
model=model_for_agent("setup"),
|
||||||
|
prompt=langfuse_prompt,
|
||||||
|
input=messages,
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
if _gen_ctx:
|
||||||
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
|
resp_text = _as_text(response.content)
|
||||||
|
|
||||||
|
# Guard against empty responses (e.g. model returned finish_reason
|
||||||
|
# 'error' which LiteLLM maps to 'stop' with empty content).
|
||||||
|
if not response.tool_calls and not resp_text.strip():
|
||||||
|
logger.warning(
|
||||||
|
"agent_setup: journey LLM returned empty response at step %d — retrying",
|
||||||
|
step,
|
||||||
|
)
|
||||||
|
# Drop the empty AIMessage so we don't pollute history, and retry.
|
||||||
|
continue
|
||||||
|
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
|
|
||||||
if not response.tool_calls:
|
if not response.tool_calls:
|
||||||
return _as_text(response.content)
|
if _span:
|
||||||
|
_span.update(output=resp_text)
|
||||||
|
return resp_text
|
||||||
|
|
||||||
for call in response.tool_calls:
|
for call in response.tool_calls:
|
||||||
call_name = str(call.get("name", ""))
|
call_name = str(call.get("name", ""))
|
||||||
@@ -247,7 +334,19 @@ async def _call_llm_with_tools(
|
|||||||
|
|
||||||
# Fallback: exceeded max steps.
|
# Fallback: exceeded max steps.
|
||||||
final = await llm.ainvoke(messages)
|
final = await llm.ainvoke(messages)
|
||||||
return _as_text(final.content)
|
final_text = _as_text(final.content)
|
||||||
|
if _span:
|
||||||
|
_span.update(output=final_text)
|
||||||
|
return final_text or (
|
||||||
|
"Sorry, I had trouble processing the files. "
|
||||||
|
"Could you try again? If the issue persists, the files might be too large for me to analyse."
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
if _span_ctx:
|
||||||
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
|
if lf:
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
|
||||||
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
# ── Journey handlers (called from device_ws.py) ──────────────────────────
|
||||||
@@ -265,12 +364,12 @@ async def handle_journey_start(
|
|||||||
agent_type = frame.get("agent_type", "local")
|
agent_type = frame.get("agent_type", "local")
|
||||||
directory = frame.get("directory", "")
|
directory = frame.get("directory", "")
|
||||||
data_types = frame.get("data_types", [])
|
data_types = frame.get("data_types", [])
|
||||||
existing_template = frame.get("existing_template")
|
existing_config = frame.get("existing_config")
|
||||||
|
|
||||||
# Use the session_id provided by the FE so the reply matches the
|
# Use the session_id provided by the FE so the reply matches the
|
||||||
# listener key; fall back to a generated one if absent.
|
# listener key; fall back to a generated one if absent.
|
||||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
session_id = frame.get("session_id") or str(uuid.uuid4())
|
||||||
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
|
||||||
|
|
||||||
session = JourneySession(
|
session = JourneySession(
|
||||||
session_id=session_id,
|
session_id=session_id,
|
||||||
@@ -279,19 +378,21 @@ async def handle_journey_start(
|
|||||||
directory=directory,
|
directory=directory,
|
||||||
data_types=data_types,
|
data_types=data_types,
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
|
# Seed with an initial user message — some providers require at least one
|
||||||
# ws_context executor (already set by the WS handler before calling us).
|
# user/input message to be present.
|
||||||
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
|
|
||||||
# require at least one user/input message to be present.
|
|
||||||
seed_history: list[dict[str, Any]] = [
|
seed_history: list[dict[str, Any]] = [
|
||||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
||||||
]
|
]
|
||||||
ai_reply = await _call_llm_with_tools(
|
ai_reply = await _call_llm_with_tools(
|
||||||
system_prompt=system_prompt,
|
system_prompt=system_prompt,
|
||||||
history=seed_history,
|
history=seed_history,
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
tools=make_directory_tools(directory),
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
session.history.extend(seed_history)
|
session.history.extend(seed_history)
|
||||||
@@ -305,14 +406,14 @@ async def handle_journey_start(
|
|||||||
directory,
|
directory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if the LLM produced the template on the first turn (unlikely but possible).
|
# Check if the LLM produced the config on the first turn (unlikely but possible).
|
||||||
prompt_template = _extract_template(ai_reply)
|
agent_config = _extract_agent_config(ai_reply)
|
||||||
done = prompt_template is not None
|
done = agent_config is not None
|
||||||
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
@@ -322,7 +423,7 @@ async def handle_journey_start(
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"message": display_message,
|
"message": display_message,
|
||||||
"done": done,
|
"done": done,
|
||||||
"prompt_template": prompt_template,
|
"agent_config": agent_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -345,53 +446,59 @@ async def handle_journey_message(
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"message": "Journey session not found or expired. Please start a new setup.",
|
"message": "Journey session not found or expired. Please start a new setup.",
|
||||||
"done": True,
|
"done": True,
|
||||||
"prompt_template": None,
|
"agent_config": None,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Append user turn.
|
# Append user turn.
|
||||||
session.history.append({"role": "user", "content": message})
|
session.history.append({"role": "user", "content": message})
|
||||||
|
|
||||||
# Call the LLM with tools.
|
# Call the LLM with tools.
|
||||||
|
session_tools = make_directory_tools(session.directory)
|
||||||
ai_reply = await _call_llm_with_tools(
|
ai_reply = await _call_llm_with_tools(
|
||||||
system_prompt=session.system_prompt,
|
system_prompt=session.system_prompt,
|
||||||
history=session.history,
|
history=session.history,
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
tools=session_tools,
|
||||||
|
user_id=session.user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
langfuse_prompt=session.langfuse_prompt,
|
||||||
)
|
)
|
||||||
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
# Check if the LLM produced the final template.
|
# Check if the LLM produced the final config.
|
||||||
prompt_template = _extract_template(ai_reply)
|
agent_config = _extract_agent_config(ai_reply)
|
||||||
done = prompt_template is not None
|
done = agent_config is not None
|
||||||
|
|
||||||
# If the LLM didn't produce a template, nudge it once it has asked enough
|
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
|
||||||
# questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap.
|
|
||||||
if not done:
|
if not done:
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
if turns >= _MAX_TURNS:
|
if turns >= _MAX_TURNS:
|
||||||
nudge_content = (
|
nudge_content = (
|
||||||
"[System: You have enough information. Please generate the final "
|
"[System: You have enough information. Please generate the final "
|
||||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
|
||||||
)
|
)
|
||||||
session.history.append({"role": "user", "content": nudge_content})
|
session.history.append({"role": "user", "content": nudge_content})
|
||||||
|
|
||||||
nudge_reply = await _call_llm_with_tools(
|
nudge_reply = await _call_llm_with_tools(
|
||||||
system_prompt=session.system_prompt,
|
system_prompt=session.system_prompt,
|
||||||
history=session.history,
|
history=session.history,
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
tools=session_tools,
|
||||||
|
user_id=session.user_id,
|
||||||
|
session_id=session_id,
|
||||||
|
langfuse_prompt=session.langfuse_prompt,
|
||||||
)
|
)
|
||||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
session.history.append({"role": "assistant", "content": nudge_reply})
|
||||||
|
|
||||||
prompt_template = _extract_template(nudge_reply)
|
agent_config = _extract_agent_config(nudge_reply)
|
||||||
if prompt_template is not None:
|
if agent_config is not None:
|
||||||
done = True
|
done = True
|
||||||
ai_reply = nudge_reply
|
ai_reply = nudge_reply
|
||||||
|
|
||||||
display_message = ai_reply
|
display_message = ai_reply
|
||||||
if done:
|
if done:
|
||||||
display_message = (
|
display_message = (
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
|
||||||
if _TEMPLATE_START in ai_reply
|
if _CONFIG_START in ai_reply
|
||||||
else "Here is your agent configuration. You can save it or continue refining."
|
else "Here is your agent configuration. You can save it or continue refining."
|
||||||
)
|
)
|
||||||
_sessions.pop(session_id, None)
|
_sessions.pop(session_id, None)
|
||||||
@@ -402,5 +509,5 @@ async def handle_journey_message(
|
|||||||
"session_id": session_id,
|
"session_id": session_id,
|
||||||
"message": display_message,
|
"message": display_message,
|
||||||
"done": done,
|
"done": done,
|
||||||
"prompt_template": prompt_template,
|
"agent_config": agent_config,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,8 +12,11 @@ in backend agent-config tables.
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
from sqlalchemy import func, select
|
from sqlalchemy import func, select
|
||||||
@@ -177,6 +180,11 @@ async def trigger_agent_run(
|
|||||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||||
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
await _enforce_run_frequency(current_user.tier, current_user.id, db)
|
||||||
|
|
||||||
|
last_run_dt = (
|
||||||
|
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
|
||||||
|
if body.last_run_at
|
||||||
|
else None
|
||||||
|
)
|
||||||
config = LocalAgentConfig(
|
config = LocalAgentConfig(
|
||||||
id=str(uuid.uuid4()),
|
id=str(uuid.uuid4()),
|
||||||
user_id=current_user.id,
|
user_id=current_user.id,
|
||||||
@@ -184,10 +192,12 @@ async def trigger_agent_run(
|
|||||||
name="Local Directory Monitor",
|
name="Local Directory Monitor",
|
||||||
directory_paths=[body.directory],
|
directory_paths=[body.directory],
|
||||||
data_types=_to_data_types(body.what_to_extract),
|
data_types=_to_data_types(body.what_to_extract),
|
||||||
prompt_template=body.custom_agent_prompt,
|
prompt_template=body.custom_agent_prompt or "",
|
||||||
|
agent_config=body.agent_config,
|
||||||
file_extensions=[],
|
file_extensions=[],
|
||||||
schedule_cron=body.batch_interval,
|
schedule_cron=body.batch_interval,
|
||||||
enabled=True,
|
enabled=True,
|
||||||
|
last_run_at=last_run_dt,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
|
||||||
|
|||||||
@@ -1,34 +1,68 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
|
||||||
|
|
||||||
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
SHA-256 hashes so plaintext never reaches the DB.
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
|
|
||||||
|
OAuth (Google):
|
||||||
|
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
|
||||||
|
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
import json
|
||||||
import time
|
import time
|
||||||
|
import urllib.parse
|
||||||
import uuid
|
import uuid
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Literal
|
||||||
|
|
||||||
import bcrypt
|
import bcrypt
|
||||||
from cryptography.fernet import Fernet
|
from cryptography.fernet import Fernet
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from fastapi.responses import RedirectResponse
|
||||||
from jose import jwt
|
from jose import jwt
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel, Field
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
|
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import RefreshToken, User
|
from app.models import OAuthAccount, RefreshToken, User
|
||||||
from app.schemas import AuthTokens, UserProfile
|
from app.schemas import AuthTokens, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth provider registry ───────────────────────────────────────────
|
||||||
|
|
||||||
|
def _get_google_provider() -> GoogleOAuthProvider:
|
||||||
|
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_503_SERVICE_UNAVAILABLE,
|
||||||
|
"Google login is not configured on this server",
|
||||||
|
)
|
||||||
|
return GoogleOAuthProvider(
|
||||||
|
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
|
||||||
|
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
|
||||||
|
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_PROVIDERS = {"google": _get_google_provider}
|
||||||
|
|
||||||
|
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
|
||||||
|
# Production note: replace with Redis for multi-process deployments.
|
||||||
|
_pending_states: dict[str, tuple[str, float]] = {}
|
||||||
|
_STATE_TTL_SECONDS = 600 # 10 minutes
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
# ── Internal helpers ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -231,5 +265,531 @@ async def update_profile(
|
|||||||
email=user.email,
|
email=user.email,
|
||||||
name=user.name,
|
name=user.name,
|
||||||
surname=user.surname,
|
surname=user.surname,
|
||||||
|
avatar_url=user.avatar_url,
|
||||||
tier=current_user.tier,
|
tier=current_user.tier,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth helpers ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
|
||||||
|
"""Create a refresh token row and return (plain_token, AuthTokens)."""
|
||||||
|
plain_token = str(uuid.uuid4())
|
||||||
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
|
)
|
||||||
|
rt = RefreshToken(
|
||||||
|
user_id=user.id,
|
||||||
|
token_hash=_hash_token(plain_token),
|
||||||
|
expires_at=expires_at,
|
||||||
|
)
|
||||||
|
db.add(rt)
|
||||||
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
|
return plain_token, AuthTokens(
|
||||||
|
access_token=access_token,
|
||||||
|
refresh_token=plain_token,
|
||||||
|
expires_at=expires_at_ms,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth request/response schemas ───────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _OAuthAuthorizeResponse(BaseModel):
|
||||||
|
url: str
|
||||||
|
state: str
|
||||||
|
|
||||||
|
|
||||||
|
class _OAuthCallbackRequest(BaseModel):
|
||||||
|
code: str
|
||||||
|
state: str
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth routes ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/oauth/{provider}/web-callback",
|
||||||
|
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
|
||||||
|
include_in_schema=False,
|
||||||
|
)
|
||||||
|
async def oauth_web_callback(
|
||||||
|
provider: Literal["google"],
|
||||||
|
code: str,
|
||||||
|
state: str,
|
||||||
|
) -> RedirectResponse:
|
||||||
|
"""Google redirects here after user consent.
|
||||||
|
|
||||||
|
This endpoint immediately redirects to the Electron deep-link URI so the
|
||||||
|
desktop app receives the authorization code. It is intentionally simple —
|
||||||
|
no state validation here (the Electron app + backend callback do that).
|
||||||
|
|
||||||
|
Registered in Google Cloud Console as:
|
||||||
|
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
|
||||||
|
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
|
||||||
|
"""
|
||||||
|
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
|
||||||
|
deep_link = f"adiuvai://oauth/callback?{params}"
|
||||||
|
return RedirectResponse(url=deep_link, status_code=302)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get(
|
||||||
|
"/oauth/{provider}/authorize",
|
||||||
|
response_model=_OAuthAuthorizeResponse,
|
||||||
|
summary="Start OAuth flow — returns the provider consent-screen URL",
|
||||||
|
)
|
||||||
|
async def oauth_authorize(
|
||||||
|
provider: Literal["google"],
|
||||||
|
) -> _OAuthAuthorizeResponse:
|
||||||
|
"""Generate a PKCE state + code_challenge and return the authorization URL.
|
||||||
|
|
||||||
|
The client opens this URL in the system browser. After the user grants
|
||||||
|
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
|
||||||
|
with ``code`` and ``state`` query params. The client then calls
|
||||||
|
``POST /auth/oauth/{provider}/callback`` with those values.
|
||||||
|
"""
|
||||||
|
provider_factory = _PROVIDERS.get(provider)
|
||||||
|
if provider_factory is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
oauth_provider = provider_factory()
|
||||||
|
state = str(uuid.uuid4())
|
||||||
|
code_verifier, code_challenge = generate_pkce_pair()
|
||||||
|
|
||||||
|
# Purge expired states to prevent unbounded growth.
|
||||||
|
now = time.time()
|
||||||
|
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
|
||||||
|
for s in expired:
|
||||||
|
del _pending_states[s]
|
||||||
|
|
||||||
|
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
|
||||||
|
|
||||||
|
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
|
||||||
|
return _OAuthAuthorizeResponse(url=url, state=state)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/oauth/{provider}/callback",
|
||||||
|
response_model=AuthTokens,
|
||||||
|
summary="Complete OAuth flow — exchange code and issue JWT tokens",
|
||||||
|
)
|
||||||
|
async def oauth_callback(
|
||||||
|
provider: Literal["google"],
|
||||||
|
body: _OAuthCallbackRequest,
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AuthTokens:
|
||||||
|
"""Validate state, exchange the authorization code, and sign in (or register) the user.
|
||||||
|
|
||||||
|
Resolution order:
|
||||||
|
1. ``oauth_accounts`` row match → existing user, log in.
|
||||||
|
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
|
||||||
|
3. No match → create new user (password_hash=None, avatar from provider).
|
||||||
|
"""
|
||||||
|
provider_factory = _PROVIDERS.get(provider)
|
||||||
|
if provider_factory is None:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
|
||||||
|
|
||||||
|
# Validate state (CSRF protection).
|
||||||
|
now = time.time()
|
||||||
|
entry = _pending_states.pop(body.state, None)
|
||||||
|
if entry is None or entry[1] < now:
|
||||||
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
|
||||||
|
|
||||||
|
code_verifier, _ = entry
|
||||||
|
|
||||||
|
oauth_provider = provider_factory()
|
||||||
|
|
||||||
|
# Exchange code for tokens.
|
||||||
|
try:
|
||||||
|
token_data = await oauth_provider.exchange_code(
|
||||||
|
code=body.code,
|
||||||
|
code_verifier=code_verifier,
|
||||||
|
redirect_uri=settings.OAUTH_REDIRECT_URI,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
|
||||||
|
)
|
||||||
|
|
||||||
|
access_token_google = token_data.get("access_token")
|
||||||
|
if not access_token_google:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
|
||||||
|
|
||||||
|
# Fetch user identity.
|
||||||
|
try:
|
||||||
|
userinfo = await oauth_provider.get_userinfo(access_token_google)
|
||||||
|
except Exception:
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
|
||||||
|
|
||||||
|
# ── Resolution order ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
# 1. Existing OAuth link?
|
||||||
|
oauth_result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
OAuthAccount.provider_user_id == userinfo.provider_user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
oauth_account = oauth_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if oauth_account is not None:
|
||||||
|
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
|
||||||
|
user = user_result.scalar_one()
|
||||||
|
# Backfill avatar if the user doesn't have one yet.
|
||||||
|
if user.avatar_url is None and userinfo.avatar_url:
|
||||||
|
user.avatar_url = userinfo.avatar_url
|
||||||
|
await db.commit()
|
||||||
|
plain_token, tokens = await _issue_refresh_token(user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# 2. Email match with a verified Google email → link accounts.
|
||||||
|
if userinfo.email_verified:
|
||||||
|
email_result = await db.execute(select(User).where(User.email == userinfo.email))
|
||||||
|
existing_user = email_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if existing_user is not None:
|
||||||
|
new_link = OAuthAccount(
|
||||||
|
user_id=existing_user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=userinfo.provider_user_id,
|
||||||
|
provider_email=userinfo.email,
|
||||||
|
)
|
||||||
|
db.add(new_link)
|
||||||
|
if existing_user.avatar_url is None and userinfo.avatar_url:
|
||||||
|
existing_user.avatar_url = userinfo.avatar_url
|
||||||
|
plain_token, tokens = await _issue_refresh_token(existing_user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
# Guard: if the email is already taken but we couldn't auto-link (e.g.
|
||||||
|
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
|
||||||
|
if not userinfo.email_verified:
|
||||||
|
conflict = await db.execute(select(User).where(User.email == userinfo.email))
|
||||||
|
if conflict.scalar_one_or_none() is not None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_409_CONFLICT,
|
||||||
|
"An account with this email already exists. "
|
||||||
|
"Please sign in with your password.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# 3. New user — social-only account (no password).
|
||||||
|
new_user = User(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
email=userinfo.email,
|
||||||
|
name=userinfo.name,
|
||||||
|
password_hash=None,
|
||||||
|
avatar_url=userinfo.avatar_url,
|
||||||
|
tier="free",
|
||||||
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
|
)
|
||||||
|
db.add(new_user)
|
||||||
|
await db.flush() # populate new_user.id
|
||||||
|
|
||||||
|
new_oauth = OAuthAccount(
|
||||||
|
user_id=new_user.id,
|
||||||
|
provider=provider,
|
||||||
|
provider_user_id=userinfo.provider_user_id,
|
||||||
|
provider_email=userinfo.email,
|
||||||
|
)
|
||||||
|
db.add(new_oauth)
|
||||||
|
|
||||||
|
plain_token, tokens = await _issue_refresh_token(new_user, db)
|
||||||
|
await db.commit()
|
||||||
|
return tokens
|
||||||
|
|
||||||
|
|
||||||
|
# ── Onboarding helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
|
||||||
|
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
|
||||||
|
|
||||||
|
# We can't call the FastAPI dependency directly, but we can replicate
|
||||||
|
# the core logic inline. Instead, we just re-query the same way.
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
|
)
|
||||||
|
default_tier = "power" if settings.ENV == "dev" else "free"
|
||||||
|
tier: str = result.scalar_one_or_none() or default_tier
|
||||||
|
|
||||||
|
user_result = await db.execute(
|
||||||
|
select(
|
||||||
|
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
|
||||||
|
User.password_hash,
|
||||||
|
).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
user_row = user_result.one_or_none()
|
||||||
|
|
||||||
|
onboarding_ms: int | None = None
|
||||||
|
if user_row and user_row.onboarding_completed_at is not None:
|
||||||
|
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
|
||||||
|
|
||||||
|
memory_dict: dict[str, str] = {}
|
||||||
|
try:
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
blocks = await mw.list_core_blocks(user_id)
|
||||||
|
memory_dict = {b["label"]: b["value"] for b in blocks}
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return UserProfile(
|
||||||
|
id=user_id,
|
||||||
|
email=email,
|
||||||
|
name=user_row.name if user_row else None,
|
||||||
|
surname=user_row.surname if user_row else None,
|
||||||
|
avatar_url=user_row.avatar_url if user_row else None,
|
||||||
|
has_password=bool(user_row.password_hash) if user_row else False,
|
||||||
|
tier=tier,
|
||||||
|
onboarding_completed_at=onboarding_ms,
|
||||||
|
memory=memory_dict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Onboarding routes ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateMemoryRequest(BaseModel):
|
||||||
|
memory: dict[str, str] = Field(default_factory=dict)
|
||||||
|
mark_onboarded: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/memory", response_model=UserProfile)
|
||||||
|
async def update_memory(
|
||||||
|
body: _UpdateMemoryRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update core memory key/value pairs and optionally mark onboarding complete."""
|
||||||
|
mw = MemoryMiddleware(db)
|
||||||
|
for key, value in body.memory.items():
|
||||||
|
await mw.update_core(current_user.id, key, value)
|
||||||
|
if body.mark_onboarded:
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.onboarding_completed_at = datetime.now(timezone.utc)
|
||||||
|
await db.commit()
|
||||||
|
return await _build_profile(current_user.id, current_user.email, db)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/me/onboarding/reset")
|
||||||
|
async def reset_onboarding(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
):
|
||||||
|
"""Reset onboarding so the wizard runs again on next login."""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.onboarding_completed_at = None
|
||||||
|
await db.commit()
|
||||||
|
return {"status": "reset"}
|
||||||
|
|
||||||
|
|
||||||
|
class _NormalizeRequest(BaseModel):
|
||||||
|
inputs: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
class _NormalizeResponse(BaseModel):
|
||||||
|
normalized: dict[str, str]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
|
||||||
|
async def normalize_onboarding(
|
||||||
|
body: _NormalizeRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _NormalizeResponse:
|
||||||
|
"""One-shot LLM normalization for free-text onboarding answers."""
|
||||||
|
if not body.inputs:
|
||||||
|
return _NormalizeResponse(normalized={})
|
||||||
|
try:
|
||||||
|
llm = get_llm(model="gpt-4o-mini", temperature=0)
|
||||||
|
prompt = (
|
||||||
|
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
|
||||||
|
"Return a JSON object with the same keys and normalized values.\n"
|
||||||
|
"Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n"
|
||||||
|
f"Input: {json.dumps(body.inputs)}"
|
||||||
|
)
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[
|
||||||
|
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
|
||||||
|
{"role": "user", "content": prompt},
|
||||||
|
],
|
||||||
|
)
|
||||||
|
normalized = json.loads(response.content)
|
||||||
|
return _NormalizeResponse(normalized=normalized)
|
||||||
|
except Exception:
|
||||||
|
# LLM failure must never block onboarding — return inputs unchanged
|
||||||
|
return _NormalizeResponse(normalized=body.inputs)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Password management ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _ChangePasswordRequest(BaseModel):
|
||||||
|
current_password: str = Field(min_length=1)
|
||||||
|
new_password: str = Field(min_length=8)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/password", status_code=status.HTTP_200_OK)
|
||||||
|
async def change_password(
|
||||||
|
body: _ChangePasswordRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Change the authenticated user's password.
|
||||||
|
|
||||||
|
Requires the current password for verification.
|
||||||
|
Returns 400 for social-only users (no password set).
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
if user.password_hash is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
"This account uses social login and has no password to change",
|
||||||
|
)
|
||||||
|
|
||||||
|
if not _verify_password(body.current_password, user.password_hash):
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
|
||||||
|
|
||||||
|
user.password_hash = _hash_password(body.new_password)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── OAuth account management ─────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/me/oauth-accounts", response_model=list[dict])
|
||||||
|
async def list_oauth_accounts(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[dict]:
|
||||||
|
"""List all OAuth providers linked to the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
accounts = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"provider": a.provider,
|
||||||
|
"provider_email": a.provider_email,
|
||||||
|
"created_at": int(a.created_at.timestamp() * 1000),
|
||||||
|
}
|
||||||
|
for a in accounts
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
|
||||||
|
async def unlink_oauth_account(
|
||||||
|
provider: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Unlink an OAuth provider from the authenticated user.
|
||||||
|
|
||||||
|
Refuses if the user has no password and this is their only login method.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
|
||||||
|
oauth_result = await db.execute(
|
||||||
|
select(OAuthAccount).where(
|
||||||
|
OAuthAccount.user_id == current_user.id,
|
||||||
|
OAuthAccount.provider == provider,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
account = oauth_result.scalar_one_or_none()
|
||||||
|
if account is None:
|
||||||
|
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
|
||||||
|
|
||||||
|
# Safety: don't let users lock themselves out.
|
||||||
|
all_oauth = await db.execute(
|
||||||
|
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
oauth_count = len(all_oauth.scalars().all())
|
||||||
|
|
||||||
|
if user.password_hash is None and oauth_count <= 1:
|
||||||
|
raise HTTPException(
|
||||||
|
status.HTTP_400_BAD_REQUEST,
|
||||||
|
"Cannot unlink the only login method. Set a password first.",
|
||||||
|
)
|
||||||
|
|
||||||
|
await db.delete(account)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Avatar update ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateAvatarRequest(BaseModel):
|
||||||
|
avatar_url: str = Field(min_length=1)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/me/avatar", response_model=UserProfile)
|
||||||
|
async def update_avatar(
|
||||||
|
body: _UpdateAvatarRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> UserProfile:
|
||||||
|
"""Update the authenticated user's avatar URL.
|
||||||
|
|
||||||
|
Accepts {"avatar_url": "https://..."} — the client uploads the image
|
||||||
|
to its own storage and passes the resulting URL here.
|
||||||
|
"""
|
||||||
|
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
|
||||||
|
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
|
||||||
|
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
user.avatar_url = body.avatar_url
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return await _build_profile(current_user.id, current_user.email, db)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Account deletion ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/me", status_code=status.HTTP_200_OK)
|
||||||
|
async def delete_account(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Permanently delete the authenticated user's account.
|
||||||
|
|
||||||
|
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
|
||||||
|
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
|
||||||
|
is cancelled if active.
|
||||||
|
"""
|
||||||
|
# Cancel Stripe subscription if present.
|
||||||
|
try:
|
||||||
|
from app.billing.stripe_service import stripe_service # noqa: PLC0415
|
||||||
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
|
except HTTPException:
|
||||||
|
pass # No subscription — that's fine
|
||||||
|
|
||||||
|
# Delete all memory rows (core, associative, episodic, proactive).
|
||||||
|
try:
|
||||||
|
from app.models import ( # noqa: PLC0415
|
||||||
|
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
|
||||||
|
)
|
||||||
|
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
|
||||||
|
await db.execute(
|
||||||
|
model.__table__.delete().where(model.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
pass # Non-critical — cascade on User will handle most
|
||||||
|
|
||||||
|
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
|
||||||
|
result = await db.execute(select(User).where(User.id == current_user.id))
|
||||||
|
user = result.scalar_one()
|
||||||
|
await db.delete(user)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|||||||
@@ -1,171 +0,0 @@
|
|||||||
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
|
||||||
|
|
||||||
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
|
||||||
PostgreSQL ``backup_metadata`` table.
|
|
||||||
|
|
||||||
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
|
||||||
treating "history" as a ``{backup_id}`` path parameter.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from email.utils import parsedate_to_datetime
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.billing.tier_manager import tier_manager
|
|
||||||
from app.db import get_session
|
|
||||||
from app.models import BackupMetadata as BackupMetadataModel
|
|
||||||
from app.schemas import BackupMetadata, UserProfile
|
|
||||||
from app.storage.blob_store import BlobStore
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/backup", tags=["backup"])
|
|
||||||
|
|
||||||
_blob_store = BlobStore()
|
|
||||||
|
|
||||||
|
|
||||||
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return total backup bytes stored by *user_id*."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
|
||||||
BackupMetadataModel.user_id == user_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return int(result.scalar_one())
|
|
||||||
|
|
||||||
|
|
||||||
async def _check_backup_quota(
|
|
||||||
user: UserProfile, size_bytes: int, db: AsyncSession
|
|
||||||
) -> None:
|
|
||||||
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
|
||||||
current = await _current_backup_bytes(user.id, db)
|
|
||||||
tier_manager.enforce_backup_quota(
|
|
||||||
user.tier, current_bytes=current, additional_bytes=size_bytes
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("")
|
|
||||||
async def upload_backup(
|
|
||||||
request: Request,
|
|
||||||
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
|
||||||
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
|
||||||
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Upload an E2E-encrypted backup blob.
|
|
||||||
|
|
||||||
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
|
||||||
"""
|
|
||||||
blob = await request.body()
|
|
||||||
reject_if_tampered(blob, x_backup_checksum)
|
|
||||||
await _check_backup_quota(current_user, len(blob), db)
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
row = BackupMetadataModel(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=current_user.id,
|
|
||||||
s3_key=s3_key,
|
|
||||||
version=x_backup_version,
|
|
||||||
timestamp=x_backup_timestamp,
|
|
||||||
checksum=x_backup_checksum,
|
|
||||||
size_bytes=len(blob),
|
|
||||||
)
|
|
||||||
db.add(row)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/history", response_model=list[BackupMetadata])
|
|
||||||
async def backup_history(
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[BackupMetadata]:
|
|
||||||
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(BackupMetadataModel)
|
|
||||||
.where(BackupMetadataModel.user_id == current_user.id)
|
|
||||||
.order_by(BackupMetadataModel.timestamp.desc())
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
return [
|
|
||||||
BackupMetadata(
|
|
||||||
version=r.version,
|
|
||||||
timestamp=r.timestamp,
|
|
||||||
checksum=r.checksum,
|
|
||||||
chunk_count=1,
|
|
||||||
)
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("")
|
|
||||||
async def download_backup(
|
|
||||||
request: Request,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> Response:
|
|
||||||
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(BackupMetadataModel)
|
|
||||||
.where(BackupMetadataModel.user_id == current_user.id)
|
|
||||||
.order_by(BackupMetadataModel.timestamp.desc())
|
|
||||||
.limit(1)
|
|
||||||
)
|
|
||||||
latest = result.scalar_one_or_none()
|
|
||||||
if latest is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
|
||||||
|
|
||||||
ims_header = request.headers.get("If-Modified-Since")
|
|
||||||
if ims_header:
|
|
||||||
try:
|
|
||||||
ims_dt = parsedate_to_datetime(ims_header)
|
|
||||||
ims_ms = int(ims_dt.timestamp() * 1000)
|
|
||||||
if latest.timestamp <= ims_ms:
|
|
||||||
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
|
||||||
except Exception:
|
|
||||||
pass # malformed header — ignore and serve the blob
|
|
||||||
|
|
||||||
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
|
||||||
return Response(
|
|
||||||
content=blob,
|
|
||||||
media_type="application/octet-stream",
|
|
||||||
headers={
|
|
||||||
"X-Backup-Version": str(latest.version),
|
|
||||||
"X-Backup-Timestamp": str(latest.timestamp),
|
|
||||||
"X-Checksum": latest.checksum,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{backup_id}", response_model=dict)
|
|
||||||
async def delete_backup(
|
|
||||||
backup_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a specific backup by ID."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(BackupMetadataModel).where(
|
|
||||||
BackupMetadataModel.id == backup_id,
|
|
||||||
BackupMetadataModel.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
target = result.scalar_one_or_none()
|
|
||||||
if target is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
|
||||||
|
|
||||||
await _blob_store.delete(current_user.id, target.s3_key)
|
|
||||||
await db.delete(target)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -83,3 +83,16 @@ async def cancel_subscription(
|
|||||||
"""Cancel the active subscription."""
|
"""Cancel the active subscription."""
|
||||||
await stripe_service.cancel_subscription(current_user.id, db)
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
return {"ok": True}
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/invoices", response_model=list[dict])
|
||||||
|
async def list_invoices(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return billing history (invoices) from Stripe.
|
||||||
|
|
||||||
|
Returns an empty list when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
invoices = await stripe_service.list_invoices(current_user.id, db)
|
||||||
|
return invoices
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
"""Chat routes: POST /chat (REST fallback).
|
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
|
||||||
|
|
||||||
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
"""
|
"""
|
||||||
@@ -7,14 +7,30 @@ from __future__ import annotations
|
|||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
from fastapi import APIRouter, Depends
|
||||||
from fastapi.responses import JSONResponse
|
from fastapi.responses import JSONResponse
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.core.deep_agent import run_home
|
from app.core.deep_agent import run_home
|
||||||
|
from app.core.llm import embed
|
||||||
from app.schemas import ChatRequest, UserProfile
|
from app.schemas import ChatRequest, UserProfile
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Embed helpers ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedResponse(BaseModel):
|
||||||
|
vector: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Endpoints ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
@router.post("")
|
||||||
async def chat(
|
async def chat(
|
||||||
body: ChatRequest,
|
body: ChatRequest,
|
||||||
@@ -27,3 +43,17 @@ async def chat(
|
|||||||
context=body.context.model_dump(),
|
context=body.context.model_dump(),
|
||||||
)
|
)
|
||||||
return JSONResponse(content={"response": response})
|
return JSONResponse(content={"response": response})
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/embed", response_model=_EmbedResponse)
|
||||||
|
async def embed_text(
|
||||||
|
body: _EmbedRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _EmbedResponse:
|
||||||
|
"""Generate a 1536-dim embedding vector for the given text.
|
||||||
|
|
||||||
|
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||||
|
Used by Electron (vectordb.ts) for local note search.
|
||||||
|
"""
|
||||||
|
vector = await embed(body.text)
|
||||||
|
return _EmbedResponse(vector=vector)
|
||||||
|
|||||||
@@ -1,148 +0,0 @@
|
|||||||
"""Plugins routes: browse and install plugins from the marketplace.
|
|
||||||
|
|
||||||
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
|
||||||
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.db import get_session
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
from app.marketplace.revenue_share import revenue_share
|
|
||||||
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
|
||||||
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier gate ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _require_plugin_tier(user: UserProfile) -> None:
|
|
||||||
"""Raise HTTP 403 for users below Power tier."""
|
|
||||||
if user.tier not in ("power", "team"):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="Plugin marketplace requires Power tier or above",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local detail schema ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _PluginDetail(BaseModel):
|
|
||||||
plugin: PluginManifest
|
|
||||||
install_count: int
|
|
||||||
ratings: list[Any]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("", response_model=PluginListResponse)
|
|
||||||
async def list_plugins(
|
|
||||||
category: str | None = Query(default=None),
|
|
||||||
q: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> PluginListResponse:
|
|
||||||
"""Browse the plugin marketplace. Requires Power tier or above."""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
|
||||||
async def get_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _PluginDetail:
|
|
||||||
"""Get full plugin details including install count. Requires Power tier or above."""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
entry = await registry.get_plugin(db, plugin_id)
|
|
||||||
if entry is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
|
||||||
|
|
||||||
# Fetch review ratings for this plugin
|
|
||||||
review_result = await db.execute(
|
|
||||||
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
|
||||||
)
|
|
||||||
reviews = review_result.scalars().all()
|
|
||||||
ratings = [
|
|
||||||
{
|
|
||||||
"reviewer_id": r.reviewer_id,
|
|
||||||
"decision": r.decision,
|
|
||||||
"notes": r.notes,
|
|
||||||
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
|
||||||
}
|
|
||||||
for r in reviews
|
|
||||||
]
|
|
||||||
|
|
||||||
return _PluginDetail(
|
|
||||||
plugin=entry["manifest"],
|
|
||||||
install_count=entry["install_count"],
|
|
||||||
ratings=ratings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/{plugin_id}/install", response_model=dict)
|
|
||||||
async def install_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
|
||||||
|
|
||||||
Requires Power tier or above.
|
|
||||||
"""
|
|
||||||
_require_plugin_tier(current_user)
|
|
||||||
entry = await registry.get_plugin(db, plugin_id)
|
|
||||||
if entry is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
|
||||||
|
|
||||||
# Record the installation in plugin_installations
|
|
||||||
installation = PluginInstallation(
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
)
|
|
||||||
db.add(installation)
|
|
||||||
await db.flush()
|
|
||||||
|
|
||||||
await revenue_share.record_install(
|
|
||||||
db,
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
amount_cents=entry["manifest"].price_cents,
|
|
||||||
)
|
|
||||||
|
|
||||||
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
|
||||||
return {"ok": True, "download_url": download_url}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/{plugin_id}/install", response_model=dict)
|
|
||||||
async def uninstall_plugin(
|
|
||||||
plugin_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Unregister a plugin installation."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(PluginInstallation).where(
|
|
||||||
PluginInstallation.plugin_id == plugin_id,
|
|
||||||
PluginInstallation.user_id == current_user.id,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
installation = result.scalar_one_or_none()
|
|
||||||
if installation is not None:
|
|
||||||
await db.delete(installation)
|
|
||||||
await db.commit()
|
|
||||||
await registry.record_uninstall(db, plugin_id)
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,195 +0,0 @@
|
|||||||
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
|
||||||
|
|
||||||
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
|
||||||
PostgreSQL ``storage_records`` table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.billing.tier_manager import tier_manager
|
|
||||||
from app.db import get_session
|
|
||||||
from app.models import StorageRecord
|
|
||||||
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
|
||||||
from app.storage.blob_store import BlobStore
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/storage", tags=["storage"])
|
|
||||||
|
|
||||||
_blob_store = BlobStore()
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local response schemas ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _CreateResponse(BaseModel):
|
|
||||||
id: str
|
|
||||||
created_at: int
|
|
||||||
|
|
||||||
|
|
||||||
class _RecordMeta(BaseModel):
|
|
||||||
id: str
|
|
||||||
table: str
|
|
||||||
checksum: str
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Helpers ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
|
||||||
"""Return total bytes stored by *user_id*."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
|
||||||
StorageRecord.user_id == user_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return int(result.scalar_one())
|
|
||||||
|
|
||||||
|
|
||||||
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
|
||||||
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
|
||||||
current = await _current_usage_bytes(user.id, db)
|
|
||||||
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_record_for_user(
|
|
||||||
record_id: str, user_id: str, db: AsyncSession
|
|
||||||
) -> StorageRecord:
|
|
||||||
"""Look up a record and verify ownership. Returns 404 on mismatch
|
|
||||||
to prevent user enumeration attacks."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(StorageRecord).where(
|
|
||||||
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
record = result.scalar_one_or_none()
|
|
||||||
if record is None:
|
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
|
||||||
return record
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def create_record(
|
|
||||||
body: StorageRecordCreate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> _CreateResponse:
|
|
||||||
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
|
||||||
await _check_quota(current_user, len(body.blob), db)
|
|
||||||
|
|
||||||
record_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, body.table, record_id, body.blob, body.checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
record = StorageRecord(
|
|
||||||
id=record_id,
|
|
||||||
user_id=current_user.id,
|
|
||||||
table_name=body.table,
|
|
||||||
s3_key=s3_key,
|
|
||||||
checksum=body.checksum,
|
|
||||||
size_bytes=len(body.blob),
|
|
||||||
)
|
|
||||||
db.add(record)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(record)
|
|
||||||
|
|
||||||
created_at_ms = int(record.created_at.timestamp() * 1000)
|
|
||||||
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/records", response_model=list[_RecordMeta])
|
|
||||||
async def list_records(
|
|
||||||
table: str | None = Query(default=None),
|
|
||||||
page: int = Query(default=1, ge=1),
|
|
||||||
limit: int = Query(default=50, ge=1, le=200),
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> list[_RecordMeta]:
|
|
||||||
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
|
||||||
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
|
||||||
if table is not None:
|
|
||||||
query = query.where(StorageRecord.table_name == table)
|
|
||||||
query = query.offset((page - 1) * limit).limit(limit)
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
|
|
||||||
return [
|
|
||||||
_RecordMeta(
|
|
||||||
id=r.id,
|
|
||||||
table=r.table_name,
|
|
||||||
checksum=r.checksum,
|
|
||||||
created_at=int(r.created_at.timestamp() * 1000),
|
|
||||||
updated_at=int(r.updated_at.timestamp() * 1000),
|
|
||||||
)
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/records/{record_id}")
|
|
||||||
async def download_record(
|
|
||||||
record_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> Response:
|
|
||||||
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
|
||||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
|
||||||
blob = await _blob_store.download(current_user.id, record.s3_key)
|
|
||||||
return Response(
|
|
||||||
content=blob,
|
|
||||||
media_type="application/octet-stream",
|
|
||||||
headers={"X-Checksum": record.checksum},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/records/{record_id}", response_model=dict)
|
|
||||||
async def update_record(
|
|
||||||
record_id: str,
|
|
||||||
body: StorageRecordUpdate,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
|
||||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
|
||||||
reject_if_tampered(body.blob, body.checksum)
|
|
||||||
|
|
||||||
delta = len(body.blob) - record.size_bytes
|
|
||||||
if delta > 0:
|
|
||||||
await _check_quota(current_user, delta, db)
|
|
||||||
|
|
||||||
s3_key = await _blob_store.upload(
|
|
||||||
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
|
||||||
)
|
|
||||||
|
|
||||||
record.s3_key = s3_key
|
|
||||||
record.checksum = body.checksum
|
|
||||||
record.size_bytes = len(body.blob)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/records/{record_id}", response_model=dict)
|
|
||||||
async def delete_record(
|
|
||||||
record_id: str,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete a record and its S3 blob."""
|
|
||||||
record = await _get_record_for_user(record_id, current_user.id, db)
|
|
||||||
await _blob_store.delete(current_user.id, record.s3_key)
|
|
||||||
await db.delete(record)
|
|
||||||
await db.commit()
|
|
||||||
return {"ok": True}
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Depends
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from app.api.deps import get_current_user
|
|
||||||
from app.core.llm import embed
|
|
||||||
from app.schemas import (
|
|
||||||
UserProfile,
|
|
||||||
VectorSearchRequest,
|
|
||||||
VectorSearchResponse,
|
|
||||||
VectorUpsertRequest,
|
|
||||||
)
|
|
||||||
from app.storage.encryption import reject_if_tampered
|
|
||||||
from app.storage.vector_store import VectorStore
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/storage", tags=["vectors"])
|
|
||||||
|
|
||||||
_vector_store = VectorStore()
|
|
||||||
|
|
||||||
|
|
||||||
class _VectorDeleteRequest(BaseModel):
|
|
||||||
ids: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class _EmbedRequest(BaseModel):
|
|
||||||
text: str
|
|
||||||
|
|
||||||
|
|
||||||
class _EmbedResponse(BaseModel):
|
|
||||||
vector: list[float]
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/upsert", response_model=dict)
|
|
||||||
async def upsert_vectors(
|
|
||||||
body: VectorUpsertRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, int]:
|
|
||||||
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
|
||||||
for item in body.vectors:
|
|
||||||
reject_if_tampered(item.blob, item.checksum)
|
|
||||||
await _vector_store.upsert(current_user.id, body.vectors)
|
|
||||||
return {"upserted": len(body.vectors)}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
|
||||||
async def search_vectors(
|
|
||||||
body: VectorSearchRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> VectorSearchResponse:
|
|
||||||
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
|
||||||
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
|
||||||
return VectorSearchResponse(results=results)
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/vectors", response_model=dict)
|
|
||||||
async def delete_vectors(
|
|
||||||
body: _VectorDeleteRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Delete vectors by ID, scoped to the authenticated user."""
|
|
||||||
await _vector_store.delete(current_user.id, body.ids)
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/vectors/embed", response_model=_EmbedResponse)
|
|
||||||
async def embed_text(
|
|
||||||
body: _EmbedRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
) -> _EmbedResponse:
|
|
||||||
"""Generate a 1536-dim embedding vector for the given text.
|
|
||||||
|
|
||||||
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
|
||||||
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
|
|
||||||
"""
|
|
||||||
vector = await embed(body.text)
|
|
||||||
return _EmbedResponse(vector=vector)
|
|
||||||
1
app/auth/__init__.py
Normal file
1
app/auth/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"OAuth provider abstractions and utilities."
|
||||||
135
app/auth/oauth_providers.py
Normal file
135
app/auth/oauth_providers.py
Normal file
@@ -0,0 +1,135 @@
|
|||||||
|
"""OAuth 2.0 + PKCE provider abstractions.
|
||||||
|
|
||||||
|
Each provider implements a three-step flow designed for a desktop (public) client:
|
||||||
|
|
||||||
|
1. get_authorization_url(state, code_challenge) → str
|
||||||
|
Build the provider's consent-screen URL. State and code_challenge are
|
||||||
|
generated server-side; the client opens this URL in the system browser.
|
||||||
|
|
||||||
|
2. exchange_code(code, code_verifier, redirect_uri) → dict
|
||||||
|
Exchange the short-lived authorization code for an access token.
|
||||||
|
The code_verifier proves ownership of the PKCE challenge.
|
||||||
|
|
||||||
|
3. get_userinfo(access_token) → OAuthUserInfo
|
||||||
|
Fetch the canonical user identity from the provider.
|
||||||
|
|
||||||
|
Currently supported providers:
|
||||||
|
- GoogleOAuthProvider (scope: openid email profile)
|
||||||
|
|
||||||
|
Adding a new provider:
|
||||||
|
- Implement the three methods above.
|
||||||
|
- Register in _PROVIDERS inside routes/auth.py.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
import os
|
||||||
|
import urllib.parse
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
|
||||||
|
# ── Data transfer objects ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class OAuthUserInfo:
|
||||||
|
"""Normalized user identity returned by any provider."""
|
||||||
|
|
||||||
|
provider_user_id: str
|
||||||
|
email: str
|
||||||
|
email_verified: bool
|
||||||
|
avatar_url: str | None
|
||||||
|
name: str | None
|
||||||
|
|
||||||
|
|
||||||
|
# ── PKCE helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def generate_pkce_pair() -> tuple[str, str]:
|
||||||
|
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
|
||||||
|
|
||||||
|
The code_verifier is a random 32-byte URL-safe base64 string.
|
||||||
|
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
|
||||||
|
"""
|
||||||
|
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
|
||||||
|
digest = hashlib.sha256(code_verifier.encode()).digest()
|
||||||
|
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
|
||||||
|
return code_verifier, code_challenge
|
||||||
|
|
||||||
|
|
||||||
|
# ── Google provider ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class GoogleOAuthProvider:
|
||||||
|
"""Google OAuth 2.0 provider (openid email profile scope).
|
||||||
|
|
||||||
|
Uses Google's standard authorization endpoint with PKCE S256.
|
||||||
|
Does NOT use google-auth-oauthlib to keep the flow generic and async.
|
||||||
|
"""
|
||||||
|
|
||||||
|
name = "google"
|
||||||
|
|
||||||
|
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||||
|
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||||
|
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||||
|
|
||||||
|
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
|
||||||
|
self.client_id = client_id
|
||||||
|
self.client_secret = client_secret
|
||||||
|
self.redirect_uri = redirect_uri
|
||||||
|
|
||||||
|
def get_authorization_url(self, state: str, code_challenge: str) -> str:
|
||||||
|
"""Build the Google consent-screen URL."""
|
||||||
|
params = {
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"redirect_uri": self.redirect_uri,
|
||||||
|
"response_type": "code",
|
||||||
|
"scope": "openid email profile",
|
||||||
|
"state": state,
|
||||||
|
"code_challenge": code_challenge,
|
||||||
|
"code_challenge_method": "S256",
|
||||||
|
"access_type": "offline",
|
||||||
|
"prompt": "select_account",
|
||||||
|
}
|
||||||
|
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||||
|
|
||||||
|
async def exchange_code(
|
||||||
|
self, code: str, code_verifier: str, redirect_uri: str
|
||||||
|
) -> dict:
|
||||||
|
"""Exchange authorization code for an access token."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.post(
|
||||||
|
self._TOKEN_URL,
|
||||||
|
data={
|
||||||
|
"client_id": self.client_id,
|
||||||
|
"client_secret": self.client_secret,
|
||||||
|
"code": code,
|
||||||
|
"code_verifier": code_verifier,
|
||||||
|
"grant_type": "authorization_code",
|
||||||
|
"redirect_uri": redirect_uri,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
|
||||||
|
"""Fetch the authenticated user's identity from Google."""
|
||||||
|
async with httpx.AsyncClient() as client:
|
||||||
|
response = await client.get(
|
||||||
|
self._USERINFO_URL,
|
||||||
|
headers={"Authorization": f"Bearer {access_token}"},
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
data = response.json()
|
||||||
|
|
||||||
|
return OAuthUserInfo(
|
||||||
|
provider_user_id=data["sub"],
|
||||||
|
email=data["email"],
|
||||||
|
email_verified=data.get("email_verified", False),
|
||||||
|
avatar_url=data.get("picture"),
|
||||||
|
name=data.get("name"),
|
||||||
|
)
|
||||||
@@ -43,8 +43,8 @@ class StripeService:
|
|||||||
self,
|
self,
|
||||||
user_id: str,
|
user_id: str,
|
||||||
tier: str,
|
tier: str,
|
||||||
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
cancel_url: str = "https://app.adiuvai.app/billing/cancel",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a Stripe checkout session and return the URL.
|
"""Create a Stripe checkout session and return the URL.
|
||||||
|
|
||||||
@@ -200,6 +200,45 @@ class StripeService:
|
|||||||
sub.status = "canceled"
|
sub.status = "canceled"
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
|
async def list_invoices(
|
||||||
|
self, user_id: str, db: AsyncSession, limit: int = 24
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Return recent invoices for the user from Stripe.
|
||||||
|
|
||||||
|
Returns an empty list when Stripe is not configured or the user has
|
||||||
|
no ``stripe_customer_id``.
|
||||||
|
"""
|
||||||
|
if not self._configured():
|
||||||
|
return []
|
||||||
|
|
||||||
|
from app.models import User # noqa: PLC0415
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(User.stripe_customer_id).where(User.id == user_id)
|
||||||
|
)
|
||||||
|
customer_id = result.scalar_one_or_none()
|
||||||
|
if not customer_id:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
s = self._client()
|
||||||
|
invoices = s.Invoice.list(customer=customer_id, limit=limit)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"id": inv.id,
|
||||||
|
"amount_due": inv.amount_due,
|
||||||
|
"amount_paid": inv.amount_paid,
|
||||||
|
"currency": inv.currency,
|
||||||
|
"status": inv.status,
|
||||||
|
"created": inv.created * 1000, # epoch ms
|
||||||
|
"invoice_url": inv.hosted_invoice_url,
|
||||||
|
"invoice_pdf": inv.invoice_pdf,
|
||||||
|
}
|
||||||
|
for inv in invoices.auto_paging_iter()
|
||||||
|
]
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
# ── Private DB helpers ───────────────────────────────────────────────
|
# ── Private DB helpers ───────────────────────────────────────────────
|
||||||
|
|
||||||
async def _upsert_subscription(
|
async def _upsert_subscription(
|
||||||
|
|||||||
@@ -22,44 +22,32 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
"batch_runs_per_day": 5,
|
"batch_runs_per_day": 5,
|
||||||
"cloud_storage_gb": 0,
|
|
||||||
"backup_gb": 0,
|
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"plugin_marketplace": False,
|
|
||||||
"sso": False,
|
"sso": False,
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1, # unlimited
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
"batch_runs_per_day": 50,
|
"batch_runs_per_day": 50,
|
||||||
"cloud_storage_gb": 5,
|
|
||||||
"backup_gb": 5,
|
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": False,
|
"batch_builder": False,
|
||||||
"plugin_marketplace": False,
|
|
||||||
"sso": False,
|
"sso": False,
|
||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1, # unlimited
|
"batch_active": -1, # unlimited
|
||||||
"batch_runs_per_day": -1, # unlimited
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": 25,
|
|
||||||
"backup_gb": 25,
|
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"plugin_marketplace": True,
|
|
||||||
"sso": False,
|
"sso": False,
|
||||||
},
|
},
|
||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
"batch_runs_per_day": -1, # unlimited
|
"batch_runs_per_day": -1, # unlimited
|
||||||
"cloud_storage_gb": -1, # unlimited
|
|
||||||
"backup_gb": -1, # unlimited
|
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"plugin_marketplace": True,
|
|
||||||
"sso": True,
|
"sso": True,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -125,71 +113,6 @@ class TierManager:
|
|||||||
"""Return the requests-per-minute limit for ``tier``."""
|
"""Return the requests-per-minute limit for ``tier``."""
|
||||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
# ── Storage quota ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def enforce_quota(
|
|
||||||
self,
|
|
||||||
tier: BillingTier,
|
|
||||||
current_bytes: int = 0,
|
|
||||||
additional_bytes: int = 0,
|
|
||||||
) -> None:
|
|
||||||
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
|
||||||
|
|
||||||
``tier`` is the caller's current tier (from ``current_user.tier``).
|
|
||||||
``current_bytes`` is the total bytes already stored (queried by caller).
|
|
||||||
"""
|
|
||||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
|
||||||
if limit_gb == 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Cloud storage is not available on the '{tier}' tier",
|
|
||||||
)
|
|
||||||
if limit_gb == -1:
|
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
|
||||||
if current_bytes + additional_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Storage quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
def enforce_backup_quota(
|
|
||||||
self,
|
|
||||||
tier: BillingTier,
|
|
||||||
current_bytes: int = 0,
|
|
||||||
additional_bytes: int = 0,
|
|
||||||
) -> None:
|
|
||||||
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
|
||||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
|
||||||
if limit_gb == 0:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Backup is not available on the '{tier}' tier",
|
|
||||||
)
|
|
||||||
if limit_gb == -1:
|
|
||||||
return # unlimited
|
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
|
||||||
if current_bytes + additional_bytes > limit_bytes:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Backup quota exceeded for tier '{tier}'",
|
|
||||||
)
|
|
||||||
|
|
||||||
def check_quota(
|
|
||||||
self,
|
|
||||||
tier: BillingTier,
|
|
||||||
current_bytes: int = 0,
|
|
||||||
additional_bytes: int = 0,
|
|
||||||
) -> bool:
|
|
||||||
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
|
||||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
|
||||||
if limit_gb == 0:
|
|
||||||
return False
|
|
||||||
if limit_gb == -1:
|
|
||||||
return True
|
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
|
||||||
return current_bytes + additional_bytes <= limit_bytes
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton shared across the app.
|
# Module-level singleton shared across the app.
|
||||||
tier_manager = TierManager()
|
tier_manager = TierManager()
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
|
|||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
class Settings(BaseSettings):
|
||||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
|
||||||
JWT_SECRET: str = "change-me-in-production"
|
JWT_SECRET: str = "change-me-in-production"
|
||||||
JWT_ALGORITHM: str = "HS256"
|
JWT_ALGORITHM: str = "HS256"
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
@@ -12,17 +12,6 @@ class Settings(BaseSettings):
|
|||||||
STRIPE_SECRET_KEY: str = ""
|
STRIPE_SECRET_KEY: str = ""
|
||||||
STRIPE_WEBHOOK_SECRET: str = ""
|
STRIPE_WEBHOOK_SECRET: str = ""
|
||||||
|
|
||||||
S3_BUCKET: str = ""
|
|
||||||
S3_REGION: str = "us-east-1"
|
|
||||||
S3_ENDPOINT_URL: str = ""
|
|
||||||
AWS_ACCESS_KEY_ID: str = ""
|
|
||||||
AWS_SECRET_ACCESS_KEY: str = ""
|
|
||||||
|
|
||||||
PINECONE_API_KEY: str = ""
|
|
||||||
PINECONE_INDEX: str = "adiuva"
|
|
||||||
QDRANT_URL: str = ""
|
|
||||||
QDRANT_API_KEY: str = ""
|
|
||||||
|
|
||||||
OPENAI_API_KEY: str = ""
|
OPENAI_API_KEY: str = ""
|
||||||
ANTHROPIC_API_KEY: str = ""
|
ANTHROPIC_API_KEY: str = ""
|
||||||
GOOGLE_API_KEY: str = ""
|
GOOGLE_API_KEY: str = ""
|
||||||
@@ -31,6 +20,14 @@ class Settings(BaseSettings):
|
|||||||
LLM_MODEL: str = "gpt-4o"
|
LLM_MODEL: str = "gpt-4o"
|
||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
|
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
|
||||||
|
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
|
||||||
|
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
|
||||||
|
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
|
||||||
|
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
|
||||||
|
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
|
||||||
|
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
|
||||||
|
|
||||||
# GitHub Copilot OAuth token storage directory.
|
# GitHub Copilot OAuth token storage directory.
|
||||||
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||||
@@ -44,18 +41,37 @@ class Settings(BaseSettings):
|
|||||||
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||||
MS_TENANT_ID: str = "common"
|
MS_TENANT_ID: str = "common"
|
||||||
|
|
||||||
|
# Google Login OAuth credentials — scope: openid email profile.
|
||||||
|
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
|
||||||
|
GOOGLE_AUTH_CLIENT_ID: str = ""
|
||||||
|
GOOGLE_AUTH_CLIENT_SECRET: str = ""
|
||||||
|
# The redirect URI registered in Google Cloud Console.
|
||||||
|
# Google redirects here after consent; this backend route then bounces to
|
||||||
|
# the adiuvai:// deep link so the Electron app receives the code.
|
||||||
|
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
|
||||||
|
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
||||||
|
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
||||||
|
|
||||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||||
OAUTH_ENCRYPTION_KEY: str = ""
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
CORS_ORIGINS: list[str] = [
|
||||||
|
"app://.",
|
||||||
|
"http://localhost:3000",
|
||||||
|
"http://localhost:5173",
|
||||||
|
"http://localhost:4173", # Vite preview (web SPA)
|
||||||
|
"https://app.adiuvai.com", # Production web portal
|
||||||
|
]
|
||||||
|
|
||||||
|
LANGFUSE_SECRET_KEY: str = ""
|
||||||
|
LANGFUSE_PUBLIC_KEY: str = ""
|
||||||
|
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
|
||||||
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
env_file=".env", env_file_encoding="utf-8", extra="ignore"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
settings = Settings()
|
||||||
|
|||||||
@@ -2,12 +2,12 @@
|
|||||||
|
|
||||||
Drives two agent types:
|
Drives two agent types:
|
||||||
|
|
||||||
* **Local directory agent** — two-step execution per file:
|
* **Local directory agent** — V2 unified flow per file:
|
||||||
Step 1 (Classification) uses code to fetch all projects and asks the LLM
|
Phase A (Detect + Preprocess, zero LLM): Python detects the content type
|
||||||
to identify which project the file belongs to and which domains are relevant.
|
and strips markup/noise, producing clean text + metadata.
|
||||||
Step 2 (Processing) fetches existing entities for that project/domains via
|
Phase B (Single LLM call with tools): the LLM identifies the project,
|
||||||
code and runs an LLM with tools — existing data in context enforces
|
checks for duplicates via list_* tools, and creates/updates records.
|
||||||
update-first naturally.
|
``items_created`` is counted from ``create_*`` tool calls.
|
||||||
|
|
||||||
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
||||||
Teams, Outlook) and pushes extracted items to Electron.
|
Teams, Outlook) and pushes extracted items to Electron.
|
||||||
@@ -29,7 +29,7 @@ from __future__ import annotations
|
|||||||
import asyncio
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
import uuid
|
import os
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -43,7 +43,9 @@ from app.agents.project_agent import PROJECT_TOOLS
|
|||||||
from app.agents.task_agent import TASK_TOOLS
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
from app.core.device_manager import DeviceConnectionManager
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
from app.core.llm import get_llm
|
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
|
from app.core.preprocessors import detect_content_type, preprocess
|
||||||
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
@@ -70,97 +72,52 @@ _MAX_PROCESSING_STEPS: int = 12
|
|||||||
_MAX_SCAN_DEPTH: int = 5
|
_MAX_SCAN_DEPTH: int = 5
|
||||||
|
|
||||||
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
||||||
# NOTE: "projects" is intentionally excluded — project creation/assignment is
|
|
||||||
# handled in code by the runner, never delegated to the Step 2 LLM.
|
|
||||||
|
|
||||||
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
||||||
"tasks": TASK_TOOLS,
|
"tasks": TASK_TOOLS,
|
||||||
"notes": NOTE_TOOLS,
|
"notes": NOTE_TOOLS,
|
||||||
"timelines": TIMELINE_TOOLS,
|
"timelines": TIMELINE_TOOLS,
|
||||||
|
"timelineEvents": TIMELINE_TOOLS,
|
||||||
|
"projects": PROJECT_TOOLS,
|
||||||
}
|
}
|
||||||
|
|
||||||
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
# ── V2: Unified processing prompt (hot-swappable via Langfuse "unified_processing") ──
|
||||||
|
|
||||||
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
_UNIFIED_PROCESSING_PROMPT = """\
|
||||||
"tasks": (
|
|
||||||
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
|
||||||
"assigned to someone, or tracked with a due date or status."
|
|
||||||
),
|
|
||||||
"notes": (
|
|
||||||
"Documentation, meeting notes, summaries, reference material — "
|
|
||||||
"written content meant to be read and referenced rather than acted on."
|
|
||||||
),
|
|
||||||
"timelines": (
|
|
||||||
"Project milestones, deadlines, scheduled events — "
|
|
||||||
"specific dates that mark a point in the progress of a project."
|
|
||||||
),
|
|
||||||
"projects": (
|
|
||||||
"High-level project entities — only relevant if the file clearly introduces "
|
|
||||||
"a new project or updates the scope of an existing one."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
_STEP1_SYSTEM_PROMPT = """\
|
|
||||||
You are a file classifier for a freelance project management tool.
|
|
||||||
|
|
||||||
Your job is to match a file to an existing project and identify which data domains to extract.
|
|
||||||
|
|
||||||
## Project matching rules (STRICT — follow in order)
|
|
||||||
|
|
||||||
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
|
||||||
that overlaps with the existing projects listed below.
|
|
||||||
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
|
||||||
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
|
||||||
when the file has zero meaningful connection to any listed project.
|
|
||||||
4. When in doubt, pick the closest match from the list.
|
|
||||||
|
|
||||||
## Response format
|
|
||||||
|
|
||||||
Respond ONLY with a JSON object — no markdown, no explanation:
|
|
||||||
|
|
||||||
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
|
||||||
|
|
||||||
## Domain definitions (only consider domains in the allowed list)
|
|
||||||
|
|
||||||
{domain_definitions}
|
|
||||||
|
|
||||||
## Existing projects
|
|
||||||
|
|
||||||
{projects_list}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
_PROCESSING_SYSTEM_PROMPT = """\
|
|
||||||
You are a data extraction assistant for a freelance project management tool.
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
|
||||||
Your task: extract structured data from the file content and persist it using the available tools.
|
## Your process (follow this exact order)
|
||||||
|
|
||||||
## Mandatory process — follow this order for EVERY item you extract
|
### 1. Identify the project
|
||||||
|
File: {filename}
|
||||||
|
{metadata_section}
|
||||||
|
|
||||||
1. READ the existing records listed below for the relevant domain.
|
Existing projects:
|
||||||
2. SEARCH for a match by title, topic, or semantic similarity.
|
{projects_list}
|
||||||
3. If a match exists → call the update_* tool with the existing record's id.
|
|
||||||
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
|
||||||
|
|
||||||
NEVER call create_* without first checking the existing records.
|
Match this file to an existing project using the filename and content clues.
|
||||||
NEVER duplicate a record that already exists under a different wording.
|
If no project matches, {no_match_behavior}.
|
||||||
|
|
||||||
## Existing records (source of truth)
|
### 2. Check existing records
|
||||||
|
Once you identify the project, use list_tasks / list_notes / list_timelines
|
||||||
|
(filtered by projectId) to see what already exists.
|
||||||
|
NEVER create a record that already exists under the same or similar title.
|
||||||
|
|
||||||
{existing_context}
|
### 3. Extract and create / update
|
||||||
|
{extraction_rules}
|
||||||
|
|
||||||
## Context
|
### Rules
|
||||||
|
- Set isAiSuggested=1 on every new record.
|
||||||
Project: {project_context}
|
- Set projectId on every record (use the id from the project list above).
|
||||||
Domains to extract: {data_types}
|
- Update existing records when a match is found by title or topic.
|
||||||
|
- Do NOT invent data — only extract what is clearly stated in the content.
|
||||||
{custom_prompt_section}
|
- Target entity types: {data_types}.
|
||||||
|
{global_rules}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# ── Cloud processing prompt (kept separate for cloud agent) ───────────────
|
# ── Cloud processing prompt (kept separate for cloud agent) ───────────────
|
||||||
|
|
||||||
_CLOUD_PROCESSING_PROMPT = """\
|
_BATCH_CLOUD_PROCESSING_PROMPT = """\
|
||||||
You are a data extraction and management assistant for a freelance project
|
You are a data extraction and management assistant for a freelance project
|
||||||
management tool.
|
management tool.
|
||||||
|
|
||||||
@@ -268,9 +225,19 @@ async def _run_agent_with_tools(
|
|||||||
user_message: str,
|
user_message: str,
|
||||||
tools: list[Any],
|
tools: list[Any],
|
||||||
max_steps: int,
|
max_steps: int,
|
||||||
|
user_id: str = "",
|
||||||
|
session_id: str = "",
|
||||||
|
langfuse_prompt: Any = None,
|
||||||
|
agent_name: str = "batch-agent",
|
||||||
|
_tool_calls_out: list[str] | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Run an LLM agent with tool-calling, returning the final text response."""
|
"""Run an LLM agent with tool-calling, returning the final text response.
|
||||||
llm = get_llm()
|
|
||||||
|
If *_tool_calls_out* is provided, the name of every tool called during the
|
||||||
|
run is appended to it (used by the caller to count ``create_*`` calls).
|
||||||
|
"""
|
||||||
|
lf = get_langfuse()
|
||||||
|
llm = get_agent_llm(agent_name)
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
llm_with_tools = llm.bind_tools(tools)
|
||||||
messages: list[Any] = [
|
messages: list[Any] = [
|
||||||
SystemMessage(content=system_prompt),
|
SystemMessage(content=system_prompt),
|
||||||
@@ -279,12 +246,45 @@ async def _run_agent_with_tools(
|
|||||||
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
|
|
||||||
|
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
|
_span_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name=agent_name,
|
||||||
|
metadata={"user_id": user_id} if user_id else None,
|
||||||
|
input=user_message,
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||||
|
|
||||||
|
try:
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
|
_gen_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name=f"{agent_name}-llm",
|
||||||
|
model=model_for_agent(agent_name),
|
||||||
|
prompt=langfuse_prompt,
|
||||||
|
input=messages,
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
if _gen_ctx:
|
||||||
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
|
|
||||||
if not response.tool_calls:
|
if not response.tool_calls:
|
||||||
return _as_text(response.content)
|
final_text = _as_text(response.content)
|
||||||
|
if _span:
|
||||||
|
_span.update(output=final_text)
|
||||||
|
return final_text
|
||||||
|
|
||||||
for call in response.tool_calls:
|
for call in response.tool_calls:
|
||||||
call_id = str(call.get("id", ""))
|
call_id = str(call.get("id", ""))
|
||||||
@@ -296,6 +296,9 @@ async def _run_agent_with_tools(
|
|||||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
json.dumps(call_args, ensure_ascii=True)[:800],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if _tool_calls_out is not None:
|
||||||
|
_tool_calls_out.append(call_name)
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
tool_fn = tool_map.get(call_name)
|
||||||
if tool_fn is None:
|
if tool_fn is None:
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
tool_output = f"Unknown tool: {call_name}"
|
||||||
@@ -310,7 +313,16 @@ async def _run_agent_with_tools(
|
|||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
||||||
|
|
||||||
final = await llm.ainvoke(messages)
|
final = await llm.ainvoke(messages)
|
||||||
return _as_text(final.content)
|
final_text = _as_text(final.content)
|
||||||
|
if _span:
|
||||||
|
_span.update(output=final_text)
|
||||||
|
return final_text
|
||||||
|
finally:
|
||||||
|
if _span_ctx:
|
||||||
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
|
if lf:
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
|
||||||
# ── Tool list builder ─────────────────────────────────────────────────────
|
# ── Tool list builder ─────────────────────────────────────────────────────
|
||||||
@@ -377,7 +389,8 @@ async def _scan_directories(
|
|||||||
for file_path in all_files:
|
for file_path in all_files:
|
||||||
try:
|
try:
|
||||||
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
||||||
modified_at = meta.get("modifiedAt")
|
# FE sends snake_case keys on the wire (toSnakeCase transform)
|
||||||
|
modified_at = meta.get("modified_at") or meta.get("modifiedAt")
|
||||||
if modified_at is None:
|
if modified_at is None:
|
||||||
filtered.append(file_path)
|
filtered.append(file_path)
|
||||||
continue
|
continue
|
||||||
@@ -479,83 +492,66 @@ def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
|||||||
return f"Existing {domain}:\n" + "\n".join(lines)
|
return f"Existing {domain}:\n" + "\n".join(lines)
|
||||||
|
|
||||||
|
|
||||||
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
# ── V2 helper functions ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def _classify_file(
|
def _format_projects(projects: list[dict]) -> str:
|
||||||
file_path: str,
|
"""Format the project list for the unified system prompt."""
|
||||||
file_content: str,
|
if not projects:
|
||||||
projects: list[dict],
|
return " (no projects yet)"
|
||||||
config_data_types: list[str],
|
lines: list[str] = []
|
||||||
) -> tuple[str, list[str], str | None]:
|
for p in projects:
|
||||||
"""Call the LLM to classify a file by project and relevant domains.
|
|
||||||
|
|
||||||
Returns ``(project_id_or_"new", domains, new_project_name_or_None)``.
|
|
||||||
- ``project_id`` is an existing project UUID, or ``"new"`` when no match found.
|
|
||||||
- ``new_project_name`` is only set when ``project_id == "new"``.
|
|
||||||
Falls back to ``("new", config_data_types, None)`` on any error.
|
|
||||||
"""
|
|
||||||
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
|
||||||
|
|
||||||
if not file_content.strip():
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
valid_project_ids = {p["id"] for p in projects}
|
|
||||||
|
|
||||||
def _fmt_project(p: dict) -> str:
|
|
||||||
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
||||||
summary_part = f" — {summary[:100]}" if summary else ""
|
summary_part = f" — {summary[:100]}" if summary else ""
|
||||||
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
lines.append(
|
||||||
|
f" - id={p['id']} | name={p.get('name', '')} | "
|
||||||
|
f"status={p.get('status', '')}{summary_part}"
|
||||||
|
)
|
||||||
|
return "\n".join(lines)
|
||||||
|
|
||||||
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
|
||||||
|
|
||||||
domain_definitions = "\n".join(
|
def _format_metadata(metadata: dict) -> str:
|
||||||
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
"""Format preprocessor metadata as a compact context block."""
|
||||||
for d in config_data_types
|
if not metadata:
|
||||||
if d in _DOMAIN_DESCRIPTIONS
|
return ""
|
||||||
|
parts: list[str] = []
|
||||||
|
for key in ("subject", "from", "to", "date"):
|
||||||
|
if metadata.get(key):
|
||||||
|
parts.append(f"{key.capitalize()}: {metadata[key]}")
|
||||||
|
# any remaining keys
|
||||||
|
for key, val in metadata.items():
|
||||||
|
if key not in ("subject", "from", "to", "date") and val:
|
||||||
|
parts.append(f"{key}: {val}")
|
||||||
|
return "\n".join(parts)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_extraction_rules(agent_config: dict, content_type: str) -> str:
|
||||||
|
"""Return the extraction_prompt for *content_type* from *agent_config*.
|
||||||
|
|
||||||
|
Falls back to a generic instruction when the type is not configured.
|
||||||
|
"""
|
||||||
|
for ct in agent_config.get("content_types", []):
|
||||||
|
if ct.get("id") == content_type:
|
||||||
|
prompt = ct.get("extraction_prompt", "").strip()
|
||||||
|
if prompt:
|
||||||
|
return prompt
|
||||||
|
return (
|
||||||
|
"Extract relevant information as tasks (action items), notes "
|
||||||
|
"(informational content), or timelines (dated events)."
|
||||||
)
|
)
|
||||||
|
|
||||||
system = _STEP1_SYSTEM_PROMPT.format(
|
|
||||||
domain_definitions=domain_definitions,
|
|
||||||
projects_list=projects_list,
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = get_llm()
|
def _get_no_match_behavior(agent_config: dict) -> str:
|
||||||
try:
|
"""Derive the 'no project match' instruction from global_rules."""
|
||||||
response = await llm.ainvoke([
|
rules = agent_config.get("global_rules", [])
|
||||||
SystemMessage(content=system),
|
for rule in rules:
|
||||||
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
lower = rule.lower()
|
||||||
])
|
if "no project" in lower or "no match" in lower or "skip" in lower:
|
||||||
raw = _as_text(response.content).strip()
|
return rule
|
||||||
# Strip markdown fences if the model wraps the JSON.
|
return "create a new project with a concise name derived from the file content"
|
||||||
if raw.startswith("```"):
|
|
||||||
raw = raw.split("```")[1]
|
|
||||||
if raw.startswith("json"):
|
|
||||||
raw = raw[4:]
|
|
||||||
parsed = json.loads(raw.strip())
|
|
||||||
raw_project_id: str = str(parsed.get("project_id") or "new")
|
|
||||||
# Reject hallucinated UUIDs — only accept ids that exist in the fetched list.
|
|
||||||
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
|
||||||
new_project_name: str | None = (
|
|
||||||
str(parsed["new_project_name"]).strip() or None
|
|
||||||
if project_id == "new" and parsed.get("new_project_name")
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
domains: list[str] = [
|
|
||||||
d for d in parsed.get("domains", [])
|
|
||||||
if d in config_data_types
|
|
||||||
]
|
|
||||||
if not domains:
|
|
||||||
domains = list(config_data_types)
|
|
||||||
return project_id, domains, new_project_name
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
|
||||||
)
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent runner (two-step per file) ────────────────────────────────
|
# ── Local agent runner (V2 — unified per-file flow) ───────────────────────
|
||||||
|
|
||||||
|
|
||||||
async def run_local_agent(
|
async def run_local_agent(
|
||||||
@@ -565,16 +561,17 @@ async def run_local_agent(
|
|||||||
device_mgr: DeviceConnectionManager,
|
device_mgr: DeviceConnectionManager,
|
||||||
run_context: dict | None = None,
|
run_context: dict | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Execute a local directory agent run using a two-step approach per file.
|
"""Execute a local directory agent run — V2 unified flow.
|
||||||
|
|
||||||
Step 1 — Classification (code + 1 LLM call per file, no tools):
|
Phase A — Detect + Preprocess (zero LLM, per file):
|
||||||
Code scans directories and fetches all projects via WS.
|
Python detects the content type from filename + content patterns and
|
||||||
For each file, LLM identifies the project and relevant domains.
|
runs the appropriate handler (e.g. email_html) to produce clean text
|
||||||
|
and structured metadata.
|
||||||
|
|
||||||
Step 2 — Processing (code + 1 LLM call per file, with tools):
|
Phase B — Single LLM call with tools (per file):
|
||||||
Code fetches existing entities for the identified project/domains.
|
One LLM call handles project identification, duplicate checking, and
|
||||||
LLM receives file content + existing entities in context and uses
|
record creation/update. ``create_*`` tool calls are counted to
|
||||||
tools to update existing records or create new ones.
|
produce the accurate ``items_created`` metric.
|
||||||
"""
|
"""
|
||||||
run_id = run_log.id
|
run_id = run_log.id
|
||||||
agent_id = (run_context or {}).get("agent_id") or config.id
|
agent_id = (run_context or {}).get("agent_id") or config.id
|
||||||
@@ -609,16 +606,11 @@ async def run_local_agent(
|
|||||||
errors: list[str] = []
|
errors: list[str] = []
|
||||||
items_processed = 0
|
items_processed = 0
|
||||||
items_created = 0
|
items_created = 0
|
||||||
|
agent_config: dict = config.agent_config or {}
|
||||||
custom_section = (
|
processing_tools = _build_processing_tools(config.data_types)
|
||||||
f"User instructions:\n{config.prompt_template}"
|
|
||||||
if config.prompt_template
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# ── Code: scan directories ───────────────────────────────────
|
# ── Code: scan directories ───────────────────────────────────
|
||||||
logger.info("agent_runner: run=%s scanning directories user=%s", run_id, user_id)
|
|
||||||
file_paths = await _scan_directories(
|
file_paths = await _scan_directories(
|
||||||
paths=config.directory_paths,
|
paths=config.directory_paths,
|
||||||
extensions=config.file_extensions or [],
|
extensions=config.file_extensions or [],
|
||||||
@@ -634,108 +626,89 @@ async def run_local_agent(
|
|||||||
|
|
||||||
# ── Code: fetch all projects once ────────────────────────────
|
# ── Code: fetch all projects once ────────────────────────────
|
||||||
projects = await _fetch_projects()
|
projects = await _fetch_projects()
|
||||||
|
projects_block = _format_projects(projects)
|
||||||
|
|
||||||
|
# Prompt template + Langfuse version linking (hot-swappable from UI).
|
||||||
|
unified_template, prompt_obj = get_prompt_or_fallback(
|
||||||
|
"unified_processing", _UNIFIED_PROCESSING_PROMPT
|
||||||
|
)
|
||||||
|
|
||||||
for file_path in file_paths:
|
for file_path in file_paths:
|
||||||
try:
|
try:
|
||||||
# Read file content via code.
|
# ── Phase A: read + detect + preprocess ─────────────
|
||||||
file_result = await execute_on_client(
|
file_result = await execute_on_client(
|
||||||
action="read_file_content", data={"path": file_path}
|
action="read_file_content", data={"path": file_path}
|
||||||
)
|
)
|
||||||
file_content: str = file_result.get("content", "")
|
raw_content: str = file_result.get("content", "")
|
||||||
if not file_content:
|
if not raw_content.strip():
|
||||||
logger.debug("agent_runner: run=%s skipping empty file %r", run_id, file_path)
|
logger.debug(
|
||||||
|
"agent_runner: run=%s skipping empty file %r", run_id, file_path
|
||||||
|
)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
items_processed += 1
|
items_processed += 1
|
||||||
|
filename = os.path.basename(file_path)
|
||||||
|
content_type = detect_content_type(filename, raw_content)
|
||||||
|
preprocessed = preprocess(content_type, raw_content)
|
||||||
|
|
||||||
# Step 1 — classify file.
|
|
||||||
project_id, domains, new_project_name = await _classify_file(
|
|
||||||
file_path=file_path,
|
|
||||||
file_content=file_content,
|
|
||||||
projects=projects,
|
|
||||||
config_data_types=config.data_types,
|
|
||||||
)
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s file=%r → project=%s new_name=%r domains=%s",
|
"agent_runner: run=%s file=%r content_type=%s clean_len=%d",
|
||||||
run_id,
|
run_id, file_path, content_type, len(preprocessed.clean_text),
|
||||||
file_path,
|
|
||||||
project_id,
|
|
||||||
new_project_name,
|
|
||||||
domains,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 2 — resolve project_id via CODE, then fetch entities.
|
# ── Phase B: single LLM call ─────────────────────────
|
||||||
# Project creation is NEVER delegated to the Step 2 LLM.
|
extraction_rules = _get_extraction_rules(agent_config, content_type)
|
||||||
if project_id == "new":
|
no_match_behavior = _get_no_match_behavior(agent_config)
|
||||||
proj_name = new_project_name or "Untitled Project"
|
global_rules_lines = "\n".join(
|
||||||
try:
|
f"- {r}" for r in agent_config.get("global_rules", [])
|
||||||
proj_result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="projects",
|
|
||||||
data={"name": proj_name, "clientId": None},
|
|
||||||
)
|
)
|
||||||
created = proj_result.get("row", {})
|
metadata_section = _format_metadata(preprocessed.metadata)
|
||||||
effective_project_id = created.get("id", "standalone")
|
|
||||||
# Add to local list so subsequent files can match it.
|
system_prompt = compile_prompt(
|
||||||
if "id" in created:
|
unified_template,
|
||||||
projects.append(created)
|
prompt_obj,
|
||||||
logger.info(
|
filename=filename,
|
||||||
"agent_runner: run=%s created project %r id=%s",
|
metadata_section=metadata_section,
|
||||||
run_id, proj_name, effective_project_id,
|
projects_list=projects_block,
|
||||||
)
|
no_match_behavior=no_match_behavior,
|
||||||
except Exception as exc:
|
extraction_rules=extraction_rules,
|
||||||
logger.warning(
|
global_rules=global_rules_lines,
|
||||||
"agent_runner: run=%s failed to create project %r: %s",
|
data_types=", ".join(config.data_types),
|
||||||
run_id, proj_name, exc,
|
|
||||||
)
|
|
||||||
effective_project_id = "standalone"
|
|
||||||
proj_name = "unknown"
|
|
||||||
project_context = (
|
|
||||||
f"Project: {proj_name} (id: {effective_project_id}). "
|
|
||||||
"Always set projectId to this id on every record you create."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
effective_project_id = project_id
|
|
||||||
proj = next((p for p in projects if p["id"] == project_id), None)
|
|
||||||
proj_name = proj.get("name", project_id) if proj else project_id
|
|
||||||
project_context = (
|
|
||||||
f"Project: {proj_name} (id: {project_id}). "
|
|
||||||
"Always set projectId to this id on every record you create."
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# "projects" domain is never passed to Step 2 — handled above in code.
|
|
||||||
domains = [d for d in domains if d != "projects"]
|
|
||||||
|
|
||||||
existing_blocks: list[str] = []
|
|
||||||
for domain in domains:
|
|
||||||
rows = await _fetch_domain_entities(domain, effective_project_id)
|
|
||||||
existing_blocks.append(_format_entities_for_context(domain, rows))
|
|
||||||
|
|
||||||
existing_context = "\n\n".join(existing_blocks)
|
|
||||||
|
|
||||||
system_prompt = _PROCESSING_SYSTEM_PROMPT.format(
|
|
||||||
existing_context=existing_context,
|
|
||||||
project_context=project_context,
|
|
||||||
data_types=", ".join(domains),
|
|
||||||
custom_prompt_section=custom_section,
|
|
||||||
)
|
|
||||||
|
|
||||||
processing_tools = _build_processing_tools(domains)
|
|
||||||
|
|
||||||
result_text = await _run_agent_with_tools(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
user_message = (
|
user_message = (
|
||||||
f"Process this file and extract relevant information.\n\n"
|
f"Process this file and extract relevant information.\n\n"
|
||||||
f"File: {file_path}\n\nContent:\n{file_content}"
|
f"File: {file_path}\n\n"
|
||||||
),
|
f"Content:\n{preprocessed.clean_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
file_tool_calls: list[str] = []
|
||||||
|
result_text = await _run_agent_with_tools(
|
||||||
|
system_prompt=system_prompt,
|
||||||
|
user_message=user_message,
|
||||||
tools=processing_tools,
|
tools=processing_tools,
|
||||||
max_steps=_MAX_PROCESSING_STEPS,
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=run_id,
|
||||||
|
langfuse_prompt=prompt_obj,
|
||||||
|
agent_name="unified-processor",
|
||||||
|
_tool_calls_out=file_tool_calls,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
file_created = sum(
|
||||||
|
1 for name in file_tool_calls if name.startswith("create_")
|
||||||
|
)
|
||||||
|
items_created += file_created
|
||||||
|
|
||||||
|
# Refresh project list when a project was created so
|
||||||
|
# subsequent files see it in the prompt context.
|
||||||
|
if "create_project" in file_tool_calls:
|
||||||
|
projects = await _fetch_projects()
|
||||||
|
projects_block = _format_projects(projects)
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s file=%r result=%s",
|
"agent_runner: run=%s file=%r created=%d result=%s",
|
||||||
run_id,
|
run_id, file_path, file_created, result_text[:200],
|
||||||
file_path,
|
|
||||||
result_text[:200],
|
|
||||||
)
|
)
|
||||||
|
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
@@ -767,10 +740,11 @@ async def run_local_agent(
|
|||||||
errors=errors,
|
errors=errors,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info(
|
||||||
"agent_runner: run=%s done status=%s processed=%d errors=%d",
|
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
run_id,
|
run_id,
|
||||||
final_status,
|
final_status,
|
||||||
items_processed,
|
items_processed,
|
||||||
|
items_created,
|
||||||
len(errors),
|
len(errors),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -928,7 +902,12 @@ async def run_cloud_agent(
|
|||||||
continue
|
continue
|
||||||
items_processed += 1
|
items_processed += 1
|
||||||
|
|
||||||
processing_prompt = _CLOUD_PROCESSING_PROMPT.format(
|
cloud_template, cloud_prompt_obj = get_prompt_or_fallback(
|
||||||
|
"batch_cloud_processing", _BATCH_CLOUD_PROCESSING_PROMPT
|
||||||
|
)
|
||||||
|
processing_prompt = compile_prompt(
|
||||||
|
cloud_template,
|
||||||
|
cloud_prompt_obj,
|
||||||
data_types=", ".join(config.data_types),
|
data_types=", ".join(config.data_types),
|
||||||
project_context="Determine the appropriate project from the message context.",
|
project_context="Determine the appropriate project from the message context.",
|
||||||
file_list=f"Message from {config.provider} (id: {msg.id})",
|
file_list=f"Message from {config.provider} (id: {msg.id})",
|
||||||
@@ -941,6 +920,10 @@ async def run_cloud_agent(
|
|||||||
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
||||||
tools=processing_tools,
|
tools=processing_tools,
|
||||||
max_steps=_MAX_PROCESSING_STEPS,
|
max_steps=_MAX_PROCESSING_STEPS,
|
||||||
|
user_id=user_id,
|
||||||
|
session_id=run_id,
|
||||||
|
langfuse_prompt=cloud_prompt_obj,
|
||||||
|
agent_name="cloud-processor",
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
||||||
|
|||||||
@@ -16,7 +16,8 @@ from app.agents.note_agent import NOTE_TOOLS
|
|||||||
from app.agents.project_agent import PROJECT_TOOLS
|
from app.agents.project_agent import PROJECT_TOOLS
|
||||||
from app.agents.task_agent import TASK_TOOLS
|
from app.agents.task_agent import TASK_TOOLS
|
||||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
from app.agents.timeline_agent import TIMELINE_TOOLS
|
||||||
from app.core.llm import get_llm
|
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
|
||||||
|
from app.core.llm import get_agent_llm, model_for_agent
|
||||||
from app.core.memory_middleware import MemoryMiddleware
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
@@ -26,7 +27,35 @@ logger = logging.getLogger(__name__)
|
|||||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
||||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
FloatingDomainSection = Literal["task", "timeline", "note"]
|
||||||
|
|
||||||
_HOME_SINGLE_AGENT_SYSTEM = (
|
# Mapping of core-memory language values to natural-language names for prompts.
|
||||||
|
_LANGUAGE_NAMES: dict[str, str] = {
|
||||||
|
"en": "English", "it": "Italian", "es": "Spanish",
|
||||||
|
"fr": "French", "de": "German",
|
||||||
|
"english": "English", "italian": "Italian", "italiano": "Italian",
|
||||||
|
"spanish": "Spanish", "español": "Spanish",
|
||||||
|
"french": "French", "français": "French",
|
||||||
|
"german": "German", "deutsch": "German",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _language_instruction(context: dict[str, Any]) -> str:
|
||||||
|
"""Return a system-prompt suffix that tells the LLM to respond in the user's language.
|
||||||
|
|
||||||
|
Returns an empty string when the language is English or unknown — saves tokens.
|
||||||
|
"""
|
||||||
|
core = context.get("core_memory") or {}
|
||||||
|
raw = (core.get("language") or "").strip().lower()
|
||||||
|
if not raw:
|
||||||
|
return ""
|
||||||
|
lang = _LANGUAGE_NAMES.get(raw, raw.title()) # best-effort capitalisation
|
||||||
|
if lang.lower() == "english":
|
||||||
|
return ""
|
||||||
|
return (
|
||||||
|
f"\n\nIMPORTANT: Always respond in {lang}. "
|
||||||
|
f"All your output text must be written in {lang}."
|
||||||
|
)
|
||||||
|
|
||||||
|
_HOME_SYSTEM_PROMPT = (
|
||||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
"Always use tools for factual data retrieval before answering. "
|
"Always use tools for factual data retrieval before answering. "
|
||||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
||||||
@@ -39,7 +68,7 @@ _HOME_SINGLE_AGENT_SYSTEM = (
|
|||||||
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
||||||
)
|
)
|
||||||
|
|
||||||
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
_FLOATING_SYSTEM_PROMPT = (
|
||||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
||||||
"Stay focused on the floating scope in context.scope and answer concisely. "
|
"Stay focused on the floating scope in context.scope and answer concisely. "
|
||||||
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
||||||
@@ -48,7 +77,7 @@ _FLOATING_SINGLE_AGENT_SYSTEM = (
|
|||||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
||||||
)
|
)
|
||||||
|
|
||||||
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
|
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
|
||||||
"You are a strict domain classifier for websocket floating requests. "
|
"You are a strict domain classifier for websocket floating requests. "
|
||||||
"Return ONLY a JSON object with keys: type, id, section. "
|
"Return ONLY a JSON object with keys: type, id, section. "
|
||||||
"Allowed type values: task, timeline, project, node. "
|
"Allowed type values: task, timeline, project, node. "
|
||||||
@@ -147,6 +176,15 @@ def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _session_id_from_context(context: dict[str, Any]) -> str | None:
|
||||||
|
debug = context.get("_debug")
|
||||||
|
if isinstance(debug, dict):
|
||||||
|
session_id = debug.get("session_id")
|
||||||
|
if isinstance(session_id, str) and session_id:
|
||||||
|
return session_id
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
||||||
sanitized = dict(context)
|
sanitized = dict(context)
|
||||||
sanitized.pop("_debug", None)
|
sanitized.pop("_debug", None)
|
||||||
@@ -535,10 +573,9 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
|||||||
}
|
}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
llm = get_llm()
|
llm = get_agent_llm("classifier")
|
||||||
response = await llm.ainvoke(
|
classifier_messages = [
|
||||||
[
|
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
|
||||||
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
|
|
||||||
HumanMessage(
|
HumanMessage(
|
||||||
content=(
|
content=(
|
||||||
f"Message:\n{message}\n\n"
|
f"Message:\n{message}\n\n"
|
||||||
@@ -546,7 +583,29 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
|
|||||||
)
|
)
|
||||||
),
|
),
|
||||||
]
|
]
|
||||||
|
lf = get_langfuse()
|
||||||
|
_, classifier_prompt_obj = get_prompt_or_fallback(
|
||||||
|
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Extract user/session from context for Langfuse attribution
|
||||||
|
_debug = context.get("_debug") if isinstance(context, dict) else None
|
||||||
|
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
|
||||||
|
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
|
||||||
|
|
||||||
|
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="floating-classifier",
|
||||||
|
model=model_for_agent("classifier"),
|
||||||
|
prompt=classifier_prompt_obj,
|
||||||
|
input=classifier_messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(classifier_messages)
|
||||||
|
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
|
else:
|
||||||
|
response = await llm.ainvoke(classifier_messages)
|
||||||
parsed = _parse_json_object(_as_text(response.content))
|
parsed = _parse_json_object(_as_text(response.content))
|
||||||
if parsed is not None:
|
if parsed is not None:
|
||||||
domain = _normalize_domain_payload(parsed, project_id)
|
domain = _normalize_domain_payload(parsed, project_id)
|
||||||
@@ -571,9 +630,13 @@ async def _run_single_agent(
|
|||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
max_steps: int = 6,
|
max_steps: int = 6,
|
||||||
|
langfuse_prompt: Any = None,
|
||||||
|
agent_name: str = "agent",
|
||||||
) -> str:
|
) -> str:
|
||||||
trace_id = _trace_id_from_context(context)
|
trace_id = _trace_id_from_context(context)
|
||||||
llm = get_llm()
|
session_id = _session_id_from_context(context)
|
||||||
|
lf = get_langfuse()
|
||||||
|
llm = get_agent_llm(agent_name)
|
||||||
tools = _all_tools_for_user(user_id, trace_id)
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
model_context = _context_for_model(context)
|
model_context = _context_for_model(context)
|
||||||
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
@@ -591,9 +654,39 @@ async def _run_single_agent(
|
|||||||
tool_calls_count = 0
|
tool_calls_count = 0
|
||||||
collected: list[dict[str, Any]] = []
|
collected: list[dict[str, Any]] = []
|
||||||
set_tool_result_collector(collected)
|
set_tool_result_collector(collected)
|
||||||
|
|
||||||
|
_lf_ctx = langfuse_context(user_id=user_id, session_id=session_id)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
|
_span_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name=agent_name,
|
||||||
|
metadata={"user_id": user_id, "session_id": trace_id},
|
||||||
|
input=message,
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
|
_gen_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name=f"{agent_name}-llm",
|
||||||
|
model=model_for_agent(agent_name),
|
||||||
|
prompt=langfuse_prompt,
|
||||||
|
input=messages,
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
if _gen_ctx:
|
||||||
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
|
|
||||||
if not response.tool_calls:
|
if not response.tool_calls:
|
||||||
@@ -605,6 +698,8 @@ async def _run_single_agent(
|
|||||||
tool_calls_count,
|
tool_calls_count,
|
||||||
len(final_text),
|
len(final_text),
|
||||||
)
|
)
|
||||||
|
if _span:
|
||||||
|
_span.update(output=final_text)
|
||||||
return final_text
|
return final_text
|
||||||
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
@@ -644,9 +739,16 @@ async def _run_single_agent(
|
|||||||
tool_calls_count,
|
tool_calls_count,
|
||||||
len(final_text),
|
len(final_text),
|
||||||
)
|
)
|
||||||
|
if _span:
|
||||||
|
_span.update(output=final_text)
|
||||||
return final_text
|
return final_text
|
||||||
finally:
|
finally:
|
||||||
clear_tool_result_collector()
|
clear_tool_result_collector()
|
||||||
|
if _span_ctx:
|
||||||
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
|
if lf:
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
|
||||||
async def _run_single_agent_stream(
|
async def _run_single_agent_stream(
|
||||||
@@ -656,9 +758,13 @@ async def _run_single_agent_stream(
|
|||||||
message: str,
|
message: str,
|
||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
max_steps: int = 6,
|
max_steps: int = 6,
|
||||||
|
langfuse_prompt: Any = None,
|
||||||
|
agent_name: str = "agent",
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
trace_id = _trace_id_from_context(context)
|
trace_id = _trace_id_from_context(context)
|
||||||
llm = get_llm()
|
session_id = _session_id_from_context(context)
|
||||||
|
lf = get_langfuse()
|
||||||
|
llm = get_agent_llm(agent_name)
|
||||||
tools = _all_tools_for_user(user_id, trace_id)
|
tools = _all_tools_for_user(user_id, trace_id)
|
||||||
model_context = _context_for_model(context)
|
model_context = _context_for_model(context)
|
||||||
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
||||||
@@ -677,9 +783,40 @@ async def _run_single_agent_stream(
|
|||||||
streamed_chars = 0
|
streamed_chars = 0
|
||||||
collected: list[dict[str, Any]] = []
|
collected: list[dict[str, Any]] = []
|
||||||
set_tool_result_collector(collected)
|
set_tool_result_collector(collected)
|
||||||
|
|
||||||
|
_lf_ctx = langfuse_context(user_id=user_id, session_id=session_id)
|
||||||
|
_lf_ctx.__enter__()
|
||||||
|
|
||||||
|
_span_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="span",
|
||||||
|
name=f"{agent_name}-stream",
|
||||||
|
metadata={"user_id": user_id, "session_id": trace_id},
|
||||||
|
input=message,
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_span = _span_ctx.__enter__() if _span_ctx else None
|
||||||
|
streamed_text: list[str] = []
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for _ in range(max_steps):
|
for _ in range(max_steps):
|
||||||
|
_gen_ctx = (
|
||||||
|
lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name=f"{agent_name}-llm",
|
||||||
|
model=model_for_agent(agent_name),
|
||||||
|
prompt=langfuse_prompt,
|
||||||
|
input=messages,
|
||||||
|
)
|
||||||
|
if lf else None
|
||||||
|
)
|
||||||
|
_gen = _gen_ctx.__enter__() if _gen_ctx else None
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
||||||
|
if _gen_ctx:
|
||||||
|
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
|
||||||
|
_gen_ctx.__exit__(None, None, None)
|
||||||
|
|
||||||
messages.append(response)
|
messages.append(response)
|
||||||
|
|
||||||
if not response.tool_calls:
|
if not response.tool_calls:
|
||||||
@@ -688,6 +825,7 @@ async def _run_single_agent_stream(
|
|||||||
token = _as_text(getattr(chunk, "content", ""))
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
if token:
|
if token:
|
||||||
streamed_chars += len(token)
|
streamed_chars += len(token)
|
||||||
|
streamed_text.append(token)
|
||||||
emitted_any = True
|
emitted_any = True
|
||||||
yield "token", token
|
yield "token", token
|
||||||
|
|
||||||
@@ -696,6 +834,7 @@ async def _run_single_agent_stream(
|
|||||||
fallback_text = _as_text(response.content)
|
fallback_text = _as_text(response.content)
|
||||||
if fallback_text:
|
if fallback_text:
|
||||||
streamed_chars += len(fallback_text)
|
streamed_chars += len(fallback_text)
|
||||||
|
streamed_text.append(fallback_text)
|
||||||
yield "token", fallback_text
|
yield "token", fallback_text
|
||||||
logger.info(
|
logger.info(
|
||||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
||||||
@@ -704,6 +843,8 @@ async def _run_single_agent_stream(
|
|||||||
tool_calls_count,
|
tool_calls_count,
|
||||||
streamed_chars,
|
streamed_chars,
|
||||||
)
|
)
|
||||||
|
if _span:
|
||||||
|
_span.update(output="".join(streamed_text))
|
||||||
return
|
return
|
||||||
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
||||||
@@ -738,6 +879,7 @@ async def _run_single_agent_stream(
|
|||||||
token = _as_text(getattr(chunk, "content", ""))
|
token = _as_text(getattr(chunk, "content", ""))
|
||||||
if token:
|
if token:
|
||||||
streamed_chars += len(token)
|
streamed_chars += len(token)
|
||||||
|
streamed_text.append(token)
|
||||||
yield "token", token
|
yield "token", token
|
||||||
logger.info(
|
logger.info(
|
||||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
||||||
@@ -746,17 +888,30 @@ async def _run_single_agent_stream(
|
|||||||
tool_calls_count,
|
tool_calls_count,
|
||||||
streamed_chars,
|
streamed_chars,
|
||||||
)
|
)
|
||||||
|
if _span:
|
||||||
|
_span.update(output="".join(streamed_text))
|
||||||
finally:
|
finally:
|
||||||
clear_tool_result_collector()
|
clear_tool_result_collector()
|
||||||
|
if _span_ctx:
|
||||||
|
_span_ctx.__exit__(None, None, None)
|
||||||
|
_lf_ctx.__exit__(None, None, None)
|
||||||
|
if lf:
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
|
||||||
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
|
"home_system", _HOME_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
system_prompt=system_prompt,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="home-agent",
|
||||||
)
|
)
|
||||||
return _normalize_tagged_list_lines(response, message)
|
return _normalize_tagged_list_lines(response, message)
|
||||||
|
|
||||||
@@ -764,11 +919,17 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
|
|||||||
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
|
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
domain = await _infer_floating_domain(message, prepared_context)
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
|
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
response = await _run_single_agent(
|
response = await _run_single_agent(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
system_prompt=system_prompt,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="floating-agent",
|
||||||
)
|
)
|
||||||
sanitized = _strip_floating_markup(response)
|
sanitized = _strip_floating_markup(response)
|
||||||
if not sanitized and response:
|
if not sanitized and response:
|
||||||
@@ -782,12 +943,18 @@ async def run_home_stream(
|
|||||||
context: dict[str, Any],
|
context: dict[str, Any],
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
prepared_context = await _prepare_context(message, context)
|
prepared_context = await _prepare_context(message, context)
|
||||||
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
|
"home_system", _HOME_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
text_chunks: list[str] = []
|
text_chunks: list[str] = []
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
|
system_prompt=system_prompt,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="home-agent",
|
||||||
):
|
):
|
||||||
event_type, data = event
|
event_type, data = event
|
||||||
if event_type != "token":
|
if event_type != "token":
|
||||||
@@ -809,14 +976,20 @@ async def run_floating_stream(
|
|||||||
domain = await _infer_floating_domain(message, prepared_context)
|
domain = await _infer_floating_domain(message, prepared_context)
|
||||||
yield "floating_domain", domain
|
yield "floating_domain", domain
|
||||||
|
|
||||||
|
system_prompt, langfuse_prompt = get_prompt_or_fallback(
|
||||||
|
"floating_system", _FLOATING_SYSTEM_PROMPT
|
||||||
|
)
|
||||||
|
system_prompt += _language_instruction(context)
|
||||||
sanitizer = _FloatingStreamSanitizer()
|
sanitizer = _FloatingStreamSanitizer()
|
||||||
emitted_sanitized = False
|
emitted_sanitized = False
|
||||||
raw_chunks: list[str] = []
|
raw_chunks: list[str] = []
|
||||||
async for event in _run_single_agent_stream(
|
async for event in _run_single_agent_stream(
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
|
system_prompt=system_prompt,
|
||||||
message=message,
|
message=message,
|
||||||
context=prepared_context,
|
context=prepared_context,
|
||||||
|
langfuse_prompt=langfuse_prompt,
|
||||||
|
agent_name="floating-agent",
|
||||||
):
|
):
|
||||||
event_type, data = event
|
event_type, data = event
|
||||||
if event_type != "token":
|
if event_type != "token":
|
||||||
|
|||||||
190
app/core/langfuse_client.py
Normal file
190
app/core/langfuse_client.py
Normal file
@@ -0,0 +1,190 @@
|
|||||||
|
"""Langfuse observability — singleton client and prompt helpers.
|
||||||
|
|
||||||
|
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
|
||||||
|
all helpers are no-ops so the app works without Langfuse configured.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
Tracing::
|
||||||
|
|
||||||
|
from app.core.langfuse_client import get_langfuse
|
||||||
|
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf:
|
||||||
|
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
|
||||||
|
span.update(input=user_message)
|
||||||
|
# ... do work ...
|
||||||
|
span.update(output=result)
|
||||||
|
lf.flush()
|
||||||
|
|
||||||
|
Prompt management::
|
||||||
|
|
||||||
|
from app.core.langfuse_client import get_prompt_or_fallback
|
||||||
|
|
||||||
|
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
|
||||||
|
# Use text as the system prompt; pass prompt_obj to generations for linking.
|
||||||
|
|
||||||
|
Linking a prompt to a generation::
|
||||||
|
|
||||||
|
with lf.start_as_current_observation(
|
||||||
|
as_type="generation",
|
||||||
|
name="llm-call",
|
||||||
|
model="gpt-4o",
|
||||||
|
prompt=prompt_obj, # links generation → prompt version in the UI
|
||||||
|
input=messages,
|
||||||
|
) as gen:
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
gen.update(output=response.content, usage=_usage(response))
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import logging
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import Any, Generator
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
_client: Any = None
|
||||||
|
_initialized: bool = False
|
||||||
|
|
||||||
|
|
||||||
|
def get_langfuse() -> Any | None:
|
||||||
|
"""Return the Langfuse singleton, or ``None`` when not configured."""
|
||||||
|
global _client, _initialized
|
||||||
|
if _initialized:
|
||||||
|
return _client
|
||||||
|
_initialized = True
|
||||||
|
|
||||||
|
from app.config.settings import settings # local import to avoid circular deps
|
||||||
|
|
||||||
|
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
||||||
|
logger.debug("langfuse: not configured — observability disabled")
|
||||||
|
return None
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import Langfuse
|
||||||
|
|
||||||
|
_client = Langfuse(
|
||||||
|
secret_key=settings.LANGFUSE_SECRET_KEY,
|
||||||
|
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
||||||
|
host=settings.LANGFUSE_BASE_URL,
|
||||||
|
)
|
||||||
|
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse: failed to initialize: %s", exc)
|
||||||
|
_client = None
|
||||||
|
|
||||||
|
return _client
|
||||||
|
|
||||||
|
|
||||||
|
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
|
||||||
|
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
|
||||||
|
|
||||||
|
Returns ``(raw_template, prompt_obj_or_None)``.
|
||||||
|
|
||||||
|
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
|
||||||
|
on it directly; use :func:`compile_prompt` instead so the correct variable
|
||||||
|
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
|
||||||
|
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
|
||||||
|
unavailable / the fetch failed. Pass this to generation observations so
|
||||||
|
Langfuse links the generation to the exact prompt version in the UI.
|
||||||
|
"""
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is None:
|
||||||
|
return fallback, None
|
||||||
|
|
||||||
|
try:
|
||||||
|
prompt = lf.get_prompt(name, label="production", fallback=fallback)
|
||||||
|
# For text-type prompts .prompt holds the raw template string.
|
||||||
|
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
|
||||||
|
return raw, prompt
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
|
||||||
|
return fallback, None
|
||||||
|
|
||||||
|
|
||||||
|
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
|
||||||
|
"""Compile *template* with *variables*, choosing the right syntax.
|
||||||
|
|
||||||
|
* When *prompt_obj* is a real Langfuse prompt object, calls
|
||||||
|
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
|
||||||
|
substitution as defined in the Langfuse UI.
|
||||||
|
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
|
||||||
|
falls back to ``template.format(**variables)`` which handles the
|
||||||
|
``{variable}`` syntax used in the hardcoded fallback strings.
|
||||||
|
|
||||||
|
This keeps callers oblivious to which syntax is in use.
|
||||||
|
"""
|
||||||
|
if prompt_obj is not None:
|
||||||
|
try:
|
||||||
|
compiled = prompt_obj.compile(**variables)
|
||||||
|
# compile() returns a string for text prompts.
|
||||||
|
if isinstance(compiled, str):
|
||||||
|
return compiled
|
||||||
|
# Chat prompts return a list of dicts — join text parts.
|
||||||
|
if isinstance(compiled, list):
|
||||||
|
return "\n".join(
|
||||||
|
m.get("content", "") for m in compiled if isinstance(m, dict)
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
|
||||||
|
getattr(prompt_obj, "name", "?"),
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
return template.format(**variables)
|
||||||
|
|
||||||
|
|
||||||
|
def extract_usage(response: Any) -> dict[str, int]:
|
||||||
|
"""Extract token usage from a LangChain AI message into Langfuse format."""
|
||||||
|
meta = getattr(response, "usage_metadata", None)
|
||||||
|
if not meta:
|
||||||
|
return {}
|
||||||
|
return {
|
||||||
|
"input": int(meta.get("input_tokens", 0)),
|
||||||
|
"output": int(meta.get("output_tokens", 0)),
|
||||||
|
"total": int(meta.get("total_tokens", 0)),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def hash_user_id(user_id: str) -> str:
|
||||||
|
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
|
||||||
|
|
||||||
|
This avoids sending raw database UUIDs to external observability services
|
||||||
|
while still providing a stable, deterministic identifier for per-user
|
||||||
|
metrics in the Langfuse dashboard.
|
||||||
|
"""
|
||||||
|
return hashlib.sha256(user_id.encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def langfuse_context(
|
||||||
|
user_id: str | None = None,
|
||||||
|
session_id: str | None = None,
|
||||||
|
) -> Generator[None, None, None]:
|
||||||
|
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
|
||||||
|
|
||||||
|
No-op when Langfuse is not configured or parameters are empty.
|
||||||
|
"""
|
||||||
|
lf = get_langfuse()
|
||||||
|
if lf is None or (not user_id and not session_id):
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
from langfuse import propagate_attributes
|
||||||
|
except ImportError:
|
||||||
|
logger.debug("langfuse: propagate_attributes not available — skipping context")
|
||||||
|
yield
|
||||||
|
return
|
||||||
|
|
||||||
|
attrs: dict[str, str] = {}
|
||||||
|
if user_id:
|
||||||
|
attrs["user_id"] = hash_user_id(user_id)
|
||||||
|
if session_id:
|
||||||
|
attrs["session_id"] = session_id
|
||||||
|
|
||||||
|
with propagate_attributes(**attrs):
|
||||||
|
yield
|
||||||
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import os
|
import os
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections.abc import Callable
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
import litellm
|
import litellm
|
||||||
@@ -95,6 +96,35 @@ def get_llm(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
|
||||||
|
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
|
||||||
|
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
|
||||||
|
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
|
||||||
|
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
|
||||||
|
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
|
||||||
|
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def model_for_agent(agent_name: str) -> str:
|
||||||
|
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
|
||||||
|
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
|
||||||
|
|
||||||
|
|
||||||
|
def get_agent_llm(
|
||||||
|
agent_name: str,
|
||||||
|
*,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
|
||||||
|
|
||||||
|
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
|
||||||
|
per-agent override is left empty in ``.env``.
|
||||||
|
"""
|
||||||
|
model = model_for_agent(agent_name)
|
||||||
|
return get_llm(model=model, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
async def embed(text: str) -> list[float]:
|
||||||
"""Return an embedding vector for *text*.
|
"""Return an embedding vector for *text*.
|
||||||
|
|
||||||
|
|||||||
104
app/core/preprocessors/__init__.py
Normal file
104
app/core/preprocessors/__init__.py
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
"""Preprocessor registry: detect content type and dispatch to handlers.
|
||||||
|
|
||||||
|
Public API
|
||||||
|
----------
|
||||||
|
detect_content_type(filename, raw_content) -> str
|
||||||
|
Heuristic detection based on file extension and content patterns.
|
||||||
|
|
||||||
|
preprocess(content_type, raw_content) -> PreprocessResult
|
||||||
|
Dispatch to the appropriate handler.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
|
||||||
|
from app.core.preprocessors.base import PreprocessResult
|
||||||
|
|
||||||
|
# ── Heuristics ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Patterns that strongly suggest an email HTML file
|
||||||
|
_EMAIL_SIGNALS = re.compile(
|
||||||
|
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Patterns that suggest a generic HTML page (not an email)
|
||||||
|
_GENERIC_HTML_SIGNALS = re.compile(
|
||||||
|
r"<(nav|main|header|footer|article|section)\b",
|
||||||
|
re.IGNORECASE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def detect_content_type(filename: str, raw_content: str) -> str:
|
||||||
|
"""Return a content-type string for the given file.
|
||||||
|
|
||||||
|
Supported types: ``"email_html"``, ``"generic_html"``,
|
||||||
|
``"plain_text"``, ``"unknown"``.
|
||||||
|
"""
|
||||||
|
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
|
||||||
|
|
||||||
|
if ext == "txt":
|
||||||
|
return "plain_text"
|
||||||
|
|
||||||
|
if ext in ("html", "htm", "eml", "mhtml", "mht"):
|
||||||
|
# Prefer email detection over generic HTML
|
||||||
|
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||||
|
return "email_html"
|
||||||
|
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
|
||||||
|
return "generic_html"
|
||||||
|
# .html without clear signals — check for any email header
|
||||||
|
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
|
||||||
|
return "email_html"
|
||||||
|
return "generic_html"
|
||||||
|
|
||||||
|
# Plain text files with email headers
|
||||||
|
if ext in ("", "txt") or not ext:
|
||||||
|
if _EMAIL_SIGNALS.search(raw_content[:4096]):
|
||||||
|
return "email_html"
|
||||||
|
|
||||||
|
# Detect binary content
|
||||||
|
try:
|
||||||
|
raw_content.encode("utf-8")
|
||||||
|
except (UnicodeEncodeError, AttributeError):
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
# Non-text bytes heuristic: high ratio of non-printable chars
|
||||||
|
sample = raw_content[:512]
|
||||||
|
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
|
||||||
|
if len(sample) > 0 and non_printable / len(sample) > 0.1:
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
return "unknown"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Generic fallback handler ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
|
||||||
|
"""Strip HTML tags if present, return text as-is."""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup
|
||||||
|
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
|
||||||
|
except ImportError:
|
||||||
|
# No BeautifulSoup — strip tags with a simple regex
|
||||||
|
text = re.sub(r"<[^>]+>", "", raw_content)
|
||||||
|
|
||||||
|
text = re.sub(r"\n{3,}", "\n\n", text).strip()
|
||||||
|
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
|
||||||
|
|
||||||
|
|
||||||
|
# ── Dispatch ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
|
||||||
|
"""Dispatch *raw_content* to the handler registered for *content_type*.
|
||||||
|
|
||||||
|
Falls back to the generic handler for unknown types.
|
||||||
|
"""
|
||||||
|
if content_type == "email_html":
|
||||||
|
from app.core.preprocessors.email_html import preprocess_email_html
|
||||||
|
return preprocess_email_html(raw_content)
|
||||||
|
|
||||||
|
return _preprocess_generic(raw_content, content_type)
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]
|
||||||
25
app/core/preprocessors/base.py
Normal file
25
app/core/preprocessors/base.py
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
"""Base types for the preprocessor system."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class PreprocessResult:
|
||||||
|
"""Output of a preprocessor handler.
|
||||||
|
|
||||||
|
Attributes
|
||||||
|
----------
|
||||||
|
content_type:
|
||||||
|
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
|
||||||
|
clean_text:
|
||||||
|
Human-readable text stripped of markup/binary noise.
|
||||||
|
metadata:
|
||||||
|
Dict of extracted metadata (keys vary by handler).
|
||||||
|
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
content_type: str
|
||||||
|
clean_text: str
|
||||||
|
metadata: dict = field(default_factory=dict)
|
||||||
111
app/core/preprocessors/email_html.py
Normal file
111
app/core/preprocessors/email_html.py
Normal file
@@ -0,0 +1,111 @@
|
|||||||
|
"""Preprocessor for email HTML files.
|
||||||
|
|
||||||
|
Handles:
|
||||||
|
- HTML stripping via BeautifulSoup
|
||||||
|
- Metadata extraction (Subject, From, To, Date)
|
||||||
|
- Thread splitting — isolates the latest reply
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
|
from app.core.preprocessors.base import PreprocessResult
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ── Thread split markers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Matches patterns like:
|
||||||
|
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
|
||||||
|
# "-----Original Message-----"
|
||||||
|
# "> " (plain-text quote prefix)
|
||||||
|
_THREAD_PATTERNS = [
|
||||||
|
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
|
||||||
|
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
|
||||||
|
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
|
||||||
|
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
|
||||||
|
]
|
||||||
|
|
||||||
|
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
|
||||||
|
|
||||||
|
_META_PATTERNS: dict[str, list[re.Pattern]] = {
|
||||||
|
"subject": [
|
||||||
|
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
|
||||||
|
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"from": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"From:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"to": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"To:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
"date": [
|
||||||
|
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
|
||||||
|
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
|
||||||
|
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_metadata(raw_html: str, text: str) -> dict:
|
||||||
|
"""Extract Subject/From/To/Date from raw HTML or plain text."""
|
||||||
|
metadata: dict[str, str] = {}
|
||||||
|
for field, patterns in _META_PATTERNS.items():
|
||||||
|
for pat in patterns:
|
||||||
|
m = pat.search(raw_html) or pat.search(text)
|
||||||
|
if m:
|
||||||
|
metadata[field] = m.group(1).strip()
|
||||||
|
break
|
||||||
|
return metadata
|
||||||
|
|
||||||
|
|
||||||
|
def _split_thread(text: str) -> str:
|
||||||
|
"""Return only the latest message in a threaded email."""
|
||||||
|
earliest_pos: int | None = None
|
||||||
|
for pat in _THREAD_PATTERNS:
|
||||||
|
m = pat.search(text)
|
||||||
|
if m and (earliest_pos is None or m.start() < earliest_pos):
|
||||||
|
earliest_pos = m.start()
|
||||||
|
|
||||||
|
if earliest_pos is not None and earliest_pos > 0:
|
||||||
|
return text[:earliest_pos].strip()
|
||||||
|
return text.strip()
|
||||||
|
|
||||||
|
|
||||||
|
def preprocess_email_html(raw_content: str) -> PreprocessResult:
|
||||||
|
"""Strip HTML, extract metadata, split thread from an email HTML file."""
|
||||||
|
try:
|
||||||
|
from bs4 import BeautifulSoup # lazy import — optional dep
|
||||||
|
except ImportError as exc:
|
||||||
|
raise ImportError(
|
||||||
|
"beautifulsoup4 is required for email_html preprocessing. "
|
||||||
|
"Install it with: pip install beautifulsoup4"
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
# Parse with lxml if available, fall back to html.parser
|
||||||
|
try:
|
||||||
|
soup = BeautifulSoup(raw_content, "lxml")
|
||||||
|
except Exception:
|
||||||
|
soup = BeautifulSoup(raw_content, "html.parser")
|
||||||
|
|
||||||
|
# Remove noise tags
|
||||||
|
for tag in soup(["style", "script", "head", "noscript"]):
|
||||||
|
tag.decompose()
|
||||||
|
|
||||||
|
clean_text = soup.get_text(separator="\n")
|
||||||
|
# Collapse excessive blank lines
|
||||||
|
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
|
||||||
|
|
||||||
|
metadata = _extract_metadata(raw_content, clean_text)
|
||||||
|
latest_message = _split_thread(clean_text)
|
||||||
|
|
||||||
|
return PreprocessResult(
|
||||||
|
content_type="email_html",
|
||||||
|
clean_text=latest_message,
|
||||||
|
metadata=metadata,
|
||||||
|
)
|
||||||
@@ -25,7 +25,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timedelta, timezone
|
from datetime import datetime, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|||||||
@@ -30,7 +30,7 @@ async def lifespan(app: FastAPI):
|
|||||||
|
|
||||||
def create_app() -> FastAPI:
|
def create_app() -> FastAPI:
|
||||||
app = FastAPI(
|
app = FastAPI(
|
||||||
title="Adiuva Cloud API",
|
title="AdiuvAI Cloud API",
|
||||||
version="0.1.0",
|
version="0.1.0",
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
redoc_url=None,
|
redoc_url=None,
|
||||||
@@ -50,14 +50,10 @@ def create_app() -> FastAPI:
|
|||||||
app.add_middleware(SanitizerMiddleware)
|
app.add_middleware(SanitizerMiddleware)
|
||||||
app.add_middleware(TierRateLimitMiddleware)
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
from app.api.routes import agents, auth, billing, chat, device_ws
|
||||||
|
|
||||||
app.include_router(auth.router, prefix="/api/v1")
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
app.include_router(chat.router, prefix="/api/v1")
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
app.include_router(storage.router, prefix="/api/v1")
|
|
||||||
app.include_router(vectors.router, prefix="/api/v1")
|
|
||||||
app.include_router(backup.router, prefix="/api/v1")
|
|
||||||
app.include_router(plugins.router, prefix="/api/v1")
|
|
||||||
app.include_router(billing.router, prefix="/api/v1")
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
app.include_router(agents.router, prefix="/api/v1")
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
app.include_router(device_ws.router, prefix="/api/v1")
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|||||||
@@ -1,7 +0,0 @@
|
|||||||
"""Plugin marketplace package.
|
|
||||||
|
|
||||||
Three service classes introduced in Step 10:
|
|
||||||
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
|
||||||
- ``ReviewQueue`` — approval workflow + security checklist
|
|
||||||
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
|
||||||
"""
|
|
||||||
@@ -1,212 +0,0 @@
|
|||||||
"""Plugin catalog registry backed by PostgreSQL.
|
|
||||||
|
|
||||||
Maintains the authoritative list of plugins, their review status, and
|
|
||||||
aggregate install counts. All data is persisted in the ``plugins`` table.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from sqlalchemy import select, func
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.models import Plugin
|
|
||||||
from app.schemas import PluginListResponse, PluginManifest
|
|
||||||
|
|
||||||
_PAGE_SIZE = 20
|
|
||||||
|
|
||||||
|
|
||||||
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
|
||||||
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
|
||||||
try:
|
|
||||||
permissions = json.loads(p.permissions) if p.permissions else []
|
|
||||||
except (json.JSONDecodeError, TypeError):
|
|
||||||
permissions = []
|
|
||||||
return PluginManifest(
|
|
||||||
id=p.id,
|
|
||||||
name=p.name,
|
|
||||||
description=p.description,
|
|
||||||
version=p.version,
|
|
||||||
author=p.author_name,
|
|
||||||
permissions=permissions,
|
|
||||||
category=p.category,
|
|
||||||
price_cents=p.price_cents,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginRegistry:
|
|
||||||
"""PostgreSQL-backed plugin catalog.
|
|
||||||
|
|
||||||
All methods accept an ``AsyncSession`` parameter so the calling route
|
|
||||||
controls the session lifecycle.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Queries ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def list_plugins(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
category: str | None = None,
|
|
||||||
query: str | None = None,
|
|
||||||
page: int = 1,
|
|
||||||
sort: Literal["rating", "installs", "newest"] = "newest",
|
|
||||||
) -> PluginListResponse:
|
|
||||||
"""Return a page of approved plugins, optionally filtered and sorted."""
|
|
||||||
base = select(Plugin).where(Plugin.status == "approved")
|
|
||||||
|
|
||||||
if category:
|
|
||||||
base = base.where(Plugin.category == category)
|
|
||||||
if query:
|
|
||||||
pattern = f"%{query}%"
|
|
||||||
base = base.where(
|
|
||||||
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
|
||||||
)
|
|
||||||
|
|
||||||
# Count
|
|
||||||
count_q = select(func.count()).select_from(base.subquery())
|
|
||||||
total = (await db.execute(count_q)).scalar_one()
|
|
||||||
|
|
||||||
# Sort
|
|
||||||
if sort == "installs":
|
|
||||||
base = base.order_by(Plugin.install_count.desc())
|
|
||||||
elif sort == "rating":
|
|
||||||
base = base.order_by(Plugin.avg_rating.desc())
|
|
||||||
else: # newest
|
|
||||||
base = base.order_by(Plugin.created_at.desc())
|
|
||||||
|
|
||||||
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
|
||||||
rows = (await db.execute(base)).scalars().all()
|
|
||||||
|
|
||||||
return PluginListResponse(
|
|
||||||
plugins=[_plugin_to_manifest(r) for r in rows],
|
|
||||||
total=total,
|
|
||||||
page=page,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
|
||||||
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
p = result.scalar_one_or_none()
|
|
||||||
if p is None:
|
|
||||||
return None
|
|
||||||
return {
|
|
||||||
"manifest": _plugin_to_manifest(p),
|
|
||||||
"status": p.status,
|
|
||||||
"install_count": p.install_count,
|
|
||||||
"avg_rating": p.avg_rating,
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── Mutations ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def submit_plugin(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
manifest: PluginManifest,
|
|
||||||
package_s3_key: str,
|
|
||||||
) -> str:
|
|
||||||
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
|
||||||
|
|
||||||
Returns the plugin_id. If a plugin with the same id already exists
|
|
||||||
it is overwritten (re-submission after rejection).
|
|
||||||
"""
|
|
||||||
plugin_id = manifest.id
|
|
||||||
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = existing.scalar_one_or_none()
|
|
||||||
|
|
||||||
if row is not None:
|
|
||||||
row.name = manifest.name
|
|
||||||
row.description = manifest.description
|
|
||||||
row.version = manifest.version
|
|
||||||
row.author_name = manifest.author
|
|
||||||
row.category = manifest.category
|
|
||||||
row.price_cents = manifest.price_cents
|
|
||||||
row.permissions = json.dumps(manifest.permissions)
|
|
||||||
row.status = "pending_review"
|
|
||||||
row.s3_package_key = package_s3_key
|
|
||||||
row.rejection_reason = None
|
|
||||||
else:
|
|
||||||
row = Plugin(
|
|
||||||
id=plugin_id,
|
|
||||||
name=manifest.name,
|
|
||||||
description=manifest.description,
|
|
||||||
version=manifest.version,
|
|
||||||
author_name=manifest.author,
|
|
||||||
category=manifest.category,
|
|
||||||
price_cents=manifest.price_cents,
|
|
||||||
permissions=json.dumps(manifest.permissions),
|
|
||||||
status="pending_review",
|
|
||||||
s3_package_key=package_s3_key,
|
|
||||||
install_count=0,
|
|
||||||
avg_rating=0.0,
|
|
||||||
)
|
|
||||||
db.add(row)
|
|
||||||
await db.commit()
|
|
||||||
return plugin_id
|
|
||||||
|
|
||||||
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
|
||||||
"""Set *plugin_id* status to ``'approved'``.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the plugin is not found.
|
|
||||||
"""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
|
||||||
row.status = "approved"
|
|
||||||
row.rejection_reason = None
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
|
||||||
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
|
||||||
|
|
||||||
Raises ``KeyError`` if the plugin is not found.
|
|
||||||
"""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
raise KeyError(f"Plugin not found: {plugin_id}")
|
|
||||||
row.status = "rejected"
|
|
||||||
row.rejection_reason = reason
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
|
||||||
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is not None:
|
|
||||||
row.install_count = row.install_count + 1
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
|
||||||
"""Decrement the install count for *plugin_id*, floored at 0."""
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is not None:
|
|
||||||
row.install_count = max(0, row.install_count - 1)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
|
||||||
|
|
||||||
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
|
||||||
"""Return all entries with status='pending_review'."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(Plugin).where(Plugin.status == "pending_review")
|
|
||||||
)
|
|
||||||
rows = result.scalars().all()
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"manifest": _plugin_to_manifest(r),
|
|
||||||
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
|
||||||
}
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
registry = PluginRegistry()
|
|
||||||
@@ -1,125 +0,0 @@
|
|||||||
"""Plugin review workflow backed by PostgreSQL.
|
|
||||||
|
|
||||||
Manages the approval queue for newly submitted plugins and enforces a
|
|
||||||
security checklist before any plugin is made visible in the marketplace.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.plugin_review import review_queue
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
from app.models import PluginReview as PluginReviewModel
|
|
||||||
from app.schemas import PluginManifest
|
|
||||||
|
|
||||||
# ── Security policy ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
|
||||||
{
|
|
||||||
"read:tasks",
|
|
||||||
"write:tasks",
|
|
||||||
"read:projects",
|
|
||||||
"write:projects",
|
|
||||||
"read:notes",
|
|
||||||
"write:notes",
|
|
||||||
"read:timelines",
|
|
||||||
"write:timelines",
|
|
||||||
"read:calendar",
|
|
||||||
"write:calendar",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
|
|
||||||
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
|
||||||
|
|
||||||
|
|
||||||
def validate_manifest(manifest: PluginManifest) -> None:
|
|
||||||
"""Enforce the plugin security checklist.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``ValueError`` on the first violation found. Callers should catch
|
|
||||||
this and return HTTP 422 / reject the submission.
|
|
||||||
|
|
||||||
Checks:
|
|
||||||
1. Plugin id matches ``^[a-z0-9-]+$``
|
|
||||||
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
|
||||||
3. No manifest field contains raw binary data
|
|
||||||
"""
|
|
||||||
if not _PLUGIN_ID_RE.match(manifest.id):
|
|
||||||
raise ValueError(
|
|
||||||
f"Invalid plugin id format: '{manifest.id}'. "
|
|
||||||
"Only lowercase letters, digits, and hyphens are allowed."
|
|
||||||
)
|
|
||||||
|
|
||||||
for perm in manifest.permissions:
|
|
||||||
if perm not in ALLOWED_PERMISSIONS:
|
|
||||||
raise ValueError(
|
|
||||||
f"Unknown permission: '{perm}'. "
|
|
||||||
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
|
||||||
)
|
|
||||||
|
|
||||||
for field_name, value in manifest.model_dump().items():
|
|
||||||
if isinstance(value, (bytes, bytearray)):
|
|
||||||
raise ValueError(
|
|
||||||
f"Binary content is not allowed in manifest field '{field_name}'."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ReviewQueue:
|
|
||||||
"""Approval queue for pending plugin submissions.
|
|
||||||
|
|
||||||
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
|
||||||
Review records are persisted in the ``plugin_reviews`` table.
|
|
||||||
"""
|
|
||||||
|
|
||||||
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
|
||||||
"""Return all plugins currently awaiting review.
|
|
||||||
|
|
||||||
Each item is ``{plugin_id, manifest, submitted_at}``.
|
|
||||||
"""
|
|
||||||
entries = await registry.get_pending_entries(db)
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"plugin_id": e["manifest"].id,
|
|
||||||
"manifest": e["manifest"],
|
|
||||||
"submitted_at": e["submitted_at"],
|
|
||||||
}
|
|
||||||
for e in entries
|
|
||||||
]
|
|
||||||
|
|
||||||
async def submit_review(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
plugin_id: str,
|
|
||||||
reviewer_id: str,
|
|
||||||
decision: Literal["approved", "rejected"],
|
|
||||||
notes: str = "",
|
|
||||||
) -> None:
|
|
||||||
"""Record a review decision and update the plugin's status.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``KeyError`` if *plugin_id* is not found in the registry.
|
|
||||||
"""
|
|
||||||
if decision == "approved":
|
|
||||||
await registry.approve_plugin(db, plugin_id)
|
|
||||||
else:
|
|
||||||
await registry.reject_plugin(db, plugin_id, reason=notes)
|
|
||||||
|
|
||||||
review = PluginReviewModel(
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
reviewer_id=reviewer_id,
|
|
||||||
decision=decision,
|
|
||||||
notes=notes,
|
|
||||||
)
|
|
||||||
db.add(review)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
review_queue = ReviewQueue()
|
|
||||||
@@ -1,233 +0,0 @@
|
|||||||
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
|
||||||
|
|
||||||
Records every plugin installation as a revenue event and facilitates
|
|
||||||
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
|
||||||
in the ``revenue_events`` table.
|
|
||||||
|
|
||||||
Module-level singleton::
|
|
||||||
|
|
||||||
from app.marketplace.revenue_share import revenue_share
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import stripe as stripe_lib
|
|
||||||
from sqlalchemy import extract, func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.marketplace.plugin_registry import registry
|
|
||||||
from app.models import Plugin, RevenueEvent
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── Revenue split constants ───────────────────────────────────────────
|
|
||||||
|
|
||||||
DEVELOPER_SHARE: float = 0.70
|
|
||||||
PLATFORM_SHARE: float = 0.30
|
|
||||||
|
|
||||||
|
|
||||||
class RevenueShare:
|
|
||||||
"""Records installation revenue events and coordinates developer payouts.
|
|
||||||
|
|
||||||
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
|
||||||
is not configured, consistent with the rest of the billing layer.
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _stripe_configured() -> bool:
|
|
||||||
return bool(settings.STRIPE_SECRET_KEY)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _stripe() -> Any:
|
|
||||||
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
|
||||||
return stripe_lib
|
|
||||||
|
|
||||||
# ── Core operations ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def record_install(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
plugin_id: str,
|
|
||||||
user_id: str,
|
|
||||||
amount_cents: int,
|
|
||||||
) -> None:
|
|
||||||
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
|
||||||
|
|
||||||
For free plugins (``amount_cents == 0``) no payment is initiated but
|
|
||||||
the event is still recorded for analytics.
|
|
||||||
|
|
||||||
For paid plugins the developer receives 70 % via a Stripe Connect
|
|
||||||
destination charge. If Stripe is not configured or the charge fails
|
|
||||||
the installation still succeeds (the event is recorded and the install
|
|
||||||
count is incremented) — a warning is logged for monitoring.
|
|
||||||
"""
|
|
||||||
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
|
||||||
stripe_transfer_id: str | None = None
|
|
||||||
|
|
||||||
if amount_cents > 0 and self._stripe_configured():
|
|
||||||
# Look up the plugin's author Stripe account from the DB
|
|
||||||
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
plugin_row = result.scalar_one_or_none()
|
|
||||||
developer_stripe_account: str | None = None
|
|
||||||
if plugin_row and plugin_row.author_id:
|
|
||||||
# Future: look up user.stripe_connect_account_id
|
|
||||||
developer_stripe_account = None # no real account yet
|
|
||||||
|
|
||||||
if developer_stripe_account:
|
|
||||||
try:
|
|
||||||
s = self._stripe()
|
|
||||||
transfer = s.Transfer.create(
|
|
||||||
amount=developer_share_cents,
|
|
||||||
currency="eur",
|
|
||||||
destination=developer_stripe_account,
|
|
||||||
description=f"Revenue share for plugin {plugin_id}",
|
|
||||||
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
|
||||||
)
|
|
||||||
stripe_transfer_id = transfer["id"]
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"Stripe Connect transfer failed for plugin %s: %s",
|
|
||||||
plugin_id,
|
|
||||||
exc,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logger.debug(
|
|
||||||
"No Stripe account on file for plugin %s developer; "
|
|
||||||
"skipping transfer.",
|
|
||||||
plugin_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
event = RevenueEvent(
|
|
||||||
plugin_id=plugin_id,
|
|
||||||
user_id=user_id,
|
|
||||||
amount_cents=amount_cents,
|
|
||||||
developer_share_cents=developer_share_cents,
|
|
||||||
stripe_transfer_id=stripe_transfer_id,
|
|
||||||
)
|
|
||||||
db.add(event)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
await registry.record_install(db, plugin_id)
|
|
||||||
|
|
||||||
async def get_earnings(
|
|
||||||
self,
|
|
||||||
db: AsyncSession,
|
|
||||||
developer_id: str,
|
|
||||||
period: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return aggregated earnings for *developer_id*.
|
|
||||||
|
|
||||||
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
|
||||||
|
|
||||||
Returns::
|
|
||||||
|
|
||||||
{
|
|
||||||
"developer_id": str,
|
|
||||||
"period": str | None,
|
|
||||||
"total_installs": int,
|
|
||||||
"total_revenue_cents": int,
|
|
||||||
"developer_share_cents": int,
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
# Find plugin ids belonging to this developer (by author_name match)
|
|
||||||
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
|
||||||
plugin_result = await db.execute(plugin_q)
|
|
||||||
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
|
||||||
|
|
||||||
if not developer_plugin_ids:
|
|
||||||
return {
|
|
||||||
"developer_id": developer_id,
|
|
||||||
"period": period,
|
|
||||||
"total_installs": 0,
|
|
||||||
"total_revenue_cents": 0,
|
|
||||||
"developer_share_cents": 0,
|
|
||||||
}
|
|
||||||
|
|
||||||
query = select(
|
|
||||||
func.count().label("total_installs"),
|
|
||||||
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
|
||||||
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
|
||||||
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
|
||||||
|
|
||||||
if period:
|
|
||||||
# Filter by YYYY-MM: extract year and month from created_at
|
|
||||||
try:
|
|
||||||
year, month = period.split("-")
|
|
||||||
query = query.where(
|
|
||||||
extract("year", RevenueEvent.created_at) == int(year),
|
|
||||||
extract("month", RevenueEvent.created_at) == int(month),
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
pass # invalid period format — return all
|
|
||||||
|
|
||||||
result = await db.execute(query)
|
|
||||||
row = result.one()
|
|
||||||
|
|
||||||
return {
|
|
||||||
"developer_id": developer_id,
|
|
||||||
"period": period,
|
|
||||||
"total_installs": row.total_installs,
|
|
||||||
"total_revenue_cents": row.total_revenue,
|
|
||||||
"developer_share_cents": row.dev_share,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
|
||||||
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
|
||||||
|
|
||||||
Marks processed events with ``paid_at`` timestamp.
|
|
||||||
Stubs gracefully when Stripe is not configured.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
year, month = period.split("-")
|
|
||||||
year_int, month_int = int(year), int(month)
|
|
||||||
except ValueError:
|
|
||||||
logger.warning("Invalid period format: %s", period)
|
|
||||||
return
|
|
||||||
|
|
||||||
result = await db.execute(
|
|
||||||
select(RevenueEvent).where(
|
|
||||||
RevenueEvent.plugin_id == plugin_id,
|
|
||||||
RevenueEvent.paid_at.is_(None),
|
|
||||||
extract("year", RevenueEvent.created_at) == year_int,
|
|
||||||
extract("month", RevenueEvent.created_at) == month_int,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
unpaid = list(result.scalars().all())
|
|
||||||
|
|
||||||
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
|
||||||
if total_dev_share <= 0 or not unpaid:
|
|
||||||
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
|
||||||
return
|
|
||||||
|
|
||||||
if self._stripe_configured():
|
|
||||||
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
|
||||||
plugin_row = plugin_result.scalar_one_or_none()
|
|
||||||
developer_stripe_account: str | None = None # Future: fetch from DB
|
|
||||||
if plugin_row and developer_stripe_account:
|
|
||||||
try:
|
|
||||||
s = self._stripe()
|
|
||||||
s.Transfer.create(
|
|
||||||
amount=total_dev_share,
|
|
||||||
currency="eur",
|
|
||||||
destination=developer_stripe_account,
|
|
||||||
description=f"Payout for plugin {plugin_id} period {period}",
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
|
||||||
return
|
|
||||||
|
|
||||||
paid_ts = datetime.now(timezone.utc)
|
|
||||||
for event in unpaid:
|
|
||||||
event.paid_at = paid_ts
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
|
||||||
revenue_share = RevenueShare()
|
|
||||||
192
app/models.py
192
app/models.py
@@ -1,19 +1,15 @@
|
|||||||
"""SQLAlchemy ORM models for all persistent tables.
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
Only auth, billing, storage metadata, and marketplace data live here.
|
Only auth, billing, agent config, and memory data live here.
|
||||||
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
User content (notes, tasks, etc.) lives exclusively on the client.
|
||||||
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
|
||||||
|
|
||||||
Table inventory:
|
Table inventory:
|
||||||
users — account credentials + tier
|
users — account credentials + tier
|
||||||
refresh_tokens — hashed refresh token store
|
refresh_tokens — hashed refresh token store
|
||||||
subscriptions — Stripe subscription records
|
subscriptions — Stripe subscription records
|
||||||
storage_records — S3 blob metadata (no plaintext)
|
local_agent_configs — per-device batch agent configs
|
||||||
backup_metadata — encrypted backup manifests
|
cloud_agent_configs — OAuth-backed cloud agent configs
|
||||||
plugins — marketplace plugin catalog
|
agent_run_logs — execution history for all agents
|
||||||
plugin_installations — per-user install records
|
|
||||||
plugin_reviews — admin review decisions
|
|
||||||
revenue_events — Stripe Connect 70/30 split ledger
|
|
||||||
memory_core — per-user persistent key/value preferences (encrypted)
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
memory_associative — per-user semantic memory with embeddings (encrypted)
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
memory_episodic — per-user session summaries (encrypted)
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
@@ -26,7 +22,6 @@ import uuid
|
|||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
from sqlalchemy import (
|
from sqlalchemy import (
|
||||||
BigInteger,
|
|
||||||
Boolean,
|
Boolean,
|
||||||
DateTime,
|
DateTime,
|
||||||
Enum,
|
Enum,
|
||||||
@@ -36,7 +31,6 @@ from sqlalchemy import (
|
|||||||
JSON,
|
JSON,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
UniqueConstraint,
|
|
||||||
Uuid,
|
Uuid,
|
||||||
func,
|
func,
|
||||||
)
|
)
|
||||||
@@ -58,8 +52,6 @@ def _now() -> datetime:
|
|||||||
# ── Enum types ────────────────────────────────────────────────────────────
|
# ── Enum types ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
||||||
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
|
||||||
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
|
||||||
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
||||||
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
||||||
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||||
@@ -77,7 +69,8 @@ class User(Base):
|
|||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
||||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
@@ -86,6 +79,9 @@ class User(Base):
|
|||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
)
|
)
|
||||||
|
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=True, default=None
|
||||||
|
)
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
@@ -96,6 +92,9 @@ class User(Base):
|
|||||||
subscription: Mapped[Subscription | None] = relationship(
|
subscription: Mapped[Subscription | None] = relationship(
|
||||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
||||||
)
|
)
|
||||||
|
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
|
||||||
|
back_populates="user", cascade="all, delete-orphan"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Base):
|
class RefreshToken(Base):
|
||||||
@@ -116,6 +115,25 @@ class RefreshToken(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccount(Base):
|
||||||
|
__tablename__ = "oauth_accounts"
|
||||||
|
|
||||||
|
id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
|
)
|
||||||
|
user_id: Mapped[str] = mapped_column(
|
||||||
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
||||||
|
)
|
||||||
|
provider: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||||
|
provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
|
provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
|
)
|
||||||
|
|
||||||
|
user: Mapped[User] = relationship(back_populates="oauth_accounts")
|
||||||
|
|
||||||
|
|
||||||
class Subscription(Base):
|
class Subscription(Base):
|
||||||
__tablename__ = "subscriptions"
|
__tablename__ = "subscriptions"
|
||||||
|
|
||||||
@@ -137,151 +155,6 @@ class Subscription(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="subscription")
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
class StorageRecord(Base):
|
|
||||||
__tablename__ = "storage_records"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
|
||||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
|
||||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
|
||||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BackupMetadata(Base):
|
|
||||||
__tablename__ = "backup_metadata"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
|
||||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
|
||||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
|
||||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin(Base):
|
|
||||||
__tablename__ = "plugins"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
|
||||||
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
|
||||||
# nullable until developer account system is built
|
|
||||||
author_id: Mapped[str | None] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
|
||||||
)
|
|
||||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
|
||||||
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
|
||||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
|
||||||
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
|
||||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
|
||||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
|
||||||
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
submitted_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
installations: Mapped[list[PluginInstallation]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
reviews: Mapped[list[PluginReview]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginInstallation(Base):
|
|
||||||
__tablename__ = "plugin_installations"
|
|
||||||
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
installed_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginReview(Base):
|
|
||||||
__tablename__ = "plugin_reviews"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
reviewer_id: Mapped[str | None] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
|
||||||
)
|
|
||||||
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
|
||||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
reviewed_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
|
||||||
|
|
||||||
|
|
||||||
class RevenueEvent(Base):
|
|
||||||
__tablename__ = "revenue_events"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
|
||||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfig(Base):
|
class LocalAgentConfig(Base):
|
||||||
__tablename__ = "local_agent_configs"
|
__tablename__ = "local_agent_configs"
|
||||||
|
|
||||||
@@ -296,6 +169,7 @@ class LocalAgentConfig(Base):
|
|||||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
|
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
||||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||||
|
|||||||
118
app/schemas.py
118
app/schemas.py
@@ -30,6 +30,16 @@ class UserProfile(BaseModel):
|
|||||||
name: str | None = None
|
name: str | None = None
|
||||||
surname: str | None = None
|
surname: str | None = None
|
||||||
tier: BillingTier
|
tier: BillingTier
|
||||||
|
avatar_url: str | None = None
|
||||||
|
has_password: bool = True
|
||||||
|
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
|
||||||
|
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
|
||||||
|
|
||||||
|
|
||||||
|
class OAuthAccountInfo(BaseModel):
|
||||||
|
provider: str
|
||||||
|
provider_email: str | None = None
|
||||||
|
created_at: int # epoch ms
|
||||||
|
|
||||||
|
|
||||||
# ── Chat ─────────────────────────────────────────────────────────────
|
# ── Chat ─────────────────────────────────────────────────────────────
|
||||||
@@ -50,88 +60,6 @@ class ChatResponse(BaseModel):
|
|||||||
response: str
|
response: str
|
||||||
|
|
||||||
|
|
||||||
# ── Backup ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class BackupMetadata(BaseModel):
|
|
||||||
version: int
|
|
||||||
timestamp: int
|
|
||||||
checksum: str
|
|
||||||
chunk_count: int
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
|
|
||||||
|
|
||||||
class StorageRecord(BaseModel):
|
|
||||||
id: str
|
|
||||||
user_id: str
|
|
||||||
table: str
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
created_at: int
|
|
||||||
updated_at: int
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRecordCreate(BaseModel):
|
|
||||||
table: str
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRecordUpdate(BaseModel):
|
|
||||||
blob: bytes
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
|
|
||||||
|
|
||||||
class VectorItem(BaseModel):
|
|
||||||
id: str
|
|
||||||
blob: bytes # encrypted vector + metadata — backend never decrypts
|
|
||||||
checksum: str
|
|
||||||
|
|
||||||
|
|
||||||
class VectorUpsertRequest(BaseModel):
|
|
||||||
vectors: list[VectorItem]
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchRequest(BaseModel):
|
|
||||||
query_blob: bytes # encrypted query — backend never decrypts
|
|
||||||
top_k: int = 10
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchResult(BaseModel):
|
|
||||||
id: str
|
|
||||||
score: float
|
|
||||||
blob: bytes
|
|
||||||
|
|
||||||
|
|
||||||
class VectorSearchResponse(BaseModel):
|
|
||||||
results: list[VectorSearchResult]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plugin Marketplace ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class PluginManifest(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
version: str
|
|
||||||
author: str
|
|
||||||
permissions: list[str]
|
|
||||||
category: str
|
|
||||||
price_cents: int = 0
|
|
||||||
|
|
||||||
|
|
||||||
class PluginListResponse(BaseModel):
|
|
||||||
plugins: list[PluginManifest]
|
|
||||||
total: int
|
|
||||||
page: int
|
|
||||||
|
|
||||||
|
|
||||||
class PluginInstallRequest(BaseModel):
|
|
||||||
plugin_id: str
|
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
# ── WebSocket Frame Protocol ──────────────────────────────────────────
|
||||||
|
|
||||||
class WsFrameType(str, Enum):
|
class WsFrameType(str, Enum):
|
||||||
@@ -273,6 +201,27 @@ class WsFloatingDomain(BaseModel):
|
|||||||
domain: WsDomain
|
domain: WsDomain
|
||||||
|
|
||||||
|
|
||||||
|
# ── Agent Config V2 ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
class ContentTypeConfig(BaseModel):
|
||||||
|
"""Per-type extraction config produced by the journey chatbot."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
label: str = ""
|
||||||
|
detection_hint: str = ""
|
||||||
|
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
|
||||||
|
extraction_prompt: str
|
||||||
|
|
||||||
|
|
||||||
|
class AgentConfig(BaseModel):
|
||||||
|
"""Structured agent configuration (replaces freeform prompt_template)."""
|
||||||
|
|
||||||
|
content_types: list[ContentTypeConfig] = []
|
||||||
|
global_rules: list[str] = []
|
||||||
|
data_types: list[str] = []
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
class AgentCatalogItem(BaseModel):
|
class AgentCatalogItem(BaseModel):
|
||||||
@@ -297,10 +246,11 @@ class AgentTriggerRequest(BaseModel):
|
|||||||
device_id: str = Field(default="")
|
device_id: str = Field(default="")
|
||||||
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
|
||||||
what_to_extract: list[str] = Field(min_length=1)
|
what_to_extract: list[str] = Field(min_length=1)
|
||||||
actions_by_type: dict[str, list[str]] | None = None
|
|
||||||
batch_interval: str = Field(min_length=1)
|
batch_interval: str = Field(min_length=1)
|
||||||
custom_agent_prompt: str = Field(min_length=1)
|
custom_agent_prompt: str | None = None
|
||||||
|
agent_config: dict | None = None
|
||||||
active_agents: int = Field(ge=0, default=0)
|
active_agents: int = Field(ge=0, default=0)
|
||||||
|
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
|
||||||
@@ -1,106 +0,0 @@
|
|||||||
"""S3-backed store for E2E-encrypted blobs.
|
|
||||||
|
|
||||||
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
|
||||||
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
import boto3
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
|
|
||||||
|
|
||||||
class BlobStore:
|
|
||||||
"""Thin wrapper around boto3 S3.
|
|
||||||
|
|
||||||
All blobs must be E2E encrypted by the client before upload.
|
|
||||||
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
|
||||||
but cannot decrypt the inner client-side payload.
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _client(self) -> Any:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"region_name": settings.S3_REGION,
|
|
||||||
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
|
|
||||||
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
|
|
||||||
}
|
|
||||||
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
|
|
||||||
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
|
|
||||||
return boto3.client("s3", **kwargs)
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _key(user_id: str, table: str, record_id: str) -> str:
|
|
||||||
return f"{user_id}/{table}/{record_id}"
|
|
||||||
|
|
||||||
async def upload(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
table: str,
|
|
||||||
record_id: str,
|
|
||||||
blob: bytes,
|
|
||||||
checksum: str,
|
|
||||||
) -> str:
|
|
||||||
"""Store *blob* in S3 and return the S3 key.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Owner of the blob (used as key prefix).
|
|
||||||
table: Logical table name (e.g. ``"tasks"``).
|
|
||||||
record_id: Record UUID.
|
|
||||||
blob: Raw bytes (pre-encrypted by client).
|
|
||||||
checksum: SHA-256 hex digest supplied by the client; stored as
|
|
||||||
object metadata for download-time verification.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
The S3 key under which the blob was stored.
|
|
||||||
"""
|
|
||||||
key = self._key(user_id, table, record_id)
|
|
||||||
self._client().put_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=key,
|
|
||||||
Body=blob,
|
|
||||||
ServerSideEncryption="AES256", # SSE-S3 at rest
|
|
||||||
Metadata={"checksum": checksum},
|
|
||||||
)
|
|
||||||
return key
|
|
||||||
|
|
||||||
async def download(self, user_id: str, s3_key: str) -> bytes:
|
|
||||||
"""Retrieve the blob stored at *s3_key*.
|
|
||||||
|
|
||||||
*user_id* is retained in the signature so higher-level code can
|
|
||||||
enforce ownership without re-parsing the key.
|
|
||||||
|
|
||||||
Raises:
|
|
||||||
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
|
||||||
object does not exist.
|
|
||||||
"""
|
|
||||||
response = self._client().get_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=s3_key,
|
|
||||||
)
|
|
||||||
return response["Body"].read()
|
|
||||||
|
|
||||||
async def delete(self, user_id: str, s3_key: str) -> None:
|
|
||||||
"""Delete the object at *s3_key*.
|
|
||||||
|
|
||||||
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
|
||||||
not exist.
|
|
||||||
"""
|
|
||||||
self._client().delete_object(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Key=s3_key,
|
|
||||||
)
|
|
||||||
|
|
||||||
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
|
||||||
"""Return all S3 keys for a given user + table combination.
|
|
||||||
|
|
||||||
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
|
||||||
"""
|
|
||||||
prefix = f"{user_id}/{table}/"
|
|
||||||
response = self._client().list_objects_v2(
|
|
||||||
Bucket=settings.S3_BUCKET,
|
|
||||||
Prefix=prefix,
|
|
||||||
)
|
|
||||||
return [obj["Key"] for obj in response.get("Contents", [])]
|
|
||||||
@@ -1,32 +0,0 @@
|
|||||||
"""Integrity verification only — the backend NEVER decrypts user data."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import hmac
|
|
||||||
|
|
||||||
from fastapi import HTTPException
|
|
||||||
|
|
||||||
|
|
||||||
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
|
||||||
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
|
||||||
|
|
||||||
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
|
||||||
timing-based side-channel attacks.
|
|
||||||
"""
|
|
||||||
computed = hashlib.sha256(blob).hexdigest()
|
|
||||||
return hmac.compare_digest(computed, checksum)
|
|
||||||
|
|
||||||
|
|
||||||
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
|
||||||
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
|
||||||
|
|
||||||
Call this before storing or forwarding any client-provided blob.
|
|
||||||
The backend never holds decryption keys — this check only verifies
|
|
||||||
that the opaque bytes arrived intact.
|
|
||||||
"""
|
|
||||||
if not verify_checksum(blob, checksum):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=400,
|
|
||||||
detail="Checksum mismatch: blob integrity check failed",
|
|
||||||
)
|
|
||||||
@@ -1,205 +0,0 @@
|
|||||||
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
|
||||||
|
|
||||||
Vectors are pre-encrypted blobs from the client. The backend stores them
|
|
||||||
alongside a deterministic 32-dim float representation derived from the blob's
|
|
||||||
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
|
||||||
is a known trade-off documented in the backend plan.
|
|
||||||
|
|
||||||
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
|
||||||
``user_id`` payload field on a shared collection.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import base64
|
|
||||||
import hashlib
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from pinecone import Pinecone
|
|
||||||
from qdrant_client import QdrantClient
|
|
||||||
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
|
||||||
|
|
||||||
from app.config.settings import settings
|
|
||||||
from app.schemas import VectorItem, VectorSearchResult
|
|
||||||
|
|
||||||
_QDRANT_COLLECTION = "adiuva_vectors"
|
|
||||||
|
|
||||||
|
|
||||||
def _blob_to_vector(blob: bytes) -> list[float]:
|
|
||||||
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
|
||||||
|
|
||||||
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
|
||||||
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
|
||||||
semantic meaning on encrypted data.
|
|
||||||
"""
|
|
||||||
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
|
||||||
|
|
||||||
|
|
||||||
class VectorStore:
|
|
||||||
"""Thin wrapper around Pinecone or Qdrant.
|
|
||||||
|
|
||||||
The backend to use is selected at runtime:
|
|
||||||
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
|
||||||
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
|
||||||
"""
|
|
||||||
|
|
||||||
def _use_pinecone(self) -> bool:
|
|
||||||
return bool(settings.PINECONE_API_KEY)
|
|
||||||
|
|
||||||
# ── Pinecone helpers ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _pinecone_index(self) -> Any:
|
|
||||||
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
|
||||||
return pc.Index(settings.PINECONE_INDEX)
|
|
||||||
|
|
||||||
# ── Qdrant helpers ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _qdrant_client(self) -> Any:
|
|
||||||
return QdrantClient(
|
|
||||||
url=settings.QDRANT_URL,
|
|
||||||
api_key=settings.QDRANT_API_KEY or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Public API ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
"""Store encrypted vectors in the backend.
|
|
||||||
|
|
||||||
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
|
||||||
so it can be returned verbatim during search.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Used as Pinecone namespace or Qdrant payload field.
|
|
||||||
vectors: List of encrypted vector items from the client.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
await self._pinecone_upsert(user_id, vectors)
|
|
||||||
else:
|
|
||||||
await self._qdrant_upsert(user_id, vectors)
|
|
||||||
|
|
||||||
async def search(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
query_blob: bytes,
|
|
||||||
top_k: int,
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
"""Query the vector store and return encrypted result blobs.
|
|
||||||
|
|
||||||
The query vector is derived from *query_blob* using the same
|
|
||||||
deterministic mapping as upsert.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Scopes the search to this user's namespace.
|
|
||||||
query_blob: Encrypted query from the client.
|
|
||||||
top_k: Maximum number of results to return.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
return await self._pinecone_search(user_id, query_blob, top_k)
|
|
||||||
return await self._qdrant_search(user_id, query_blob, top_k)
|
|
||||||
|
|
||||||
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
"""Remove vectors by ID, scoped to *user_id*.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
user_id: Namespace / payload filter to prevent cross-user deletion.
|
|
||||||
vector_ids: List of vector IDs to remove.
|
|
||||||
"""
|
|
||||||
if self._use_pinecone():
|
|
||||||
await self._pinecone_delete(user_id, vector_ids)
|
|
||||||
else:
|
|
||||||
await self._qdrant_delete(user_id, vector_ids)
|
|
||||||
|
|
||||||
# ── Pinecone implementation ───────────────────────────────────────
|
|
||||||
|
|
||||||
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
records = [
|
|
||||||
{
|
|
||||||
"id": v.id,
|
|
||||||
"values": _blob_to_vector(v.blob),
|
|
||||||
"metadata": {
|
|
||||||
"blob": base64.b64encode(v.blob).decode(),
|
|
||||||
"checksum": v.checksum,
|
|
||||||
"user_id": user_id,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for v in vectors
|
|
||||||
]
|
|
||||||
index.upsert(vectors=records, namespace=user_id)
|
|
||||||
|
|
||||||
async def _pinecone_search(
|
|
||||||
self, user_id: str, query_blob: bytes, top_k: int
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
query_vector = _blob_to_vector(query_blob)
|
|
||||||
response = index.query(
|
|
||||||
vector=query_vector,
|
|
||||||
top_k=top_k,
|
|
||||||
namespace=user_id,
|
|
||||||
include_metadata=True,
|
|
||||||
)
|
|
||||||
results: list[VectorSearchResult] = []
|
|
||||||
for match in response.get("matches", []):
|
|
||||||
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
|
||||||
results.append(
|
|
||||||
VectorSearchResult(
|
|
||||||
id=match["id"],
|
|
||||||
score=match["score"],
|
|
||||||
blob=blob_bytes,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return results
|
|
||||||
|
|
||||||
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
index = self._pinecone_index()
|
|
||||||
index.delete(ids=vector_ids, namespace=user_id)
|
|
||||||
|
|
||||||
# ── Qdrant implementation ─────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
points = [
|
|
||||||
PointStruct(
|
|
||||||
id=v.id,
|
|
||||||
vector=_blob_to_vector(v.blob),
|
|
||||||
payload={
|
|
||||||
"blob": base64.b64encode(v.blob).decode(),
|
|
||||||
"checksum": v.checksum,
|
|
||||||
"user_id": user_id,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
for v in vectors
|
|
||||||
]
|
|
||||||
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
|
||||||
|
|
||||||
async def _qdrant_search(
|
|
||||||
self, user_id: str, query_blob: bytes, top_k: int
|
|
||||||
) -> list[VectorSearchResult]:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
query_vector = _blob_to_vector(query_blob)
|
|
||||||
hits = client.search(
|
|
||||||
collection_name=_QDRANT_COLLECTION,
|
|
||||||
query_vector=query_vector,
|
|
||||||
query_filter=Filter(
|
|
||||||
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
|
||||||
),
|
|
||||||
limit=top_k,
|
|
||||||
)
|
|
||||||
return [
|
|
||||||
VectorSearchResult(
|
|
||||||
id=str(hit.id),
|
|
||||||
score=hit.score,
|
|
||||||
blob=base64.b64decode(hit.payload["blob"]),
|
|
||||||
)
|
|
||||||
for hit in hits
|
|
||||||
]
|
|
||||||
|
|
||||||
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
|
||||||
client = self._qdrant_client()
|
|
||||||
client.delete(
|
|
||||||
collection_name=_QDRANT_COLLECTION,
|
|
||||||
points_selector=PointIdsList(points=vector_ids),
|
|
||||||
)
|
|
||||||
@@ -7,7 +7,7 @@ services:
|
|||||||
- path: .env
|
- path: .env
|
||||||
required: false
|
required: false
|
||||||
environment:
|
environment:
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
|
||||||
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||||
volumes:
|
volumes:
|
||||||
- copilot_tokens:/root/.config/litellm/github_copilot
|
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||||
@@ -21,7 +21,7 @@ services:
|
|||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: postgres
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: postgres
|
POSTGRES_PASSWORD: postgres
|
||||||
POSTGRES_DB: adiuva
|
POSTGRES_DB: adiuvai
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
@@ -36,37 +36,6 @@ services:
|
|||||||
# image: redis:7-alpine
|
# image: redis:7-alpine
|
||||||
# restart: unless-stopped
|
# restart: unless-stopped
|
||||||
|
|
||||||
# ── Local S3-compatible storage (MinIO) ──
|
|
||||||
minio:
|
|
||||||
image: minio/minio:latest
|
|
||||||
command: server /data --console-address ":9001"
|
|
||||||
ports:
|
|
||||||
- "9000:9000"
|
|
||||||
- "9001:9001"
|
|
||||||
environment:
|
|
||||||
MINIO_ROOT_USER: minioadmin
|
|
||||||
MINIO_ROOT_PASSWORD: minioadmin
|
|
||||||
volumes:
|
|
||||||
- minio_data:/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "mc", "ready", "local"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
# ── Local vector store (Qdrant) ──
|
|
||||||
qdrant:
|
|
||||||
image: qdrant/qdrant:latest
|
|
||||||
ports:
|
|
||||||
- "6333:6333"
|
|
||||||
- "6334:6334"
|
|
||||||
volumes:
|
|
||||||
- qdrant_data:/qdrant/storage
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
minio_data:
|
|
||||||
qdrant_data:
|
|
||||||
copilot_tokens:
|
copilot_tokens:
|
||||||
|
|||||||
@@ -1,941 +0,0 @@
|
|||||||
# Adiuva — Architettura Microservizi (MVP)
|
|
||||||
|
|
||||||
## Panoramica
|
|
||||||
|
|
||||||
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
|
|
||||||
|
|
||||||
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐
|
|
||||||
│ Cloudflare │
|
|
||||||
│ (DNS + CDN) │
|
|
||||||
└──────┬───────┘
|
|
||||||
│ HTTPS / WSS
|
|
||||||
┌──────▼───────┐
|
|
||||||
│ Traefik │
|
|
||||||
│ API Gateway │
|
|
||||||
│ (routing, │
|
|
||||||
│ TLS, rate │
|
|
||||||
│ limiting) │
|
|
||||||
└──────┬───────┘
|
|
||||||
│
|
|
||||||
┌──────────┬───────────┼───────────┐
|
|
||||||
│ │ │ │
|
|
||||||
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
|
|
||||||
│ Auth │ │ Chat │ │ Agent │ │Billing │
|
|
||||||
│ Service │ │Service│ │ Service │ │Service │
|
|
||||||
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
|
|
||||||
│ │ │ │
|
|
||||||
┌─────▼──────────▼──────────▼───────────▼────┐
|
|
||||||
│ Infrastruttura │
|
|
||||||
│ PostgreSQL │ Redis │ Qdrant │
|
|
||||||
└─────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. Suddivisione dei Servizi
|
|
||||||
|
|
||||||
### 1.1 Auth Service (`auth-service`)
|
|
||||||
|
|
||||||
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
|
||||||
|
|
||||||
| Endpoint originale | Metodo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/auth/register` | POST |
|
|
||||||
| `/api/v1/auth/login` | POST |
|
|
||||||
| `/api/v1/auth/refresh` | POST |
|
|
||||||
| `/api/v1/auth/me` | GET / PUT |
|
|
||||||
|
|
||||||
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
|
||||||
|
|
||||||
**Modifica chiave — JWT con RS256**:
|
|
||||||
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
|
||||||
- L'Auth Service firma i JWT con la **chiave privata**.
|
|
||||||
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
|
||||||
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
|
||||||
|
|
||||||
```python
|
|
||||||
# auth-service/app/auth/jwt.py
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PRIVATE_KEY = ... # Da env/secret
|
|
||||||
PUBLIC_KEY = ... # Derivata o da env
|
|
||||||
|
|
||||||
def create_access_token(user_id: str, tier: str) -> str:
|
|
||||||
return jwt.encode(
|
|
||||||
{"sub": user_id, "tier": tier, "exp": ...},
|
|
||||||
PRIVATE_KEY,
|
|
||||||
algorithm="RS256",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/auth.py (usato da tutti gli altri servizi)
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
|
||||||
|
|
||||||
def verify_token(token: str) -> dict:
|
|
||||||
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
|
||||||
```
|
|
||||||
|
|
||||||
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
|
|
||||||
|
|
||||||
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
|
|
||||||
|
|
||||||
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
|
|
||||||
|
|
||||||
| Endpoint | Tipo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
|
|
||||||
| `/api/v1/chat` | POST (REST fallback) |
|
|
||||||
|
|
||||||
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
|
|
||||||
|
|
||||||
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
|
|
||||||
|
|
||||||
**Scaling**: 2–N repliche. Sticky cookies per le WS + Redis per cross-instance.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.3 Agent Service (`agent-service`) ⭐ Batch
|
|
||||||
|
|
||||||
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
|
|
||||||
|
|
||||||
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
|
|
||||||
|
|
||||||
| Endpoint | Tipo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/agents/catalog` | GET |
|
|
||||||
| `/api/v1/agents/can-create` | POST |
|
|
||||||
| `/api/v1/agents/trigger` | POST |
|
|
||||||
| `/api/v1/agents/journey/start` | POST (o WS relay) |
|
|
||||||
| `/api/v1/agents/journey/message` | POST (o WS relay) |
|
|
||||||
|
|
||||||
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
|
|
||||||
|
|
||||||
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐ ┌──────────────┐ ┌──────────┐
|
|
||||||
│ Agent Service│ │ Redis │ │ Chat │
|
|
||||||
│ (batch run) │ │ │ │ Service │
|
|
||||||
│ │ │ │ │ (ha WS) │
|
|
||||||
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
|
|
||||||
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
|
|
||||||
│ from │ │ │ │ al │
|
|
||||||
│ device │ │ │ │ device│
|
|
||||||
│ │ │ │ │ via WS│
|
|
||||||
│ │ SUBSCRIBE │ │ PUBLISH │ │
|
|
||||||
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
|
|
||||||
│ risultato │ │ │ │ reply │
|
|
||||||
└──────────────┘ └──────────────┘ └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
**Scaling**: 1–N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
|
|
||||||
|
|
||||||
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.4 Billing Service (`billing-service`)
|
|
||||||
|
|
||||||
**Responsabilità**: Stripe checkout, webhook, subscription management.
|
|
||||||
|
|
||||||
| Endpoint originale | Metodo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/billing/checkout` | POST |
|
|
||||||
| `/api/v1/billing/webhook` | POST |
|
|
||||||
| `/api/v1/billing/subscription` | GET / DELETE |
|
|
||||||
|
|
||||||
**Database**: Tabelle `subscriptions` (schema `billing`).
|
|
||||||
|
|
||||||
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
|
|
||||||
|
|
||||||
**Scaling**: 1 replica sufficiente. Basso traffico.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.5 Servizi esclusi dall'MVP
|
|
||||||
|
|
||||||
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
|
|
||||||
|
|
||||||
| Servizio | Responsabilità | Note |
|
|
||||||
|---|---|---|
|
|
||||||
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
|
|
||||||
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. Tier Check — Dove e Come
|
|
||||||
|
|
||||||
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
|
|
||||||
|
|
||||||
### Strategia: Tier nel JWT
|
|
||||||
|
|
||||||
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"sub": "user_123",
|
|
||||||
"tier": "pro",
|
|
||||||
"exp": 1742515200,
|
|
||||||
"iat": 1742511600
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Ogni servizio:
|
|
||||||
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
|
|
||||||
2. Legge `payload["tier"]` — **zero chiamate extra**
|
|
||||||
3. Applica le sue regole di enforcement localmente
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/auth.py — dependency FastAPI condivisa
|
|
||||||
from fastapi import Depends, HTTPException, Request
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PUBLIC_KEY = ...
|
|
||||||
|
|
||||||
class CurrentUser:
|
|
||||||
def __init__(self, user_id: str, tier: str):
|
|
||||||
self.user_id = user_id
|
|
||||||
self.tier = tier
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> CurrentUser:
|
|
||||||
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
|
|
||||||
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
|
||||||
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
|
|
||||||
|
|
||||||
def require_tier(*allowed_tiers: str):
|
|
||||||
"""Dependency che blocca se il tier non è tra quelli ammessi."""
|
|
||||||
async def check(user: CurrentUser = Depends(get_current_user)):
|
|
||||||
if user.tier not in allowed_tiers:
|
|
||||||
raise HTTPException(403, "Tier insufficient")
|
|
||||||
return user
|
|
||||||
return check
|
|
||||||
```
|
|
||||||
|
|
||||||
### Cosa succede quando il tier cambia (upgrade/downgrade)?
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
|
|
||||||
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
|
|
||||||
│ │ │ Service │ (Redis pub/sub) │ Service │
|
|
||||||
└──────────┘ └──────────┘ └────┬─────┘
|
|
||||||
│
|
|
||||||
UPDATE users
|
|
||||||
SET tier = 'power'
|
|
||||||
│
|
|
||||||
Al prossimo /refresh
|
|
||||||
il JWT conterrà tier='power'
|
|
||||||
```
|
|
||||||
|
|
||||||
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 15–30 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
|
|
||||||
|
|
||||||
### Dove si applica in ciascun servizio
|
|
||||||
|
|
||||||
| Servizio | Enforcement |
|
|
||||||
|---|---|
|
|
||||||
| **Auth Service** | Nessuno (è lui che scrive il tier) |
|
|
||||||
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
|
|
||||||
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
|
|
||||||
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
|
|
||||||
|
|
||||||
### Rate-limit distribuito via Redis
|
|
||||||
|
|
||||||
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/middleware/rate_limit.py
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
|
|
||||||
class DistributedRateLimiter:
|
|
||||||
def __init__(self, redis: aioredis.Redis):
|
|
||||||
self._redis = redis
|
|
||||||
|
|
||||||
async def check(self, user_id: str, tier: str, service: str) -> bool:
|
|
||||||
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
|
||||||
max_req = limits.get(tier, 20)
|
|
||||||
key = f"rate:{service}:{user_id}"
|
|
||||||
|
|
||||||
pipe = self._redis.pipeline()
|
|
||||||
pipe.incr(key)
|
|
||||||
pipe.expire(key, 60)
|
|
||||||
count, _ = await pipe.execute()
|
|
||||||
|
|
||||||
return count <= max_req
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
|
||||||
|
|
||||||
`DeviceConnectionManager` è un **singleton in-memory**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class DeviceConnectionManager:
|
|
||||||
def __init__(self):
|
|
||||||
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
|
||||||
```
|
|
||||||
|
|
||||||
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
|
||||||
|
|
||||||
### La soluzione: Redis Pub/Sub + Registry
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────────────────────────────────────────────────────┐
|
|
||||||
│ Redis │
|
|
||||||
│ │
|
|
||||||
│ Hash: ws:connections │
|
|
||||||
│ user_123 → instance_A │
|
|
||||||
│ user_456 → instance_B │
|
|
||||||
│ │
|
|
||||||
│ Pub/Sub channels: │
|
|
||||||
│ tool_call:{user_id} → tool call payloads │
|
|
||||||
│ tool_result:{call_id} → tool result payloads │
|
|
||||||
│ stream:{user_id} → text_chunk streaming │
|
|
||||||
└──────────────────────────────────────────────────────────────┘
|
|
||||||
|
|
||||||
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
|
||||||
┌───────────────────────┐ ┌───────────────────────┐
|
|
||||||
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
|
||||||
│ tool_call:user_123│ │ → user_123 è su A │
|
|
||||||
│ │ │ │
|
|
||||||
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
|
||||||
│ da Redis channel │ │ tool_call:user_123 │
|
|
||||||
│ │ │ {id, action, ...} │
|
|
||||||
│ 3. Invia al device │ │ │
|
|
||||||
│ via WS │ │ 4. SUBSCRIBE │
|
|
||||||
│ │ │ tool_result:{id} │
|
|
||||||
│ 4. Device risponde │ │ │
|
|
||||||
│ tool_result │──────────│► 5. Riceve risultato │
|
|
||||||
│ │ │ │
|
|
||||||
│ 5. PUBLISH │ │ │
|
|
||||||
│ tool_result:{id} │ │ │
|
|
||||||
└───────────────────────┘ └───────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### Implementazione: `RedisDeviceManager`
|
|
||||||
|
|
||||||
```python
|
|
||||||
# chat-service/app/core/device_manager.py
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from fastapi import WebSocket
|
|
||||||
|
|
||||||
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LocalConnection:
|
|
||||||
ws: WebSocket
|
|
||||||
device_id: str
|
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisDeviceManager:
|
|
||||||
"""Device manager backed by Redis for cross-instance communication."""
|
|
||||||
|
|
||||||
def __init__(self, redis_url: str = "redis://redis:6379"):
|
|
||||||
self._redis = aioredis.from_url(redis_url)
|
|
||||||
self._pubsub = self._redis.pubsub()
|
|
||||||
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
|
||||||
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""Avvia il listener Redis per tool_call in arrivo."""
|
|
||||||
asyncio.create_task(self._listen_tool_calls())
|
|
||||||
|
|
||||||
# ── Registrazione ──
|
|
||||||
|
|
||||||
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
|
||||||
# Registra localmente
|
|
||||||
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
|
||||||
# Registra in Redis quale istanza ha la connessione
|
|
||||||
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
|
||||||
# Sottoscrivi ai tool_call per questo utente
|
|
||||||
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
|
||||||
|
|
||||||
async def unregister(self, user_id: str):
|
|
||||||
conn = self._local.pop(user_id, None)
|
|
||||||
if conn:
|
|
||||||
for fut in conn.pending_calls.values():
|
|
||||||
if not fut.done():
|
|
||||||
fut.cancel()
|
|
||||||
await self._redis.hdel("ws:connections", user_id)
|
|
||||||
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
|
||||||
|
|
||||||
# ── Presenza ──
|
|
||||||
|
|
||||||
async def is_online(self, user_id: str) -> bool:
|
|
||||||
return await self._redis.hexists("ws:connections", user_id)
|
|
||||||
|
|
||||||
# ── Tool-call round-trip (cross-instance) ──
|
|
||||||
|
|
||||||
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Invia un tool_call al device dell'utente.
|
|
||||||
Funziona sia che la WS sia locale che su un'altra istanza.
|
|
||||||
"""
|
|
||||||
call_id = payload["id"]
|
|
||||||
|
|
||||||
# Caso 1: connessione locale → invio diretto
|
|
||||||
if user_id in self._local:
|
|
||||||
conn = self._local[user_id]
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
fut: asyncio.Future[dict] = loop.create_future()
|
|
||||||
conn.pending_calls[call_id] = fut
|
|
||||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
|
|
||||||
# Caso 2: connessione remota → Redis pub/sub
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
fut = loop.create_future()
|
|
||||||
self._remote_futures[call_id] = fut
|
|
||||||
|
|
||||||
# Sottoscrivi al canale di risposta
|
|
||||||
result_channel = f"tool_result:{call_id}"
|
|
||||||
await self._pubsub.subscribe(result_channel)
|
|
||||||
|
|
||||||
# Pubblica il tool_call
|
|
||||||
await self._redis.publish(
|
|
||||||
f"tool_call:{user_id}",
|
|
||||||
json.dumps(payload),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
finally:
|
|
||||||
self._remote_futures.pop(call_id, None)
|
|
||||||
await self._pubsub.unsubscribe(result_channel)
|
|
||||||
|
|
||||||
# ── Risoluzione tool_result (da WS locale) ──
|
|
||||||
|
|
||||||
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
|
||||||
conn = self._local.get(user_id)
|
|
||||||
if conn:
|
|
||||||
fut = conn.pending_calls.pop(call_id, None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(result)
|
|
||||||
|
|
||||||
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
|
||||||
"""Chiamato quando il device locale invia un tool_result."""
|
|
||||||
self.resolve_local(user_id, call_id, result)
|
|
||||||
# Pubblica anche su Redis per l'istanza remota che aspetta
|
|
||||||
await self._redis.publish(
|
|
||||||
f"tool_result:{call_id}",
|
|
||||||
json.dumps(result),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Listener Redis ──
|
|
||||||
|
|
||||||
async def _listen_tool_calls(self):
|
|
||||||
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
|
||||||
async for message in self._pubsub.listen():
|
|
||||||
if message["type"] != "message":
|
|
||||||
continue
|
|
||||||
channel = message["channel"]
|
|
||||||
if isinstance(channel, bytes):
|
|
||||||
channel = channel.decode()
|
|
||||||
|
|
||||||
data = json.loads(message["data"])
|
|
||||||
|
|
||||||
if channel.startswith("tool_call:"):
|
|
||||||
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
|
||||||
user_id = channel.split(":", 1)[1]
|
|
||||||
conn = self._local.get(user_id)
|
|
||||||
if conn:
|
|
||||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
|
||||||
|
|
||||||
elif channel.startswith("tool_result:"):
|
|
||||||
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
|
||||||
call_id = channel.split(":", 1)[1]
|
|
||||||
fut = self._remote_futures.pop(call_id, None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(data)
|
|
||||||
|
|
||||||
# ── Stream cross-instance ──
|
|
||||||
|
|
||||||
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
|
||||||
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
|
||||||
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. Struttura Directory Proposta (MVP)
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── docker-compose.yml # Orchestrazione completa
|
|
||||||
├── docker-compose.dev.yml # Override per sviluppo locale
|
|
||||||
├── shared/ # Codice condiviso (montato come volume)
|
|
||||||
│ ├── auth.py # JWT verification (chiave pubblica)
|
|
||||||
│ ├── schemas.py # Pydantic schemas condivisi
|
|
||||||
│ ├── middleware/
|
|
||||||
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
|
|
||||||
│ │ └── sanitizer.py
|
|
||||||
│ └── models/
|
|
||||||
│ └── base.py # SQLAlchemy base condivisa
|
|
||||||
│
|
|
||||||
├── auth-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # users, refresh_tokens
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ └── auth.py
|
|
||||||
│ └── services/
|
|
||||||
│ ├── jwt_service.py # RS256 signing
|
|
||||||
│ └── user_service.py
|
|
||||||
│
|
|
||||||
├── chat-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # memory_*
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ ├── device_ws.py # WS connection owner
|
|
||||||
│ │ └── chat.py # REST fallback
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── device_manager.py # RedisDeviceManager
|
|
||||||
│ │ ├── deep_agent.py # Home + floating chat
|
|
||||||
│ │ ├── memory_middleware.py
|
|
||||||
│ │ ├── ws_context.py
|
|
||||||
│ │ ├── output_formatter.py
|
|
||||||
│ │ └── llm.py
|
|
||||||
│ └── agents/ # Tool definitions (used by deep_agent)
|
|
||||||
│ ├── task_agent.py
|
|
||||||
│ ├── project_agent.py
|
|
||||||
│ ├── note_agent.py
|
|
||||||
│ └── timeline_agent.py
|
|
||||||
│
|
|
||||||
├── agent-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ ├── agents.py # catalog, can-create, trigger
|
|
||||||
│ │ └── agent_setup.py # journey start/message
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── agent_runner.py # Batch classify → process
|
|
||||||
│ │ ├── agent_registry.py
|
|
||||||
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
|
|
||||||
│ │ └── llm.py
|
|
||||||
│ └── agents/
|
|
||||||
│ ├── task_agent.py # Tool definitions (batch context)
|
|
||||||
│ ├── project_agent.py
|
|
||||||
│ ├── note_agent.py
|
|
||||||
│ ├── timeline_agent.py
|
|
||||||
│ └── filesystem_agent.py
|
|
||||||
│
|
|
||||||
├── billing-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # subscriptions
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ └── billing.py
|
|
||||||
│ └── services/
|
|
||||||
│ ├── stripe_service.py
|
|
||||||
│ └── tier_manager.py
|
|
||||||
│
|
|
||||||
└── infra/
|
|
||||||
├── traefik/
|
|
||||||
│ └── traefik.yml
|
|
||||||
├── keys/
|
|
||||||
│ ├── jwt_private.pem # Solo auth-service
|
|
||||||
│ └── jwt_public.pem # Tutti i servizi
|
|
||||||
└── alembic/ # Migrazioni condivise o per-servizio
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. Docker Compose — Configurazione MVP
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# docker-compose.yml
|
|
||||||
|
|
||||||
services:
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# API Gateway
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
traefik:
|
|
||||||
image: traefik:v3.2
|
|
||||||
command:
|
|
||||||
- "--api.insecure=true"
|
|
||||||
- "--providers.docker=true"
|
|
||||||
- "--providers.docker.exposedbydefault=false"
|
|
||||||
- "--entrypoints.web.address=:80"
|
|
||||||
- "--entrypoints.websecure.address=:443"
|
|
||||||
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
|
||||||
ports:
|
|
||||||
- "80:80"
|
|
||||||
- "443:443"
|
|
||||||
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
|
|
||||||
volumes:
|
|
||||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
|
||||||
- ./infra/certs:/certs:ro
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Auth Service (2 repliche)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
auth-service:
|
|
||||||
build: ./auth-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
|
||||||
SERVICE_NAME: auth
|
|
||||||
secrets:
|
|
||||||
- jwt_private_key
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
|
||||||
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Chat Service — Real-time WS + Chat (scalabile)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
chat-service:
|
|
||||||
build: ./chat-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: chat
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
# REST chat endpoint
|
|
||||||
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
|
|
||||||
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
|
||||||
# WebSocket route con sticky session
|
|
||||||
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
|
||||||
- "traefik.http.routers.ws.service=chat-ws"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Agent Service — Batch processing (scalabile indipendentemente)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
agent-service:
|
|
||||||
build: ./agent-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: agent
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
|
|
||||||
- "traefik.http.services.agents.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Billing Service (1 replica)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
billing-service:
|
|
||||||
build: ./billing-service
|
|
||||||
deploy:
|
|
||||||
replicas: 1
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: billing
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
|
||||||
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Infrastruttura
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
db:
|
|
||||||
image: pgvector/pgvector:pg16
|
|
||||||
environment:
|
|
||||||
POSTGRES_USER: postgres
|
|
||||||
POSTGRES_PASSWORD: postgres
|
|
||||||
POSTGRES_DB: adiuva
|
|
||||||
volumes:
|
|
||||||
- postgres_data:/var/lib/postgresql/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
redis:
|
|
||||||
image: redis:7-alpine
|
|
||||||
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
|
||||||
volumes:
|
|
||||||
- redis_data:/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 3s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
qdrant:
|
|
||||||
image: qdrant/qdrant:latest
|
|
||||||
volumes:
|
|
||||||
- qdrant_data:/qdrant/storage
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
secrets:
|
|
||||||
jwt_private_key:
|
|
||||||
file: ./infra/keys/jwt_private.pem
|
|
||||||
jwt_public_key:
|
|
||||||
file: ./infra/keys/jwt_public.pem
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
postgres_data:
|
|
||||||
redis_data:
|
|
||||||
qdrant_data:
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. Configurazione Cloudflare + VPS
|
|
||||||
|
|
||||||
### 6.1 DNS
|
|
||||||
|
|
||||||
```
|
|
||||||
api.tuodominio.com → A record → IP del VPS
|
|
||||||
→ Proxy: ON (orange cloud)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.2 Cloudflare Settings
|
|
||||||
|
|
||||||
| Setting | Valore | Motivo |
|
|
||||||
|---------|--------|--------|
|
|
||||||
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
|
||||||
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
|
||||||
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
|
||||||
| Under Attack Mode | Off (attivare se necessario) | |
|
|
||||||
|
|
||||||
### 6.3 TLS sul VPS
|
|
||||||
|
|
||||||
Due opzioni:
|
|
||||||
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
|
||||||
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# traefik.yml — con Cloudflare Origin Certificate
|
|
||||||
entryPoints:
|
|
||||||
websecure:
|
|
||||||
address: ":443"
|
|
||||||
|
|
||||||
tls:
|
|
||||||
certificates:
|
|
||||||
- certFile: /certs/origin.pem
|
|
||||||
keyFile: /certs/origin-key.pem
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.4 Rete VPS
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
|
||||||
# https://www.cloudflare.com/ips/
|
|
||||||
ufw default deny incoming
|
|
||||||
ufw allow from 173.245.48.0/20 to any port 443
|
|
||||||
ufw allow from 103.21.244.0/22 to any port 443
|
|
||||||
# ... (tutti gli IP range di Cloudflare)
|
|
||||||
ufw allow ssh
|
|
||||||
ufw enable
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 7. Comunicazione Inter-Servizio
|
|
||||||
|
|
||||||
### 7.1 Redis Pub/Sub — Event Bus
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
|
||||||
│ Billing │ ────────────────────────► │ Auth │
|
|
||||||
│ Service │ │ Service │
|
|
||||||
└──────────┘ └──────────┘
|
|
||||||
|
|
||||||
┌──────────┐ tool_call:user_123 ┌──────────┐
|
|
||||||
│ Agent │ ────────────────────────► │ Chat │
|
|
||||||
│ Service │ │ Service │
|
|
||||||
│ (batch) │ ◄────────────────────────│ (ha WS) │
|
|
||||||
└──────────┘ tool_result:{call_id} └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 7.2 Health Checks e Service Discovery
|
|
||||||
|
|
||||||
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
|
||||||
- **Redis pub/sub** (tool-call cross-instance, tier events)
|
|
||||||
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
|
|
||||||
- **PostgreSQL** (dati persistenti condivisi)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 8. Piano di Migrazione Incrementale (MVP)
|
|
||||||
|
|
||||||
### Fase 1 — Preparazione (nel monolite attuale)
|
|
||||||
1. Aggiungere Redis al `docker-compose.yml` attuale
|
|
||||||
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
|
|
||||||
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
|
|
||||||
4. Estrarre `shared/` con auth verification, schemas, middleware
|
|
||||||
|
|
||||||
### Fase 2 — Auth Service (primo split)
|
|
||||||
1. Estrarre `auth.py` routes + models in `auth-service/`
|
|
||||||
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
|
||||||
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
|
|
||||||
4. Il monolite continua a servire tutto il resto
|
|
||||||
|
|
||||||
### Fase 3 — Billing Service
|
|
||||||
1. Estrarre billing routes, Stripe service, tier manager
|
|
||||||
2. Configurare Redis pub/sub per `tier_changed` events
|
|
||||||
3. Routare via Traefik
|
|
||||||
|
|
||||||
### Fase 4 — Split Chat + Agent (il più delicato)
|
|
||||||
1. Il monolite residuo contiene WS + chat + agents
|
|
||||||
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
|
|
||||||
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
|
|
||||||
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
|
|
||||||
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
|
|
||||||
|
|
||||||
### Fase 5 — Scaling test
|
|
||||||
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
|
|
||||||
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
|
|
||||||
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 9. Monitoraggio e Logging
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Aggiungere al docker-compose.yml
|
|
||||||
|
|
||||||
prometheus:
|
|
||||||
image: prom/prometheus:latest
|
|
||||||
volumes:
|
|
||||||
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
grafana:
|
|
||||||
image: grafana/grafana:latest
|
|
||||||
ports:
|
|
||||||
- "3000:3000"
|
|
||||||
volumes:
|
|
||||||
- grafana_data:/var/lib/grafana
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
loki:
|
|
||||||
image: grafana/loki:latest
|
|
||||||
restart: unless-stopped
|
|
||||||
```
|
|
||||||
|
|
||||||
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 10. Sizing VPS Minimo Consigliato (MVP)
|
|
||||||
|
|
||||||
| Componente | CPU | RAM | Note |
|
|
||||||
|---|---|---|---|
|
|
||||||
| Traefik | 0.25 | 128MB | |
|
|
||||||
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
|
|
||||||
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
|
|
||||||
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
|
|
||||||
| Billing Service | 0.25 | 128MB | |
|
|
||||||
| PostgreSQL | 1.0 | 1GB | |
|
|
||||||
| Redis | 0.25 | 256MB | |
|
|
||||||
| Qdrant | 0.5 | 512MB | |
|
|
||||||
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
|
|
||||||
|
|
||||||
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Riepilogo Architettura MVP
|
|
||||||
|
|
||||||
| Servizio | Repliche | Proprietario di |
|
|
||||||
|---|---|---|
|
|
||||||
| **Traefik** | 1 | Routing, TLS, sticky sessions |
|
|
||||||
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
|
|
||||||
| **Chat Service** | 2–N | WebSocket, home/floating chat, streaming |
|
|
||||||
| **Agent Service** | 2–N | Batch processing, directory scan, agent setup |
|
|
||||||
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
|
|
||||||
|
|
||||||
| Decisione | Scelta | Motivazione |
|
|
||||||
|---|---|---|
|
|
||||||
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
|
||||||
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
|
||||||
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
|
|
||||||
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
|
|
||||||
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
|
|
||||||
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
|
|
||||||
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
|
|
||||||
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
|
|
||||||
| TLS | Cloudflare Origin Certificate | Zero maintenance |
|
|
||||||
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
|
||||||
| Storage / Plugin | Post-MVP | Non critici per il lancio |
|
|
||||||
@@ -32,6 +32,8 @@ google-auth-oauthlib>=1.2.0
|
|||||||
google-auth-httplib2>=0.2.0
|
google-auth-httplib2>=0.2.0
|
||||||
msal>=1.28.0
|
msal>=1.28.0
|
||||||
cryptography>=42.0.0
|
cryptography>=42.0.0
|
||||||
redis>=5.0.0
|
langfuse>=2.0.0
|
||||||
langfuse>=3.0.0
|
beautifulsoup4>=4.12.0
|
||||||
|
lxml>=5.0.0
|
||||||
|
PyYAML>=6.0.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
@@ -1,19 +0,0 @@
|
|||||||
# ── Auth Service ──────────────────────────────────────────────────────────────
|
|
||||||
# This file contains env vars specific to the Auth Service.
|
|
||||||
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
|
|
||||||
# or from docker-compose environment.
|
|
||||||
|
|
||||||
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
|
|
||||||
# Generate keypair:
|
|
||||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
|
||||||
# openssl rsa -in private.pem -pubout -out public.pem
|
|
||||||
#
|
|
||||||
# Paste PEM content with literal \n for newlines:
|
|
||||||
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
|
|
||||||
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
|
|
||||||
|
|
||||||
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
|
|
||||||
JWT_PRIVATE_KEY=
|
|
||||||
|
|
||||||
# PUBLIC KEY — used to VERIFY JWTs.
|
|
||||||
JWT_PUBLIC_KEY=
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# ── builder ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS builder
|
|
||||||
|
|
||||||
WORKDIR /build
|
|
||||||
|
|
||||||
# Install shared + service deps in one layer
|
|
||||||
COPY services/auth/requirements.txt ./requirements.txt
|
|
||||||
RUN pip install --upgrade pip && \
|
|
||||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
|
||||||
|
|
||||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS runtime
|
|
||||||
|
|
||||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY --from=builder /install /usr/local
|
|
||||||
|
|
||||||
# Copy shared module (available to all services)
|
|
||||||
COPY shared/ shared/
|
|
||||||
|
|
||||||
# Copy service source
|
|
||||||
COPY services/auth/app/ app/
|
|
||||||
|
|
||||||
RUN chown -R appuser:appgroup /app
|
|
||||||
|
|
||||||
USER appuser
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
CMD ["gunicorn", "app.main:app", \
|
|
||||||
"-k", "uvicorn.workers.UvicornWorker", \
|
|
||||||
"--bind", "0.0.0.0:8000", \
|
|
||||||
"--workers", "2", \
|
|
||||||
"--timeout", "30"]
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
# Auth Service
|
|
||||||
|
|
||||||
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
|
|
||||||
|
|
||||||
## Tables owned
|
|
||||||
- `users`
|
|
||||||
- `refresh_tokens`
|
|
||||||
- `subscriptions` (read; Billing Service writes)
|
|
||||||
|
|
||||||
## Endpoints
|
|
||||||
- `POST /auth/register`
|
|
||||||
- `POST /auth/login`
|
|
||||||
- `POST /auth/refresh`
|
|
||||||
- `GET /auth/me`
|
|
||||||
- `PUT /auth/me`
|
|
||||||
- `GET /auth/verify` (ForwardAuth for Traefik)
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
"""Auth Service — local configuration.
|
|
||||||
|
|
||||||
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
|
|
||||||
These are NOT in shared/config.py to prevent other services from accessing them.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pydantic import field_validator
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
||||||
|
|
||||||
|
|
||||||
class AuthSettings(BaseSettings):
|
|
||||||
# RS256 private key (PEM format). Used to SIGN JWTs.
|
|
||||||
# Only the Auth Service has this. Generate with:
|
|
||||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
|
||||||
# Then set the env var (newlines as \n):
|
|
||||||
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
|
|
||||||
JWT_PRIVATE_KEY: str = ""
|
|
||||||
|
|
||||||
# RS256 public key (PEM format). Used to VERIFY JWTs.
|
|
||||||
# Derived from the private key:
|
|
||||||
# openssl rsa -in private.pem -pubout -out public.pem
|
|
||||||
JWT_PUBLIC_KEY: str = ""
|
|
||||||
|
|
||||||
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _expand_pem_newlines(cls, v: str) -> str:
|
|
||||||
if isinstance(v, str) and r"\n" in v:
|
|
||||||
return v.replace(r"\n", "\n")
|
|
||||||
return v
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
auth_settings = AuthSettings()
|
|
||||||
@@ -1,69 +0,0 @@
|
|||||||
"""Auth dependencies — JWT validation for the Auth Service.
|
|
||||||
|
|
||||||
This is the canonical get_current_user used by protected endpoints
|
|
||||||
within the Auth Service itself (/me, /me PUT).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import Depends, HTTPException, status
|
|
||||||
from fastapi.security import OAuth2PasswordBearer
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
from shared.db import get_session
|
|
||||||
from shared.models import Subscription, User
|
|
||||||
from shared.schemas import UserProfile
|
|
||||||
|
|
||||||
from app.config import auth_settings
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
|
||||||
|
|
||||||
|
|
||||||
async def get_current_user(
|
|
||||||
token: str = Depends(oauth2_scheme),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> UserProfile:
|
|
||||||
"""Validate a Bearer JWT and return the authenticated user.
|
|
||||||
|
|
||||||
The JWT is used for identity and expiry. Tier is fetched live from the
|
|
||||||
subscriptions table so upgrades/downgrades take effect immediately.
|
|
||||||
"""
|
|
||||||
credentials_exc = HTTPException(
|
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
|
||||||
detail="Could not validate credentials",
|
|
||||||
headers={"WWW-Authenticate": "Bearer"},
|
|
||||||
)
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
|
||||||
)
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
email: str | None = payload.get("email")
|
|
||||||
if not user_id or not email:
|
|
||||||
raise credentials_exc
|
|
||||||
except JWTError:
|
|
||||||
raise credentials_exc
|
|
||||||
|
|
||||||
# Live tier lookup
|
|
||||||
result = await db.execute(
|
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
|
||||||
)
|
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
# Fetch name/surname
|
|
||||||
user_result = await db.execute(
|
|
||||||
select(User.name, User.surname).where(User.id == user_id)
|
|
||||||
)
|
|
||||||
user_row = user_result.one_or_none()
|
|
||||||
|
|
||||||
return UserProfile(
|
|
||||||
id=user_id,
|
|
||||||
email=email,
|
|
||||||
name=user_row.name if user_row else None,
|
|
||||||
surname=user_row.surname if user_row else None,
|
|
||||||
tier=tier,
|
|
||||||
) # type: ignore[arg-type]
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
|
|
||||||
|
|
||||||
Standalone FastAPI service extracted from the adiuva-api monolith.
|
|
||||||
Owns: users, refresh_tokens, subscriptions (read).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Ensure the repo root is on sys.path so "shared" is importable.
|
|
||||||
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
|
|
||||||
# In local dev, we need to add the repo root (two levels up from this file).
|
|
||||||
_repo_root = str(Path(__file__).resolve().parents[3])
|
|
||||||
if _repo_root not in sys.path:
|
|
||||||
sys.path.insert(0, _repo_root)
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
yield
|
|
||||||
from shared.db import engine
|
|
||||||
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
|
||||||
app = FastAPI(
|
|
||||||
title="Adiuva Auth Service",
|
|
||||||
version="0.1.0",
|
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
|
||||||
redoc_url=None,
|
|
||||||
lifespan=lifespan,
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=settings.CORS_ORIGINS,
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.routes import router
|
|
||||||
from app.verify import router as verify_router
|
|
||||||
|
|
||||||
app.include_router(router, prefix="/api/v1")
|
|
||||||
app.include_router(verify_router, prefix="/api/v1")
|
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
|
||||||
async def health() -> dict:
|
|
||||||
return {"status": "ok", "service": "auth", "version": app.version}
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
app = create_app()
|
|
||||||
@@ -1,249 +0,0 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
|
||||||
|
|
||||||
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import hashlib
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
|
|
||||||
import bcrypt
|
|
||||||
from cryptography.fernet import Fernet
|
|
||||||
from fastapi import APIRouter, Depends, HTTPException, status
|
|
||||||
from jose import jwt
|
|
||||||
from pydantic import BaseModel
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
from shared.db import get_session
|
|
||||||
from shared.models import RefreshToken, Subscription, User
|
|
||||||
from shared.schemas import AuthTokens, UserProfile
|
|
||||||
|
|
||||||
from app.config import auth_settings
|
|
||||||
from app.deps import get_current_user
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helpers ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _hash_password(password: str) -> str:
|
|
||||||
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def _verify_password(password: str, hashed: str) -> bool:
|
|
||||||
return bcrypt.checkpw(password.encode(), hashed.encode())
|
|
||||||
|
|
||||||
|
|
||||||
def _hash_token(plain_token: str) -> str:
|
|
||||||
"""SHA-256 of the plain refresh token string."""
|
|
||||||
return hashlib.sha256(plain_token.encode()).hexdigest()
|
|
||||||
|
|
||||||
|
|
||||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
|
||||||
"""Return (RS256-signed JWT, expires_at_ms)."""
|
|
||||||
now = int(time.time())
|
|
||||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
|
||||||
payload = {
|
|
||||||
"sub": user_id,
|
|
||||||
"email": email,
|
|
||||||
"tier": tier,
|
|
||||||
"exp": exp,
|
|
||||||
"iat": now,
|
|
||||||
}
|
|
||||||
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
|
|
||||||
return token, exp * 1000 # ms for client
|
|
||||||
|
|
||||||
|
|
||||||
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
|
|
||||||
"""Fetch authoritative tier from subscriptions table."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
|
||||||
)
|
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
|
||||||
return result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _RegisterRequest(BaseModel):
|
|
||||||
email: str
|
|
||||||
password: str
|
|
||||||
name: str | None = None
|
|
||||||
surname: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class _LoginRequest(BaseModel):
|
|
||||||
email: str
|
|
||||||
password: str
|
|
||||||
|
|
||||||
|
|
||||||
class _RefreshRequest(BaseModel):
|
|
||||||
refresh_token: str
|
|
||||||
|
|
||||||
|
|
||||||
class _UpdateProfileRequest(BaseModel):
|
|
||||||
name: str | None = None
|
|
||||||
surname: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
|
|
||||||
async def register(
|
|
||||||
body: _RegisterRequest,
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> AuthTokens:
|
|
||||||
"""Create a new account and return JWT tokens."""
|
|
||||||
existing = await db.execute(select(User).where(User.email == body.email))
|
|
||||||
if existing.scalar_one_or_none() is not None:
|
|
||||||
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
|
|
||||||
|
|
||||||
user = User(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
email=body.email,
|
|
||||||
name=body.name,
|
|
||||||
surname=body.surname,
|
|
||||||
password_hash=_hash_password(body.password),
|
|
||||||
tier="free",
|
|
||||||
encryption_key=Fernet.generate_key().decode(),
|
|
||||||
)
|
|
||||||
db.add(user)
|
|
||||||
await db.flush()
|
|
||||||
|
|
||||||
plain_token = str(uuid.uuid4())
|
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
|
||||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
|
||||||
)
|
|
||||||
rt = RefreshToken(
|
|
||||||
user_id=user.id,
|
|
||||||
token_hash=_hash_token(plain_token),
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
db.add(rt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
|
||||||
return AuthTokens(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=plain_token,
|
|
||||||
expires_at=expires_at_ms,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/login", response_model=AuthTokens)
|
|
||||||
async def login(
|
|
||||||
body: _LoginRequest,
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> AuthTokens:
|
|
||||||
"""Validate credentials and return JWT tokens."""
|
|
||||||
result = await db.execute(select(User).where(User.email == body.email))
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if user is None or not _verify_password(body.password, user.password_hash):
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
|
||||||
|
|
||||||
# Fetch live tier for the JWT claim
|
|
||||||
tier = await _get_live_tier(db, user.id)
|
|
||||||
|
|
||||||
plain_token = str(uuid.uuid4())
|
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
|
||||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
|
||||||
)
|
|
||||||
rt = RefreshToken(
|
|
||||||
user_id=user.id,
|
|
||||||
token_hash=_hash_token(plain_token),
|
|
||||||
expires_at=expires_at,
|
|
||||||
)
|
|
||||||
db.add(rt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
|
||||||
return AuthTokens(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=plain_token,
|
|
||||||
expires_at=expires_at_ms,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("/refresh", response_model=AuthTokens)
|
|
||||||
async def refresh(
|
|
||||||
body: _RefreshRequest,
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> AuthTokens:
|
|
||||||
"""Rotate a refresh token and return a new token pair."""
|
|
||||||
token_hash = _hash_token(body.refresh_token)
|
|
||||||
result = await db.execute(
|
|
||||||
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
|
|
||||||
)
|
|
||||||
rt = result.scalar_one_or_none()
|
|
||||||
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
|
||||||
|
|
||||||
await db.delete(rt)
|
|
||||||
|
|
||||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
|
||||||
user = user_result.scalar_one_or_none()
|
|
||||||
if user is None:
|
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
|
||||||
|
|
||||||
# Fetch live tier for the new JWT
|
|
||||||
tier = await _get_live_tier(db, user.id)
|
|
||||||
|
|
||||||
plain_token = str(uuid.uuid4())
|
|
||||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
|
||||||
new_rt = RefreshToken(
|
|
||||||
user_id=user.id,
|
|
||||||
token_hash=_hash_token(plain_token),
|
|
||||||
expires_at=new_expires,
|
|
||||||
)
|
|
||||||
db.add(new_rt)
|
|
||||||
await db.commit()
|
|
||||||
|
|
||||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
|
||||||
return AuthTokens(
|
|
||||||
access_token=access_token,
|
|
||||||
refresh_token=plain_token,
|
|
||||||
expires_at=expires_at_ms,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserProfile)
|
|
||||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
|
||||||
"""Return the profile for the authenticated user."""
|
|
||||||
return current_user
|
|
||||||
|
|
||||||
|
|
||||||
@router.put("/me", response_model=UserProfile)
|
|
||||||
async def update_profile(
|
|
||||||
body: _UpdateProfileRequest,
|
|
||||||
current_user: UserProfile = Depends(get_current_user),
|
|
||||||
db: AsyncSession = Depends(get_session),
|
|
||||||
) -> UserProfile:
|
|
||||||
"""Update the authenticated user's name and surname."""
|
|
||||||
result = await db.execute(select(User).where(User.id == current_user.id))
|
|
||||||
user = result.scalar_one()
|
|
||||||
|
|
||||||
if body.name is not None:
|
|
||||||
user.name = body.name
|
|
||||||
if body.surname is not None:
|
|
||||||
user.surname = body.surname
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(user)
|
|
||||||
|
|
||||||
return UserProfile(
|
|
||||||
id=user.id,
|
|
||||||
email=user.email,
|
|
||||||
name=user.name,
|
|
||||||
surname=user.surname,
|
|
||||||
tier=current_user.tier,
|
|
||||||
)
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
"""ForwardAuth verification endpoint for Traefik.
|
|
||||||
|
|
||||||
Traefik calls GET /api/v1/auth/verify on every request to a protected
|
|
||||||
service. This endpoint validates the JWT from the Authorization header
|
|
||||||
and returns identity headers that Traefik injects into downstream requests.
|
|
||||||
|
|
||||||
Downstream services NEVER validate JWTs themselves — they trust the
|
|
||||||
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Request, Response
|
|
||||||
from fastapi import status as http_status
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
from shared.db import async_session
|
|
||||||
from shared.models import Subscription
|
|
||||||
|
|
||||||
from app.config import auth_settings
|
|
||||||
|
|
||||||
router = APIRouter(tags=["auth"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/auth/verify")
|
|
||||||
async def verify(request: Request) -> Response:
|
|
||||||
"""Validate JWT and return identity headers for Traefik ForwardAuth.
|
|
||||||
|
|
||||||
Returns 200 with X-User-* headers on success, 401 on failure.
|
|
||||||
Traefik copies response headers to the downstream request.
|
|
||||||
"""
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
|
||||||
if not auth_header.startswith("Bearer "):
|
|
||||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
|
||||||
|
|
||||||
token = auth_header[7:] # strip "Bearer "
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
|
||||||
)
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
email: str | None = payload.get("email")
|
|
||||||
if not user_id or not email:
|
|
||||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
|
||||||
except JWTError:
|
|
||||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
|
||||||
|
|
||||||
# Live tier lookup from subscriptions table
|
|
||||||
async with async_session() as db:
|
|
||||||
result = await db.execute(
|
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
|
||||||
)
|
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
status_code=http_status.HTTP_200_OK,
|
|
||||||
headers={
|
|
||||||
"X-User-Id": user_id,
|
|
||||||
"X-User-Email": email,
|
|
||||||
"X-User-Tier": tier,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
gunicorn>=22.0.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
pydantic-settings>=2.7.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
cryptography>=42.0.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
# Batch Agent Service
|
|
||||||
|
|
||||||
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
|
|
||||||
|
|
||||||
## Tables owned
|
|
||||||
- `local_agent_configs`
|
|
||||||
- `cloud_agent_configs`
|
|
||||||
- `agent_run_logs`
|
|
||||||
|
|
||||||
## Endpoints
|
|
||||||
- `GET /agents/catalog`
|
|
||||||
- `POST /agents/can-create`
|
|
||||||
- `POST /agents/trigger`
|
|
||||||
- `GET /agents/{id}/history`
|
|
||||||
|
|
||||||
## Redis channels
|
|
||||||
- Subscribe: `batch:request:{user_id}`
|
|
||||||
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
|
|
||||||
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
|
||||||
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
|
|
||||||
|
|
||||||
## TODO
|
|
||||||
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
# Billing Service
|
|
||||||
|
|
||||||
Owns: Stripe integration, tier management, subscription CRUD.
|
|
||||||
|
|
||||||
## Tables owned (write)
|
|
||||||
- `subscriptions`
|
|
||||||
|
|
||||||
## Endpoints
|
|
||||||
- `POST /billing/checkout`
|
|
||||||
- `POST /billing/webhook` (Stripe, no JWT auth)
|
|
||||||
- `GET /billing/subscription`
|
|
||||||
- `DELETE /billing/subscription`
|
|
||||||
|
|
||||||
## Redis channels
|
|
||||||
- Publish: `tier:changed:{user_id}` on tier change
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# ── builder ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS builder
|
|
||||||
|
|
||||||
WORKDIR /build
|
|
||||||
|
|
||||||
COPY services/chat/requirements.txt ./requirements.txt
|
|
||||||
RUN pip install --upgrade pip && \
|
|
||||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
|
||||||
|
|
||||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS runtime
|
|
||||||
|
|
||||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY --from=builder /install /usr/local
|
|
||||||
|
|
||||||
# Shared module
|
|
||||||
COPY shared/ shared/
|
|
||||||
|
|
||||||
# Service source
|
|
||||||
COPY services/chat/app/ app/
|
|
||||||
|
|
||||||
RUN chown -R appuser:appgroup /app
|
|
||||||
|
|
||||||
USER appuser
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
# Chat service is CPU-bound (LLM calls) — use multiple workers
|
|
||||||
CMD ["gunicorn", "app.main:app", \
|
|
||||||
"-k", "uvicorn.workers.UvicornWorker", \
|
|
||||||
"--bind", "0.0.0.0:8000", \
|
|
||||||
"--workers", "2", \
|
|
||||||
"--timeout", "120"]
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
# Chat Service
|
|
||||||
|
|
||||||
Owns: deep_agent (home + floating chat), memory middleware, domain agents
|
|
||||||
(task, note, project, timeline), LLM orchestration.
|
|
||||||
|
|
||||||
## Tables owned
|
|
||||||
- `memory_core`
|
|
||||||
- `memory_associative`
|
|
||||||
- `memory_episodic`
|
|
||||||
- `memory_proactive`
|
|
||||||
|
|
||||||
## Tables read (cross-service)
|
|
||||||
- `users` (for encryption_key — memory decryption)
|
|
||||||
|
|
||||||
## Endpoints
|
|
||||||
- `POST /chat` (REST fallback)
|
|
||||||
|
|
||||||
## Redis channels
|
|
||||||
- Subscribe: `chat:request:{user_id}`
|
|
||||||
- Publish: `ws:out:{user_id}` (stream frames + tool calls)
|
|
||||||
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Chat Service domain agents."""
|
|
||||||
@@ -1,142 +0,0 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete).
|
|
||||||
|
|
||||||
Adapted for Chat Service: import from app.ws_context and app.llm.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.llm import embed
|
|
||||||
from app.ws_context import execute_on_client
|
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
NOTE_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_notes(project_id: str = "") -> str:
|
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table="notes",
|
|
||||||
filters={"projectId": normalized_project_id or None},
|
|
||||||
)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No notes found."
|
|
||||||
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
|
|
||||||
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_note(note_id: str) -> str:
|
|
||||||
"""Fetch a single note by its UUID to read its full Markdown content."""
|
|
||||||
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
|
|
||||||
row = result.get("row")
|
|
||||||
if not row:
|
|
||||||
return f"Note {note_id} not found."
|
|
||||||
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def create_note(
|
|
||||||
title: str,
|
|
||||||
content: str,
|
|
||||||
project_id: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""Create a new note.
|
|
||||||
title: note heading (required)
|
|
||||||
content: Markdown body text (required)
|
|
||||||
project_id: optional UUID linking this note to a project
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="notes",
|
|
||||||
data={
|
|
||||||
"title": title,
|
|
||||||
"content": content,
|
|
||||||
"projectId": project_id or None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
# Index the note content in the vector store.
|
|
||||||
vector = await embed(content)
|
|
||||||
await execute_on_client(
|
|
||||||
action="vector_upsert",
|
|
||||||
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
|
|
||||||
vector=vector,
|
|
||||||
)
|
|
||||||
return f"Note created: '{row['title']}' (id: {row['id']})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def update_note(
|
|
||||||
note_id: str,
|
|
||||||
title: str = "",
|
|
||||||
content: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""Update an existing note. Only pass fields that should change.
|
|
||||||
note_id: UUID of the note (required)
|
|
||||||
If you need to preserve existing content, call get_note first.
|
|
||||||
"""
|
|
||||||
updates: dict[str, Any] = {}
|
|
||||||
if title:
|
|
||||||
updates["title"] = title
|
|
||||||
if content:
|
|
||||||
updates["content"] = content
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="update",
|
|
||||||
table="notes",
|
|
||||||
data={"id": note_id, "updates": updates},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
# Re-index if content changed.
|
|
||||||
if content:
|
|
||||||
vector = await embed(content)
|
|
||||||
await execute_on_client(
|
|
||||||
action="vector_upsert",
|
|
||||||
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
|
|
||||||
vector=vector,
|
|
||||||
)
|
|
||||||
return f"Note updated: '{row['title']}' (id: {row['id']})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_note(note_id: str) -> str:
|
|
||||||
"""Delete a note permanently by its UUID."""
|
|
||||||
await execute_on_client(action="delete", table="notes", data={"id": note_id})
|
|
||||||
return f"Note {note_id} deleted."
|
|
||||||
|
|
||||||
|
|
||||||
NOTE_TOOLS: list[Any] = [
|
|
||||||
list_notes,
|
|
||||||
get_note,
|
|
||||||
create_note,
|
|
||||||
update_note,
|
|
||||||
delete_note,
|
|
||||||
]
|
|
||||||
@@ -1,146 +0,0 @@
|
|||||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).
|
|
||||||
|
|
||||||
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.ws_context import execute_on_client
|
|
||||||
|
|
||||||
PROJECT_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_projects(
|
|
||||||
client_id: str = "",
|
|
||||||
include_archived: int = 0,
|
|
||||||
) -> str:
|
|
||||||
"""List projects, optionally filtered by client_id.
|
|
||||||
include_archived: 1 to include archived projects, 0 for active only (default).
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table="projects",
|
|
||||||
filters={
|
|
||||||
"clientId": client_id or None,
|
|
||||||
"includeArchived": bool(include_archived),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No projects found."
|
|
||||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
|
||||||
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_all_projects() -> str:
|
|
||||||
"""List every project regardless of client or status.
|
|
||||||
Use only when the user wants a complete cross-client overview.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(action="select", table="projects")
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No projects found."
|
|
||||||
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
|
|
||||||
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_project(project_id: str) -> str:
|
|
||||||
"""Fetch a single project by its UUID."""
|
|
||||||
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
|
|
||||||
row = result.get("row")
|
|
||||||
if not row:
|
|
||||||
return f"Project {project_id} not found."
|
|
||||||
return (
|
|
||||||
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
|
|
||||||
f"clientId: {row.get('clientId', 'none')})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def create_project(
|
|
||||||
name: str,
|
|
||||||
client_id: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""Create a new project.
|
|
||||||
name: human-readable project name (required)
|
|
||||||
client_id: optional UUID of the owning client
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="projects",
|
|
||||||
data={"name": name, "clientId": client_id or None},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
return f"Project created: '{row['name']}' (id: {row['id']})"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def update_project(
|
|
||||||
project_id: str,
|
|
||||||
name: str = "",
|
|
||||||
client_id: str = "",
|
|
||||||
status: str = "",
|
|
||||||
ai_summary: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""Update a project. Only pass fields that should change.
|
|
||||||
project_id: UUID of the project (required)
|
|
||||||
status: active | archived
|
|
||||||
ai_summary: AI-generated summary text (populate only when explicitly requested)
|
|
||||||
"""
|
|
||||||
updates: dict[str, Any] = {}
|
|
||||||
if name:
|
|
||||||
updates["name"] = name
|
|
||||||
if client_id:
|
|
||||||
updates["clientId"] = client_id
|
|
||||||
if status:
|
|
||||||
updates["status"] = status
|
|
||||||
if ai_summary:
|
|
||||||
updates["aiSummary"] = ai_summary
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="update",
|
|
||||||
table="projects",
|
|
||||||
data={"id": project_id, "updates": updates},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_project(project_id: str) -> str:
|
|
||||||
"""Permanently delete a project and orphan its tasks.
|
|
||||||
IMPORTANT: prefer update_project(status='archived') unless the user
|
|
||||||
has explicitly confirmed they want permanent deletion.
|
|
||||||
"""
|
|
||||||
await execute_on_client(action="delete", table="projects", data={"id": project_id})
|
|
||||||
return f"Project {project_id} permanently deleted."
|
|
||||||
|
|
||||||
|
|
||||||
PROJECT_TOOLS: list[Any] = [
|
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
create_project,
|
|
||||||
update_project,
|
|
||||||
delete_project,
|
|
||||||
]
|
|
||||||
@@ -1,240 +0,0 @@
|
|||||||
"""Task agent — full CRUD for tasks and task comments.
|
|
||||||
|
|
||||||
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
import re
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.ws_context import execute_on_client
|
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
TASK_SYSTEM_PROMPT = (
|
|
||||||
"You are a task management assistant for a project workspace.\n"
|
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_tasks(
|
|
||||||
project_id: str = "",
|
|
||||||
status: str = "",
|
|
||||||
search: str = "",
|
|
||||||
order_by: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table="tasks",
|
|
||||||
filters={
|
|
||||||
"projectId": normalized_project_id or None,
|
|
||||||
"status": status or None,
|
|
||||||
"search": search or None,
|
|
||||||
"orderBy": order_by or None,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No tasks found matching the given filters."
|
|
||||||
lines = [
|
|
||||||
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def create_task(
|
|
||||||
title: str,
|
|
||||||
description: str = "",
|
|
||||||
status: str = "todo",
|
|
||||||
priority: str = "medium",
|
|
||||||
assignees: str = "[]",
|
|
||||||
due_date: int = 0,
|
|
||||||
project_id: str = "",
|
|
||||||
is_ai_suggested: int = 0,
|
|
||||||
) -> str:
|
|
||||||
"""Create a new task.
|
|
||||||
title: task title (required)
|
|
||||||
description: optional details
|
|
||||||
status: todo | in_progress | done (default: todo)
|
|
||||||
priority: high | medium | low (default: medium)
|
|
||||||
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
|
|
||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
|
||||||
project_id: optional UUID of the parent project
|
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="tasks",
|
|
||||||
data={
|
|
||||||
"title": title,
|
|
||||||
"description": description or None,
|
|
||||||
"status": status,
|
|
||||||
"priority": priority,
|
|
||||||
"assignee": assignees,
|
|
||||||
"dueDate": due_date or None,
|
|
||||||
"projectId": project_id or None,
|
|
||||||
"isAiSuggested": is_ai_suggested,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
return (
|
|
||||||
f"Task created: '{row['title']}' "
|
|
||||||
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def update_task(
|
|
||||||
task_id: str,
|
|
||||||
title: str = "",
|
|
||||||
description: str = "",
|
|
||||||
status: str = "",
|
|
||||||
priority: str = "",
|
|
||||||
assignees: str = "",
|
|
||||||
due_date: int = -1,
|
|
||||||
project_id: str = "",
|
|
||||||
) -> str:
|
|
||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
|
||||||
task_id: the task's UUID (required)
|
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
|
||||||
"""
|
|
||||||
updates: dict[str, Any] = {}
|
|
||||||
if title:
|
|
||||||
updates["title"] = title
|
|
||||||
if description:
|
|
||||||
updates["description"] = description
|
|
||||||
if status:
|
|
||||||
updates["status"] = status
|
|
||||||
if priority:
|
|
||||||
updates["priority"] = priority
|
|
||||||
if assignees:
|
|
||||||
updates["assignee"] = assignees
|
|
||||||
if due_date != -1:
|
|
||||||
updates["dueDate"] = due_date or None
|
|
||||||
if project_id:
|
|
||||||
updates["projectId"] = project_id
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="update",
|
|
||||||
table="tasks",
|
|
||||||
data={"id": task_id, "updates": updates},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_task(task_id: str) -> str:
|
|
||||||
"""Delete a task permanently by its UUID."""
|
|
||||||
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
|
|
||||||
return f"Task {task_id} deleted."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_tasks_due_today() -> str:
|
|
||||||
"""List all tasks whose due date falls on today's date."""
|
|
||||||
now = datetime.now(tz=timezone.utc)
|
|
||||||
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
|
|
||||||
end_ms = start_ms + 86_400_000 - 1 # last ms of today
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table="tasks",
|
|
||||||
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
|
|
||||||
)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No tasks are due today."
|
|
||||||
lines = [
|
|
||||||
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
|
|
||||||
for r in rows
|
|
||||||
]
|
|
||||||
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task comment tools ────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_task_comments(task_id: str) -> str:
|
|
||||||
"""List all comments on a task by its UUID."""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table="taskComments",
|
|
||||||
filters={"taskId": task_id},
|
|
||||||
)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return f"No comments found for task {task_id}."
|
|
||||||
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
|
|
||||||
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|
||||||
"""Add a comment to a task.
|
|
||||||
task_id: UUID of the task to comment on
|
|
||||||
author: name or ID of the comment author
|
|
||||||
content: comment text
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="taskComments",
|
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
|
||||||
)
|
|
||||||
row = result.get("row", {})
|
|
||||||
row_author = row.get("author", author)
|
|
||||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
|
||||||
row_comment_id = row.get("id", "unknown")
|
|
||||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_task_comment(comment_id: str) -> str:
|
|
||||||
"""Delete a task comment by its UUID."""
|
|
||||||
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
|
|
||||||
return f"Comment {comment_id} deleted."
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
TASK_TOOLS: list[Any] = [
|
|
||||||
list_tasks,
|
|
||||||
create_task,
|
|
||||||
update_task,
|
|
||||||
delete_task,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
add_task_comment,
|
|
||||||
delete_task_comment,
|
|
||||||
]
|
|
||||||
@@ -1,117 +0,0 @@
|
|||||||
"""Timeline agent — project milestone management (list, create, update, delete).
|
|
||||||
|
|
||||||
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.ws_context import execute_on_client
|
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
TIMELINE_SYSTEM_PROMPT = (
|
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all timelines across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table="timelines",
|
|
||||||
filters={"projectId": normalized_project_id or None},
|
|
||||||
)
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not rows:
|
|
||||||
return "No timelines found."
|
|
||||||
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
|
|
||||||
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def create_timeline(
|
|
||||||
project_id: str,
|
|
||||||
title: str,
|
|
||||||
date: int,
|
|
||||||
is_ai_suggested: int = 0,
|
|
||||||
) -> str:
|
|
||||||
"""Create a project timeline (milestone).
|
|
||||||
project_id: REQUIRED UUID of the parent project
|
|
||||||
title: descriptive name for the milestone
|
|
||||||
date: Unix timestamp in milliseconds
|
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="timelines",
|
|
||||||
data={
|
|
||||||
"projectId": project_id,
|
|
||||||
"title": title,
|
|
||||||
"date": date,
|
|
||||||
"isAiSuggested": is_ai_suggested,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def update_timeline(
|
|
||||||
timeline_id: str,
|
|
||||||
title: str = "",
|
|
||||||
date: int = -1,
|
|
||||||
) -> str:
|
|
||||||
"""Update a timeline. Only pass fields that should change.
|
|
||||||
timeline_id: UUID of the timeline (required)
|
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
|
||||||
"""
|
|
||||||
updates: dict[str, Any] = {}
|
|
||||||
if title:
|
|
||||||
updates["title"] = title
|
|
||||||
if date != -1:
|
|
||||||
updates["date"] = date
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="update",
|
|
||||||
table="timelines",
|
|
||||||
data={"id": timeline_id, "updates": updates},
|
|
||||||
)
|
|
||||||
row = result["row"]
|
|
||||||
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def delete_timeline(timeline_id: str) -> str:
|
|
||||||
"""Delete a timeline permanently by its UUID."""
|
|
||||||
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
|
|
||||||
return f"Timeline {timeline_id} deleted."
|
|
||||||
|
|
||||||
|
|
||||||
TIMELINE_TOOLS: list[Any] = [
|
|
||||||
list_timelines,
|
|
||||||
create_timeline,
|
|
||||||
update_timeline,
|
|
||||||
delete_timeline,
|
|
||||||
]
|
|
||||||
@@ -1,883 +0,0 @@
|
|||||||
"""Single-agent runners for home and floating chat contexts.
|
|
||||||
|
|
||||||
Adapted from app/core/deep_agent.py for the Chat Service.
|
|
||||||
Import paths changed to use local app modules and shared/.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import re
|
|
||||||
from datetime import date
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from app.agents.note_agent import NOTE_TOOLS
|
|
||||||
from app.agents.project_agent import PROJECT_TOOLS
|
|
||||||
from app.agents.task_agent import TASK_TOOLS
|
|
||||||
from app.agents.timeline_agent import TIMELINE_TOOLS
|
|
||||||
from app.llm import get_llm
|
|
||||||
from app.memory_middleware import MemoryMiddleware
|
|
||||||
from app.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
|
|
||||||
from app import tracing
|
|
||||||
from shared.db import async_session
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
FloatingDomainType = Literal["task", "timeline", "project", "node"]
|
|
||||||
FloatingDomainSection = Literal["task", "timeline", "note"]
|
|
||||||
|
|
||||||
_HOME_SINGLE_AGENT_SYSTEM = (
|
|
||||||
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
|
||||||
"Always use tools for factual data retrieval before answering. "
|
|
||||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
|
||||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
|
||||||
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
|
|
||||||
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
|
|
||||||
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
|
|
||||||
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
|
|
||||||
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
|
|
||||||
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
|
|
||||||
)
|
|
||||||
|
|
||||||
_FLOATING_SINGLE_AGENT_SYSTEM = (
|
|
||||||
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
|
|
||||||
"Stay focused on the floating scope in context.scope and answer concisely. "
|
|
||||||
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
|
|
||||||
"Always use tools for factual data retrieval before answering. "
|
|
||||||
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
|
|
||||||
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
|
|
||||||
)
|
|
||||||
|
|
||||||
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
|
|
||||||
"You are a strict domain classifier for websocket floating requests. "
|
|
||||||
"Return ONLY a JSON object with keys: type, id, section. "
|
|
||||||
"Allowed type values: task, timeline, project, node. "
|
|
||||||
"Allowed section values: task, timeline, note, or null. "
|
|
||||||
"Rules: infer from user message intent first; do not blindly trust scope.type. "
|
|
||||||
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
|
|
||||||
"If project id is unknown but context.resolved_project_id exists, use it as id. "
|
|
||||||
"If id is unknown, use null. "
|
|
||||||
"No markdown, no prose, JSON only."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _as_text(content: Any) -> str:
|
|
||||||
if content is None:
|
|
||||||
return ""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
parts: list[str] = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, str):
|
|
||||||
parts.append(item)
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
text = item.get("text")
|
|
||||||
if isinstance(text, str):
|
|
||||||
parts.append(text)
|
|
||||||
return "".join(parts)
|
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
def _candidate_tokens(message: str) -> list[str]:
|
|
||||||
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
|
|
||||||
return [token for token in tokens if len(token) >= 3]
|
|
||||||
|
|
||||||
|
|
||||||
async def _resolve_project_id_from_message(message: str) -> str | None:
|
|
||||||
"""Resolve likely project UUID from user message using client project list."""
|
|
||||||
try:
|
|
||||||
result = await execute_on_client(action="select", table="projects")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("deep_agent: project resolve select failed: %s", exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
rows = result.get("rows", [])
|
|
||||||
if not isinstance(rows, list) or not rows:
|
|
||||||
return None
|
|
||||||
|
|
||||||
tokens = _candidate_tokens(message)
|
|
||||||
scored: list[tuple[int, dict[str, Any]]] = []
|
|
||||||
for row in rows:
|
|
||||||
if not isinstance(row, dict):
|
|
||||||
continue
|
|
||||||
name = str(row.get("name", "")).lower()
|
|
||||||
score = sum(1 for token in tokens if token in name)
|
|
||||||
if score > 0:
|
|
||||||
scored.append((score, row))
|
|
||||||
|
|
||||||
if not scored:
|
|
||||||
return None
|
|
||||||
|
|
||||||
scored.sort(key=lambda item: item[0], reverse=True)
|
|
||||||
top_score = scored[0][0]
|
|
||||||
top_rows = [row for score, row in scored if score == top_score]
|
|
||||||
if len(top_rows) != 1:
|
|
||||||
return None
|
|
||||||
|
|
||||||
project_id = top_rows[0].get("id")
|
|
||||||
return project_id if isinstance(project_id, str) else None
|
|
||||||
|
|
||||||
|
|
||||||
def _needs_project_resolution(message: str) -> bool:
|
|
||||||
lowered = message.lower()
|
|
||||||
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
|
|
||||||
|
|
||||||
|
|
||||||
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
prepared = dict(context)
|
|
||||||
if _needs_project_resolution(message):
|
|
||||||
resolved_project_id = await _resolve_project_id_from_message(message)
|
|
||||||
if resolved_project_id:
|
|
||||||
prepared["resolved_project_id"] = resolved_project_id
|
|
||||||
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
|
|
||||||
return prepared
|
|
||||||
|
|
||||||
|
|
||||||
def _all_tools() -> list[Any]:
|
|
||||||
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
|
|
||||||
|
|
||||||
|
|
||||||
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
|
|
||||||
debug = context.get("_debug")
|
|
||||||
if isinstance(debug, dict):
|
|
||||||
request_id = debug.get("request_id")
|
|
||||||
if isinstance(request_id, str) and request_id:
|
|
||||||
return request_id
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
sanitized = dict(context)
|
|
||||||
sanitized.pop("_debug", None)
|
|
||||||
return sanitized
|
|
||||||
|
|
||||||
|
|
||||||
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
|
|
||||||
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
|
|
||||||
|
|
||||||
|
|
||||||
def _is_upcoming_timeline_query(message: str) -> bool:
|
|
||||||
lowered = message.lower()
|
|
||||||
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
|
|
||||||
has_timeline_topic = any(
|
|
||||||
token in lowered
|
|
||||||
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
|
|
||||||
)
|
|
||||||
return has_upcoming and has_timeline_topic
|
|
||||||
|
|
||||||
|
|
||||||
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
|
|
||||||
match = _TIMELINE_DMY_RE.search(dmy)
|
|
||||||
if not match:
|
|
||||||
return True
|
|
||||||
try:
|
|
||||||
parsed = date(
|
|
||||||
int(match.group("y")),
|
|
||||||
int(match.group("m")),
|
|
||||||
int(match.group("d")),
|
|
||||||
)
|
|
||||||
except ValueError:
|
|
||||||
return True
|
|
||||||
|
|
||||||
today = date.today()
|
|
||||||
return parsed >= today and parsed.year == today.year and parsed.month == today.month
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_tagged_list_lines(text: str, message: str) -> str:
|
|
||||||
if not text:
|
|
||||||
return text
|
|
||||||
|
|
||||||
upcoming_timeline_only = _is_upcoming_timeline_query(message)
|
|
||||||
output_lines: list[str] = []
|
|
||||||
|
|
||||||
for line in text.splitlines():
|
|
||||||
matches = list(_TAG_LINE_RE.finditer(line))
|
|
||||||
if not matches:
|
|
||||||
output_lines.append(line)
|
|
||||||
continue
|
|
||||||
|
|
||||||
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
|
|
||||||
if not had_non_tag_text and len(matches) == 1:
|
|
||||||
tag_text = matches[0].group(0)
|
|
||||||
if (
|
|
||||||
upcoming_timeline_only
|
|
||||||
and "<timeline>" in tag_text
|
|
||||||
and not _timeline_date_in_current_month_or_future(line)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
output_lines.append(tag_text)
|
|
||||||
continue
|
|
||||||
|
|
||||||
for match in matches:
|
|
||||||
tag_text = match.group(0)
|
|
||||||
if (
|
|
||||||
upcoming_timeline_only
|
|
||||||
and "<timeline>" in tag_text
|
|
||||||
and not _timeline_date_in_current_month_or_future(line)
|
|
||||||
):
|
|
||||||
continue
|
|
||||||
output_lines.append(tag_text)
|
|
||||||
|
|
||||||
return "\n".join(output_lines)
|
|
||||||
|
|
||||||
|
|
||||||
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
|
|
||||||
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
|
|
||||||
_FLOATING_EMPTY_FALLBACK = "No results found."
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_floating_markup_fragment(text: str) -> str:
|
|
||||||
if not text:
|
|
||||||
return text
|
|
||||||
cleaned = _GENERIC_TAG_RE.sub("", text)
|
|
||||||
return _BRACKETED_ID_RE.sub("", cleaned)
|
|
||||||
|
|
||||||
|
|
||||||
def _strip_floating_markup(text: str) -> str:
|
|
||||||
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
|
|
||||||
if not text:
|
|
||||||
return text
|
|
||||||
|
|
||||||
cleaned = _strip_floating_markup_fragment(text)
|
|
||||||
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
|
|
||||||
return "\n".join(line for line in lines if line)
|
|
||||||
|
|
||||||
|
|
||||||
def _fallback_from_raw_floating_text(raw_text: str) -> str:
|
|
||||||
fallback = _strip_floating_markup_fragment(raw_text or "")
|
|
||||||
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
|
|
||||||
return fallback or _FLOATING_EMPTY_FALLBACK
|
|
||||||
|
|
||||||
|
|
||||||
class _FloatingStreamSanitizer:
|
|
||||||
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
|
|
||||||
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._pending = ""
|
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def _split_safe_boundary(text: str) -> tuple[str, str]:
|
|
||||||
boundary = len(text)
|
|
||||||
|
|
||||||
last_lt = text.rfind("<")
|
|
||||||
if last_lt != -1 and ">" not in text[last_lt:]:
|
|
||||||
boundary = min(boundary, last_lt)
|
|
||||||
|
|
||||||
last_lb = text.rfind("[")
|
|
||||||
if last_lb != -1 and "]" not in text[last_lb:]:
|
|
||||||
boundary = min(boundary, last_lb)
|
|
||||||
|
|
||||||
if boundary == len(text):
|
|
||||||
return text, ""
|
|
||||||
return text[:boundary], text[boundary:]
|
|
||||||
|
|
||||||
def feed(self, chunk: str) -> str:
|
|
||||||
combined = f"{self._pending}{chunk}"
|
|
||||||
safe_text, self._pending = self._split_safe_boundary(combined)
|
|
||||||
return _strip_floating_markup_fragment(safe_text)
|
|
||||||
|
|
||||||
def finalize(self) -> str:
|
|
||||||
tail = re.sub(r"<[^>\n]*$", "", self._pending)
|
|
||||||
tail = re.sub(r"\[[^\]\n]*$", "", tail)
|
|
||||||
self._pending = ""
|
|
||||||
return _strip_floating_markup_fragment(tail)
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_memory_label(path_or_label: str) -> str:
|
|
||||||
value = path_or_label.strip()
|
|
||||||
if value.startswith("/memories/"):
|
|
||||||
value = value[len("/memories/"):]
|
|
||||||
value = value.strip("/")
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
|
|
||||||
@tool
|
|
||||||
async def memory_list_blocks() -> str:
|
|
||||||
"""List all core memory blocks currently stored for the user."""
|
|
||||||
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
blocks = await memory.list_core_blocks(user_id)
|
|
||||||
if not blocks:
|
|
||||||
return "No memory blocks found."
|
|
||||||
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
|
|
||||||
return "Memory blocks:\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def memory_get(path_or_label: str) -> str:
|
|
||||||
"""Get one memory block by label or /memories/<label> path."""
|
|
||||||
label = _normalize_memory_label(path_or_label)
|
|
||||||
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
|
||||||
if not label:
|
|
||||||
return "Invalid memory label."
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
value = await memory.get_core_block(user_id, label)
|
|
||||||
if value is None:
|
|
||||||
return f"Memory block '{label}' not found."
|
|
||||||
return f"Memory block '{label}':\n{value}"
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def memory_create(path_or_label: str, value: str) -> str:
|
|
||||||
"""Create or overwrite a memory block value by label or /memories/<label> path."""
|
|
||||||
label = _normalize_memory_label(path_or_label)
|
|
||||||
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
|
||||||
if not label:
|
|
||||||
return "Invalid memory label."
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
await memory.update_core(user_id, label, value, trace_id=trace_id)
|
|
||||||
return f"Memory block '{label}' saved."
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def memory_append(path_or_label: str, content: str) -> str:
|
|
||||||
"""Append content to a memory block, creating it if missing."""
|
|
||||||
label = _normalize_memory_label(path_or_label)
|
|
||||||
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
|
||||||
if not label:
|
|
||||||
return "Invalid memory label."
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
await memory.append_core(user_id, label, content)
|
|
||||||
return f"Memory block '{label}' appended."
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
|
|
||||||
"""Replace one exact string in a memory block."""
|
|
||||||
label = _normalize_memory_label(path_or_label)
|
|
||||||
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
|
||||||
if not label:
|
|
||||||
return "Invalid memory label."
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
changed = await memory.replace_core(user_id, label, old_string, new_string)
|
|
||||||
if not changed:
|
|
||||||
return f"No replacement made in '{label}' (old string not found)."
|
|
||||||
return f"Memory block '{label}' updated."
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def memory_delete(path_or_label: str) -> str:
|
|
||||||
"""Delete a memory block by label or /memories/<label> path."""
|
|
||||||
label = _normalize_memory_label(path_or_label)
|
|
||||||
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
|
|
||||||
if not label:
|
|
||||||
return "Invalid memory label."
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
deleted = await memory.delete_core(user_id, label)
|
|
||||||
if not deleted:
|
|
||||||
return f"Memory block '{label}' not found."
|
|
||||||
return f"Memory block '{label}' deleted."
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def archival_memory_insert(content: str) -> str:
|
|
||||||
"""Insert a long-term archival memory entry."""
|
|
||||||
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
await memory.insert_archival(user_id, content, source="assistant")
|
|
||||||
return "Archival memory saved."
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def archival_memory_search(query: str, top_k: int = 5) -> str:
|
|
||||||
"""Search long-term archival memory by semantic fallback (keyword currently)."""
|
|
||||||
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
results = await memory.search_archival(user_id, query, top_k=top_k)
|
|
||||||
if not results:
|
|
||||||
return "No archival memory results found."
|
|
||||||
lines = [f"- {item}" for item in results]
|
|
||||||
return "Archival memory results:\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def conversation_search(query: str, top_k: int = 5) -> str:
|
|
||||||
"""Search recall memory from prior episodic conversation summaries."""
|
|
||||||
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
results = await memory.search_recall(user_id, query, top_k=top_k)
|
|
||||||
if not results:
|
|
||||||
return "No recall memory results found."
|
|
||||||
lines = [f"- {item}" for item in results]
|
|
||||||
return "Recall memory results:\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
return [
|
|
||||||
memory_list_blocks,
|
|
||||||
memory_get,
|
|
||||||
memory_create,
|
|
||||||
memory_append,
|
|
||||||
memory_replace,
|
|
||||||
memory_delete,
|
|
||||||
archival_memory_insert,
|
|
||||||
archival_memory_search,
|
|
||||||
conversation_search,
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
|
|
||||||
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
|
|
||||||
|
|
||||||
|
|
||||||
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
|
|
||||||
lowered = message.lower()
|
|
||||||
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
|
|
||||||
return "timeline"
|
|
||||||
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
|
|
||||||
return "task"
|
|
||||||
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
|
|
||||||
return "note"
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
|
|
||||||
type_raw = str(payload.get("type") or "").strip().lower()
|
|
||||||
domain_type: FloatingDomainType = "task"
|
|
||||||
if type_raw in {"task", "timeline", "project", "node"}:
|
|
||||||
domain_type = type_raw
|
|
||||||
|
|
||||||
id_value = payload.get("id")
|
|
||||||
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
|
|
||||||
if domain_type == "project" and not domain_id:
|
|
||||||
domain_id = fallback_id
|
|
||||||
|
|
||||||
section_raw = payload.get("section")
|
|
||||||
section: FloatingDomainSection | None = None
|
|
||||||
if isinstance(section_raw, str):
|
|
||||||
section_candidate = section_raw.strip().lower()
|
|
||||||
if section_candidate in {"task", "timeline", "note"}:
|
|
||||||
section = section_candidate
|
|
||||||
|
|
||||||
if domain_type != "project":
|
|
||||||
section = None
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": domain_type,
|
|
||||||
"id": domain_id,
|
|
||||||
"section": section,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_json_object(text: str) -> dict[str, Any] | None:
|
|
||||||
raw = text.strip()
|
|
||||||
if not raw:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = json.loads(raw)
|
|
||||||
return parsed if isinstance(parsed, dict) else None
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
match = re.search(r"\{.*\}", raw, re.DOTALL)
|
|
||||||
if not match:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
parsed = json.loads(match.group(0))
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
return None
|
|
||||||
return parsed if isinstance(parsed, dict) else None
|
|
||||||
|
|
||||||
|
|
||||||
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
|
|
||||||
section = _detect_domain_section(message)
|
|
||||||
scope = context.get("scope") if isinstance(context, dict) else None
|
|
||||||
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
|
||||||
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
|
||||||
|
|
||||||
if isinstance(scope, dict):
|
|
||||||
scope_type = str(scope.get("type") or "").strip().lower()
|
|
||||||
scope_id = scope.get("id")
|
|
||||||
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
|
|
||||||
|
|
||||||
if scope_type in {"task", "tasks"}:
|
|
||||||
return {"type": "task", "id": scope_id_value, "section": None}
|
|
||||||
if scope_type in {"project", "projects"}:
|
|
||||||
project_scope_id = scope_id_value or project_id
|
|
||||||
return {
|
|
||||||
"type": "project",
|
|
||||||
"id": project_scope_id,
|
|
||||||
"section": section,
|
|
||||||
}
|
|
||||||
if scope_type in {"note", "notes"}:
|
|
||||||
return {
|
|
||||||
"type": "node",
|
|
||||||
"id": scope_id_value,
|
|
||||||
"section": None,
|
|
||||||
}
|
|
||||||
if scope_type in {"timeline", "timelines"}:
|
|
||||||
return {"type": "timeline", "id": scope_id_value, "section": None}
|
|
||||||
|
|
||||||
lowered = message.lower()
|
|
||||||
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
|
|
||||||
return {
|
|
||||||
"type": "project",
|
|
||||||
"id": project_id,
|
|
||||||
"section": section,
|
|
||||||
}
|
|
||||||
if section == "timeline":
|
|
||||||
return {"type": "timeline", "id": None, "section": None}
|
|
||||||
if section == "note":
|
|
||||||
return {"type": "node", "id": None, "section": None}
|
|
||||||
return {"type": "task", "id": None, "section": None}
|
|
||||||
|
|
||||||
|
|
||||||
async def _infer_floating_domain(
|
|
||||||
message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None,
|
|
||||||
) -> dict[str, str | None]:
|
|
||||||
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
|
|
||||||
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
|
|
||||||
|
|
||||||
classifier_context = {
|
|
||||||
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
|
|
||||||
"resolved_project_id": project_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
try:
|
|
||||||
classifier_prompt = _get_system_prompt(
|
|
||||||
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM,
|
|
||||||
)
|
|
||||||
callbacks = _build_callbacks(langfuse_handler)
|
|
||||||
llm = get_llm(callbacks=callbacks)
|
|
||||||
response = await llm.ainvoke(
|
|
||||||
[
|
|
||||||
SystemMessage(content=classifier_prompt),
|
|
||||||
HumanMessage(
|
|
||||||
content=(
|
|
||||||
f"Message:\n{message}\n\n"
|
|
||||||
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
|
|
||||||
)
|
|
||||||
),
|
|
||||||
]
|
|
||||||
)
|
|
||||||
parsed = _parse_json_object(_as_text(response.content))
|
|
||||||
if parsed is not None:
|
|
||||||
domain = _normalize_domain_payload(parsed, project_id)
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
|
|
||||||
domain.get("type"),
|
|
||||||
domain.get("id"),
|
|
||||||
domain.get("section"),
|
|
||||||
)
|
|
||||||
return domain
|
|
||||||
logger.warning("deep_agent: floating_domain classifier returned non-json output")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
|
|
||||||
|
|
||||||
return _infer_floating_domain_rule_based(message, context)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_system_prompt(langfuse_name: str, fallback: str) -> str:
|
|
||||||
"""Fetch a managed prompt from Langfuse, falling back to the hardcoded string."""
|
|
||||||
managed = tracing.get_prompt(langfuse_name, fallback=None)
|
|
||||||
return managed if managed is not None else fallback
|
|
||||||
|
|
||||||
|
|
||||||
def _build_callbacks(langfuse_handler: Any | None) -> list[Any] | None:
|
|
||||||
"""Return a callbacks list if a Langfuse handler is available."""
|
|
||||||
if langfuse_handler is None:
|
|
||||||
return None
|
|
||||||
return [langfuse_handler]
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_single_agent(
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
system_prompt: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
max_steps: int = 6,
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
) -> str:
|
|
||||||
trace_id = _trace_id_from_context(context)
|
|
||||||
callbacks = _build_callbacks(langfuse_handler)
|
|
||||||
llm = get_llm(callbacks=callbacks)
|
|
||||||
tools = _all_tools_for_user(user_id, trace_id)
|
|
||||||
model_context = _context_for_model(context)
|
|
||||||
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
|
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
|
||||||
messages: list[Any] = [
|
|
||||||
SystemMessage(content=system_prompt),
|
|
||||||
HumanMessage(
|
|
||||||
content=(
|
|
||||||
f"User message:\n{message}\n\n"
|
|
||||||
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
|
||||||
)
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
tool_calls_count = 0
|
|
||||||
collected: list[dict[str, Any]] = []
|
|
||||||
set_tool_result_collector(collected)
|
|
||||||
try:
|
|
||||||
for _ in range(max_steps):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
final_text = _as_text(response.content)
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
tool_calls_count,
|
|
||||||
len(final_text),
|
|
||||||
)
|
|
||||||
return final_text
|
|
||||||
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_calls_count += 1
|
|
||||||
call_id = str(call.get("id", ""))
|
|
||||||
call_name = str(call.get("name", ""))
|
|
||||||
call_args = call.get("args", {})
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
|
||||||
call_id,
|
|
||||||
call_name,
|
|
||||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
|
||||||
if tool_fn is None:
|
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
|
||||||
else:
|
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
|
||||||
call_id,
|
|
||||||
call_name,
|
|
||||||
str(tool_output)[:1200],
|
|
||||||
)
|
|
||||||
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
||||||
|
|
||||||
final = await llm.ainvoke(messages)
|
|
||||||
final_text = _as_text(final.content)
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
tool_calls_count,
|
|
||||||
len(final_text),
|
|
||||||
)
|
|
||||||
return final_text
|
|
||||||
finally:
|
|
||||||
clear_tool_result_collector()
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_single_agent_stream(
|
|
||||||
*,
|
|
||||||
user_id: str,
|
|
||||||
system_prompt: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
max_steps: int = 6,
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
|
||||||
trace_id = _trace_id_from_context(context)
|
|
||||||
callbacks = _build_callbacks(langfuse_handler)
|
|
||||||
llm = get_llm(callbacks=callbacks)
|
|
||||||
tools = _all_tools_for_user(user_id, trace_id)
|
|
||||||
model_context = _context_for_model(context)
|
|
||||||
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
|
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
|
||||||
messages: list[Any] = [
|
|
||||||
SystemMessage(content=system_prompt),
|
|
||||||
HumanMessage(
|
|
||||||
content=(
|
|
||||||
f"User message:\n{message}\n\n"
|
|
||||||
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
|
|
||||||
)
|
|
||||||
),
|
|
||||||
]
|
|
||||||
|
|
||||||
tool_calls_count = 0
|
|
||||||
streamed_chars = 0
|
|
||||||
collected: list[dict[str, Any]] = []
|
|
||||||
set_tool_result_collector(collected)
|
|
||||||
try:
|
|
||||||
for _ in range(max_steps):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
emitted_any = False
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
token = _as_text(getattr(chunk, "content", ""))
|
|
||||||
if token:
|
|
||||||
streamed_chars += len(token)
|
|
||||||
emitted_any = True
|
|
||||||
yield "token", token
|
|
||||||
|
|
||||||
if not emitted_any:
|
|
||||||
fallback_text = _as_text(response.content)
|
|
||||||
if fallback_text:
|
|
||||||
streamed_chars += len(fallback_text)
|
|
||||||
yield "token", fallback_text
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
tool_calls_count,
|
|
||||||
streamed_chars,
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
|
||||||
for call in response.tool_calls:
|
|
||||||
tool_calls_count += 1
|
|
||||||
call_id = str(call.get("id", ""))
|
|
||||||
call_name = str(call.get("name", ""))
|
|
||||||
call_args = call.get("args", {})
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
|
|
||||||
call_id,
|
|
||||||
call_name,
|
|
||||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
|
||||||
if tool_fn is None:
|
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
|
||||||
else:
|
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
|
|
||||||
call_id,
|
|
||||||
call_name,
|
|
||||||
str(tool_output)[:1200],
|
|
||||||
)
|
|
||||||
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
||||||
|
|
||||||
async for chunk in llm.astream(messages):
|
|
||||||
token = _as_text(getattr(chunk, "content", ""))
|
|
||||||
if token:
|
|
||||||
streamed_chars += len(token)
|
|
||||||
yield "token", token
|
|
||||||
logger.info(
|
|
||||||
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
|
|
||||||
trace_id or "-",
|
|
||||||
user_id,
|
|
||||||
tool_calls_count,
|
|
||||||
streamed_chars,
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_tool_result_collector()
|
|
||||||
|
|
||||||
|
|
||||||
async def run_home(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> str:
|
|
||||||
prepared_context = await _prepare_context(message, context)
|
|
||||||
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
|
|
||||||
response = await _run_single_agent(
|
|
||||||
user_id=user_id,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
message=message,
|
|
||||||
context=prepared_context,
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
)
|
|
||||||
return _normalize_tagged_list_lines(response, message)
|
|
||||||
|
|
||||||
|
|
||||||
async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]:
|
|
||||||
prepared_context = await _prepare_context(message, context)
|
|
||||||
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
|
|
||||||
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
|
|
||||||
response = await _run_single_agent(
|
|
||||||
user_id=user_id,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
message=message,
|
|
||||||
context=prepared_context,
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
)
|
|
||||||
sanitized = _strip_floating_markup(response)
|
|
||||||
if not sanitized and response:
|
|
||||||
sanitized = _fallback_from_raw_floating_text(response)
|
|
||||||
return sanitized, domain
|
|
||||||
|
|
||||||
|
|
||||||
async def run_home_stream(
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
*,
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
|
||||||
prepared_context = await _prepare_context(message, context)
|
|
||||||
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
|
|
||||||
text_chunks: list[str] = []
|
|
||||||
async for event in _run_single_agent_stream(
|
|
||||||
user_id=user_id,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
message=message,
|
|
||||||
context=prepared_context,
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
):
|
|
||||||
event_type, data = event
|
|
||||||
if event_type != "token":
|
|
||||||
yield event
|
|
||||||
continue
|
|
||||||
text_chunks.append(str(data or ""))
|
|
||||||
|
|
||||||
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
|
|
||||||
if normalized:
|
|
||||||
yield "token", normalized
|
|
||||||
|
|
||||||
|
|
||||||
async def run_floating_stream(
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
context: dict[str, Any],
|
|
||||||
*,
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
) -> AsyncGenerator[tuple[str, Any], None]:
|
|
||||||
prepared_context = await _prepare_context(message, context)
|
|
||||||
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
|
|
||||||
yield "floating_domain", domain
|
|
||||||
|
|
||||||
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
|
|
||||||
sanitizer = _FloatingStreamSanitizer()
|
|
||||||
emitted_sanitized = False
|
|
||||||
raw_chunks: list[str] = []
|
|
||||||
async for event in _run_single_agent_stream(
|
|
||||||
user_id=user_id,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
message=message,
|
|
||||||
context=prepared_context,
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
):
|
|
||||||
event_type, data = event
|
|
||||||
if event_type != "token":
|
|
||||||
yield event
|
|
||||||
continue
|
|
||||||
|
|
||||||
raw_chunk = str(data or "")
|
|
||||||
raw_chunks.append(raw_chunk)
|
|
||||||
sanitized_chunk = sanitizer.feed(raw_chunk)
|
|
||||||
if sanitized_chunk:
|
|
||||||
emitted_sanitized = True
|
|
||||||
yield "token", sanitized_chunk
|
|
||||||
|
|
||||||
tail = sanitizer.finalize()
|
|
||||||
if tail:
|
|
||||||
emitted_sanitized = True
|
|
||||||
yield "token", tail
|
|
||||||
|
|
||||||
if not emitted_sanitized and raw_chunks:
|
|
||||||
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
|
|
||||||
|
|
||||||
|
|
||||||
async def update_core_memory(user_id: str, key: str, value: str) -> None:
|
|
||||||
"""Compatibility helper kept for callers that expect explicit memory update API."""
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
await memory.update_core(user_id, key, value)
|
|
||||||
@@ -1,72 +0,0 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
|
||||||
|
|
||||||
Adapted from app/core/llm.py for the Chat Service.
|
|
||||||
Uses shared.config.settings instead of app.config.settings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
import litellm
|
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain_litellm import ChatLiteLLM
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
litellm.drop_params = True
|
|
||||||
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
|
||||||
category=UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
|
||||||
if model.startswith("anthropic/"):
|
|
||||||
return settings.ANTHROPIC_API_KEY or None
|
|
||||||
if model.startswith("gemini/") or model.startswith("google/"):
|
|
||||||
return settings.GOOGLE_API_KEY or None
|
|
||||||
if model.startswith("cerebras/"):
|
|
||||||
return settings.CEREBRAS_API_KEY or None
|
|
||||||
if model.startswith("github_copilot/"):
|
|
||||||
return None
|
|
||||||
return settings.OPENAI_API_KEY or None
|
|
||||||
|
|
||||||
|
|
||||||
def get_llm(
|
|
||||||
*,
|
|
||||||
model: str | None = None,
|
|
||||||
temperature: float = 0,
|
|
||||||
callbacks: list | None = None,
|
|
||||||
) -> ChatOpenAI | ChatLiteLLM:
|
|
||||||
model = model or settings.LLM_MODEL
|
|
||||||
|
|
||||||
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
|
||||||
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
|
||||||
|
|
||||||
if "/" in model:
|
|
||||||
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
|
|
||||||
|
|
||||||
return ChatOpenAI(
|
|
||||||
model=model,
|
|
||||||
temperature=temperature,
|
|
||||||
api_key=_api_key_for_model(model),
|
|
||||||
callbacks=callbacks,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
|
||||||
model = settings.LLM_EMBED_MODEL
|
|
||||||
|
|
||||||
if model.startswith("github_copilot/") or "/" in model:
|
|
||||||
response = await litellm.aembedding(model=model, input=[text])
|
|
||||||
return response.data[0]["embedding"]
|
|
||||||
|
|
||||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
|
||||||
response = await client.embeddings.create(model=model, input=text)
|
|
||||||
return response.data[0].embedding
|
|
||||||
@@ -1,87 +0,0 @@
|
|||||||
"""Chat Service — LLM orchestration, domain agents, memory.
|
|
||||||
|
|
||||||
Consumes chat requests from Redis, executes deep_agent (home/floating),
|
|
||||||
streams responses back via Redis pub/sub to WS Gateway.
|
|
||||||
|
|
||||||
Owns: memory_core, memory_associative, memory_episodic, memory_proactive tables.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
|
|
||||||
_repo_root = str(Path(__file__).resolve().parents[3])
|
|
||||||
if _repo_root not in sys.path:
|
|
||||||
sys.path.insert(0, _repo_root)
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
|
||||||
)
|
|
||||||
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
|
||||||
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
# Initialise Langfuse tracing (no-op if keys are missing)
|
|
||||||
from app.tracing import init_langfuse
|
|
||||||
|
|
||||||
init_langfuse()
|
|
||||||
|
|
||||||
# Start Redis consumer in background
|
|
||||||
from app.redis_consumer import start_consumer
|
|
||||||
|
|
||||||
consumer_task = start_consumer()
|
|
||||||
yield
|
|
||||||
consumer_task.cancel()
|
|
||||||
|
|
||||||
from app.tracing import shutdown as shutdown_langfuse
|
|
||||||
|
|
||||||
shutdown_langfuse()
|
|
||||||
|
|
||||||
from shared.db import engine
|
|
||||||
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
from shared.redis import redis_client
|
|
||||||
|
|
||||||
await redis_client.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
|
||||||
app = FastAPI(
|
|
||||||
title="Adiuva Chat Service",
|
|
||||||
version="0.1.0",
|
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
|
||||||
redoc_url=None,
|
|
||||||
lifespan=lifespan,
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=settings.CORS_ORIGINS,
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.routes import router
|
|
||||||
|
|
||||||
app.include_router(router, prefix="/api/v1")
|
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
|
||||||
async def health() -> dict:
|
|
||||||
return {"status": "ok", "service": "chat", "version": app.version}
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
app = create_app()
|
|
||||||
@@ -1,295 +0,0 @@
|
|||||||
"""Memory Middleware — adapted for Chat Service.
|
|
||||||
|
|
||||||
Uses shared.models instead of app.models. Otherwise identical to the
|
|
||||||
monolith's app/core/memory_middleware.py.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from cryptography.fernet import Fernet, InvalidToken
|
|
||||||
from sqlalchemy import select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from shared.models import (
|
|
||||||
MemoryAssociative,
|
|
||||||
MemoryCore,
|
|
||||||
MemoryEpisodic,
|
|
||||||
MemoryProactive,
|
|
||||||
User,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_ASSOCIATIVE_TOP_K = 5
|
|
||||||
_EPISODIC_RECENT_N = 10
|
|
||||||
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryMiddleware:
|
|
||||||
|
|
||||||
def __init__(self, db: AsyncSession) -> None:
|
|
||||||
self._db = db
|
|
||||||
|
|
||||||
async def enrich_context(
|
|
||||||
self,
|
|
||||||
user_id: str,
|
|
||||||
message: str,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
session_id: str | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return {}
|
|
||||||
|
|
||||||
core = await self._load_core(user_id, fernet)
|
|
||||||
associative = await self._load_associative(user_id, message, fernet)
|
|
||||||
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
|
|
||||||
proactive = await self._load_proactive(user_id, fernet)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d",
|
|
||||||
trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive),
|
|
||||||
)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"core_memory": core,
|
|
||||||
"associative_memory": associative,
|
|
||||||
"episodic_memory": episodic,
|
|
||||||
"proactive_hints": proactive,
|
|
||||||
}
|
|
||||||
|
|
||||||
async def store_episode(
|
|
||||||
self, user_id: str, session_id: str, message: str, response: str,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
|
||||||
encrypted = _encrypt(fernet, summary)
|
|
||||||
|
|
||||||
row = MemoryEpisodic(
|
|
||||||
id=str(uuid.uuid4()),
|
|
||||||
user_id=user_id,
|
|
||||||
summary_encrypted=encrypted,
|
|
||||||
session_id=session_id,
|
|
||||||
)
|
|
||||||
self._db.add(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
encrypted = _encrypt(fernet, value)
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key)
|
|
||||||
)
|
|
||||||
existing = result.scalar_one_or_none()
|
|
||||||
if existing is not None:
|
|
||||||
existing.value_encrypted = encrypted
|
|
||||||
else:
|
|
||||||
self._db.add(MemoryCore(
|
|
||||||
id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted,
|
|
||||||
))
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc())
|
|
||||||
)
|
|
||||||
out: list[dict[str, str]] = []
|
|
||||||
for row in result.scalars().all():
|
|
||||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out.append({"label": row.key, "value": plaintext})
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def get_core_block(self, user_id: str, label: str) -> str | None:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return None
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
return None
|
|
||||||
return _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
|
|
||||||
async def delete_core(self, user_id: str, label: str) -> bool:
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
|
|
||||||
)
|
|
||||||
row = result.scalar_one_or_none()
|
|
||||||
if row is None:
|
|
||||||
return False
|
|
||||||
await self._db.delete(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
return True
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def append_core(self, user_id: str, label: str, content: str) -> None:
|
|
||||||
current = await self.get_core_block(user_id, label)
|
|
||||||
if current is None:
|
|
||||||
await self.update_core(user_id, label, content)
|
|
||||||
return
|
|
||||||
await self.update_core(user_id, label, f"{current}\n{content}")
|
|
||||||
|
|
||||||
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
|
|
||||||
current = await self.get_core_block(user_id, label)
|
|
||||||
if current is None or old not in current:
|
|
||||||
return False
|
|
||||||
await self.update_core(user_id, label, current.replace(old, new, 1))
|
|
||||||
return True
|
|
||||||
|
|
||||||
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return
|
|
||||||
encrypted = _encrypt(fernet, content)
|
|
||||||
row = MemoryAssociative(
|
|
||||||
id=str(uuid.uuid4()), user_id=user_id,
|
|
||||||
content_encrypted=encrypted, embedding=None,
|
|
||||||
entity_type=source, entity_id=None,
|
|
||||||
)
|
|
||||||
self._db.add(row)
|
|
||||||
try:
|
|
||||||
await self._db.commit()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
|
|
||||||
await self._db.rollback()
|
|
||||||
|
|
||||||
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
|
|
||||||
.order_by(MemoryAssociative.updated_at.desc()).limit(100)
|
|
||||||
)
|
|
||||||
needle = query.strip().lower()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in result.scalars().all():
|
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
|
||||||
if plaintext is None:
|
|
||||||
continue
|
|
||||||
if not needle or needle in plaintext.lower():
|
|
||||||
out.append(plaintext)
|
|
||||||
if len(out) >= max(top_k, 1):
|
|
||||||
break
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
|
|
||||||
fernet = await self._get_fernet(user_id)
|
|
||||||
if fernet is None:
|
|
||||||
return []
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
|
||||||
.order_by(MemoryEpisodic.created_at.desc()).limit(100)
|
|
||||||
)
|
|
||||||
needle = query.strip().lower()
|
|
||||||
out: list[str] = []
|
|
||||||
for row in result.scalars().all():
|
|
||||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
|
||||||
if plaintext is None:
|
|
||||||
continue
|
|
||||||
if not needle or needle in plaintext.lower():
|
|
||||||
out.append(plaintext)
|
|
||||||
if len(out) >= max(top_k, 1):
|
|
||||||
break
|
|
||||||
return out
|
|
||||||
|
|
||||||
# ── Private ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
|
||||||
result = await self._db.execute(select(User).where(User.id == user_id))
|
|
||||||
user = result.scalar_one_or_none()
|
|
||||||
if user is None or not user.encryption_key:
|
|
||||||
logger.warning("memory: no encryption_key for user=%s", user_id)
|
|
||||||
return None
|
|
||||||
return Fernet(user.encryption_key.encode())
|
|
||||||
|
|
||||||
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
|
||||||
)
|
|
||||||
out: dict[str, str] = {}
|
|
||||||
for row in result.scalars().all():
|
|
||||||
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out[row.key] = plaintext
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]:
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
|
|
||||||
.order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K)
|
|
||||||
)
|
|
||||||
out: list[str] = []
|
|
||||||
for row in result.scalars().all():
|
|
||||||
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out.append(plaintext)
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def _load_episodic(self, user_id: str, fernet: Fernet, session_id: str | None = None) -> list[str]:
|
|
||||||
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
|
|
||||||
if session_id:
|
|
||||||
query = query.where(MemoryEpisodic.session_id == session_id)
|
|
||||||
result = await self._db.execute(
|
|
||||||
query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
|
|
||||||
)
|
|
||||||
out: list[str] = []
|
|
||||||
for row in result.scalars().all():
|
|
||||||
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out.append(plaintext)
|
|
||||||
return out
|
|
||||||
|
|
||||||
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
|
||||||
result = await self._db.execute(
|
|
||||||
select(MemoryProactive).where(
|
|
||||||
MemoryProactive.user_id == user_id,
|
|
||||||
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
|
||||||
).order_by(MemoryProactive.confidence.desc())
|
|
||||||
)
|
|
||||||
out: list[str] = []
|
|
||||||
for row in result.scalars().all():
|
|
||||||
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
|
||||||
if plaintext is not None:
|
|
||||||
out.append(plaintext)
|
|
||||||
return out
|
|
||||||
|
|
||||||
|
|
||||||
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
|
||||||
return fernet.encrypt(plaintext.encode()).decode()
|
|
||||||
|
|
||||||
|
|
||||||
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
|
||||||
try:
|
|
||||||
return fernet.decrypt(ciphertext.encode()).decode()
|
|
||||||
except (InvalidToken, Exception) as exc:
|
|
||||||
logger.warning("memory: decrypt failed: %s", exc)
|
|
||||||
return None
|
|
||||||
@@ -1,50 +0,0 @@
|
|||||||
"""Output formatter for deep-agent stream events — Chat Service copy.
|
|
||||||
|
|
||||||
Converts (event_type, data) tuples into WebSocket frame Pydantic models.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from shared.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
|
|
||||||
|
|
||||||
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
|
||||||
|
|
||||||
|
|
||||||
class StreamFormatter:
|
|
||||||
"""Convert `(event_type, data)` stream events into websocket frame models."""
|
|
||||||
|
|
||||||
def __init__(self, request_id: str) -> None:
|
|
||||||
self.request_id = request_id
|
|
||||||
|
|
||||||
async def format(
|
|
||||||
self,
|
|
||||||
event_stream: AsyncGenerator[tuple[str, Any], None],
|
|
||||||
) -> AsyncGenerator[WsFrame, None]:
|
|
||||||
started = False
|
|
||||||
|
|
||||||
async for event_type, data in event_stream:
|
|
||||||
if event_type == "floating_domain":
|
|
||||||
if isinstance(data, dict):
|
|
||||||
yield WsFloatingDomain(
|
|
||||||
request_id=self.request_id,
|
|
||||||
domain=data,
|
|
||||||
)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if event_type != "token":
|
|
||||||
continue
|
|
||||||
|
|
||||||
if not started:
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
|
||||||
started = True
|
|
||||||
|
|
||||||
text = str(data or "")
|
|
||||||
if text:
|
|
||||||
yield WsStreamText(request_id=self.request_id, chunk=text)
|
|
||||||
|
|
||||||
if not started:
|
|
||||||
yield WsStreamStart(request_id=self.request_id)
|
|
||||||
yield WsStreamEnd(request_id=self.request_id)
|
|
||||||
@@ -1,209 +0,0 @@
|
|||||||
"""Redis consumer — listens for chat requests and dispatches to deep_agent.
|
|
||||||
|
|
||||||
Subscribes to a Redis pattern channel chat:request:* so it receives
|
|
||||||
requests for ALL users. Each request is processed in a separate asyncio task.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from shared.db import async_session
|
|
||||||
from shared.redis import redis_client, ws_out_channel
|
|
||||||
|
|
||||||
from app.deep_agent import run_floating_stream, run_home_stream
|
|
||||||
from app.memory_middleware import MemoryMiddleware
|
|
||||||
from app.output_formatter import StreamFormatter
|
|
||||||
from app.ws_context import clear_current_user, set_current_user
|
|
||||||
from app import tracing
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def start_consumer() -> asyncio.Task:
|
|
||||||
"""Start the Redis consumer as a background asyncio task."""
|
|
||||||
return asyncio.create_task(_consumer_loop())
|
|
||||||
|
|
||||||
|
|
||||||
async def _consumer_loop() -> None:
|
|
||||||
"""Subscribe to chat:request:* and dispatch incoming frames."""
|
|
||||||
pubsub = redis_client.pubsub()
|
|
||||||
await pubsub.psubscribe("chat:request:*")
|
|
||||||
logger.info("redis_consumer: subscribed to chat:request:*")
|
|
||||||
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
message = await pubsub.get_message(
|
|
||||||
ignore_subscribe_messages=True, timeout=1.0
|
|
||||||
)
|
|
||||||
if message is not None and message["type"] == "pmessage":
|
|
||||||
frame = json.loads(message["data"])
|
|
||||||
asyncio.create_task(_dispatch(frame))
|
|
||||||
else:
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("redis_consumer: shutting down")
|
|
||||||
finally:
|
|
||||||
await pubsub.punsubscribe()
|
|
||||||
await pubsub.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch(frame: dict) -> None:
|
|
||||||
"""Route a chat request frame to the appropriate handler."""
|
|
||||||
frame_type = frame.get("type")
|
|
||||||
user_id = frame.get("user_id")
|
|
||||||
|
|
||||||
if not user_id:
|
|
||||||
logger.warning("redis_consumer: frame missing user_id: %s", frame.get("type"))
|
|
||||||
return
|
|
||||||
|
|
||||||
if frame_type == "home_request":
|
|
||||||
await _handle_home_request(user_id, frame)
|
|
||||||
elif frame_type == "floating_request":
|
|
||||||
await _handle_floating_request(user_id, frame)
|
|
||||||
else:
|
|
||||||
logger.debug("redis_consumer: unknown frame type %r", frame_type)
|
|
||||||
|
|
||||||
|
|
||||||
async def _publish_frame(user_id: str, frame_data: str) -> None:
|
|
||||||
"""Publish a frame to ws:out:{user_id} for the WS Gateway to forward."""
|
|
||||||
channel = ws_out_channel(user_id)
|
|
||||||
await redis_client.publish(channel, frame_data)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_home_request(user_id: str, frame: dict) -> None:
|
|
||||||
"""Process a home_request — enrich with memory, run deep_agent, stream results."""
|
|
||||||
request_id = frame.get("request_id") or str(uuid4())
|
|
||||||
message: str = frame.get("message", "")
|
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"redis_consumer: home_request user=%s req=%s msg=%s",
|
|
||||||
user_id, request_id, message[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
response_chunks: list[str] = []
|
|
||||||
|
|
||||||
with tracing.trace_span(
|
|
||||||
name="home_request",
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
trace_id=request_id,
|
|
||||||
input=message,
|
|
||||||
metadata={"message_preview": message[:200]},
|
|
||||||
tags=["home"],
|
|
||||||
) as span:
|
|
||||||
langfuse_handler = tracing.get_langfuse_callback()
|
|
||||||
|
|
||||||
# Enrich with memory context
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
memory_context = await memory.enrich_context(
|
|
||||||
user_id, message,
|
|
||||||
trace_id=request_id, session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
|
||||||
"conversation_history": frame.get("conversation_history", []),
|
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
set_current_user(user_id)
|
|
||||||
try:
|
|
||||||
event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
|
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(event_stream):
|
|
||||||
await _publish_frame(user_id, ws_frame.model_dump_json())
|
|
||||||
if hasattr(ws_frame, "chunk"):
|
|
||||||
response_chunks.append(ws_frame.chunk)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
|
|
||||||
finally:
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
# Link prompt and attach output preview
|
|
||||||
tracing.link_prompt_to_trace(span, "home_system")
|
|
||||||
response_text = "".join(response_chunks)
|
|
||||||
span.update(output=response_text[:500] if response_text else None)
|
|
||||||
|
|
||||||
tracing.flush()
|
|
||||||
|
|
||||||
# Store episode
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
await memory.store_episode(
|
|
||||||
user_id, session_id, message, "".join(response_chunks),
|
|
||||||
trace_id=request_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_floating_request(user_id: str, frame: dict) -> None:
|
|
||||||
"""Process a floating_request — enrich with memory, run deep_agent, stream results."""
|
|
||||||
request_id = frame.get("request_id") or str(uuid4())
|
|
||||||
message: str = frame.get("message", "")
|
|
||||||
session_id: str = frame.get("session_id") or str(uuid4())
|
|
||||||
scope: dict = frame.get("scope", {})
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
|
|
||||||
user_id, request_id, json.dumps(scope)[:200], message[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
response_chunks: list[str] = []
|
|
||||||
|
|
||||||
with tracing.trace_span(
|
|
||||||
name="floating_request",
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
trace_id=request_id,
|
|
||||||
input=message,
|
|
||||||
metadata={"message_preview": message[:200], "scope": scope},
|
|
||||||
tags=["floating"],
|
|
||||||
) as span:
|
|
||||||
langfuse_handler = tracing.get_langfuse_callback()
|
|
||||||
|
|
||||||
# Enrich with memory context
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
memory_context = await memory.enrich_context(
|
|
||||||
user_id, message,
|
|
||||||
trace_id=request_id, session_id=session_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
context: dict = {
|
|
||||||
"scope": scope,
|
|
||||||
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
|
|
||||||
**memory_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
set_current_user(user_id)
|
|
||||||
try:
|
|
||||||
event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
|
|
||||||
formatter = StreamFormatter(request_id=request_id)
|
|
||||||
async for ws_frame in formatter.format(event_stream):
|
|
||||||
await _publish_frame(user_id, ws_frame.model_dump_json())
|
|
||||||
if hasattr(ws_frame, "chunk"):
|
|
||||||
response_chunks.append(ws_frame.chunk)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
|
|
||||||
finally:
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
# Link prompt and attach output preview
|
|
||||||
tracing.link_prompt_to_trace(span, "floating_system")
|
|
||||||
response_text = "".join(response_chunks)
|
|
||||||
span.update(output=response_text[:500] if response_text else None)
|
|
||||||
|
|
||||||
tracing.flush()
|
|
||||||
|
|
||||||
# Store episode
|
|
||||||
async with async_session() as db:
|
|
||||||
memory = MemoryMiddleware(db)
|
|
||||||
await memory.store_episode(
|
|
||||||
user_id, session_id, message, "".join(response_chunks),
|
|
||||||
trace_id=request_id,
|
|
||||||
)
|
|
||||||
@@ -1,37 +0,0 @@
|
|||||||
"""Chat REST route — POST /chat fallback when WS is unavailable."""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Request
|
|
||||||
from fastapi.responses import JSONResponse
|
|
||||||
|
|
||||||
from shared.schemas import ChatRequest
|
|
||||||
|
|
||||||
from app.deep_agent import run_home
|
|
||||||
from app.ws_context import clear_current_user, set_current_user
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/chat", tags=["chat"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.post("")
|
|
||||||
async def chat(body: ChatRequest, request: Request) -> JSONResponse:
|
|
||||||
"""REST fallback for home chat.
|
|
||||||
|
|
||||||
In the microservices setup, Traefik ForwardAuth has already validated
|
|
||||||
the JWT and injected X-User-Id / X-User-Email / X-User-Tier headers.
|
|
||||||
"""
|
|
||||||
user_id = request.headers.get("X-User-Id", "")
|
|
||||||
if not user_id:
|
|
||||||
return JSONResponse(status_code=401, content={"detail": "Missing X-User-Id header"})
|
|
||||||
|
|
||||||
set_current_user(user_id)
|
|
||||||
try:
|
|
||||||
response = await run_home(
|
|
||||||
user_id=user_id,
|
|
||||||
message=body.message,
|
|
||||||
context=body.context.model_dump(),
|
|
||||||
)
|
|
||||||
finally:
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
return JSONResponse(content={"response": response})
|
|
||||||
@@ -1,264 +0,0 @@
|
|||||||
"""Langfuse tracing & prompt management for the Chat Service (v4 SDK).
|
|
||||||
|
|
||||||
Provides:
|
|
||||||
- ``init_langfuse()`` — initialise the singleton client at startup
|
|
||||||
- ``trace_span()`` — context manager that creates a trace + span
|
|
||||||
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
|
|
||||||
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
|
|
||||||
- ``flush()`` / ``shutdown()`` — lifecycle management
|
|
||||||
|
|
||||||
All functions gracefully degrade to no-ops when Langfuse is not configured,
|
|
||||||
so the service works identically with or without observability keys.
|
|
||||||
|
|
||||||
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── State ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_initialised: bool = False
|
|
||||||
_disabled: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_configured() -> bool:
|
|
||||||
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
|
|
||||||
|
|
||||||
|
|
||||||
def init_langfuse() -> None:
|
|
||||||
"""Initialise the Langfuse singleton. Call once at startup."""
|
|
||||||
global _initialised, _disabled
|
|
||||||
|
|
||||||
if _initialised or _disabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not _is_configured():
|
|
||||||
_disabled = True
|
|
||||||
logger.info("tracing: Langfuse keys not set — tracing disabled")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import Langfuse
|
|
||||||
|
|
||||||
Langfuse(
|
|
||||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
|
||||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
|
||||||
host=settings.LANGFUSE_HOST,
|
|
||||||
)
|
|
||||||
_initialised = True
|
|
||||||
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
|
|
||||||
except Exception as exc:
|
|
||||||
_disabled = True
|
|
||||||
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_client() -> Any | None:
|
|
||||||
"""Return the singleton Langfuse client, or *None* if disabled."""
|
|
||||||
if _disabled:
|
|
||||||
return None
|
|
||||||
if not _initialised:
|
|
||||||
init_langfuse()
|
|
||||||
if _disabled:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
from langfuse import get_client
|
|
||||||
return get_client()
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _NullSpan:
|
|
||||||
"""Drop-in replacement when Langfuse is disabled."""
|
|
||||||
|
|
||||||
def update(self, **_: Any) -> None: ...
|
|
||||||
def set_trace_io(self, **_: Any) -> None: ...
|
|
||||||
def score_trace(self, **_: Any) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
# ── Trace context manager ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def trace_span(
|
|
||||||
*,
|
|
||||||
name: str,
|
|
||||||
user_id: str,
|
|
||||||
session_id: str | None = None,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
input: Any = None,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
tags: list[str] | None = None,
|
|
||||||
):
|
|
||||||
"""Context manager that creates a Langfuse trace/span.
|
|
||||||
|
|
||||||
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
|
|
||||||
A ``CallbackHandler`` created inside this block auto-inherits the trace
|
|
||||||
context, so there is no need to pass trace IDs manually.
|
|
||||||
"""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None:
|
|
||||||
yield _NullSpan()
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import Langfuse, propagate_attributes
|
|
||||||
|
|
||||||
trace_ctx: dict[str, str] = {}
|
|
||||||
if trace_id is not None:
|
|
||||||
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
|
|
||||||
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="span",
|
|
||||||
name=name,
|
|
||||||
input=input,
|
|
||||||
metadata=metadata or {},
|
|
||||||
**({"trace_context": trace_ctx} if trace_ctx else {}),
|
|
||||||
) as span:
|
|
||||||
with propagate_attributes(
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
tags=tags or [],
|
|
||||||
):
|
|
||||||
yield span
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
|
|
||||||
yield _NullSpan()
|
|
||||||
|
|
||||||
|
|
||||||
# ── LangChain callback handler ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def get_langfuse_callback() -> Any | None:
|
|
||||||
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
|
|
||||||
|
|
||||||
Must be called inside a ``trace_span()`` block for proper linking.
|
|
||||||
Returns *None* when Langfuse is disabled.
|
|
||||||
"""
|
|
||||||
if _disabled and not _initialised:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse.langchain import CallbackHandler
|
|
||||||
return CallbackHandler()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Prompt management ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt(
|
|
||||||
name: str,
|
|
||||||
*,
|
|
||||||
version: int | None = None,
|
|
||||||
label: str | None = None,
|
|
||||||
fallback: str | None = None,
|
|
||||||
cache_ttl_seconds: int = 300,
|
|
||||||
) -> str | None:
|
|
||||||
"""Fetch a managed prompt from Langfuse by name.
|
|
||||||
|
|
||||||
Returns the compiled prompt string, or *fallback* if the prompt is not
|
|
||||||
found or Langfuse is disabled.
|
|
||||||
"""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None:
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
try:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"name": name,
|
|
||||||
"cache_ttl_seconds": cache_ttl_seconds,
|
|
||||||
}
|
|
||||||
if version is not None:
|
|
||||||
kwargs["version"] = version
|
|
||||||
if label is not None:
|
|
||||||
kwargs["label"] = label
|
|
||||||
prompt = lf.get_prompt(**kwargs)
|
|
||||||
return prompt.prompt
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
|
|
||||||
def link_prompt_to_trace(
|
|
||||||
span: Any,
|
|
||||||
prompt_name: str,
|
|
||||||
*,
|
|
||||||
version: int | None = None,
|
|
||||||
label: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Attach prompt metadata to a span/trace."""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None or isinstance(span, _NullSpan):
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
kwargs: dict[str, Any] = {"name": prompt_name}
|
|
||||||
if version is not None:
|
|
||||||
kwargs["version"] = version
|
|
||||||
if label is not None:
|
|
||||||
kwargs["label"] = label
|
|
||||||
prompt = lf.get_prompt(**kwargs)
|
|
||||||
span.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}})
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Scoring helper ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def score_trace(
|
|
||||||
trace_id: str,
|
|
||||||
name: str,
|
|
||||||
value: float,
|
|
||||||
*,
|
|
||||||
comment: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: score_trace failed: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Shutdown ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def flush() -> None:
|
|
||||||
"""Flush pending Langfuse events."""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is not None:
|
|
||||||
try:
|
|
||||||
lf.flush()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: flush failed: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
def shutdown() -> None:
|
|
||||||
"""Flush and close the Langfuse client."""
|
|
||||||
global _initialised, _disabled
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is not None:
|
|
||||||
try:
|
|
||||||
lf.flush()
|
|
||||||
lf.shutdown()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: shutdown failed: %s", exc)
|
|
||||||
_initialised = False
|
|
||||||
_disabled = False
|
|
||||||
@@ -1,115 +0,0 @@
|
|||||||
"""WebSocket context for Chat Service — Redis-based tool call round-trip.
|
|
||||||
|
|
||||||
Replaces the monolith's ws_context.py. Instead of calling Electron directly
|
|
||||||
via WebSocket, this publishes tool_call frames to Redis (ws:out:{user_id})
|
|
||||||
and awaits the result via BRPOP on tool:result:{call_id}.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from contextvars import ContextVar
|
|
||||||
from typing import Any
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from shared.redis import redis_client, tool_result_key, ws_out_channel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
|
|
||||||
|
|
||||||
# Per-request user_id context var (set before agent runs)
|
|
||||||
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
|
|
||||||
|
|
||||||
# Optional collector for debug
|
|
||||||
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
|
||||||
"_tool_result_collector", default=None
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def set_current_user(user_id: str) -> None:
|
|
||||||
_current_user_id.set(user_id)
|
|
||||||
|
|
||||||
|
|
||||||
def clear_current_user() -> None:
|
|
||||||
_current_user_id.set(None)
|
|
||||||
|
|
||||||
|
|
||||||
def set_tool_result_collector(lst: list[dict]) -> None:
|
|
||||||
_tool_result_collector.set(lst)
|
|
||||||
|
|
||||||
|
|
||||||
def clear_tool_result_collector() -> None:
|
|
||||||
_tool_result_collector.set(None)
|
|
||||||
|
|
||||||
|
|
||||||
async def execute_on_client(
|
|
||||||
action: str,
|
|
||||||
table: str | None = None,
|
|
||||||
data: dict[str, Any] | None = None,
|
|
||||||
filters: dict[str, Any] | None = None,
|
|
||||||
vector: list[float] | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Send a tool_call to Electron via Redis and await the result.
|
|
||||||
|
|
||||||
1. Build tool_call payload
|
|
||||||
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
|
|
||||||
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
|
|
||||||
4. Return result dict
|
|
||||||
|
|
||||||
Raises RuntimeError if no user_id is set or if the call times out.
|
|
||||||
"""
|
|
||||||
user_id = _current_user_id.get()
|
|
||||||
if not user_id:
|
|
||||||
raise RuntimeError(
|
|
||||||
"execute_on_client() called without a user_id — "
|
|
||||||
"set_current_user() must be called first."
|
|
||||||
)
|
|
||||||
|
|
||||||
call_id = str(uuid4())
|
|
||||||
payload: dict[str, Any] = {
|
|
||||||
"type": "tool_call",
|
|
||||||
"id": call_id,
|
|
||||||
"action": action,
|
|
||||||
}
|
|
||||||
if table is not None:
|
|
||||||
payload["table"] = table
|
|
||||||
if data is not None:
|
|
||||||
payload["data"] = data
|
|
||||||
if filters is not None:
|
|
||||||
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
|
||||||
if vector is not None:
|
|
||||||
payload["vector"] = vector
|
|
||||||
if limit is not None:
|
|
||||||
payload["limit"] = limit
|
|
||||||
|
|
||||||
# Publish tool_call to WS Gateway → Electron
|
|
||||||
channel = ws_out_channel(user_id)
|
|
||||||
await redis_client.publish(channel, json.dumps(payload))
|
|
||||||
|
|
||||||
# Wait for Electron's tool_result
|
|
||||||
result_key = tool_result_key(call_id)
|
|
||||||
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
|
|
||||||
|
|
||||||
if response is None:
|
|
||||||
raise RuntimeError(
|
|
||||||
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
|
|
||||||
f"device may be offline or unresponsive."
|
|
||||||
)
|
|
||||||
|
|
||||||
# response is (key, value) tuple
|
|
||||||
_, raw = response
|
|
||||||
result = json.loads(raw)
|
|
||||||
|
|
||||||
# Collect for debug if requested
|
|
||||||
collector = _tool_result_collector.get(None)
|
|
||||||
if collector is not None:
|
|
||||||
collector.append({
|
|
||||||
"action": action,
|
|
||||||
"table": table,
|
|
||||||
"data": result,
|
|
||||||
})
|
|
||||||
|
|
||||||
return result
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
gunicorn>=22.0.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
pydantic-settings>=2.7.0
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
redis>=5.0.0
|
|
||||||
cryptography>=42.0.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
langchain-core>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
langchain-litellm>=0.3.0
|
|
||||||
litellm>=1.50.0
|
|
||||||
openai>=1.50.0
|
|
||||||
httpx>=0.27.0
|
|
||||||
langfuse>=3.0.0
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# ── builder ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS builder
|
|
||||||
|
|
||||||
WORKDIR /build
|
|
||||||
|
|
||||||
COPY services/ws-gateway/requirements.txt ./requirements.txt
|
|
||||||
RUN pip install --upgrade pip && \
|
|
||||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
|
||||||
|
|
||||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS runtime
|
|
||||||
|
|
||||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY --from=builder /install /usr/local
|
|
||||||
|
|
||||||
# Shared module
|
|
||||||
COPY shared/ shared/
|
|
||||||
|
|
||||||
# Service source
|
|
||||||
COPY services/ws-gateway/app/ app/
|
|
||||||
|
|
||||||
RUN chown -R appuser:appgroup /app
|
|
||||||
|
|
||||||
USER appuser
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
# Single worker — each instance handles many WS connections via asyncio
|
|
||||||
CMD ["gunicorn", "app.main:app", \
|
|
||||||
"-k", "uvicorn.workers.UvicornWorker", \
|
|
||||||
"--bind", "0.0.0.0:8000", \
|
|
||||||
"--workers", "1", \
|
|
||||||
"--timeout", "0"]
|
|
||||||
@@ -1,17 +0,0 @@
|
|||||||
# WS Gateway
|
|
||||||
|
|
||||||
Stateless WebSocket proxy. Accepts Electron connections, authenticates JWT,
|
|
||||||
routes frames to Chat/Batch services via Redis pub/sub.
|
|
||||||
|
|
||||||
## No business logic
|
|
||||||
This service does NOT know what tasks, notes, or agents are.
|
|
||||||
It only routes JSON frames between Electron and downstream services.
|
|
||||||
|
|
||||||
## Scaling
|
|
||||||
Sticky sessions on `user_id` (Traefik consistent hashing).
|
|
||||||
|
|
||||||
## Redis channels used
|
|
||||||
- Subscribe: `ws:out:{user_id}` (frames to send to client)
|
|
||||||
- Publish: `chat:request:{user_id}`, `batch:request:{user_id}`
|
|
||||||
- LPUSH: `tool:result:{call_id}` (from client tool_result frames)
|
|
||||||
- HSET/HDEL: `ws:devices:{user_id}` (device registry)
|
|
||||||
@@ -1,173 +0,0 @@
|
|||||||
"""WebSocket handler — device connection lifecycle.
|
|
||||||
|
|
||||||
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
|
|
||||||
and runs two concurrent loops:
|
|
||||||
1. Message loop: receive frames from Electron, route to Redis
|
|
||||||
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
|
|
||||||
3. Heartbeat loop: ping every 30s
|
|
||||||
|
|
||||||
No business logic lives here — the handler is a JSON frame router.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from uuid import uuid4
|
|
||||||
|
|
||||||
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
from shared.schemas import WsFrameType
|
|
||||||
|
|
||||||
from app.redis_bridge import (
|
|
||||||
publish_batch_request,
|
|
||||||
publish_chat_request,
|
|
||||||
push_tool_result,
|
|
||||||
register_device,
|
|
||||||
set_gateway_id,
|
|
||||||
subscribe_outbound,
|
|
||||||
unregister_device,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/ws", tags=["ws-gateway"])
|
|
||||||
|
|
||||||
_HEARTBEAT_INTERVAL = 30 # seconds
|
|
||||||
|
|
||||||
# Set a unique gateway instance ID on module load
|
|
||||||
set_gateway_id(str(uuid4()))
|
|
||||||
|
|
||||||
|
|
||||||
@router.websocket("/device")
|
|
||||||
async def device_ws(websocket: WebSocket) -> None:
|
|
||||||
"""Persistent WebSocket endpoint for Electron device connections."""
|
|
||||||
|
|
||||||
# ── 1. Authenticate via ?token= query parameter ──────────────────
|
|
||||||
token = websocket.query_params.get("token", "")
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token,
|
|
||||||
settings.JWT_PUBLIC_KEY,
|
|
||||||
algorithms=["RS256"],
|
|
||||||
)
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
email: str | None = payload.get("email")
|
|
||||||
if not user_id:
|
|
||||||
raise JWTError("missing sub")
|
|
||||||
except JWTError:
|
|
||||||
await websocket.close(code=1008)
|
|
||||||
return
|
|
||||||
|
|
||||||
await websocket.accept()
|
|
||||||
|
|
||||||
# ── 2. Await device_hello frame ──────────────────────────────────
|
|
||||||
try:
|
|
||||||
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
|
||||||
except (asyncio.TimeoutError, WebSocketDisconnect):
|
|
||||||
await websocket.close(code=1008)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
hello = json.loads(raw)
|
|
||||||
if hello.get("type") != WsFrameType.device_hello:
|
|
||||||
raise ValueError("expected device_hello as first frame")
|
|
||||||
device_id: str = hello["device_id"]
|
|
||||||
agent_ids: list[str] = hello.get("agent_ids", [])
|
|
||||||
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
|
||||||
logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
|
|
||||||
await websocket.close(code=1008)
|
|
||||||
return
|
|
||||||
|
|
||||||
# ── 3. Register device in Redis ──────────────────────────────────
|
|
||||||
await register_device(user_id, device_id)
|
|
||||||
logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)
|
|
||||||
|
|
||||||
# Notify downstream services that device is online (for agent trigger)
|
|
||||||
await publish_batch_request(user_id, {
|
|
||||||
"type": "device_online",
|
|
||||||
"user_id": user_id,
|
|
||||||
"device_id": device_id,
|
|
||||||
"agent_ids": agent_ids,
|
|
||||||
})
|
|
||||||
|
|
||||||
# ── 4. Subscribe to outbound Redis channel ───────────────────────
|
|
||||||
pubsub = await subscribe_outbound(user_id)
|
|
||||||
|
|
||||||
# ── 5. Run concurrent loops ──────────────────────────────────────
|
|
||||||
try:
|
|
||||||
await asyncio.gather(
|
|
||||||
_inbound_loop(websocket, user_id),
|
|
||||||
_outbound_loop(websocket, pubsub),
|
|
||||||
_heartbeat_loop(websocket),
|
|
||||||
)
|
|
||||||
except WebSocketDisconnect:
|
|
||||||
pass
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("handler: unhandled exception user=%s: %s", user_id, exc)
|
|
||||||
finally:
|
|
||||||
await pubsub.unsubscribe()
|
|
||||||
await pubsub.aclose()
|
|
||||||
await unregister_device(user_id)
|
|
||||||
logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Inbound: Electron → Redis ────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
|
|
||||||
"""Receive frames from Electron and route to the appropriate Redis channel."""
|
|
||||||
async for raw in websocket.iter_text():
|
|
||||||
try:
|
|
||||||
frame: dict = json.loads(raw)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning("handler: invalid JSON from user=%s", user_id)
|
|
||||||
continue
|
|
||||||
|
|
||||||
frame_type = frame.get("type")
|
|
||||||
|
|
||||||
# Inject user_id so downstream services know who sent it
|
|
||||||
frame["user_id"] = user_id
|
|
||||||
|
|
||||||
if frame_type == WsFrameType.tool_result:
|
|
||||||
call_id = frame.get("id")
|
|
||||||
if call_id:
|
|
||||||
await push_tool_result(call_id, frame)
|
|
||||||
else:
|
|
||||||
logger.warning("handler: tool_result missing id user=%s", user_id)
|
|
||||||
|
|
||||||
elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
|
|
||||||
await publish_chat_request(user_id, frame)
|
|
||||||
|
|
||||||
elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
|
|
||||||
await publish_batch_request(user_id, frame)
|
|
||||||
|
|
||||||
elif frame_type == "pong":
|
|
||||||
pass # heartbeat ack
|
|
||||||
|
|
||||||
else:
|
|
||||||
logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Outbound: Redis → Electron ───────────────────────────────────────
|
|
||||||
|
|
||||||
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
|
|
||||||
"""Subscribe to Redis ws:out:{user_id} and forward frames to Electron."""
|
|
||||||
while True:
|
|
||||||
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
|
|
||||||
if message is not None and message["type"] == "message":
|
|
||||||
await websocket.send_text(message["data"])
|
|
||||||
else:
|
|
||||||
# Brief sleep to avoid busy-wait when no messages
|
|
||||||
await asyncio.sleep(0.01)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Heartbeat ────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
|
||||||
"""Send ping frames every 30s to keep the connection alive."""
|
|
||||||
while True:
|
|
||||||
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
|
||||||
await websocket.send_text(json.dumps({"type": "ping"}))
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
"""WS Gateway — stateless WebSocket proxy.
|
|
||||||
|
|
||||||
Accepts Electron device connections, authenticates JWT (RS256 public key),
|
|
||||||
and routes frames between Electron and downstream services via Redis pub/sub.
|
|
||||||
|
|
||||||
This service has NO business logic — it only routes JSON frames.
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
import logging
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
|
|
||||||
_repo_root = str(Path(__file__).resolve().parents[3])
|
|
||||||
if _repo_root not in sys.path:
|
|
||||||
sys.path.insert(0, _repo_root)
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
logging.basicConfig(
|
|
||||||
level=logging.INFO,
|
|
||||||
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
yield
|
|
||||||
from shared.redis import redis_client
|
|
||||||
|
|
||||||
await redis_client.aclose()
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
|
||||||
app = FastAPI(
|
|
||||||
title="Adiuva WS Gateway",
|
|
||||||
version="0.1.0",
|
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
|
||||||
redoc_url=None,
|
|
||||||
lifespan=lifespan,
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.handler import router
|
|
||||||
|
|
||||||
app.include_router(router, prefix="/api/v1")
|
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
|
||||||
async def health() -> dict:
|
|
||||||
return {"status": "ok", "service": "ws-gateway", "version": app.version}
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
app = create_app()
|
|
||||||
@@ -1,104 +0,0 @@
|
|||||||
"""Redis bridge — device registry + pub/sub routing.
|
|
||||||
|
|
||||||
All inter-service communication passes through Redis:
|
|
||||||
- Device registry: HSET/HDEL ws:devices:{user_id}
|
|
||||||
- Outbound frames: Subscribe ws:out:{user_id}
|
|
||||||
- Chat requests: Publish chat:request:{user_id}
|
|
||||||
- Batch requests: Publish batch:request:{user_id}
|
|
||||||
- Tool results: LPUSH tool:result:{call_id}
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
|
|
||||||
from shared.redis import (
|
|
||||||
batch_request_channel,
|
|
||||||
chat_request_channel,
|
|
||||||
device_key,
|
|
||||||
redis_client,
|
|
||||||
tool_result_key,
|
|
||||||
ws_out_channel,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# Instance ID for this gateway replica (set on startup)
|
|
||||||
_GATEWAY_ID: str = ""
|
|
||||||
|
|
||||||
|
|
||||||
def set_gateway_id(gid: str) -> None:
|
|
||||||
global _GATEWAY_ID
|
|
||||||
_GATEWAY_ID = gid
|
|
||||||
|
|
||||||
|
|
||||||
def get_gateway_id() -> str:
|
|
||||||
return _GATEWAY_ID
|
|
||||||
|
|
||||||
|
|
||||||
# ── Device Registry ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def register_device(user_id: str, device_id: str) -> None:
|
|
||||||
"""Register a connected device in Redis."""
|
|
||||||
key = device_key(user_id)
|
|
||||||
await redis_client.hset(key, mapping={
|
|
||||||
"device_id": device_id,
|
|
||||||
"gateway_id": _GATEWAY_ID,
|
|
||||||
})
|
|
||||||
logger.info("redis_bridge: registered user=%s device=%s gateway=%s", user_id, device_id, _GATEWAY_ID)
|
|
||||||
|
|
||||||
|
|
||||||
async def unregister_device(user_id: str) -> None:
|
|
||||||
"""Remove device registration from Redis."""
|
|
||||||
key = device_key(user_id)
|
|
||||||
await redis_client.delete(key)
|
|
||||||
logger.info("redis_bridge: unregistered user=%s", user_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def is_device_online(user_id: str) -> bool:
|
|
||||||
"""Check if a device is registered."""
|
|
||||||
key = device_key(user_id)
|
|
||||||
return await redis_client.exists(key) > 0
|
|
||||||
|
|
||||||
|
|
||||||
# ── Frame Routing ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def publish_chat_request(user_id: str, frame: dict) -> None:
|
|
||||||
"""Forward a chat request frame to the Chat Service via Redis."""
|
|
||||||
channel = chat_request_channel(user_id)
|
|
||||||
await redis_client.publish(channel, json.dumps(frame))
|
|
||||||
logger.debug("redis_bridge: published chat_request user=%s", user_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def publish_batch_request(user_id: str, frame: dict) -> None:
|
|
||||||
"""Forward a batch request frame to the Batch Agent Service via Redis."""
|
|
||||||
channel = batch_request_channel(user_id)
|
|
||||||
await redis_client.publish(channel, json.dumps(frame))
|
|
||||||
logger.debug("redis_bridge: published batch_request user=%s", user_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def push_tool_result(call_id: str, result: dict) -> None:
|
|
||||||
"""Push a tool_result to the Redis list for the waiting service.
|
|
||||||
|
|
||||||
Chat/Batch services do BRPOP on this key with a 30s timeout.
|
|
||||||
"""
|
|
||||||
key = tool_result_key(call_id)
|
|
||||||
await redis_client.lpush(key, json.dumps(result))
|
|
||||||
# Auto-expire after 60s to prevent stale keys
|
|
||||||
await redis_client.expire(key, 60)
|
|
||||||
logger.debug("redis_bridge: pushed tool_result call_id=%s", call_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def subscribe_outbound(user_id: str):
|
|
||||||
"""Return an async pubsub subscription for frames to send to Electron.
|
|
||||||
|
|
||||||
Chat/Batch services publish to ws:out:{user_id} and this gateway
|
|
||||||
forwards them to the connected WebSocket.
|
|
||||||
"""
|
|
||||||
channel = ws_out_channel(user_id)
|
|
||||||
pubsub = redis_client.pubsub()
|
|
||||||
await pubsub.subscribe(channel)
|
|
||||||
return pubsub
|
|
||||||
@@ -1,8 +0,0 @@
|
|||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
gunicorn>=22.0.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
pydantic-settings>=2.7.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
redis>=5.0.0
|
|
||||||
websockets>=14.0
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
"""Shared module — imported by all microservices.
|
|
||||||
|
|
||||||
Contains DB engine/session, ORM models, Pydantic schemas, config,
|
|
||||||
and Redis utilities. Changes here affect ALL services.
|
|
||||||
"""
|
|
||||||
@@ -1,98 +0,0 @@
|
|||||||
"""Shared configuration — Pydantic Settings loaded from environment.
|
|
||||||
|
|
||||||
All services import ``settings`` from here. Each service only uses a subset
|
|
||||||
of the vars, but keeping one Settings class avoids fragmentation.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Literal
|
|
||||||
|
|
||||||
from pydantic import field_validator
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
||||||
|
|
||||||
# Locate the repo root (adiuva-api/) so we can load its .env as a fallback.
|
|
||||||
# Works whether cwd is adiuva-api/ (monolith) or adiuva-api/services/xyz/ (microservice).
|
|
||||||
_this_dir = Path(__file__).resolve().parent # shared/
|
|
||||||
_repo_root = _this_dir.parent # adiuva-api/
|
|
||||||
_root_env = _repo_root / ".env"
|
|
||||||
|
|
||||||
|
|
||||||
class Settings(BaseSettings):
|
|
||||||
# ── Database ─────────────────────────────────────────────────────
|
|
||||||
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
|
||||||
|
|
||||||
# ── JWT ────────────────────────────────────────────────────────
|
|
||||||
# RS256 public key (PEM). Used by any service that needs to verify
|
|
||||||
# JWTs locally (optional — Traefik ForwardAuth handles this in prod).
|
|
||||||
# The private key lives ONLY in the Auth Service config.
|
|
||||||
JWT_PUBLIC_KEY: str = ""
|
|
||||||
|
|
||||||
@field_validator("JWT_PUBLIC_KEY", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _expand_pem_newlines(cls, v: str) -> str:
|
|
||||||
if isinstance(v, str) and r"\n" in v:
|
|
||||||
return v.replace(r"\n", "\n")
|
|
||||||
return v
|
|
||||||
|
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
|
||||||
|
|
||||||
# ── Redis ────────────────────────────────────────────────────────
|
|
||||||
REDIS_URL: str = "redis://localhost:6379/0"
|
|
||||||
|
|
||||||
# ── Stripe ───────────────────────────────────────────────────────
|
|
||||||
STRIPE_SECRET_KEY: str = ""
|
|
||||||
STRIPE_WEBHOOK_SECRET: str = ""
|
|
||||||
|
|
||||||
# ── S3 ───────────────────────────────────────────────────────────
|
|
||||||
S3_BUCKET: str = ""
|
|
||||||
S3_REGION: str = "us-east-1"
|
|
||||||
S3_ENDPOINT_URL: str = ""
|
|
||||||
AWS_ACCESS_KEY_ID: str = ""
|
|
||||||
AWS_SECRET_ACCESS_KEY: str = ""
|
|
||||||
|
|
||||||
# ── Vector stores ────────────────────────────────────────────────
|
|
||||||
PINECONE_API_KEY: str = ""
|
|
||||||
PINECONE_INDEX: str = "adiuva"
|
|
||||||
QDRANT_URL: str = ""
|
|
||||||
QDRANT_API_KEY: str = ""
|
|
||||||
|
|
||||||
# ── LLM providers ────────────────────────────────────────────────
|
|
||||||
OPENAI_API_KEY: str = ""
|
|
||||||
ANTHROPIC_API_KEY: str = ""
|
|
||||||
GOOGLE_API_KEY: str = ""
|
|
||||||
CEREBRAS_API_KEY: str = ""
|
|
||||||
|
|
||||||
LLM_MODEL: str = "gpt-4o"
|
|
||||||
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
|
||||||
|
|
||||||
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
|
||||||
|
|
||||||
# ── OAuth (integrations) ─────────────────────────────────────────
|
|
||||||
GMAIL_CLIENT_ID: str = ""
|
|
||||||
GMAIL_CLIENT_SECRET: str = ""
|
|
||||||
MS_CLIENT_ID: str = ""
|
|
||||||
MS_CLIENT_SECRET: str = ""
|
|
||||||
MS_TENANT_ID: str = "common"
|
|
||||||
OAUTH_ENCRYPTION_KEY: str = ""
|
|
||||||
|
|
||||||
# ── Langfuse (observability) ─────────────────────────────────────
|
|
||||||
LANGFUSE_SECRET_KEY: str = ""
|
|
||||||
LANGFUSE_PUBLIC_KEY: str = ""
|
|
||||||
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
|
|
||||||
|
|
||||||
# ── CORS ─────────────────────────────────────────────────────────
|
|
||||||
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
|
||||||
|
|
||||||
# ── Environment ──────────────────────────────────────────────────
|
|
||||||
ENV: Literal["dev", "prod"] = "dev"
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(
|
|
||||||
# Local .env (cwd) takes priority; root .env is fallback.
|
|
||||||
env_file=(".env", str(_root_env)),
|
|
||||||
env_file_encoding="utf-8",
|
|
||||||
extra="ignore",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
settings = Settings()
|
|
||||||
32
shared/db.py
32
shared/db.py
@@ -1,32 +0,0 @@
|
|||||||
"""Database engine, session factory, and declarative base.
|
|
||||||
|
|
||||||
All services use the async SQLAlchemy API via ``get_session()``.
|
|
||||||
Alembic migrations use the synchronous psycopg2 URL (see alembic/env.py).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from collections.abc import AsyncGenerator
|
|
||||||
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
engine = create_async_engine(
|
|
||||||
settings.DATABASE_URL,
|
|
||||||
pool_pre_ping=True,
|
|
||||||
echo=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
async_session = async_sessionmaker(engine, expire_on_commit=False)
|
|
||||||
|
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
|
||||||
"""Shared declarative base for all ORM models."""
|
|
||||||
|
|
||||||
|
|
||||||
async def get_session() -> AsyncGenerator[AsyncSession, None]:
|
|
||||||
"""FastAPI dependency that yields an async DB session per request."""
|
|
||||||
async with async_session() as session:
|
|
||||||
yield session
|
|
||||||
455
shared/models.py
455
shared/models.py
@@ -1,455 +0,0 @@
|
|||||||
"""SQLAlchemy ORM models for all persistent tables.
|
|
||||||
|
|
||||||
Centralized here so that Alembic migrations and all services share
|
|
||||||
the same model definitions. Each service only queries the tables it owns.
|
|
||||||
|
|
||||||
Ownership:
|
|
||||||
Auth Service → users, refresh_tokens, subscriptions
|
|
||||||
Chat Service → memory_core, memory_associative, memory_episodic, memory_proactive
|
|
||||||
Batch Agent → local_agent_configs, cloud_agent_configs, agent_run_logs
|
|
||||||
Billing Service → subscriptions (shared write with Auth)
|
|
||||||
(excluded MVP) → storage_records, backup_metadata, plugins, plugin_*, revenue_events
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from sqlalchemy import (
|
|
||||||
BigInteger,
|
|
||||||
Boolean,
|
|
||||||
DateTime,
|
|
||||||
Enum,
|
|
||||||
Float,
|
|
||||||
ForeignKey,
|
|
||||||
Integer,
|
|
||||||
JSON,
|
|
||||||
String,
|
|
||||||
Text,
|
|
||||||
UniqueConstraint,
|
|
||||||
Uuid,
|
|
||||||
func,
|
|
||||||
)
|
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
|
||||||
|
|
||||||
from shared.db import Base
|
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _uuid() -> str:
|
|
||||||
return str(uuid.uuid4())
|
|
||||||
|
|
||||||
|
|
||||||
def _now() -> datetime:
|
|
||||||
return datetime.now(timezone.utc)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Enum types ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
|
|
||||||
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
|
|
||||||
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
|
|
||||||
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
|
|
||||||
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
|
|
||||||
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
|
||||||
|
|
||||||
|
|
||||||
# ── Auth models ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
|
||||||
__tablename__ = "users"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
|
|
||||||
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
|
||||||
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
|
||||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
|
||||||
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
|
|
||||||
back_populates="user", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
subscription: Mapped[Subscription | None] = relationship(
|
|
||||||
back_populates="user", uselist=False, cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RefreshToken(Base):
|
|
||||||
__tablename__ = "refresh_tokens"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
|
|
||||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
user: Mapped[User] = relationship(back_populates="refresh_tokens")
|
|
||||||
|
|
||||||
|
|
||||||
class Subscription(Base):
|
|
||||||
__tablename__ = "subscriptions"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
|
||||||
nullable=False, unique=True, index=True
|
|
||||||
)
|
|
||||||
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
|
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
|
||||||
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
|
|
||||||
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
user: Mapped[User] = relationship(back_populates="subscription")
|
|
||||||
|
|
||||||
|
|
||||||
# ── Storage models (excluded from MVP, kept for Alembic) ──────────────
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRecord(Base):
|
|
||||||
__tablename__ = "storage_records"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
|
|
||||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
|
||||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
|
||||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class BackupMetadata(Base):
|
|
||||||
__tablename__ = "backup_metadata"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
|
|
||||||
version: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
|
|
||||||
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
|
|
||||||
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Plugin models (excluded from MVP, kept for Alembic) ───────────────
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin(Base):
|
|
||||||
__tablename__ = "plugins"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(String(255), primary_key=True)
|
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
|
||||||
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
|
||||||
author_id: Mapped[str | None] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
|
||||||
)
|
|
||||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
|
||||||
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
|
||||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
|
||||||
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
|
||||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
|
||||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
|
|
||||||
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
submitted_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
installations: Mapped[list[PluginInstallation]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
reviews: Mapped[list[PluginReview]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
revenue_events: Mapped[list[RevenueEvent]] = relationship(
|
|
||||||
back_populates="plugin", cascade="all, delete-orphan"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class PluginInstallation(Base):
|
|
||||||
__tablename__ = "plugin_installations"
|
|
||||||
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
installed_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="installations")
|
|
||||||
|
|
||||||
|
|
||||||
class PluginReview(Base):
|
|
||||||
__tablename__ = "plugin_reviews"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
reviewer_id: Mapped[str | None] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
|
||||||
)
|
|
||||||
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
|
|
||||||
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
reviewed_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
|
|
||||||
|
|
||||||
|
|
||||||
class RevenueEvent(Base):
|
|
||||||
__tablename__ = "revenue_events"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
plugin_id: Mapped[str] = mapped_column(
|
|
||||||
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
|
||||||
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
|
||||||
|
|
||||||
|
|
||||||
# ── Agent models ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfig(Base):
|
|
||||||
__tablename__ = "local_agent_configs"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
|
||||||
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
|
||||||
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
|
||||||
back_populates="local_agent",
|
|
||||||
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
|
||||||
foreign_keys="AgentRunLog.agent_id",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
overlaps="run_logs,cloud_agent",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class CloudAgentConfig(Base):
|
|
||||||
__tablename__ = "cloud_agent_configs"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
|
|
||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
|
|
||||||
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
|
||||||
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
|
|
||||||
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
|
|
||||||
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
|
|
||||||
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
|
||||||
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
run_logs: Mapped[list[AgentRunLog]] = relationship(
|
|
||||||
back_populates="cloud_agent",
|
|
||||||
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
|
||||||
foreign_keys="AgentRunLog.agent_id",
|
|
||||||
cascade="all, delete-orphan",
|
|
||||||
overlaps="run_logs,local_agent",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class AgentRunLog(Base):
|
|
||||||
__tablename__ = "agent_run_logs"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
|
||||||
)
|
|
||||||
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
|
||||||
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
|
|
||||||
)
|
|
||||||
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
|
|
||||||
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
|
||||||
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
|
||||||
started_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
|
||||||
|
|
||||||
local_agent: Mapped[LocalAgentConfig | None] = relationship(
|
|
||||||
back_populates="run_logs",
|
|
||||||
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
|
|
||||||
foreign_keys="AgentRunLog.agent_id",
|
|
||||||
overlaps="run_logs,cloud_agent",
|
|
||||||
)
|
|
||||||
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
|
|
||||||
back_populates="run_logs",
|
|
||||||
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
|
|
||||||
foreign_keys="AgentRunLog.agent_id",
|
|
||||||
overlaps="run_logs,local_agent",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Memory models ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryCore(Base):
|
|
||||||
"""Per-user persistent key/value preferences, encrypted at rest."""
|
|
||||||
|
|
||||||
__tablename__ = "memory_core"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
|
||||||
nullable=False, index=True,
|
|
||||||
)
|
|
||||||
key: Mapped[str] = mapped_column(String(255), nullable=False)
|
|
||||||
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryAssociative(Base):
|
|
||||||
"""Per-user semantic memory: encrypted content + pgvector embedding."""
|
|
||||||
|
|
||||||
__tablename__ = "memory_associative"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
|
||||||
nullable=False, index=True,
|
|
||||||
)
|
|
||||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
|
||||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
|
||||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
|
||||||
updated_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryEpisodic(Base):
|
|
||||||
"""Per-user session summaries, encrypted at rest."""
|
|
||||||
|
|
||||||
__tablename__ = "memory_episodic"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
|
||||||
nullable=False, index=True,
|
|
||||||
)
|
|
||||||
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class MemoryProactive(Base):
|
|
||||||
"""Per-user inferred behavioral patterns, encrypted at rest."""
|
|
||||||
|
|
||||||
__tablename__ = "memory_proactive"
|
|
||||||
|
|
||||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
|
||||||
user_id: Mapped[str] = mapped_column(
|
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
|
|
||||||
nullable=False, index=True,
|
|
||||||
)
|
|
||||||
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
|
||||||
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
|
|
||||||
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
|
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
|
||||||
)
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
"""Redis client and pub/sub utilities for inter-service communication.
|
|
||||||
|
|
||||||
All services that need Redis import from here.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
redis_client: aioredis.Redis = aioredis.from_url(
|
|
||||||
settings.REDIS_URL,
|
|
||||||
decode_responses=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Channel naming conventions ────────────────────────────────────────
|
|
||||||
# See /memories/repo/microservices-architecture.md for full list.
|
|
||||||
|
|
||||||
def ws_out_channel(user_id: str) -> str:
|
|
||||||
"""Frames to forward to Electron via WS Gateway."""
|
|
||||||
return f"ws:out:{user_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def chat_request_channel(user_id: str) -> str:
|
|
||||||
"""Chat requests (home + floating) from WS Gateway → Chat Service."""
|
|
||||||
return f"chat:request:{user_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def batch_request_channel(user_id: str) -> str:
|
|
||||||
"""Batch requests (journey + triggers) from WS Gateway → Batch Agent."""
|
|
||||||
return f"batch:request:{user_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def tool_result_key(call_id: str) -> str:
|
|
||||||
"""Tool result list: LPUSH by WS Gateway, BRPOP by Chat/Batch."""
|
|
||||||
return f"tool:result:{call_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def device_key(user_id: str) -> str:
|
|
||||||
"""Device registry hash."""
|
|
||||||
return f"ws:devices:{user_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def tier_changed_channel(user_id: str) -> str:
|
|
||||||
"""Billing tier change notifications."""
|
|
||||||
return f"tier:changed:{user_id}"
|
|
||||||
|
|
||||||
|
|
||||||
def journey_session_key(user_id: str) -> str:
|
|
||||||
"""Journey builder session (String + TTL 1800s)."""
|
|
||||||
return f"journey:{user_id}"
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user