24 Commits

Author SHA1 Message Date
Roberto Musso
2b7d302ef2 refactor: remove monolith app/, Dockerfile, requirements.txt
All business logic has been extracted into microservices:
  - services/auth/       (Step 1)
  - services/ws-gateway/ (Step 2)
  - services/chat/       (Step 2)
  - services/batch-agent/ (Step 3)
  - services/billing/    (Step 4)

Shared code lives in shared/.
Migrations remain in alembic/.
Tests in tests/ will need updating to target individual services.
2026-04-06 23:44:12 +02:00
Roberto Musso
7f6ea29525 feat(infra): Docker Compose orchestration + env updates (Step 5)
- Replace monolith docker-compose with full microservices stack
- Services: traefik, db, redis, migrate, auth, ws-gateway, chat, batch-agent, billing
- Traefik API gateway with ForwardAuth, ACME/Cloudflare DNS-01 (from Step 2)
- Centralized migrations via 'migrate' service (run-once)
- All services share .env via env_file + override DATABASE_URL/REDIS_URL
- Health checks on db and redis; service dependency ordering
- MinIO and Qdrant kept as optional (commented out)
- .env.example: add JWT_PRIVATE_KEY, CF_DNS_API_TOKEN, ACME_EMAIL, POSTGRES_ vars
2026-04-06 23:40:14 +02:00
Roberto Musso
48036397f1 fix(billing): auto-detect repo root for shared module import in local dev 2026-04-06 23:32:17 +02:00
Roberto Musso
57b5648915 feat(billing): extract Billing Service (Step 4)
- stripe_service: checkout sessions, webhook handling, subscription CRUD
- tier_manager: feature matrix (4 tiers), quota enforcement, rate limits
- routes: checkout, webhook (no auth), subscription, tier query, features
- Traefik header auth (X-User-Id) replaces get_current_user dependency
- /tier/{user_id} endpoint for internal service-to-service lookups
- /features and /features/{tier} for feature matrix queries
- Dockerfile: single worker, 30s timeout (lightweight service)
2026-04-06 23:07:46 +02:00
Roberto Musso
7e4374c69b feat(eval): add custom system prompt support for step-1 classification 2026-04-06 22:56:30 +02:00
Roberto Musso
fe0dd038ee fix: Langfuse SDK v4 migration, tracing improvements, and LLM config
- Langfuse SDK v4: fix prompt-to-trace linking (as_type=generation)
- tracing: compile_prompt with Langfuse managed prompt fallback
- journey: remove journey CLI subcommand (keep only interactive)
- LLM: add service-specific llm modules for batch-agent and chat
- gitignore: exclude eval private test data
- config: add LANGFUSE settings to shared config
2026-03-24 16:25:51 +01:00
Roberto Musso
d3f7099d93 refactor(eval): 3-mode eval harness (step1/step2/full) with Langfuse fixes
- Rewrite eval config with EvalMode (step1, step2, full) replacing prompt_variants
- Rewrite runner with _run_step1, _run_step2, _run_full dispatch
- CLI: replace --variants with --mode flag
- Add 3 fixture YAMLs: classify_invoices (step1), process_invoices (step2), full_invoices (full)
- Remove old freelance_invoices fixture
- Langfuse: mode-aware dataset items (classifications for step1, extraction for step2, both for full)
- Langfuse: link both prompts (batch_file_classifier + batch_processing) in full mode
- Langfuse: post separate classification_precision/recall/f1 scores for full mode
- Langfuse: skip misleading field_accuracy=0 when field_scores is empty (step1)
- Langfuse: include step1_results in trace output
- MockExecutor: mock async_session to bypass DB in full mode
- Journey fixture: remove user_messages (only interactive test kept)
2026-03-24 16:18:51 +01:00
Roberto Musso
63fa119543 feat(batch-agent): add journey eval to E2E harness
- journey_runner.py: orchestrates journey start → simulated user
  messages → template extraction → LLM judge scoring
- config.py: JourneyFixture dataclass with user_messages and
  expected_template_criteria, discover_journey_fixtures()
- langfuse_eval.py: sync_journey_fixture_to_dataset()
- cli.py: new 'journey' subcommand (python -m eval journey)
  with --fixture, --models, --judge-model flags
- fixtures/journey_invoice_setup.yaml: example journey fixture
  with 4 user messages and 8 quality criteria
2026-03-23 23:16:41 +01:00
Roberto Musso
d856dfd28c refactor: deduplicate shared code into shared/ module
Move duplicated files from chat + batch-agent into shared/:
- shared/ws_context.py — Redis-based tool call round-trip
- shared/llm.py — LiteLLM factory (get_llm, embed)
- shared/agents/ — 4 domain agents (task, note, project, timeline)

Update all service imports to use shared.* instead of app.*.
Delete 12 duplicated files across both services.
2026-03-23 23:01:45 +01:00
Roberto Musso
ccba54ac24 fix(tracing): use Langfuse compile_prompt with {{variable}} syntax
- tracing.py: add compile_prompt() that uses Langfuse .compile(**vars)
  for {{variable}} substitution, falls back to Python .format() for
  hardcoded {variable} templates
- agent_runner.py: replace _get_system_prompt().format() with
  tracing.compile_prompt() for batch_file_classifier, batch_processing,
  batch_cloud_processing prompts
- journey.py: replace get_prompt + .format() with compile_prompt()
  for journey_system prompt
- chat tracing.py: add compile_prompt() for parity (chat prompts
  currently have no variables, but ready for future use)
- Remove unused _get_system_prompt helper
2026-03-23 22:39:27 +01:00
Roberto Musso
55500cc818 feat(batch-agent): add Langfuse prompt management
- _get_system_prompt helper: fetches managed prompts from Langfuse
  with hardcoded fallback (same pattern as chat service)
- journey.py: journey_system prompt manageable via Langfuse
- agent_runner.py: batch_file_classifier, batch_processing,
  batch_cloud_processing prompts all manageable via Langfuse
- redis_consumer.py: link_prompt_to_trace for all three handlers
2026-03-23 22:30:36 +01:00
Roberto Musso
75a826c9d8 feat(batch-agent): add E2E evaluation harness with Langfuse integration
- eval/mock_executor.py: intercepts execute_on_client, serves fixture
  files from disk, records all mutations (insert/update/delete)
- eval/config.py: YAML fixture loader with prompt variants, expected
  results, seed records, model overrides
- eval/scorer.py: FieldMatchScorer (fuzzy title match, per-field
  accuracy, precision/recall/F1) + LLMJudgeScorer (semantic eval)
- eval/langfuse_eval.py: sync fixtures to Langfuse datasets, create
  dataset runs, post scores, link traces to runs
- eval/runner.py: orchestrates fixture → mock → agent pipeline →
  scoring → Langfuse reporting
- eval/cli.py: CLI (python -m eval run/list/sync) with --models,
  --variants, --fixture, --no-judge flags
- eval/fixtures/: example Italian freelance scenario with 3 prompt
  variants (baseline, detailed_italian, minimal)
2026-03-23 08:54:19 +01:00
Roberto Musso
971f1dd84f feat(batch-agent): integrate Langfuse tracing
- tracing.py: init/shutdown, trace_span, get_langfuse_callback, prompt mgmt
- main.py: init_langfuse at startup, shutdown on teardown
- redis_consumer.py: trace_span around journey_start/message/agent_trigger
- agent_runner.py: thread langfuse_handler through classify + processing LLM
- journey.py: thread langfuse_handler through _call_llm_with_tools
- llm.py: accept callbacks param, forward to LLM constructors
- requirements.txt: add langfuse>=3.0.0
2026-03-23 08:43:15 +01:00
Roberto Musso
333bba6fdd feat(batch-agent): extract Batch Agent Service (Step 3)
- agent_runner: local directory + cloud agent orchestration via Redis
- 5 domain agents: filesystem, task, note, project, timeline
- integrations: Gmail, MS Graph (Outlook + Teams)
- journey: guided chatbot conversation to build prompt_template
- routes: REST endpoints (catalog, can-create, trigger)
- redis_consumer: subscribes to batch:request:* pattern
- ws_context: Redis-based execute_on_client for tool round-trip
- Dockerfile with 300s timeout for long-running batch jobs
2026-03-23 07:19:02 +01:00
Roberto Musso
229e20d073 docs: add Langfuse integration TODO for batch-agent service 2026-03-23 00:25:42 +01:00
Roberto Musso
0b491b3643 fix: langfuse v4 SDK compatibility and pass user message as trace input 2026-03-23 00:23:59 +01:00
Roberto Musso
0d5fa3e569 feat(chat): integrate Langfuse tracing, prompt management & generation tracking
- shared/config.py: add LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, LANGFUSE_HOST
- services/chat/app/tracing.py: new module — Langfuse client singleton,
  create_trace(), get_langfuse_callback(), get_prompt(), link_prompt_to_trace(),
  score_trace(), flush/shutdown helpers. Gracefully no-ops when keys are missing.
- services/chat/app/llm.py: add callbacks param to get_llm() for LangChain
  callback handler injection
- services/chat/app/deep_agent.py: accept langfuse_handler in all run_* and
  _run_single_agent* functions, pipe callbacks to LLM calls, fetch managed
  prompts from Langfuse with fallback to hardcoded system prompts
- services/chat/app/redis_consumer.py: create Langfuse trace per request
  (home_request/floating_request), pass callback handler to deep_agent,
  link prompt name to trace, attach output preview, flush after each request
- services/chat/app/main.py: shutdown Langfuse client in lifespan teardown
- services/chat/requirements.txt: add langfuse>=2.0.0

Langfuse prompt names: 'home_system', 'floating_system' — create these in
the Langfuse dashboard to manage prompts. Without them, hardcoded defaults
are used transparently.
2026-03-22 23:15:04 +01:00
Roberto Musso
aff68a9051 fix: shared config loads root .env as fallback for microservices 2026-03-22 22:42:54 +01:00
Roberto Musso
5e9ef2809e fix: add extra=ignore to monolith Settings for strangler fig compat 2026-03-22 22:28:50 +01:00
Roberto Musso
90018af311 feat: add WS Gateway and Chat Service (Step 2)
WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)

Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
2026-03-22 01:20:11 +01:00
Roberto Musso
1e2e395676 fix: PEM newline parsing + shared config extra=ignore
- Add field_validator to expand literal \n in PEM keys (auth config + shared config)
- Set extra='ignore' on shared Settings so service-specific .env vars don't cause ValidationError
- Add *.pem to .gitignore
2026-03-22 01:03:28 +01:00
Roberto Musso
59d3a53980 chore: update .env.example files for RS256 + Redis
- Root .env.example: replace JWT_SECRET/JWT_ALGORITHM with JWT_PUBLIC_KEY, add REDIS_URL
- Auth Service .env.example: JWT_PRIVATE_KEY + JWT_PUBLIC_KEY with generation instructions
2026-03-22 00:51:54 +01:00
Roberto Musso
9feeaa79c8 feat(auth): migrate JWT from HS256 to RS256
- Add services/auth/app/config.py with JWT_PRIVATE_KEY and JWT_PUBLIC_KEY
  (Auth Service local config - private key never leaves this service)
- Update routes.py: sign tokens with RS256 private key
- Update deps.py + verify.py: verify tokens with RS256 public key
- Update shared/config.py: replace JWT_SECRET/JWT_ALGORITHM with
  JWT_PUBLIC_KEY (for optional local verification by other services)
- Add sys.path fix in main.py for local dev without PYTHONPATH
2026-03-22 00:50:36 +01:00
Roberto Musso
aa219a4d08 feat: microservices scaffold + Auth Service (Step 1)
- Add shared/ module: config, db, models, schemas, redis utilities
- Add Auth Service (services/auth/): register, login, refresh, me,
  ForwardAuth /verify endpoint for Traefik
- Add Traefik config: ACME/Cloudflare DNS-01, dynamic routing,
  ForwardAuth middleware, sticky sessions for WS Gateway
- Add service scaffolds: ws-gateway, chat, batch-agent, billing (READMEs)
- Add redis>=5.0.0 to requirements.txt
- Monolith app/ is untouched — strangler fig migration
2026-03-22 00:29:51 +01:00
202 changed files with 15432 additions and 15899 deletions

View File

@@ -2,94 +2,66 @@
ENV=dev
# ── Database ──────────────────────────────────────────────────────────────────
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
# ── Auth ──────────────────────────────────────────────────────────────────────
JWT_SECRET=replace-with-a-long-random-secret
JWT_ALGORITHM=HS256
# ── Redis ─────────────────────────────────────────────────────────────────────
REDIS_URL=redis://localhost:6379/0
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
# Generate keypair:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# openssl rsa -in private.pem -pubout -out public.pem
# Paste PEM content with literal \n for newlines.
#
# Private key — ONLY used by the Auth Service (JWT signing).
JWT_PRIVATE_KEY=
# Public key — used by all services / Traefik ForwardAuth (JWT verification).
JWT_PUBLIC_KEY=
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
# ── LLM ───────────────────────────────────────────────────────────────────────
# LiteLLM model identifiers — change to swap providers without code changes.
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
#
# API keys — only the key(s) matching your chosen provider(s) are required.
# The correct key is picked automatically from the model prefix (e.g.
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GOOGLE_API_KEY=
CEREBRAS_API_KEY=
GROQ_API_KEY=
DEEPSEEK_API_KEY=
# Default model used by any agent that does not have a specific override below.
LLM_MODEL=gpt-5-mini
LLM_EMBED_MODEL=text-embedding-3-small
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
# In Docker, point this to a named-volume path so tokens survive restarts.
# GITHUB_COPILOT_TOKEN_DIR=
# ── Per-agent model overrides ─────────────────────────────────────────────────
# Leave a value empty to fall back to LLM_MODEL.
# Each agent resolves its API key from the model prefix automatically.
#
# Intent classifier — routes user messages to the right domain agent.
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
LLM_MODEL_CLASSIFIER=
# Home-agent — handles chat from the home screen (all tools available).
LLM_MODEL_HOME_AGENT=
# Floating-agent — handles contextual chat triggered from a task/project/note.
LLM_MODEL_FLOATING_AGENT=
# Unified-processor — processes local directory files (local agent runner).
LLM_MODEL_UNIFIED_PROCESSOR=
# Cloud-processor — fetches and processes data from cloud connectors.
LLM_MODEL_CLOUD_PROCESSOR=
# Brief-agent — produces home and project text briefs.
# A small model (e.g. gpt-4o-mini) is sufficient.
# LLM_MODEL_BRIEF_AGENT=
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
# LLM_MODEL_TASK_BRIEF_AGENT=
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
LLM_MODEL_SETUP_AGENT=
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
LLM_MODEL_MEMORY_EXTRACTOR=
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
# Defaults to gpt-4o-mini when empty.
LLM_MODEL_MEMORY_MINER=
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
LLM_MODEL_MEMORY_AUDITOR=
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
SCHEDULER_ENABLED=true
LLM_MODEL=gpt-4o
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
# ── Langfuse (leave empty to disable observability) ───────────────────────────
LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY=
# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
# ── Vector Store ──────────────────────────────────────────────────────────────
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
PINECONE_API_KEY=
PINECONE_INDEX=adiuva
QDRANT_URL=
QDRANT_API_KEY=
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
# ── CORS ──────────────────────────────────────────────────────────────────────
# Comma-separated list parsed by Settings (override default if needed)
# CORS_ORIGINS=["app://.","http://localhost:3000"]
# ── Langfuse (observability) ─────────────────────────────────────────────────
LANGFUSE_SECRET_KEY=sk-lf-...
LANGFUSE_PUBLIC_KEY=pk-lf-...
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL
# ── Cloudflare (Traefik ACME DNS-01 challenge) ───────────────────────────────
CF_DNS_API_TOKEN=
ACME_EMAIL=
# ── PostgreSQL (used by docker-compose) ──────────────────────────────────────
POSTGRES_USER=postgres
POSTGRES_PASSWORD=postgres
POSTGRES_DB=adiuva

View File

@@ -48,23 +48,23 @@ jobs:
key: ${{ secrets.SSH_KEY }}
script: |
set -e
DEPLOY_DIR="/opt/adiuvai-api"
DEPLOY_DIR="/opt/adiuva-api"
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
TAG="${{ gitea.ref_name }}"
# ── Pull latest code ──
cd /tmp && rm -rf adiuvai-api-deploy
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
cd /tmp && rm -rf adiuva-api-deploy
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
# ── Sync source (preserve .env) ──
cp -rf /tmp/adiuvai-api-deploy/app/ \
/tmp/adiuvai-api-deploy/alembic/ \
/tmp/adiuvai-api-deploy/alembic.ini \
/tmp/adiuvai-api-deploy/Dockerfile \
/tmp/adiuvai-api-deploy/docker-compose.yml \
/tmp/adiuvai-api-deploy/requirements.txt \
cp -rf /tmp/adiuva-api-deploy/app/ \
/tmp/adiuva-api-deploy/alembic/ \
/tmp/adiuva-api-deploy/alembic.ini \
/tmp/adiuva-api-deploy/Dockerfile \
/tmp/adiuva-api-deploy/docker-compose.yml \
/tmp/adiuva-api-deploy/requirements.txt \
"$DEPLOY_DIR/"
rm -rf /tmp/adiuvai-api-deploy
rm -rf /tmp/adiuva-api-deploy
# ── Verify .env ──
if [ ! -f "$DEPLOY_DIR/.env" ]; then

View File

@@ -58,7 +58,7 @@ jobs:
- uses: actions/checkout@v4
- name: Build image
run: docker build -t adiuvai-api:ci .
run: docker build -t adiuva-api:ci .
- name: Verify gunicorn installed
run: docker run --rm adiuvai-api:ci gunicorn --version
run: docker run --rm adiuva-api:ci gunicorn --version

10
.gitignore vendored
View File

@@ -13,6 +13,9 @@ env/
# Environment variables
.env
# Cryptographic keys
*.pem
# IDE
.vscode/
.idea/
@@ -21,18 +24,17 @@ env/
.pytest_cache/
htmlcov/
.coverage
tests/fixtures/private*/
# Docker
*.log
# OS
.DS_Store
# Smoke scripts (dev-only, not for CI)
scripts/smoke_*.py
Thumbs.db
# Claude Code
.claude/
logs/
# Eval private test data
services/batch-agent/eval/fixtures/private_data/

796
README.md
View File

@@ -1,5 +1,793 @@
## DEV
Run in DEV with command:
# Adiuva Cloud API
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
---
## Table of Contents
- [Overview](#overview)
- [Architecture](#architecture)
- [Key Features](#key-features)
- [Tech Stack](#tech-stack)
- [Getting Started](#getting-started)
- [Docker Deployment](#docker-deployment)
- [Environment Variables](#environment-variables)
- [API Reference](#api-reference)
- [Data Model](#data-model)
- [AI Agent System](#ai-agent-system)
- [Orchestration & Execution Plans](#orchestration--execution-plans)
- [Middleware](#middleware)
- [Storage Layer](#storage-layer)
- [Billing & Tiers](#billing--tiers)
- [Plugin Marketplace](#plugin-marketplace)
- [Testing](#testing)
- [Project Structure](#project-structure)
- [License](#license)
---
## Overview
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
### Design Principles
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
---
## Architecture
```
uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf
```
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
│ Desktop App │────▶│ │
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
└──────────────┘ │ │
│ ┌──────────────────┐ ┌────────────────────────────┐ │
│ │ Auth Routes │ │ Chat Routes │ │
│ │ Billing Routes │ │ ↓ │ │
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
│ │ Backup Routes │ │ ↓ classify intent │ │
│ │ Plugin Routes │ │ Agent Registry │ │
│ │ Vector Routes │ │ ↓ │ │
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
│ │ (GPT-4o + LangChain) │ │
│ └────────────────────────────┘ │
└────────────────────────────────────────────────────────┘
│ │ │
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
│ Billing, │ │ backups) │ │ (Vectors) │
│ Metadata) │ └───────────────┘ └────────────────┘
└────────────┘
┌────────▼───┐
│ Stripe │
│ (Billing, │
│ Connect) │
└────────────┘
```
---
## Key Features
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
---
## Tech Stack
| Package | Version | Purpose |
|---|---|---|
| `fastapi` | ≥ 0.115.0 | Web framework |
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
| `gunicorn` | ≥ 22.0.0 | Production process manager |
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
| `boto3` | ≥ 1.35.0 | AWS S3 client |
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
| `alembic` | ≥ 1.14.0 | Database migration management |
| `bcrypt` | ≥ 4.2.0 | Password hashing |
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
| `websockets` | ≥ 14.0 | WebSocket protocol support |
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
| `pytest` | ≥ 8.0.0 | Test framework |
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
| `ruff` | ≥ 0.8.0 | Linter and formatter |
---
## Getting Started
### Prerequisites
- Python 3.12+
- PostgreSQL 16+
- An OpenAI API key (for LLM features)
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
- AWS credentials (optional — needed for S3 storage in production)
### Installation
```bash
# Clone the repository
git clone <repo-url> && cd adiuva-api
# Create a virtual environment
python -m venv .venv && source .venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Configure environment
cp .env.example .env
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
```
### Database Setup
```bash
# Start PostgreSQL (or use the Docker Compose database)
docker compose up db -d
# Run migrations
alembic upgrade head
```
### Run the Development Server
```bash
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
---
## Docker Deployment
### Quick Start
```bash
docker compose up --build
```
This starts two services:
- **app** — FastAPI server on port `8000`
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
The compose file also includes optional services for fully local deployments:
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
### Dockerfile Details
The Dockerfile uses a multi-stage build:
1. **Builder stage** — Installs Python dependencies into a virtual environment.
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
```bash
# Production command (run by the container)
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
```
---
## Homelab / Self-Hosted Deployment
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
### 1. Start all services
```bash
docker compose up -d
```
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
### 2. Create the MinIO bucket
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
```bash
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
docker compose exec minio mc mb local/adiuva
```
### 3. Configure your `.env`
```bash
# Database (uses the compose PostgreSQL)
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
# S3 → MinIO
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=http://minio:9000
AWS_ACCESS_KEY_ID=minioadmin
AWS_SECRET_ACCESS_KEY=minioadmin
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
QDRANT_URL=http://qdrant:6333
QDRANT_API_KEY=
PINECONE_API_KEY=
# Billing — leave empty to stub (no Stripe needed)
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=
# LLM — the only external service
OPENAI_API_KEY=sk-...
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# Auth
JWT_SECRET=your-secret-here
ENV=dev
```
### 4. Run migrations
```bash
docker compose exec app alembic upgrade head
```
### What runs where
| Service | Runs on | Port | Notes |
|---|---|---|---|
| FastAPI app | Docker | 8000 | API server |
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
| Stripe | — | — | Stubbed when keys are empty |
| OpenAI / LLM | Cloud | — | Only external dependency |
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
---
## Environment Variables
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
| Variable | Type | Default | Description |
|---|---|---|---|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
| `S3_REGION` | `str` | `us-east-1` | AWS region |
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
---
## API Reference
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
### Health
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
### Auth
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
### Chat
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
### Plans
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
### Storage (Cloud Records)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
### Vectors (Cloud Vector Store)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
### Backup
| Method | Path | Auth | Description |
|---|---|---|---|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
### Plugins (Marketplace)
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
### Billing
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
---
## Data Model
9 tables managed by Alembic migrations. Source: `app/models.py`
### Tables
| Table | Primary Key | Key Columns | Purpose |
|---|---|---|---|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
### Enum Types
| Enum | Values |
|---|---|
| `billing_tier` | `free`, `pro`, `power`, `team` |
| `plugin_status` | `pending_review`, `approved`, `rejected` |
| `review_decision` | `approved`, `rejected` |
### Migrations
| Version | Description |
|---|---|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
---
## AI Agent System
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
### Architecture
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
### Registered Agents
| Agent | Registry Name | Tools | Description |
|---|---|---|---|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
### Switching LLM Providers
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
```bash
# OpenAI (default)
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# Anthropic
LLM_MODEL=anthropic/claude-3.5-sonnet
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
# Google Gemini
LLM_MODEL=gemini/gemini-pro
LLM_ROUTER_MODEL=gemini/gemini-flash
# Local Ollama
LLM_MODEL=ollama/llama3
LLM_ROUTER_MODEL=ollama/llama3
# AWS Bedrock
LLM_MODEL=bedrock/anthropic.claude-v2
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
```
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
---
## Orchestration & Execution Plans
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
### Orchestrator
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
### Execution Plans
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
### Built-in Templates (6)
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
### Built-in Playbooks (2)
| Playbook | Description |
|---|---|
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
---
## Middleware
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
### JWT Authentication
Source: `app/api/middleware/auth.py`
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
- Falls back to `free` when no subscription row exists.
- Raises `401 Unauthorized` on invalid or expired tokens.
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
### Tier-Based Rate Limiter
Source: `app/api/middleware/rate_limit.py`
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
- Per-user 60-second window sized by subscription tier:
| Tier | Requests / Minute |
|---|---|
| Free | 20 |
| Pro | 60 |
| Power | 120 |
| Team | 200 |
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
- **Exempt paths:** register, login, webhook, health
### Response Sanitizer
Source: `app/api/middleware/sanitizer.py`
- Runs only on `/api/v1/chat` endpoints.
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
- Logs sanitization events as `WARNING`.
- Binary responses (storage, backup) are never touched.
---
## Storage Layer
### Blob Store
Source: `app/storage/blob_store.py`
- S3-backed storage for E2E encrypted blobs.
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
- The backend **never inspects or decrypts blob content**.
### Vector Store
Source: `app/storage/vector_store.py`
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
- Methods: `upsert()`, `search()`, `delete()`
### Encryption Utilities
Source: `app/storage/encryption.py`
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
- **No decryption key ever reaches the backend.**
---
## Billing & Tiers
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
### Feature Matrix
| Feature | Free | Pro | Power | Team |
|---|---|---|---|---|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
| Batch Active | 2 | 10 | Unlimited | Unlimited |
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
| Batch Builder | — | — | ✓ | ✓ |
| Plugin Marketplace | — | — | ✓ | ✓ |
| SSO | — | — | — | ✓ |
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
### Stripe Integration
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
### Tier Manager
- `get_tier(user_id)` — Returns the user's current billing tier.
- `check_feature(tier, feature)` — Boolean feature gate check.
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
---
## Plugin Marketplace
Source: `app/marketplace/`
### Plugin Registry
- PostgreSQL-backed catalog of submitted and approved plugins.
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
### Review Queue
- Automated security checklist before human review:
- Plugin ID must match `^[a-z0-9-]+$`
- Permissions must be from the allowed set only
- No binary blobs in the manifest
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
- `get_pending(db)` — Lists plugins awaiting review.
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
### Revenue Sharing
- **70% developer / 30% platform** split on all paid plugin sales.
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
- Gracefully stubs transfers when Stripe is not configured.
### Seed Plugins
| Plugin | Category | Price |
|---|---|---|
| GitHub Sync | Productivity | Free |
| Slack Notifier | Communication | €4.99 |
| Time Tracker | Productivity | €9.99 |
---
## Testing
### Running Tests
```bash
# Run all tests
pytest
# Run a specific test file
pytest tests/test_auth.py
# Run with verbose output
pytest -v
```
### Test Infrastructure
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
- **No external dependencies** — all tests run fully offline.
### Test Coverage
| File | Coverage |
|---|---|
| `test_auth.py` | Register, login, token access, refresh, expiration |
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
---
## Project Structure
```
adiuva-api/
├── alembic.ini # Alembic configuration
├── BACKEND_PLAN.md # Architecture & design decisions
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
├── Dockerfile # Multi-stage production build
├── requirements.txt # Python dependencies
├── alembic/ # Database migrations
│ ├── env.py # Alembic environment config
│ ├── script.py.mako # Migration template
│ └── versions/
│ ├── 001_initial_schema.py # Tables, indexes, FKs
│ └── 002_seed_plugins.py # Seed marketplace plugins
├── app/ # Application source
│ ├── main.py # FastAPI app factory, middleware, routes
│ ├── db.py # Async SQLAlchemy engine & session
│ ├── models.py # SQLAlchemy ORM models (9 tables)
│ ├── schemas.py # Pydantic request/response schemas
│ │
│ ├── config/
│ │ └── settings.py # Pydantic Settings (env vars)
│ │
│ ├── agents/ # LLM-powered domain agents
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
│ │ ├── project_agent.py # Project lifecycle (6 tools)
│ │ ├── timeline_agent.py # Milestones (4 tools)
│ │ └── note_agent.py # Markdown notes (5 tools)
│ │
│ ├── core/ # Orchestration engine
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
│ │ ├── llm.py # LiteLLM factory (get_llm)
│ │ ├── orchestrator.py # Intent classification & routing
│ │ └── execution_plan.py # Plan builder, templates, cache
│ │
│ ├── api/ # HTTP layer
│ │ ├── deps.py # Shared FastAPI dependencies
│ │ ├── middleware/
│ │ │ ├── auth.py # JWT validation, live tier lookup
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
│ │ │ └── sanitizer.py # Prompt IP leak protection
│ │ └── routes/
│ │ ├── auth.py # Register, login, refresh, me
│ │ ├── chat.py # Chat + WebSocket streaming
│ │ ├── plans.py # Execution plan playbooks
│ │ ├── storage.py # E2E encrypted record CRUD
│ │ ├── vectors.py # Vector upsert, search, delete
│ │ ├── backup.py # Encrypted backup management
│ │ ├── plugins.py # Marketplace browse & install
│ │ └── billing.py # Stripe checkout & webhooks
│ │
│ ├── storage/ # Storage backends
│ │ ├── blob_store.py # S3 blob storage
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
│ │ └── encryption.py # Checksum verification utilities
│ │
│ ├── billing/ # Subscription management
│ │ ├── stripe_service.py # Stripe API integration
│ │ └── tier_manager.py # Feature matrix & quota enforcement
│ │
│ └── marketplace/ # Plugin ecosystem
│ ├── plugin_registry.py # Catalog CRUD & search
│ ├── plugin_review.py # Security checklist & review queue
│ └── revenue_share.py # 70/30 split & Stripe Connect
└── tests/ # Test suite
├── conftest.py # Fixtures: DB, S3, auth, seeds
├── test_auth.py
├── test_orchestrator.py
├── test_agents.py
├── test_storage.py
├── test_backup.py
├── test_plugins.py
├── test_agent_registry.py
├── test_execution_plan.py
└── test_middleware.py
```
---
## License
*To be determined.*

View File

@@ -16,7 +16,7 @@ import re
from logging.config import fileConfig
from alembic import context
from sqlalchemy import pool
from sqlalchemy import engine_from_config, pool
from sqlalchemy.ext.asyncio import create_async_engine
# Alembic Config object (gives access to alembic.ini values).

View File

@@ -1,4 +1,5 @@
"""Initial schema: users, refresh_tokens, subscriptions.
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
Revision ID: 001
Revises:
@@ -27,6 +28,18 @@ def upgrade() -> None:
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
# ── users ─────────────────────────────────────────────────────────────
op.create_table(
@@ -75,10 +88,122 @@ def upgrade() -> None:
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
# ── storage_records ───────────────────────────────────────────────────
op.create_table(
"storage_records",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("table_name", sa.String(100), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
# ── backup_metadata ───────────────────────────────────────────────────
op.create_table(
"backup_metadata",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column("timestamp", sa.BigInteger, nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
# ── plugins ───────────────────────────────────────────────────────────
op.create_table(
"plugins",
sa.Column("id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=False, server_default=""),
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
sa.Column("category", sa.String(100), nullable=False, server_default=""),
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
sa.Column("s3_package_key", sa.String(500), nullable=True),
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
sa.Column("rejection_reason", sa.Text, nullable=True),
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
)
# ── plugin_installations ──────────────────────────────────────────────
op.create_table(
"plugin_installations",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
)
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
# ── plugin_reviews ────────────────────────────────────────────────────
op.create_table(
"plugin_reviews",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
sa.Column("notes", sa.Text, nullable=True),
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
)
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
# ── revenue_events ────────────────────────────────────────────────────
op.create_table(
"revenue_events",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
def downgrade() -> None:
op.drop_table("revenue_events")
op.drop_table("plugin_reviews")
op.drop_table("plugin_installations")
op.drop_table("plugins")
op.drop_table("backup_metadata")
op.drop_table("storage_records")
op.drop_table("subscriptions")
op.drop_table("refresh_tokens")
op.drop_table("users")
op.execute("DROP TYPE IF EXISTS review_decision")
op.execute("DROP TYPE IF EXISTS plugin_status")
op.execute("DROP TYPE IF EXISTS billing_tier")

View File

@@ -0,0 +1,92 @@
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
Revision ID: 002
Revises: 001
Create Date: 2026-03-03
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_SEED_PLUGINS = [
{
"id": "plugin-github-sync",
"name": "GitHub Sync",
"description": "Sync tasks with GitHub Issues and pull requests.",
"version": "1.0.0",
"author_name": "Adiuva",
"category": "productivity",
"price_cents": 0,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-slack-notify",
"name": "Slack Notifier",
"description": "Post task and timeline updates to Slack channels.",
"version": "1.2.0",
"author_name": "Adiuva",
"category": "communication",
"price_cents": 499,
"permissions": json.dumps(["read:tasks", "read:timelines"]),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-time-tracker",
"name": "Time Tracker",
"description": "Track time spent on tasks with automatic reporting.",
"version": "0.9.1",
"author_name": "Third Party",
"category": "productivity",
"price_cents": 999,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
]
def upgrade() -> None:
plugins = sa.table(
"plugins",
sa.column("id", sa.String),
sa.column("name", sa.String),
sa.column("description", sa.Text),
sa.column("version", sa.String),
sa.column("author_name", sa.String),
sa.column("category", sa.String),
sa.column("price_cents", sa.Integer),
sa.column("permissions", sa.Text),
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
sa.column("s3_package_key", sa.String),
sa.column("install_count", sa.Integer),
sa.column("avg_rating", sa.Float),
)
op.bulk_insert(plugins, _SEED_PLUGINS)
def downgrade() -> None:
op.execute(
"DELETE FROM plugins WHERE id IN ("
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
")"
)

View File

@@ -14,7 +14,7 @@ from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "003"
down_revision: Union[str, None] = "001"
down_revision: Union[str, None] = "002"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

View File

@@ -1,54 +0,0 @@
"""Phase 1 — confirm pgvector activation on memory_associative.
Migration 004 created the embedding column as vector(1536) and added the
IVFFlat index. This migration is the Phase-1 checkpoint:
1. Ensures the pgvector extension is enabled (idempotent).
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
memory_associative_embedding_idx (creates it only if absent).
Revision ID: 005
Revises: 9a1f2d0b6c7e
Create Date: 2026-04-15
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
revision: str = "005"
down_revision: Union[str, None] = "e04100e88ace"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Ensure pgvector extension is enabled (also done in 004, idempotent).
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# Ensure the canonical Phase-1 IVFFlat index exists.
# 004 may have created ix_memory_associative_embedding; this adds the
# Phase-1 name memory_associative_embedding_idx if it is missing.
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_indexes
WHERE tablename = 'memory_associative'
AND indexname = 'memory_associative_embedding_idx'
) THEN
CREATE INDEX memory_associative_embedding_idx
ON memory_associative
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);
END IF;
END $$;
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")

View File

@@ -1,74 +0,0 @@
"""Add memory_relations table (Phase 3 — relational tier).
Revision ID: 006
Revises: 1f5975a4f3f4
Create Date: 2026-04-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "006"
down_revision: Union[str, None] = "1f5975a4f3f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"memory_relations",
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("subject_label", sa.String(128), nullable=False),
sa.Column("subject_type", sa.String(32), nullable=False),
sa.Column("predicate", sa.String(64), nullable=False),
sa.Column("object_label", sa.String(128), nullable=False),
sa.Column("object_type", sa.String(32), nullable=False),
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
sa.Column(
"source_episode_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
)
op.create_index(
"memory_relations_user_subject_idx",
"memory_relations",
["user_id", "subject_label"],
)
op.create_index(
"memory_relations_user_predicate_idx",
"memory_relations",
["user_id", "predicate"],
)
def downgrade() -> None:
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
op.drop_table("memory_relations")

View File

@@ -1,38 +0,0 @@
"""add extraction_queue
Revision ID: 1f5975a4f3f4
Revises: 005
Create Date: 2026-04-16 17:26:25.790870
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '1f5975a4f3f4'
down_revision: Union[str, None] = '005'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
'extraction_queue',
sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
)
op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
def downgrade() -> None:
op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
op.drop_table('extraction_queue')

View File

@@ -1,107 +0,0 @@
"""Restore agent config tables and add agent_config column.
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
ORM models are still active. This migration recreates them with agent_config
added to local_agent_configs.
Revision ID: a3b9c0d1e2f3
Revises: 9a1f2d0b6c7e
Create Date: 2026-04-07 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "a3b9c0d1e2f3"
down_revision: Union[str, None] = "9a1f2d0b6c7e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Recreate enum types (idempotent — they may already exist from migration 003)
op.execute("""
DO $$ BEGIN
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
bind = op.get_bind()
inspector = sa.inspect(bind)
existing = set(inspector.get_table_names())
# ── local_agent_configs (with agent_config column) ────────────────────
if "local_agent_configs" not in existing:
op.create_table(
"local_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("agent_config", sa.JSON, nullable=True),
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
# ── cloud_agent_configs ───────────────────────────────────────────────
if "cloud_agent_configs" not in existing:
op.create_table(
"cloud_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"provider",
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
nullable=False,
),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
sa.Column("filter_config", sa.JSON, nullable=True),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
def downgrade() -> None:
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
op.drop_table("cloud_agent_configs")
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
op.drop_table("local_agent_configs")

View File

@@ -1,56 +0,0 @@
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
Revision ID: b4c0d1e2f3a4
Revises: a3b9c0d1e2f3
Create Date: 2026-04-10 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "b4c0d1e2f3a4"
down_revision: Union[str, None] = "a3b9c0d1e2f3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── users: make password_hash nullable (social users have no password) ──
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
# ── users: add avatar_url ─────────────────────────────────────────────
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
# ── oauth_accounts ────────────────────────────────────────────────────
op.create_table(
"oauth_accounts",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("provider", sa.String(50), nullable=False),
sa.Column("provider_user_id", sa.String(255), nullable=False),
sa.Column("provider_email", sa.String(255), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
)
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
def downgrade() -> None:
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
op.drop_table("oauth_accounts")
op.drop_column("users", "avatar_url")
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)

View File

@@ -1,31 +0,0 @@
"""Add onboarding_completed_at column to users table.
Revision ID: c5d1e2f3a4b5
Revises: b4c0d1e2f3a4
Create Date: 2026-04-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c5d1e2f3a4b5"
down_revision: Union[str, None] = "b4c0d1e2f3a4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"users",
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("users", "onboarding_completed_at")

View File

@@ -1,46 +0,0 @@
"""Add token tracking columns for folder integration.
Revision ID: d6e3f4a5b6c7
Revises: 006
Create Date: 2026-05-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = "d6e3f4a5b6c7"
down_revision: Union[str, None] = "006"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"agent_run_logs",
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
)
op.create_table(
"monthly_token_usage",
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("year_month", sa.String(7), nullable=False),
sa.Column("feature", sa.String(64), nullable=False),
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
)
op.create_index(
"ix_monthly_token_usage_user_month",
"monthly_token_usage",
["user_id", "year_month"],
)
def downgrade() -> None:
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
op.drop_table("monthly_token_usage")
op.drop_column("agent_run_logs", "tokens_used")

View File

@@ -1,34 +0,0 @@
"""avatar_url_varchar_to_text
Revision ID: e04100e88ace
Revises: c5d1e2f3a4b5
Create Date: 2026-04-13 09:13:06.733674
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'e04100e88ace'
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.alter_column('users', 'avatar_url',
existing_type=sa.VARCHAR(length=2048),
type_=sa.Text(),
existing_nullable=True)
def downgrade() -> None:
op.alter_column('users', 'avatar_url',
existing_type=sa.Text(),
type_=sa.VARCHAR(length=2048),
existing_nullable=True)

View File

@@ -1,5 +0,0 @@
"""Expose tool modules used by deep orchestrator-worker graphs."""
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]

View File

@@ -1,52 +0,0 @@
"""Client agent — read-only tools for the clients table."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
@tool
async def list_clients(search: str = "", limit: int = 20) -> str:
"""List clients, optionally filtered by a name/email substring search.
search: optional substring to match against client name or email.
limit: max rows to return (default 20).
"""
filters: dict[str, Any] = {"limit": limit}
if search:
filters["search"] = search
result = await execute_on_client(action="select", table="clients", filters=filters)
rows = result.get("rows", [])
if not rows:
return "No clients found."
lines = [
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
f"company: {r.get('company', '')})"
for r in rows
]
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
@tool
async def get_client(id: str) -> str:
"""Get full details for one client by UUID.
id: the client's UUID.
"""
if not id:
return "Client id is required."
result = await execute_on_client(action="get", table="clients", data={"id": id})
row = result.get("row") or result.get("rows", [None])[0] if result else None
if not row:
return f"Client '{id}' not found."
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
CLIENT_TOOLS: list[Any] = [list_clients, get_client]

View File

@@ -1,194 +0,0 @@
"""Filesystem agent — tools for reading local directories and files on Electron.
These tools delegate to the Electron client via ``execute_on_client()`` using
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
handles actual disk I/O and responds with ``tool_result`` frames.
"""
from __future__ import annotations
import os
import re
from pathlib import Path
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
# Max characters returned by read_file_content in journey (exploration) tools.
# The journey only needs to understand file structure, not full content.
_JOURNEY_READ_MAX_CHARS: int = 4000
def _resolve_path(path: str, base: str) -> str:
"""Resolve *path* against *base* when *path* is relative.
The LLM often passes ``"."`` meaning "the configured directory".
Without this, Electron resolves ``"."`` relative to its own CWD instead
of the user's chosen directory.
"""
if os.path.isabs(path):
return path
return str(Path(base) / path)
@tool
async def list_directory(path: str) -> str:
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
result = await execute_on_client(
action="list_directory",
data={"path": path},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{path}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str:
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
result = await execute_on_client(
action="read_file_content",
data={"path": path},
)
content: str = result.get("content", "")
if not content:
return f"File '{path}' is empty or could not be read."
return content
@tool
async def get_file_metadata(path: str) -> str:
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
result = await execute_on_client(
action="get_file_metadata",
data={"path": path},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", path)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
FILESYSTEM_TOOLS: list[Any] = [
list_directory,
read_file_content,
get_file_metadata,
]
def make_directory_tools(base_directory: str) -> list[Any]:
"""Return filesystem tools that resolve relative paths against *base_directory*.
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
from the LLM are resolved to the correct absolute path before being sent to
the Electron client, preventing it from falling back to its own CWD.
"""
def _compact_for_journey(raw: str) -> str:
"""Strip HTML noise and truncate for journey exploration.
The journey LLM only needs to understand file structure (headers,
first paragraphs). Full CSS/style blocks are pure noise that eat
up context window budget.
"""
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
if len(text) > _JOURNEY_READ_MAX_CHARS:
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
return text
@tool
async def list_directory(path: str) -> str: # noqa: F811
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="list_directory",
data={"path": resolved},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{resolved}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str: # noqa: F811
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="read_file_content",
data={"path": resolved},
)
content: str = result.get("content", "")
if not content:
return f"File '{resolved}' is empty or could not be read."
return _compact_for_journey(content)
@tool
async def get_file_metadata(path: str) -> str: # noqa: F811
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="get_file_metadata",
data={"path": resolved},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", resolved)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
return [list_directory, read_file_content, get_file_metadata]

View File

@@ -1,168 +0,0 @@
"""Scoped file-read and search tools for the project folder feature."""
from __future__ import annotations
from langchain_core.tools import tool
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
from app.core.ws_context import execute_on_client
# Cap returned slice size to keep tool output under control.
_MAX_RETURN_CHARS = 50_000
_MAX_SEARCH_MATCHES = 20
def _is_unsafe_path(rel: str) -> bool:
if not rel:
return True
norm = rel.replace("\\", "/")
if norm.startswith("/"):
return True
# Windows drive letter
if len(rel) >= 2 and rel[1] == ":":
return True
parts = norm.split("/")
return ".." in parts
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
"""Return the raw Electron tool_result dict for a file read."""
return await execute_on_client(
action="read_project_folder_file",
data={
"projectId": project_id,
"relativePath": relative_path,
"offset": offset,
"length": length,
},
)
def _decode(result: dict) -> tuple[str, str, int]:
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
extracts text from base64. For images, returns a placeholder string.
For text, content is already a sliced utf-8 string.
"""
kind = result.get("kind", "text")
content = result.get("content", "") or ""
total = int(result.get("totalSize", 0) or 0)
if kind == "image":
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
if kind == "pdf":
return (_extract_pdf_text(content), kind, total)
if kind == "docx":
return (_extract_docx_text(content), kind, total)
return (content, kind, total)
@tool
async def read_project_folder_file(
project_id: str,
relative_path: str,
offset: int = 0,
length: int = _MAX_RETURN_CHARS,
) -> str:
"""Read a slice of a file inside the project's linked folder.
Args:
project_id: project ID.
relative_path: path relative to the linked folder root.
offset: char offset to start reading from (0 = beginning).
length: max chars to return. Default 50000. Use smaller values to save tokens.
Returns text content slice with a header showing position. Header tells you
when more content is available; call again with the suggested next offset.
For PDF / DOCX files the backend extracts text first, then applies offset/length
on the extracted text. For images returns a placeholder; navigate with the
manifest summary instead.
"""
if _is_unsafe_path(relative_path):
return "Access denied"
result = await _fetch_file(project_id, relative_path, offset, length)
text, kind, total_size = _decode(result)
if not text and kind in ("missing", "error"):
return f"File not found or unreadable: {relative_path}"
if kind in ("pdf", "docx"):
# Backend extracted full text — apply offset/length on chars.
sliced = text[offset:offset + length]
slice_end = min(offset + length, len(text))
header = (
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
f"totalChars={len(text)}]"
)
if slice_end < len(text):
header += f"\n[More content available — call again with offset={slice_end}.]"
return header + "\n" + sliced
if kind == "text":
slice_end = offset + len(text)
header = (
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
f"totalBytes={total_size}]"
)
if slice_end < total_size:
header += f"\n[More content available — call again with offset={slice_end}.]"
return header + "\n" + text
# image or unknown
return text
@tool
async def search_project_folder_file(
project_id: str,
relative_path: str,
query: str,
context_lines: int = 3,
) -> str:
"""Search a project folder file for a query string (case-insensitive substring).
Args:
project_id: project ID.
relative_path: path relative to the linked folder root.
query: text to search for.
context_lines: number of lines of context around each match (default 3).
Returns matching line ranges with surrounding context and 1-based line numbers.
Capped at 20 matches; if more exist the header shows the total.
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
Images and binary files are not searchable.
"""
if _is_unsafe_path(relative_path):
return "Access denied"
if not query:
return "Empty query."
# For text we still need full file; pass length=very large.
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
text, kind, _ = _decode(result)
if not text and kind in ("missing", "error"):
return f"File not found or unreadable: {relative_path}"
if kind == "image":
return "Cannot search inside images."
lines = text.splitlines()
q = query.lower()
matches = [i for i, line in enumerate(lines) if q in line.lower()]
if not matches:
return f"No matches for '{query}' in {relative_path}."
shown = matches[:_MAX_SEARCH_MATCHES]
snippets: list[str] = []
for i in shown:
start = max(0, i - context_lines)
end = min(len(lines), i + context_lines + 1)
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
snippets.append(block)
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
body = "\n---\n".join(snippets)
return header + "\n" + body
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]

View File

@@ -1,206 +0,0 @@
"""Note agent — Markdown note management (list, get, create, update, propose edit)."""
from __future__ import annotations
import asyncio
import re
from typing import Any
from langchain_core.tools import tool
from app.core.note_summarizer import generate_note_summary
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
def _fmt_summary(row: dict) -> str:
summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
if summary:
return f"{summary}"
snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
return f"{snippet}" if snippet else ""
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes with AI summaries, optionally scoped to a project by project_id.
Returns id, title, and ai_summary for each note so you can decide which
note to read in full with get_note before creating or updating.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No notes found."
lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
row = result.get("row")
if not row:
return f"Note {note_id} not found."
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
@tool
async def create_note(
title: str,
content: str,
project_id: str = "",
) -> str:
"""Create a new note.
title: note heading (required)
content: Markdown body text (required)
project_id: optional UUID linking this note to a project
"""
result = await execute_on_client(
action="insert",
table="notes",
data={
"title": title,
"content": content,
"projectId": project_id or None,
},
)
row = result["row"]
note_id: str = row["id"]
# Generate summary asynchronously — fire-and-forget.
asyncio.create_task(_refresh_summary(note_id, title, content))
return f"Note created: '{row['title']}' (id: {note_id})."
@tool
async def update_note(
note_id: str,
title: str = "",
content: str = "",
) -> str:
"""Update an existing note directly (no approval required).
Use propose_note_edit instead when human review is needed.
note_id: UUID of the note (required)
If you need to preserve existing content, call get_note first.
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if content:
updates["content"] = content
result = await execute_on_client(
action="update",
table="notes",
data={"id": note_id, "updates": updates},
)
row = result["row"]
if content:
new_title = title or row.get("title", "")
asyncio.create_task(_refresh_summary(note_id, new_title, content))
return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def propose_note_edit(
note_id: str,
edit_type: str,
proposed_content: str,
reasoning: str = "",
anchor_before: str = "",
anchor_text: str = "",
agent_id: str = "",
run_id: str = "",
) -> str:
"""Propose an AI edit to an existing note, pending human approval.
Use this instead of update_note when review_required is true.
The user will see the proposal highlighted before it is merged.
note_id: UUID of the target note (required)
edit_type: 'append' | 'insert' | 'replace'
- append: adds proposed_content at the end of the note
- insert: inserts proposed_content immediately after anchor_before text
- replace: replaces the first occurrence of anchor_text with proposed_content
proposed_content: the new Markdown text to add or substitute (required)
reasoning: brief explanation shown to the user (recommended)
anchor_before: for 'insert' — the text snippet that precedes the insertion point
anchor_text: for 'replace' — the exact text to be replaced
agent_id: agent identifier (for traceability)
run_id: run identifier (for traceability)
"""
if edit_type not in ("append", "insert", "replace"):
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
result = await execute_on_client(
action="propose_note_edit",
data={
"noteId": note_id,
"type": edit_type,
"proposedContent": proposed_content,
"reasoning": reasoning or None,
"anchorBefore": anchor_before or None,
"anchorText": anchor_text or None,
"agentId": agent_id or None,
"runId": run_id or None,
},
)
edit_id = result.get("id", "?")
return (
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
f"Status: pending user approval."
)
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
await execute_on_client(action="delete", table="notes", data={"id": note_id})
return f"Note {note_id} deleted."
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
"""Generate and persist the AI summary for a note. Fire-and-forget."""
try:
summary = await generate_note_summary(title, content)
if summary:
await execute_on_client(
action="update",
table="notes",
data={
"id": note_id,
"updates": {
"aiSummary": summary,
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
},
},
)
except Exception:
pass # fire-and-forget; errors logged by generate_note_summary
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
propose_note_edit,
delete_note,
]
NOTE_READ_TOOLS: list[Any] = [
list_notes,
get_note,
]

View File

@@ -1,63 +0,0 @@
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
# Each tool closure captures the user_id bound at factory time.
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
"""Return a query_relations tool bound to *user_id*."""
@tool
async def query_relations(
subject_label: str = "",
predicate: str = "",
object_label: str = "",
limit: int = 10,
) -> str:
"""Query the relational memory graph for entity relationships.
Returns rows where subject ↔ predicate ↔ object match the given filters.
All parameters are optional — omit to retrieve all relations up to limit.
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
limit: max rows to return (default 10).
"""
import logging
logger = logging.getLogger(__name__)
logger.info(
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
trace_id or "-", user_id, subject_label, predicate, object_label,
)
async with async_session() as db:
memory = MemoryMiddleware(db)
rows = await memory.query_relations(
user_id=user_id,
subject=subject_label or None,
predicate=predicate or None,
object_=object_label or None,
limit=limit,
)
if not rows:
return "No relational memory entries found for the given filters."
lines = [
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
for r in rows
]
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
return query_relations

View File

@@ -1,358 +0,0 @@
"""Task agent — full CRUD for tasks and task comments."""
from __future__ import annotations
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
# ── Task tools ────────────────────────────────────────────────────────
@tool
async def list_tasks(
project_id: str = "",
status: str = "",
priority: str = "",
assignee: str = "",
search: str = "",
order_by: str = "",
order_dir: str = "",
due_date_from: int = -1,
due_date_to: int = -1,
created_at_from: int = -1,
created_at_to: int = -1,
completed_at_from: int = -1,
completed_at_to: int = -1,
is_ai_suggested: int = -1,
limit: int = 50,
offset: int = 0,
) -> str:
"""List tasks with optional filters. Returns up to `limit` results (default 50).
project_id: UUID of the project to scope results to.
status: filter by status — todo | in_progress | done.
priority: filter by priority — high | medium | low.
assignee: substring to match against assignee names. OMIT unless the user explicitly
names a person or refers to themselves ("my tasks", "assigned to me", "mine").
Do NOT default to the current user.
search: substring search across title and description.
order_by: sort field — dueDate | priority | createdAt | completedAt.
order_dir: asc (default) | desc.
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
limit: max rows to return (default 50). Use with offset to paginate.
offset: skip first N rows (default 0).
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
Tip — prefer count_tasks for "how many" questions to avoid listing rows.
Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.)
take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt;
do not compute boundaries from the current UTC instant.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
filters: dict[str, Any] = {
"projectId": normalized_project_id or None,
"status": status or None,
"priority": priority or None,
"search": search or None,
"orderBy": order_by or None,
"orderDir": order_dir or None,
"limit": limit,
"offset": offset,
}
if assignee:
filters["assignee"] = assignee
if due_date_from != -1:
filters["dueDateFrom"] = due_date_from
if due_date_to != -1:
filters["dueDateTo"] = due_date_to
if created_at_from != -1:
filters["createdAtFrom"] = created_at_from
if created_at_to != -1:
filters["createdAtTo"] = created_at_to
if completed_at_from != -1:
filters["completedAtFrom"] = completed_at_from
if completed_at_to != -1:
filters["completedAtTo"] = completed_at_to
if is_ai_suggested != -1:
filters["isAiSuggested"] = is_ai_suggested
result = await execute_on_client(action="select", table="tasks", filters=filters)
rows = result.get("rows", [])
if not rows:
return "No tasks found matching the given filters."
lines = [
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, "
f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, "
f"projectId: {r.get('projectId')}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
@tool
async def count_tasks(
project_id: str = "",
status: str = "",
priority: str = "",
assignee: str = "",
search: str = "",
due_date_from: int = -1,
due_date_to: int = -1,
created_at_from: int = -1,
created_at_to: int = -1,
completed_at_from: int = -1,
completed_at_to: int = -1,
is_ai_suggested: int = -1,
) -> str:
"""Count tasks matching the given filters without returning rows.
Use this instead of list_tasks for "how many" questions — it is much cheaper.
Same filter parameters as list_tasks (no limit/offset/order_by needed).
assignee: OMIT unless the user explicitly names a person or refers to themselves
("my tasks"). Do NOT default to the current user.
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block;
do not compute boundaries from the current UTC instant.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
filters: dict[str, Any] = {
"projectId": normalized_project_id or None,
"status": status or None,
"priority": priority or None,
"search": search or None,
}
if assignee:
filters["assignee"] = assignee
if due_date_from != -1:
filters["dueDateFrom"] = due_date_from
if due_date_to != -1:
filters["dueDateTo"] = due_date_to
if created_at_from != -1:
filters["createdAtFrom"] = created_at_from
if created_at_to != -1:
filters["createdAtTo"] = created_at_to
if completed_at_from != -1:
filters["completedAtFrom"] = completed_at_from
if completed_at_to != -1:
filters["completedAtTo"] = completed_at_to
if is_ai_suggested != -1:
filters["isAiSuggested"] = is_ai_suggested
result = await execute_on_client(action="count", table="tasks", filters=filters)
return f"Task count: {result.get('count', 0)}"
@tool
async def create_task(
title: str,
description: str = "",
status: str = "todo",
priority: str = "medium",
assignees: str = "[]",
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
) -> str:
"""Create a new task.
title: task title (required)
description: optional details
status: todo | in_progress | done (default: todo)
priority: high | medium | low (default: medium)
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
completedAt is set automatically when status is 'done'.
"""
result = await execute_on_client(
action="insert",
table="tasks",
data={
"title": title,
"description": description or None,
"status": status,
"priority": priority,
"assignee": assignees,
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return (
f"Task created: '{row['title']}' "
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})"
)
@tool
async def update_task(
task_id: str,
title: str = "",
description: str = "",
status: str = "",
priority: str = "",
assignees: str = "",
due_date: int = -1,
project_id: str = "",
) -> str:
"""Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
completedAt is managed automatically:
- setting status to 'done' records the current timestamp
- changing status away from 'done' clears completedAt
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if description:
updates["description"] = description
if status:
updates["status"] = status
if priority:
updates["priority"] = priority
if assignees:
updates["assignee"] = assignees
if due_date != -1:
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
result = await execute_on_client(
action="update",
table="tasks",
data={"id": task_id, "updates": updates},
)
row = result["row"]
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})"
@tool
async def delete_task(task_id: str) -> str:
"""Delete a task permanently by its UUID."""
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
return f"Task {task_id} deleted."
@tool
async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str:
"""List all tasks whose due date falls on today's date.
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
Always pass the user's timezone so 'today' is computed in their local time.
include_done: set True to also include already-completed tasks due today (default False).
"""
try:
from zoneinfo import ZoneInfo
tz = ZoneInfo(user_timezone or "UTC")
except Exception:
tz = timezone.utc
now_local = datetime.now(tz=tz)
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
start_ms = int(start_dt.timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1
filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms}
if not include_done:
filters["status"] = "todo"
result = await execute_on_client(
action="select",
table="tasks",
filters=filters,
)
rows = result.get("rows", [])
if not rows:
return "No tasks are due today."
lines = [
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, "
f"projectId: {r.get('projectId')}, id: {r['id']})"
for r in rows
]
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
# ── Task comment tools ────────────────────────────────────────────────
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
result = await execute_on_client(
action="select",
table="taskComments",
filters={"taskId": task_id},
)
rows = result.get("rows", [])
if not rows:
return f"No comments found for task {task_id}."
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
"""Add a comment to a task.
task_id: UUID of the task to comment on
author: name or ID of the comment author
content: comment text
"""
result = await execute_on_client(
action="insert",
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result.get("row", {})
row_author = row.get("author", author)
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
TASK_TOOLS: list[Any] = [
list_tasks,
count_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]
TASK_READ_TOOLS: list[Any] = [
list_tasks,
count_tasks,
list_tasks_due_today,
list_task_comments,
]

View File

@@ -1,270 +0,0 @@
"""Timeline agent — project milestone management (list, create, update, delete)."""
from __future__ import annotations
import re
from datetime import datetime, timezone
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
@tool
async def list_timelines(
project_id: str = "",
type: str = "",
is_completed: int = -1,
is_ai_suggested: int = -1,
order_by: str = "",
order_dir: str = "",
date_from: int = -1,
date_to: int = -1,
created_at_from: int = -1,
created_at_to: int = -1,
completed_at_from: int = -1,
completed_at_to: int = -1,
limit: int = 50,
offset: int = 0,
) -> str:
"""List timeline events (milestones, checkpoints, activities) with optional filters.
project_id: UUID to scope results to a specific project.
type: filter by event type — milestone | checkpoint | activity.
is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default).
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
order_by: sort field — date (default) | createdAt | completedAt.
order_dir: asc (default) | desc.
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
limit: max rows to return (default 50). Use with offset to paginate.
offset: skip first N rows (default 0).
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
Tip — prefer count_timelines for "how many" questions to avoid listing rows.
Tip — for natural-language windows ("today", "this week", "last month", etc.)
take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt;
do not compute boundaries from the current UTC instant.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
filters: dict[str, Any] = {
"projectId": normalized_project_id or None,
"orderBy": order_by or None,
"orderDir": order_dir or None,
"limit": limit,
"offset": offset,
}
if type:
filters["type"] = type
if is_completed != -1:
filters["isCompleted"] = is_completed
if is_ai_suggested != -1:
filters["isAiSuggested"] = is_ai_suggested
if date_from != -1:
filters["dateFrom"] = date_from
if date_to != -1:
filters["dateTo"] = date_to
if created_at_from != -1:
filters["createdAtFrom"] = created_at_from
if created_at_to != -1:
filters["createdAtTo"] = created_at_to
if completed_at_from != -1:
filters["completedAtFrom"] = completed_at_from
if completed_at_to != -1:
filters["completedAtTo"] = completed_at_to
result = await execute_on_client(action="select", table="timelines", filters=filters)
rows = result.get("rows", [])
if not rows:
return "No timeline events found."
lines = [
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, "
f"projectId: {r.get('projectId')}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines)
@tool
async def count_timelines(
project_id: str = "",
type: str = "",
is_completed: int = -1,
is_ai_suggested: int = -1,
date_from: int = -1,
date_to: int = -1,
created_at_from: int = -1,
created_at_to: int = -1,
completed_at_from: int = -1,
completed_at_to: int = -1,
) -> str:
"""Count timeline events matching the given filters without returning rows.
Use this instead of list_timelines for "how many" questions — it is much cheaper.
Same filter parameters as list_timelines (no limit/offset/order_by needed).
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block;
do not compute boundaries from the current UTC instant.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
filters: dict[str, Any] = {"projectId": normalized_project_id or None}
if type:
filters["type"] = type
if is_completed != -1:
filters["isCompleted"] = is_completed
if is_ai_suggested != -1:
filters["isAiSuggested"] = is_ai_suggested
if date_from != -1:
filters["dateFrom"] = date_from
if date_to != -1:
filters["dateTo"] = date_to
if created_at_from != -1:
filters["createdAtFrom"] = created_at_from
if created_at_to != -1:
filters["createdAtTo"] = created_at_to
if completed_at_from != -1:
filters["completedAtFrom"] = completed_at_from
if completed_at_to != -1:
filters["completedAtTo"] = completed_at_to
result = await execute_on_client(action="count", table="timelines", filters=filters)
return f"Timeline event count: {result.get('count', 0)}"
@tool
async def create_timeline(
project_id: str,
title: str,
date: int,
type: str = "milestone",
is_completed: int = 0,
is_ai_suggested: int = 0,
) -> str:
"""Create a project timeline event.
project_id: REQUIRED UUID of the parent project
title: descriptive name for the event
date: Unix timestamp in milliseconds for the event date
type: milestone (default) | checkpoint | activity
is_completed: 1 if already completed, 0 if not (default 0)
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
completedAt is set automatically when is_completed is 1.
"""
result = await execute_on_client(
action="insert",
table="timelines",
data={
"projectId": project_id,
"title": title,
"date": date,
"type": type,
"isCompleted": is_completed,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})"
@tool
async def update_timeline(
timeline_id: str,
title: str = "",
date: int = -1,
is_completed: int = -1,
) -> str:
"""Update a timeline event. Only pass fields that should change.
timeline_id: UUID of the event (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp)
is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged
completedAt is managed automatically:
- setting is_completed to 1 records the current timestamp
- setting is_completed to 0 clears completedAt
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
if is_completed != -1:
updates["isCompleted"] = is_completed
result = await execute_on_client(
action="update",
table="timelines",
data={"id": timeline_id, "updates": updates},
)
row = result["row"]
return f"Timeline event updated: '{row['title']}' (id: {row['id']})"
@tool
async def delete_timeline(timeline_id: str) -> str:
"""Delete a timeline event permanently by its UUID."""
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
return f"Timeline event {timeline_id} deleted."
@tool
async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str:
"""List all timeline events whose date falls on today.
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
Always pass the user's timezone so 'today' is computed in their local time.
include_completed: set False to exclude already-completed events (default True).
"""
try:
from zoneinfo import ZoneInfo
tz = ZoneInfo(user_timezone or "UTC")
except Exception:
tz = timezone.utc
now_local = datetime.now(tz=tz)
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
start_ms = int(start_dt.timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1
filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms}
if not include_completed:
filters["isCompleted"] = 0
result = await execute_on_client(
action="select",
table="timelines",
filters=filters,
)
rows = result.get("rows", [])
if not rows:
return "No timeline events today."
lines = [
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})"
for r in rows
]
return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
TIMELINE_TOOLS: list[Any] = [
list_timelines,
count_timelines,
list_timelines_today,
create_timeline,
update_timeline,
delete_timeline,
]
TIMELINE_READ_TOOLS: list[Any] = [
list_timelines,
count_timelines,
list_timelines_today,
]

View File

@@ -1,14 +0,0 @@
"""Shared FastAPI dependencies.
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
(the canonical location per Step 9). This module re-exports them so that all
existing route imports (``from app.api.deps import get_current_user``) continue
to work without modification.
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
instead of reading it from the JWT payload.
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
__all__ = ["get_current_user", "oauth2_scheme"]

View File

@@ -1,19 +0,0 @@
"""API middleware package.
Exports the three middleware components introduced in Step 9:
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
- Sanitizer: ``SanitizerMiddleware``
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
from app.api.middleware.sanitizer import SanitizerMiddleware
__all__ = [
"get_current_user",
"oauth2_scheme",
"TierRateLimitMiddleware",
"limiter",
"SanitizerMiddleware",
]

View File

@@ -1,103 +0,0 @@
"""Auth middleware — JWT validation dependency.
``get_current_user`` is the FastAPI dependency used by all protected routes.
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
from the ``subscriptions`` table so that tier changes take effect immediately
without requiring token re-issue.
Exempt routes (no JWT required):
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.db import get_session
from app.schemas import UserProfile
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
The JWT is used for identity and expiry only. The tier is fetched live
from the ``subscriptions`` table so that upgrades/downgrades take effect
immediately. Falls back to ``'free'`` when no subscription row exists.
Raises HTTP 401 on any invalid or expired token.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
# Live tier lookup — subscription row is the authoritative source.
# In dev, fall back to 'power' (unlimited) so quota limits don't
# block local development when no Stripe subscription exists.
from app.models import Subscription, User # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
user_result = await db.execute(
select(
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
User.password_hash,
).where(User.id == user_id)
)
user_row = user_result.one_or_none()
# Convert onboarding_completed_at to epoch ms (int) or None.
onboarding_ms: int | None = None
if user_row and user_row.onboarding_completed_at is not None:
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
# Load decrypted core memory.
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
memory_dict: dict[str, str] = {}
try:
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(user_id)
memory_dict = {b["label"]: b["value"] for b in blocks}
except Exception:
pass # Non-critical — return empty memory on failure
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
avatar_url=user_row.avatar_url if user_row else None,
has_password=bool(user_row.password_hash) if user_row else False,
tier=tier,
onboarding_completed_at=onboarding_ms,
memory=memory_dict,
) # type: ignore[arg-type]

View File

@@ -1,129 +0,0 @@
"""Tier-aware rate limiting middleware.
Uses a per-user sliding-window counter (in-process, no Redis required).
The ``slowapi`` Limiter is also exported for optional route-level decoration.
Limits (requests per minute):
- free: 20
- pro: 60
- power: 120
- team: 200
Exempt paths bypass the limiter entirely:
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
- GET /api/v1/health
"""
from __future__ import annotations
import json
import time
from collections import defaultdict
from fastapi import Request, Response
from jose import JWTError, jwt
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.config.settings import settings
_TIER_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/billing/webhook",
"/api/v1/health",
}
)
def _get_user_id_from_jwt(request: Request) -> str:
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return get_remote_address(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
return payload.get("sub") or get_remote_address(request)
except JWTError:
return get_remote_address(request)
# Exported Limiter instance — available for optional route-level decoration.
limiter = Limiter(key_func=_get_user_id_from_jwt)
class TierRateLimitMiddleware(BaseHTTPMiddleware):
"""Sliding-window rate limiter applied globally across all non-exempt routes.
Each authenticated user gets their own 60-second window sized by tier.
Unauthenticated requests pass through (the auth dependency will reject them
with 401 before the route handler runs).
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
# user_id → list of request timestamps (float, seconds since epoch)
self._window: dict[str, list[float]] = defaultdict(list)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
if request.url.path in _EXEMPT_PATHS:
return await call_next(request)
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return await call_next(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str = payload.get("sub") or get_remote_address(request)
tier: str = payload.get("tier", "free")
except JWTError:
return await call_next(request)
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
now = time.monotonic()
window_start = now - 60.0
# Slide the window: discard timestamps older than 60 seconds.
timestamps = [t for t in self._window[user_id] if t > window_start]
if len(timestamps) >= limit:
retry_after = max(1, int(60 - (now - min(timestamps))))
return Response(
content=json.dumps(
{
"detail": (
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
f"Retry in {retry_after}s."
)
}
),
status_code=429,
headers={
"Retry-After": str(retry_after),
"Content-Type": "application/json",
},
)
timestamps.append(now)
self._window[user_id] = timestamps
return await call_next(request)

View File

@@ -1,138 +0,0 @@
"""Response sanitizer middleware.
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
that could reveal server-side prompt IP:
- System prompt openers ("You are a/an/the …")
- Agent routing metadata ("Available agents:", "intent classifier", …)
- LangChain tool schema fragments (``"type": "function"``)
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints
The middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.
"""
from __future__ import annotations
import json
import logging
import re
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Detection patterns — order matters: fingerprints checked first (exact),
# then compiled regexes.
# ---------------------------------------------------------------------------
_FINGERPRINTS: tuple[str, ...] = (
"You are an intent classifier",
"Respond with just the agent name",
"Summarize these agent results",
"Available agents:",
"route to:",
)
_PATTERNS: tuple[re.Pattern[str], ...] = (
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
re.compile(r"Available agents\s*:", re.IGNORECASE),
re.compile(r"\bintent classifier\b", re.IGNORECASE),
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
re.compile(r"route\s+to\s*:", re.IGNORECASE),
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
)
def _sanitize_text(text: str) -> tuple[str, bool]:
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
Returns ``(cleaned_text, was_changed)``.
"""
# Fingerprint check — if any exact phrase is present, redact the whole string.
for fp in _FINGERPRINTS:
if fp in text:
return "[REDACTED]", True
changed = False
for pattern in _PATTERNS:
new_text, n = pattern.subn("[REDACTED]", text)
if n:
text = new_text
changed = True
return text, changed
class SanitizerMiddleware(BaseHTTPMiddleware):
"""Strip prompt IP from /api/v1/chat JSON responses."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
response: Response = await call_next(request)
# Only process chat endpoint responses.
if not request.url.path.startswith("/api/v1/chat"):
return response
# Read body — collect streaming chunks.
body_bytes = b""
async for chunk in response.body_iterator:
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
try:
body = json.loads(body_bytes.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
if not isinstance(body, dict):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Walk top-level string fields and sanitise.
sanitised_fields: list[str] = []
for key, value in body.items():
if isinstance(value, str):
cleaned, changed = _sanitize_text(value)
if changed:
body[key] = cleaned
sanitised_fields.append(key)
if sanitised_fields:
logger.warning(
"Sanitizer redacted prompt fragments",
extra={
"path": request.url.path,
"fields": sanitised_fields,
},
)
new_body = json.dumps(body).encode("utf-8")
headers = dict(response.headers)
headers["content-length"] = str(len(new_body))
return Response(
content=new_body,
status_code=response.status_code,
headers=headers,
media_type="application/json",
)

View File

@@ -1,513 +0,0 @@
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
The journey is driven entirely through WebSocket frames (no REST endpoints).
The device WS handler dispatches ``journey_start`` and ``journey_message``
frames to the functions exported here.
Journey flow:
1. FE sends ``journey_start`` frame with basic agent info (directory,
data_types, schedule).
2. Server creates an in-memory session, sets up a WS executor so the
setup LLM can use file-system tools, does a first directory scrape,
and sends back a ``journey_reply`` with the first question.
3. FE sends ``journey_message`` frames for each user reply.
4. Server appends the user message, calls the LLM (which may read files
via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
6. Server parses and validates the JSON with Pydantic, sends
``journey_reply`` with ``done=True`` and the serialised config.
FE stores it locally.
"""
from __future__ import annotations
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import make_directory_tools
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
from app.schemas import AgentConfig
logger = logging.getLogger(__name__)
# ── Session TTL ───────────────────────────────────────────────────────────
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
_CONFIG_START = "AGENT_CONFIG_START"
_CONFIG_END = "AGENT_CONFIG_END"
# Minimum turns before we consider nudging the LLM to wrap up.
_MIN_TURNS_BEFORE_NUDGE: int = 3
# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion).
_MAX_TURNS: int = 15
# Max tool-calling steps per LLM invocation.
_MAX_TOOL_STEPS: int = 6
# ── In-memory session store ───────────────────────────────────────────────
@dataclass
class JourneySession:
session_id: str
user_id: str
agent_type: str # "local" | "cloud"
directory: str
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
langfuse_prompt: Any = None
created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool:
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
# session_id → session
_sessions: dict[str, JourneySession] = {}
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
"""Retrieve session; return None on missing, expired, or wrong owner."""
s = _sessions.get(session_id)
if s is None or s.is_expired():
_sessions.pop(session_id, None)
return None
if s.user_id != user_id:
return None
return s
# ── System prompt ─────────────────────────────────────────────────────────
_JOURNEY_SYSTEM_PROMPT = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand what files the user has in their directory and produce a
structured AgentConfig JSON that the extraction agent will use as its instruction set.
You have access to file-system tools to explore the user's directory:
- list_directory: see folder structure and file names
- read_file_content: peek at a file's content
- get_file_metadata: check file size, extension, dates
The user's configured directory is: {directory}
Target data types: {data_types}
## Your process
### Step 1 — Explore the directory
Use list_directory and read_file_content to understand what types of files are present
(HTML emails, plain-text documents, CSVs, etc.).
### Step 2 — Identify content types
For each distinct file type found, decide:
- A short id (e.g. "email_html", "plain_text", "csv")
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
- A human-readable label and optional detection_hint
### Step 3 — Ask focused questions (one at a time)
Cover these topics based on what you discovered:
1. How to map content to entity types (task / note / timeline entry)
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
3. Priority or status rules (e.g. "urgent" in subject → high priority)
4. Date extraction (e.g. "by Friday" → dueDate)
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
### Step 4 — Produce the AgentConfig JSON
Once you are ≥ 90% confident, output the final config between these exact markers
(each on its own line):
{config_start}
{{
"content_types": [
{{
"id": "email_html",
"label": "Email HTML",
"detection_hint": "HTML file with From/To/Subject headers",
"preprocessing": "email_html",
"extraction_prompt": "Detailed extraction instructions for this content type..."
}}
],
"global_rules": [
"If the file cannot be matched to any project, do not create any entity."
],
"data_types": {data_types_json}
}}
{config_end}
## Rules for the extraction_prompt field
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
- Include field mapping rules based on what you found in the directory
- Include priority/status/date rules if applicable
- Do NOT include projectId logic — the runner handles project assignment automatically
- Do NOT mention isAiSuggested — the runner always sets it to 1
## Constraints
- Never ask about projects, projectId, or how to link records to projects
- Never include projectId or project creation logic in the generated config
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
{existing_section}\
Begin by exploring the directory, then ask your first question.\
"""
def _build_system_prompt(
directory: str,
data_types: list[str],
existing_config: str | None = None,
) -> tuple[str, Any]:
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
existing_section = (
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
f"```json\n{existing_config}\n```\n"
if existing_config
else ""
)
template, prompt_obj = get_prompt_or_fallback(
"journey_system", _JOURNEY_SYSTEM_PROMPT
)
compiled = compile_prompt(
template,
prompt_obj,
directory=directory,
data_types=", ".join(data_types),
data_types_json=json.dumps(data_types),
config_start=_CONFIG_START,
config_end=_CONFIG_END,
existing_section=existing_section,
)
return compiled, prompt_obj
# ── AgentConfig extraction ────────────────────────────────────────────────
def _extract_agent_config(text: str) -> str | None:
"""Return validated AgentConfig JSON string from between markers, or None.
Parses the JSON with Pydantic to ensure it conforms to the schema before
returning. Returns None if markers are absent or JSON is invalid.
"""
if _CONFIG_START not in text or _CONFIG_END not in text:
return None
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
end_idx = text.index(_CONFIG_END)
raw = text[start_idx:end_idx].strip()
if not raw:
return None
try:
parsed = AgentConfig.model_validate_json(raw)
return parsed.model_dump_json()
except Exception as exc:
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
return None
# ── LLM call with tool support ───────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
*,
user_id: str = "",
session_id: str = "",
langfuse_prompt: Any = None,
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
lf = get_langfuse()
messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history:
if turn["role"] == "user":
messages.append(HumanMessage(content=turn["content"]))
else:
messages.append(AIMessage(content=turn["content"]))
llm = get_agent_llm("setup", temperature=0.4)
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools}
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
_lf_ctx.__enter__()
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name="journey-setup",
input=history[-1]["content"] if history else "",
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
try:
for step in range(_MAX_TOOL_STEPS):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name="journey-setup-llm",
model=model_for_agent("setup"),
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
resp_text = _as_text(response.content)
# Guard against empty responses (e.g. model returned finish_reason
# 'error' which LiteLLM maps to 'stop' with empty content).
if not response.tool_calls and not resp_text.strip():
logger.warning(
"agent_setup: journey LLM returned empty response at step %d — retrying",
step,
)
# Drop the empty AIMessage so we don't pollute history, and retry.
continue
messages.append(response)
if not response.tool_calls:
if _span:
_span.update(output=resp_text)
return resp_text
for call in response.tool_calls:
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_setup: journey tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:500],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_setup: journey tool_result name=%s output=%s",
call_name,
str(tool_output)[:800],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max steps.
final = await llm.ainvoke(messages)
final_text = _as_text(final.content)
if _span:
_span.update(output=final_text)
return final_text or (
"Sorry, I had trouble processing the files. "
"Could you try again? If the issue persists, the files might be too large for me to analyse."
)
finally:
if _span_ctx:
_span_ctx.__exit__(None, None, None)
_lf_ctx.__exit__(None, None, None)
if lf:
lf.flush()
# ── Journey handlers (called from device_ws.py) ──────────────────────────
async def handle_journey_start(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_start`` WS frame.
Creates a session, runs the setup LLM with directory exploration,
and returns the ``journey_reply`` payload.
"""
agent_type = frame.get("agent_type", "local")
directory = frame.get("directory", "")
data_types = frame.get("data_types", [])
existing_config = frame.get("existing_config")
# Use the session_id provided by the FE so the reply matches the
# listener key; fall back to a generated one if absent.
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
session = JourneySession(
session_id=session_id,
user_id=user_id,
agent_type=agent_type,
directory=directory,
data_types=data_types,
system_prompt=system_prompt,
langfuse_prompt=langfuse_prompt,
)
# Seed with an initial user message — some providers require at least one
# user/input message to be present.
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=make_directory_tools(directory),
user_id=user_id,
session_id=session_id,
langfuse_prompt=langfuse_prompt,
)
session.history.extend(seed_history)
session.history.append({"role": "assistant", "content": ai_reply})
_sessions[session_id] = session
logger.info(
"agent_setup: journey session %s started for user %s (directory=%s)",
session_id,
user_id,
directory,
)
# Check if the LLM produced the config on the first turn (unlikely but possible).
agent_config = _extract_agent_config(ai_reply)
done = agent_config is not None
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"agent_config": agent_config,
}
async def handle_journey_message(
user_id: str,
frame: dict[str, Any],
) -> dict[str, Any]:
"""Handle a ``journey_message`` WS frame.
Appends the user message, calls the LLM, and returns the
``journey_reply`` payload.
"""
session_id = frame.get("session_id", "")
message = frame.get("message", "")
session = get_journey_session(session_id, user_id)
if session is None:
return {
"type": "journey_reply",
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"agent_config": None,
}
# Append user turn.
session.history.append({"role": "user", "content": message})
# Call the LLM with tools.
session_tools = make_directory_tools(session.directory)
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=session_tools,
user_id=session.user_id,
session_id=session_id,
langfuse_prompt=session.langfuse_prompt,
)
session.history.append({"role": "assistant", "content": ai_reply})
# Check if the LLM produced the final config.
agent_config = _extract_agent_config(ai_reply)
done = agent_config is not None
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=session_tools,
user_id=session.user_id,
session_id=session_id,
langfuse_prompt=session.langfuse_prompt,
)
session.history.append({"role": "assistant", "content": nudge_reply})
agent_config = _extract_agent_config(nudge_reply)
if agent_config is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
if _CONFIG_START in ai_reply
else "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"agent_config": agent_config,
}

View File

@@ -1,257 +0,0 @@
"""Agent routes.
Backend responsibilities are intentionally minimal:
GET /agents/catalog — static catalog for UI display
POST /agents/can-create — billing eligibility check
POST /agents/trigger — trigger a local agent run
Agent configuration is owned by the Electron app and is not persisted
in backend agent-config tables.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import is_agent_running, run_local_agent
from app.core.device_manager import device_manager
from app.core.note_summarizer import generate_note_summary
from app.db import get_session
from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import (
AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse,
AgentTriggerRequest,
UserProfile,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Datetime helpers ──────────────────────────────────────────────────
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
return AgentRunLogResponse(
id=log.id,
agent_id=log.agent_id,
agent_type=log.agent_type, # type: ignore[arg-type]
status=log.status, # type: ignore[arg-type]
items_processed=log.items_processed,
items_created=log.items_created,
errors=log.errors or [],
started_at=_dt_ms(log.started_at),
completed_at=_dt_ms_opt(log.completed_at),
)
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog", response_model=list[AgentCatalogItem])
async def get_agent_catalog(
current_user: UserProfile = Depends(get_current_user),
) -> list[AgentCatalogItem]:
"""Return the static list of available agent types and their descriptions."""
return [
AgentCatalogItem(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
),
AgentCatalogItem(
type="gmail",
name="Gmail Connector",
description="Scans Gmail inbox, extracts tasks/notes from emails",
),
AgentCatalogItem(
type="teams",
name="Microsoft Teams Connector",
description="Monitors Teams messages, extracts action items",
),
AgentCatalogItem(
type="outlook",
name="Outlook Connector",
description="Scans Outlook inbox, extracts tasks/notes",
),
]
@router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
body: AgentCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
) -> AgentCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse:
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
last_run_dt = (
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
if body.last_run_at
else None
)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt or "",
agent_config=body.agent_config,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
last_run_at=last_run_dt,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running. Only one run per agent is allowed at a time.",
)
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=current_user.id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)
# ── Note summary endpoint ──────────────────────────────────────────────────────
class NoteSummarizeRequest(BaseModel):
title: str
content: str
class NoteSummarizeResponse(BaseModel):
summary: str
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
async def summarize_note(
body: NoteSummarizeRequest,
current_user: UserProfile = Depends(get_current_user),
) -> NoteSummarizeResponse:
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
summary = await generate_note_summary(body.title, body.content)
return NoteSummarizeResponse(summary=summary)

View File

@@ -1,795 +0,0 @@
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB.
OAuth (Google):
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
"""
from __future__ import annotations
import hashlib
import json
import time
import urllib.parse
import uuid
from datetime import datetime, timedelta, timezone
from typing import Literal
import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from jose import jwt
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
from app.config.settings import settings
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.models import OAuthAccount, RefreshToken, User
from app.schemas import AuthTokens, UserProfile
router = APIRouter(prefix="/auth", tags=["auth"])
# ── OAuth provider registry ───────────────────────────────────────────
def _get_google_provider() -> GoogleOAuthProvider:
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
raise HTTPException(
status.HTTP_503_SERVICE_UNAVAILABLE,
"Google login is not configured on this server",
)
return GoogleOAuthProvider(
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
redirect_uri=settings.OAUTH_REDIRECT_URI,
)
_PROVIDERS = {"google": _get_google_provider}
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
# Production note: replace with Redis for multi-process deployments.
_pending_states: dict[str, tuple[str, float]] = {}
_STATE_TTL_SECONDS = 600 # 10 minutes
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _hash_token(plain_token: str) -> str:
"""SHA-256 of the plain refresh token string."""
return hashlib.sha256(plain_token.encode()).hexdigest()
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (signed JWT, expires_at_ms)."""
now = int(time.time())
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": exp,
"iat": now,
}
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
return token, exp * 1000 # ms for client
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
name: str | None = None
surname: str | None = None
class _LoginRequest(BaseModel):
email: str
password: str
class _RefreshRequest(BaseModel):
refresh_token: str
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
body: _RegisterRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Create a new account and return JWT tokens."""
existing = await db.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user = User(
id=str(uuid.uuid4()),
email=body.email,
name=body.name,
surname=body.surname,
password_hash=_hash_password(body.password),
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(user)
await db.flush() # get user.id without committing
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/login", response_model=AuthTokens)
async def login(
body: _LoginRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/refresh", response_model=AuthTokens)
async def refresh(
body: _RefreshRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
token_hash = _hash_token(body.refresh_token)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
rt = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
# Rotate: delete old token, issue new one.
await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id))
user = user_result.scalar_one_or_none()
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=new_expires,
)
db.add(new_rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
class _UpdateProfileRequest(BaseModel):
name: str | None = None
surname: str | None = None
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user."""
return current_user
@router.put("/me", response_model=UserProfile)
async def update_profile(
body: _UpdateProfileRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's name and surname."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if body.name is not None:
user.name = body.name
if body.surname is not None:
user.surname = body.surname
await db.commit()
await db.refresh(user)
return UserProfile(
id=user.id,
email=user.email,
name=user.name,
surname=user.surname,
avatar_url=user.avatar_url,
tier=current_user.tier,
)
# ── OAuth helpers ─────────────────────────────────────────────────────
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
"""Create a refresh token row and return (plain_token, AuthTokens)."""
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return plain_token, AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
# ── OAuth request/response schemas ───────────────────────────────────
class _OAuthAuthorizeResponse(BaseModel):
url: str
state: str
class _OAuthCallbackRequest(BaseModel):
code: str
state: str
# ── OAuth routes ──────────────────────────────────────────────────────
@router.get(
"/oauth/{provider}/web-callback",
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
include_in_schema=False,
)
async def oauth_web_callback(
provider: Literal["google"],
code: str,
state: str,
) -> RedirectResponse:
"""Google redirects here after user consent.
This endpoint immediately redirects to the Electron deep-link URI so the
desktop app receives the authorization code. It is intentionally simple —
no state validation here (the Electron app + backend callback do that).
Registered in Google Cloud Console as:
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
"""
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
deep_link = f"adiuvai://oauth/callback?{params}"
return RedirectResponse(url=deep_link, status_code=302)
@router.get(
"/oauth/{provider}/authorize",
response_model=_OAuthAuthorizeResponse,
summary="Start OAuth flow — returns the provider consent-screen URL",
)
async def oauth_authorize(
provider: Literal["google"],
) -> _OAuthAuthorizeResponse:
"""Generate a PKCE state + code_challenge and return the authorization URL.
The client opens this URL in the system browser. After the user grants
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
with ``code`` and ``state`` query params. The client then calls
``POST /auth/oauth/{provider}/callback`` with those values.
"""
provider_factory = _PROVIDERS.get(provider)
if provider_factory is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
oauth_provider = provider_factory()
state = str(uuid.uuid4())
code_verifier, code_challenge = generate_pkce_pair()
# Purge expired states to prevent unbounded growth.
now = time.time()
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
for s in expired:
del _pending_states[s]
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
return _OAuthAuthorizeResponse(url=url, state=state)
@router.post(
"/oauth/{provider}/callback",
response_model=AuthTokens,
summary="Complete OAuth flow — exchange code and issue JWT tokens",
)
async def oauth_callback(
provider: Literal["google"],
body: _OAuthCallbackRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate state, exchange the authorization code, and sign in (or register) the user.
Resolution order:
1. ``oauth_accounts`` row match → existing user, log in.
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
3. No match → create new user (password_hash=None, avatar from provider).
"""
provider_factory = _PROVIDERS.get(provider)
if provider_factory is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
# Validate state (CSRF protection).
now = time.time()
entry = _pending_states.pop(body.state, None)
if entry is None or entry[1] < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
code_verifier, _ = entry
oauth_provider = provider_factory()
# Exchange code for tokens.
try:
token_data = await oauth_provider.exchange_code(
code=body.code,
code_verifier=code_verifier,
redirect_uri=settings.OAUTH_REDIRECT_URI,
)
except Exception:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
)
access_token_google = token_data.get("access_token")
if not access_token_google:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
# Fetch user identity.
try:
userinfo = await oauth_provider.get_userinfo(access_token_google)
except Exception:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
# ── Resolution order ──────────────────────────────────────────────
# 1. Existing OAuth link?
oauth_result = await db.execute(
select(OAuthAccount).where(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == userinfo.provider_user_id,
)
)
oauth_account = oauth_result.scalar_one_or_none()
if oauth_account is not None:
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
user = user_result.scalar_one()
# Backfill avatar if the user doesn't have one yet.
if user.avatar_url is None and userinfo.avatar_url:
user.avatar_url = userinfo.avatar_url
await db.commit()
plain_token, tokens = await _issue_refresh_token(user, db)
await db.commit()
return tokens
# 2. Email match with a verified Google email → link accounts.
if userinfo.email_verified:
email_result = await db.execute(select(User).where(User.email == userinfo.email))
existing_user = email_result.scalar_one_or_none()
if existing_user is not None:
new_link = OAuthAccount(
user_id=existing_user.id,
provider=provider,
provider_user_id=userinfo.provider_user_id,
provider_email=userinfo.email,
)
db.add(new_link)
if existing_user.avatar_url is None and userinfo.avatar_url:
existing_user.avatar_url = userinfo.avatar_url
plain_token, tokens = await _issue_refresh_token(existing_user, db)
await db.commit()
return tokens
# Guard: if the email is already taken but we couldn't auto-link (e.g.
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
if not userinfo.email_verified:
conflict = await db.execute(select(User).where(User.email == userinfo.email))
if conflict.scalar_one_or_none() is not None:
raise HTTPException(
status.HTTP_409_CONFLICT,
"An account with this email already exists. "
"Please sign in with your password.",
)
# 3. New user — social-only account (no password).
new_user = User(
id=str(uuid.uuid4()),
email=userinfo.email,
name=userinfo.name,
password_hash=None,
avatar_url=userinfo.avatar_url,
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(new_user)
await db.flush() # populate new_user.id
new_oauth = OAuthAccount(
user_id=new_user.id,
provider=provider,
provider_user_id=userinfo.provider_user_id,
provider_email=userinfo.email,
)
db.add(new_oauth)
plain_token, tokens = await _issue_refresh_token(new_user, db)
await db.commit()
return tokens
# ── Onboarding helpers ────────────────────────────────────────────────
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
# We can't call the FastAPI dependency directly, but we can replicate
# the core logic inline. Instead, we just re-query the same way.
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
user_result = await db.execute(
select(
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
User.password_hash,
).where(User.id == user_id)
)
user_row = user_result.one_or_none()
onboarding_ms: int | None = None
if user_row and user_row.onboarding_completed_at is not None:
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
memory_dict: dict[str, str] = {}
try:
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(user_id)
memory_dict = {b["label"]: b["value"] for b in blocks}
except Exception:
pass
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
avatar_url=user_row.avatar_url if user_row else None,
has_password=bool(user_row.password_hash) if user_row else False,
tier=tier,
onboarding_completed_at=onboarding_ms,
memory=memory_dict,
)
# ── Onboarding routes ────────────────────────────────────────────────
class _UpdateMemoryRequest(BaseModel):
memory: dict[str, str] = Field(default_factory=dict)
mark_onboarded: bool = False
@router.put("/me/memory", response_model=UserProfile)
async def update_memory(
body: _UpdateMemoryRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update core memory key/value pairs and optionally mark onboarding complete."""
mw = MemoryMiddleware(db)
for key, value in body.memory.items():
await mw.update_core(current_user.id, key, value)
if body.mark_onboarded:
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.onboarding_completed_at = datetime.now(timezone.utc)
await db.commit()
return await _build_profile(current_user.id, current_user.email, db)
@router.post("/me/onboarding/reset")
async def reset_onboarding(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
):
"""Reset onboarding so the wizard runs again on next login."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.onboarding_completed_at = None
await db.commit()
return {"status": "reset"}
class _NormalizeRequest(BaseModel):
inputs: dict[str, str]
class _NormalizeResponse(BaseModel):
normalized: dict[str, str]
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
async def normalize_onboarding(
body: _NormalizeRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _NormalizeResponse:
"""One-shot LLM normalization for free-text onboarding answers."""
if not body.inputs:
return _NormalizeResponse(normalized={})
try:
llm = get_llm(model="gpt-4o-mini", temperature=0)
prompt = (
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
"Return a JSON object with the same keys and normalized values.\n"
"Examples: 'i build websites''Web Developer', 'tech-ish stuff''Technology'\n"
f"Input: {json.dumps(body.inputs)}"
)
response = await llm.ainvoke(
[
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
{"role": "user", "content": prompt},
],
)
normalized = json.loads(response.content)
return _NormalizeResponse(normalized=normalized)
except Exception:
# LLM failure must never block onboarding — return inputs unchanged
return _NormalizeResponse(normalized=body.inputs)
# ── Password management ───────────────────────────────────────────────
class _ChangePasswordRequest(BaseModel):
current_password: str = Field(min_length=1)
new_password: str = Field(min_length=8)
@router.put("/me/password", status_code=status.HTTP_200_OK)
async def change_password(
body: _ChangePasswordRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Change the authenticated user's password.
Requires the current password for verification.
Returns 400 for social-only users (no password set).
"""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if user.password_hash is None:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"This account uses social login and has no password to change",
)
if not _verify_password(body.current_password, user.password_hash):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
user.password_hash = _hash_password(body.new_password)
await db.commit()
return {"ok": True}
# ── OAuth account management ─────────────────────────────────────────
@router.get("/me/oauth-accounts", response_model=list[dict])
async def list_oauth_accounts(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[dict]:
"""List all OAuth providers linked to the authenticated user."""
result = await db.execute(
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
)
accounts = result.scalars().all()
return [
{
"provider": a.provider,
"provider_email": a.provider_email,
"created_at": int(a.created_at.timestamp() * 1000),
}
for a in accounts
]
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
async def unlink_oauth_account(
provider: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unlink an OAuth provider from the authenticated user.
Refuses if the user has no password and this is their only login method.
"""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
oauth_result = await db.execute(
select(OAuthAccount).where(
OAuthAccount.user_id == current_user.id,
OAuthAccount.provider == provider,
)
)
account = oauth_result.scalar_one_or_none()
if account is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
# Safety: don't let users lock themselves out.
all_oauth = await db.execute(
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
)
oauth_count = len(all_oauth.scalars().all())
if user.password_hash is None and oauth_count <= 1:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Cannot unlink the only login method. Set a password first.",
)
await db.delete(account)
await db.commit()
return {"ok": True}
# ── Avatar update ─────────────────────────────────────────────────────
class _UpdateAvatarRequest(BaseModel):
avatar_url: str = Field(min_length=1)
@router.put("/me/avatar", response_model=UserProfile)
async def update_avatar(
body: _UpdateAvatarRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's avatar URL.
Accepts {"avatar_url": "https://..."} — the client uploads the image
to its own storage and passes the resulting URL here.
"""
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.avatar_url = body.avatar_url
await db.commit()
return await _build_profile(current_user.id, current_user.email, db)
# ── Account deletion ─────────────────────────────────────────────────
@router.delete("/me", status_code=status.HTTP_200_OK)
async def delete_account(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Permanently delete the authenticated user's account.
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
is cancelled if active.
"""
# Cancel Stripe subscription if present.
try:
from app.billing.stripe_service import stripe_service # noqa: PLC0415
await stripe_service.cancel_subscription(current_user.id, db)
except HTTPException:
pass # No subscription — that's fine
# Delete all memory rows (core, associative, episodic, proactive).
try:
from app.models import ( # noqa: PLC0415
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
)
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
await db.execute(
model.__table__.delete().where(model.user_id == current_user.id)
)
except Exception:
pass # Non-critical — cascade on User will handle most
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
await db.delete(user)
await db.commit()
return {"ok": True}

View File

@@ -1,132 +0,0 @@
"""Billing routes: Stripe checkout, webhook, subscription management.
Business logic lives in ``app.billing.stripe_service.StripeService``.
The route layer handles HTTP concerns (request parsing, response shaping)
and delegates everything else to the service singleton.
"""
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.stripe_service import stripe_service
from app.db import get_session
from app.schemas import BillingTier, UserProfile
router = APIRouter(prefix="/billing", tags=["billing"])
# ── Request bodies ─────────────────────────────────────────────────────
class _CheckoutRequest(BaseModel):
tier: BillingTier
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/checkout", response_model=dict)
async def create_checkout(
body: _CheckoutRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, str]:
"""Create a Stripe checkout session for a tier upgrade.
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
"""
url = stripe_service.create_checkout_session(current_user.id, body.tier)
return {"checkout_url": url}
@router.post("/webhook", response_model=dict)
async def stripe_webhook(
request: Request,
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Handle Stripe webhook events.
No JWT auth — authenticated via Stripe signature verification instead.
Returns 200 immediately when Stripe is not configured (local dev).
"""
payload = await request.body()
await stripe_service.handle_webhook(payload, stripe_signature, db)
return {"ok": True}
@router.get("/subscription", response_model=dict)
async def get_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Return the current subscription info for the authenticated user."""
sub = await stripe_service.get_subscription(current_user.id, db)
if sub is None:
return {
"tier": current_user.tier,
"status": "free",
"stripe_subscription_id": None,
"current_period_end": None,
}
return sub
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
async def cancel_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Cancel the active subscription."""
await stripe_service.cancel_subscription(current_user.id, db)
return {"ok": True}
@router.get("/invoices", response_model=list[dict])
async def list_invoices(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[dict[str, Any]]:
"""Return billing history (invoices) from Stripe.
Returns an empty list when Stripe is not configured.
"""
invoices = await stripe_service.list_invoices(current_user.id, db)
return invoices
# ── Quota check ────────────────────────────────────────────────────────
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
class QuotaCheckRequest(BaseModel):
feature: str
estimated_files: int
@router.post("/quota/check")
async def quota_check(
payload: QuotaCheckRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict:
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
if payload.feature != "folder_index":
raise HTTPException(status_code=400, detail="Unknown feature")
try:
await check_folder_quota(
user_id=current_user.id,
tier=current_user.tier,
estimated_files=payload.estimated_files,
db=db,
)
except QuotaExceeded as exc:
raise HTTPException(
status_code=402,
detail={"reason": exc.reason, "message": str(exc)},
)
return {"ok": True}

View File

@@ -1,116 +0,0 @@
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
"""
from __future__ import annotations
import uuid
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_home
from app.core.llm import embed
from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
# ── Embed helpers ─────────────────────────────────────────────────────────
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("")
async def chat(
body: ChatRequest,
current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
"""REST fallback for home chat when websocket streaming is unavailable."""
response = await run_home(
user_id=current_user.id,
message=body.message,
context=body.context.model_dump(),
)
return JSONResponse(content={"response": response})
class _BriefRequest(BaseModel):
mode: Literal["home", "project"]
project_id: str | None = None
class _BriefResponse(BaseModel):
response: str
@router.post("/brief", response_model=_BriefResponse)
async def brief(
body: _BriefRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _BriefResponse:
"""REST fallback for brief when the device WebSocket is not ready."""
if body.mode == "project":
if not body.project_id:
raise HTTPException(status_code=422, detail="project_id required for project mode")
try:
uuid.UUID(body.project_id)
except ValueError:
raise HTTPException(status_code=422, detail="project_id must be a valid UUID")
request_id = str(uuid.uuid4())
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
current_user.id,
"",
trace_id=request_id,
session_id=request_id,
)
context: dict = {
"_debug": {"request_id": request_id, "user_id": current_user.id},
**memory_context,
}
chunks: list[str] = []
if body.mode == "project":
stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type]
else:
stream = run_home_brief(current_user.id, context)
async for event_type, data in stream:
if event_type == "token" and data:
chunks.append(str(data))
return _BriefResponse(response="".join(chunks))
@router.post("/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by Electron (vectordb.ts) for local note search.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

View File

@@ -1,845 +0,0 @@
"""Device WebSocket endpoint.
Persistent connection from Electron devices to the backend.
WS /api/v1/ws/device?token=<jwt>
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
available during the WebSocket handshake).
Protocol:
1. Client connects → JWT validated → connection accepted.
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
3. Backend registers the connection in ``DeviceConnectionManager``.
4. Session enters message dispatch loop + heartbeat.
Incoming frame dispatch:
- ``tool_result`` → resolves a pending tool-call Future.
- ``journey_start`` → starts a guided setup journey session.
- ``journey_message`` → continues a journey conversation.
- ``pong`` → heartbeat acknowledgement (updates last-seen).
- unknown types → logged, ignored.
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
On disconnect:
- Unregisters from DeviceConnectionManager.
- Marks all in-progress AgentRunLog rows for this user as ``error``
with message "device disconnected".
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.agent_session_buffer import session_buffer
from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
from app.core.output_formatter import extract_canvas_block
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog
from app.schemas import WsFrameType, WsStreamEnd
from app.schemas.contextual import ContextualScope, render_scope_block
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"])
# ── v7 folder index session state ─────────────────────────────────────
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
_index_sessions: dict[str, dict] = {}
_HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
"""Persistent WebSocket endpoint for Electron device connections.
Authentication is via ``?token=<jwt>`` query parameter.
"""
# ── 1. Authenticate before accepting ─────────────────────────────
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008) # Policy Violation
return
await websocket.accept()
# ── 2. Await device_hello frame ───────────────────────────────────
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
except (asyncio.TimeoutError, WebSocketDisconnect):
await websocket.close(code=1008)
return
try:
hello = json.loads(raw)
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
await websocket.close(code=1008)
return
# ── 3. Register connection ────────────────────────────────────────
device_manager.register(user_id, device_id, websocket)
logger.info(
"device_ws: connected user=%s device=%s agents=%s",
user_id,
device_id,
agent_ids,
)
# Trigger any overdue agent runs now that the device is connected.
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
# ── 4. Concurrent message loop + heartbeat ────────────────────────
try:
await asyncio.gather(
_message_loop(websocket, user_id),
_heartbeat_loop(websocket),
)
except WebSocketDisconnect:
pass
except Exception as exc:
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
finally:
device_manager.unregister(user_id)
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
await _mark_runs_disconnected(user_id)
# ── Message dispatch loop ─────────────────────────────────────────────
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
"""Receive frames from Electron and dispatch to the appropriate handler."""
async for raw in websocket.iter_text():
try:
frame: dict = json.loads(raw)
except json.JSONDecodeError:
logger.warning("device_ws: invalid JSON from user=%s", user_id)
continue
frame_type = frame.get("type")
if frame_type == WsFrameType.tool_result:
call_id = frame.get("id")
if call_id:
device_manager.resolve_pending_call(user_id, call_id, frame)
else:
logger.warning(
"device_ws: tool_result missing id from user=%s", user_id
)
elif frame_type == WsFrameType.home_request:
asyncio.create_task(
_handle_home_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.brief_request:
asyncio.create_task(
_handle_brief_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.task_brief_request:
asyncio.create_task(
_handle_task_brief_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_start:
asyncio.create_task(
_handle_journey_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_message:
asyncio.create_task(
_handle_journey_message(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_start:
asyncio.create_task(
_handle_index_session_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_file_batch:
asyncio.create_task(
_handle_index_file_batch(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_cancel:
await _handle_index_session_cancel(websocket, frame)
elif frame_type == WsFrameType.contextual_request:
asyncio.create_task(
_handle_contextual_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.contextual_scope_update:
asyncio.create_task(
_handle_contextual_scope_update(websocket, user_id, frame)
)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
else:
logger.debug(
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
)
# ── v3 Chat Handlers ──────────────────────────────────────────────────
async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call
await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, payload["id"])
return await future
return _executor
async def _handle_home_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
user_id,
request_id,
session_id,
project_id,
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_home_stream(user_id, message, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: home_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
"""Return a session-scoped buffer proxy for the given user+session.
Returns a _ContextualBufferProxy that exposes append_system_message().
Defined at module level so tests can monkeypatch it.
The channel kwarg is accepted for forward-compatibility.
"""
from app.core.agent_session_buffer import ContextualBufferProxy # noqa: PLC0415
return ContextualBufferProxy(session_buffer, user_id, session_id)
async def _handle_contextual_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a contextual_request frame — runs the contextual agent and streams frames."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope_payload: dict = frame.get("scope", {})
logger.info(
"device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
user_id,
request_id,
session_id,
message[:200],
)
scope = ContextualScope.model_validate(scope_payload)
# Enrich context with memory before the LLM call.
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"format_prefs": frame.get("format_prefs"),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_contextual_stream(
user_id=user_id,
message=message,
context=context,
scope=scope,
)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: contextual_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# Store episode so the contextual agent can recall prior turns.
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
async def _handle_contextual_scope_update(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a contextual_scope_update frame.
Injects a synthetic system message into the session buffer so the next
agent turn knows the user navigated. No LLM call is made.
"""
session_id: str = frame.get("session_id") or str(uuid4())
scope = ContextualScope.model_validate(frame.get("scope", {}))
block = render_scope_block(scope)
buf = get_session_buffer(user_id, session_id, channel="contextual")
buf.append_system_message(
f"User navigated to a new view. {block} Treat this as the new active context."
)
await websocket.send_text(json.dumps({
"type": WsFrameType.contextual_scope_ack,
"session_id": session_id,
}))
logger.info(
"device_ws: contextual_scope_update user=%s session=%s page=%s",
user_id, session_id, scope.page,
)
async def _handle_brief_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a brief_request frame — streams plain-text brief back on the socket.
No episode storage — briefs are not conversations.
"""
import uuid as _uuid
request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4())
mode: str = frame.get("mode", "home")
project_id: str | None = frame.get("project_id")
logger.info(
"device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
user_id, request_id, mode, project_id,
)
# Validate project_id for project mode before touching LLM.
if mode == "project":
try:
if not project_id:
raise ValueError("project_id required for project mode")
_uuid.UUID(project_id)
except (ValueError, AttributeError) as exc:
logger.warning(
"device_ws: brief_request invalid project_id user=%s req=%s: %s",
user_id, request_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
return
# Enrich context with memory (no user message — use empty string as probe).
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
"",
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
if mode == "project":
event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type]
else:
event_stream = run_home_brief(user_id, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
except Exception as exc:
logger.error(
"device_ws: brief_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
finally:
clear_client_executor()
logger.info(
"device_ws: brief_request_end user=%s req=%s mode=%s",
user_id, request_id, mode,
)
# ── v6 Task Brief Handler ────────────────────────────────────────────
async def _handle_task_brief_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
Streams the briefing markdown back to the client.
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
"""
request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4())
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
user_id, request_id, task_id, project_id,
)
if not task_id:
await websocket.send_text(
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
)
return
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
f"task brief: {task_id}",
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
await websocket.send_text(ws_frame.model_dump_json())
elif ws_frame.type == "stream_start":
await websocket.send_text(ws_frame.model_dump_json())
# stream_end is emitted below with mutations — skip formatter's version
except Exception as exc:
logger.error(
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
user_id, request_id, task_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
return
finally:
clear_client_executor()
# Extract canvas block then emit stream_end with optional mutations.
full_response = "".join(response_chunks)
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
mutations: list[dict] = []
if canvas_content:
mutations.append({
"type": "canvas_draft",
"content": canvas_content,
"kind": canvas_kind,
})
await websocket.send_text(
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
)
logger.info(
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
)
# ── v4 Journey Handlers ─────────────────────────────────────────────
async def _handle_journey_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_start frame — explores directory and sends first question."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_start(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
logger.error(
"device_ws: journey_start failed user=%s: %s", user_id, exc
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": frame.get("session_id", ""),
"message": f"Failed to start journey: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
async def _handle_journey_message(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_message frame — continues the journey conversation."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_message(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
session_id = frame.get("session_id", "")
logger.error(
"device_ws: journey_message failed user=%s session=%s: %s",
user_id, session_id, exc,
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey error: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
# ── v7 Folder Index Handlers ──────────────────────────────────────────
async def _handle_index_session_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Register a new folder index session. No response sent — client is declaring intent."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
project_id: str | None = frame.get("projectId") or frame.get("project_id")
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
if not session_id:
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
return
_index_sessions[session_id] = {
"user_id": user_id,
"project_id": project_id,
"processed": 0,
"total": total,
"cancelled": False,
}
logger.info(
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
user_id, session_id, project_id, total,
)
async def _handle_index_session_cancel(
websocket: WebSocket,
frame: dict,
) -> None:
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
session = _index_sessions.get(session_id)
if session:
session["cancelled"] = True
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "cancelled",
}))
_index_sessions.pop(session_id, None)
logger.info("device_ws: index_session_cancel session=%s", session_id)
async def _handle_index_file_batch(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Process a batch of files for an index session, streaming results back."""
# Lazy imports to avoid heavy load at module startup.
from app.core.folder_indexer import ( # noqa: PLC0415
summarize_image,
summarize_pdf,
summarize_docx,
summarize_text,
)
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.billing.quota import add_token_usage # noqa: PLC0415
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
files: list[dict] = frame.get("files", [])
session = _index_sessions.get(session_id)
if not session or session.get("cancelled"):
return
async with async_session() as db:
tier = await tier_manager.get_tier(user_id, db)
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
cap: int | None = None if raw_cap == -1 else raw_cap
for file_info in files:
if session.get("cancelled"):
return
# Electron's toSnakeCase converts payload keys, so accept both forms.
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
kind: str = file_info.get("kind") or "text"
content: str = file_info.get("content") or ""
ext: str = file_info.get("ext") or ""
mime: str = file_info.get("mime") or "application/octet-stream"
name: str = rel_path.split("/")[-1] or rel_path
try:
if kind == "image":
res = await summarize_image(image_b64=content, mime=mime)
elif kind == "pdf":
res = await summarize_pdf(pdf_b64=content, name=name)
elif kind == "docx":
res = await summarize_docx(docx_b64=content, name=name)
else:
res = await summarize_text(content=content, ext=ext, name=name)
except Exception as exc:
logger.warning(
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
session_id, rel_path, exc,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": None,
"tokensUsed": 0,
"error": str(exc),
}))
session["processed"] += 1
continue
# Account for token usage and check cap.
usage = await add_token_usage(
user_id=user_id,
feature="folder_index",
tokens=res.tokens_used,
db=db,
cap=cap,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": res.summary,
"tokensUsed": res.tokens_used,
}))
session["processed"] += 1
if usage.exhausted:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "quota_exceeded",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session quota_exceeded user=%s session=%s",
user_id, session_id,
)
return
# After processing the batch, emit progress.
processed = session["processed"]
total = session["total"]
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_progress,
"sessionId": session_id,
"processed": processed,
"total": total,
}))
if processed >= total:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "completed",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session_done completed user=%s session=%s processed=%d",
user_id, session_id, processed,
)
# ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
"""Send a ping frame every 30 s to keep the connection alive."""
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"type": "ping"}))
# ── Disconnect cleanup ────────────────────────────────────────────────
async def _mark_runs_disconnected(user_id: str) -> None:
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
try:
async with async_session() as db:
await db.execute(
update(AgentRunLog)
.where(
AgentRunLog.user_id == user_id,
AgentRunLog.status == "running",
)
.values(
status="error",
errors=["device disconnected"],
)
)
await db.commit()
except Exception as exc:
logger.error(
"device_ws: failed to mark runs as disconnected for user=%s: %s",
user_id,
exc,
)

View File

@@ -1,225 +0,0 @@
"""Memory management routes — view/edit/delete user memory tiers.
All routes require authentication. Data is always user-scoped.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.models import (
ExtractionQueue,
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
MemoryRelation,
)
from app.schemas import UserProfile
router = APIRouter(prefix="/memory", tags=["memory"])
logger = logging.getLogger(__name__)
_ALLOWED_PREDICATES = {
"works_at",
"reports_to",
"stakeholder_of",
"last_contacted_on",
"owes_followup",
"manages",
"collaborates_with",
"owns",
"member_of",
"custom",
}
# ── Response schemas ─────────────────────────────────────────────────────────
class RelationOut(BaseModel):
id: str
subject_label: str
subject_type: str
predicate: str
object_label: str
object_type: str
confidence: float
last_confirmed_at: int | None = None # epoch ms
class RelationPatch(BaseModel):
subject_label: str | None = None
object_label: str | None = None
predicate: str | None = None
confidence: float | None = Field(None, ge=0.0, le=1.0)
class CoreAddBody(BaseModel):
key: str = Field(..., min_length=1, max_length=255)
value: str = Field(..., min_length=1)
# ── Helpers ──────────────────────────────────────────────────────────────────
def _relation_to_out(row: MemoryRelation) -> RelationOut:
last_ms: int | None = None
if row.last_confirmed_at is not None:
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
return RelationOut(
id=row.id,
subject_label=row.subject_label,
subject_type=row.subject_type,
predicate=row.predicate,
object_label=row.object_label,
object_type=row.object_type,
confidence=row.confidence,
last_confirmed_at=last_ms,
)
# ── Routes ───────────────────────────────────────────────────────────────────
@router.get("/core", response_model=dict[str, str])
async def get_core_memory(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
"""Return all core memory k/v pairs (plaintext) for the current user."""
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(current_user.id)
return {b["label"]: b["value"] for b in blocks}
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_core_key(
key: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Delete a single core memory key (GDPR Art. 17)."""
mw = MemoryMiddleware(db)
deleted = await mw.delete_core(current_user.id, key)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
async def add_core_key(
body: CoreAddBody,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
"""Add or overwrite a core memory key/value pair."""
mw = MemoryMiddleware(db)
await mw.update_core(current_user.id, body.key, body.value)
return {body.key: body.value}
@router.get("/relational", response_model=list[RelationOut])
async def get_relational_memory(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[RelationOut]:
"""Return all relational memory rows for the current user."""
mw = MemoryMiddleware(db)
rows = await mw.query_relations(current_user.id, limit=200)
return [_relation_to_out(r) for r in rows]
@router.patch("/relational/{relation_id}", response_model=RelationOut)
async def patch_relation(
relation_id: str,
body: RelationPatch,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> RelationOut:
"""Edit a relation row's labels, predicate, or confidence."""
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
)
result = await db.execute(
select(MemoryRelation).where(
MemoryRelation.id == relation_id,
MemoryRelation.user_id == current_user.id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
if body.subject_label is not None:
row.subject_label = body.subject_label
if body.object_label is not None:
row.object_label = body.object_label
if body.predicate is not None:
row.predicate = body.predicate
if body.confidence is not None:
row.confidence = body.confidence
row.last_confirmed_at = datetime.now(timezone.utc)
await db.commit()
await db.refresh(row)
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
return _relation_to_out(row)
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_relation(
relation_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Hard-delete a relation row (GDPR Art. 17)."""
result = await db.execute(
select(MemoryRelation).where(
MemoryRelation.id == relation_id,
MemoryRelation.user_id == current_user.id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
await db.delete(row)
await db.commit()
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
async def forget_all(
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Wipe all memory tiers for the current user (GDPR Art. 17).
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
"""
if x_confirm != "true":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
)
uid = current_user.id
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
await db.commit()
logger.warning("memory: forget_all GDPR wipe user=%s", uid)

View File

@@ -1 +0,0 @@
"OAuth provider abstractions and utilities."

View File

@@ -1,135 +0,0 @@
"""OAuth 2.0 + PKCE provider abstractions.
Each provider implements a three-step flow designed for a desktop (public) client:
1. get_authorization_url(state, code_challenge) → str
Build the provider's consent-screen URL. State and code_challenge are
generated server-side; the client opens this URL in the system browser.
2. exchange_code(code, code_verifier, redirect_uri) → dict
Exchange the short-lived authorization code for an access token.
The code_verifier proves ownership of the PKCE challenge.
3. get_userinfo(access_token) → OAuthUserInfo
Fetch the canonical user identity from the provider.
Currently supported providers:
- GoogleOAuthProvider (scope: openid email profile)
Adding a new provider:
- Implement the three methods above.
- Register in _PROVIDERS inside routes/auth.py.
"""
from __future__ import annotations
import base64
import hashlib
import os
import urllib.parse
from dataclasses import dataclass
import httpx
# ── Data transfer objects ─────────────────────────────────────────────
@dataclass
class OAuthUserInfo:
"""Normalized user identity returned by any provider."""
provider_user_id: str
email: str
email_verified: bool
avatar_url: str | None
name: str | None
# ── PKCE helpers ──────────────────────────────────────────────────────
def generate_pkce_pair() -> tuple[str, str]:
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
The code_verifier is a random 32-byte URL-safe base64 string.
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
"""
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
return code_verifier, code_challenge
# ── Google provider ───────────────────────────────────────────────────
class GoogleOAuthProvider:
"""Google OAuth 2.0 provider (openid email profile scope).
Uses Google's standard authorization endpoint with PKCE S256.
Does NOT use google-auth-oauthlib to keep the flow generic and async.
"""
name = "google"
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self, state: str, code_challenge: str) -> str:
"""Build the Google consent-screen URL."""
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code",
"scope": "openid email profile",
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "select_account",
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
async def exchange_code(
self, code: str, code_verifier: str, redirect_uri: str
) -> dict:
"""Exchange authorization code for an access token."""
async with httpx.AsyncClient() as client:
response = await client.post(
self._TOKEN_URL,
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"code_verifier": code_verifier,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
},
)
response.raise_for_status()
return response.json()
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
"""Fetch the authenticated user's identity from Google."""
async with httpx.AsyncClient() as client:
response = await client.get(
self._USERINFO_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
data = response.json()
return OAuthUserInfo(
provider_user_id=data["sub"],
email=data["email"],
email_verified=data.get("email_verified", False),
avatar_url=data.get("picture"),
name=data.get("name"),
)

View File

@@ -1,4 +0,0 @@
from app.billing.stripe_service import stripe_service
from app.billing.tier_manager import tier_manager
__all__ = ["stripe_service", "tier_manager"]

View File

@@ -1,139 +0,0 @@
"""Quota checks and atomic token-usage accounting for folder integration."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier
class QuotaExceeded(Exception):
"""Raised when a folder operation cannot proceed under the user's tier."""
def __init__(self, reason: str, message: str) -> None:
super().__init__(message)
self.reason = reason # "max_files" | "monthly_tokens"
@dataclass
class TokenUsageResult:
tokens_used: int
exhausted: bool
def _current_year_month() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m")
_tier_manager = TierManager()
async def check_folder_quota(
*,
user_id: str,
tier: BillingTier,
estimated_files: int,
db: AsyncSession,
) -> None:
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
would be violated. -1 in either feature means unlimited."""
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
if max_files != -1 and estimated_files > max_files:
raise QuotaExceeded(
"max_files",
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
)
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
if cap == -1:
return
ym = _current_year_month()
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)
).scalar_one_or_none()
used = row.tokens_used if row else 0
if used >= cap:
raise QuotaExceeded(
"monthly_tokens",
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
)
async def add_token_usage(
*,
user_id: str,
feature: str,
tokens: int,
db: AsyncSession,
cap: int | None = None,
) -> TokenUsageResult:
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
back to a read-then-write on other engines (e.g. aiosqlite in tests).
Returns post-update total and whether cap is exhausted.
"""
ym = _current_year_month()
# Detect dialect to choose between native upsert and portable fallback.
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
if dialect_name == "postgresql":
# Native atomic upsert — production path.
stmt = (
pg_insert(MonthlyTokenUsage)
.values(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
.on_conflict_do_update(
index_elements=["user_id", "year_month", "feature"],
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
)
.returning(MonthlyTokenUsage.tokens_used)
)
used: int = (await db.execute(stmt)).scalar_one()
await db.commit()
else:
# Portable fallback — used in tests (SQLite) and any non-PG engine.
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == feature,
)
)
).scalar_one_or_none()
if row is None:
row = MonthlyTokenUsage(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
db.add(row)
else:
row.tokens_used += tokens
await db.commit()
await db.refresh(row)
used = row.tokens_used
exhausted = cap is not None and cap != -1 and used >= cap
return TokenUsageResult(tokens_used=used, exhausted=exhausted)

View File

@@ -1,149 +0,0 @@
"""Tier manager: feature matrix and quota enforcement.
``TierManager`` is the single source of truth for what each billing tier
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
Quota-enforcement helpers take ``tier`` directly — the caller already has it
from ``current_user.tier`` (provided by ``get_current_user``).
"""
from __future__ import annotations
from typing import Any
from fastapi import HTTPException, status
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas import BillingTier
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
FEATURES: dict[str, dict[str, Any]] = {
"free": {
"agents": 3,
"batch_active": 2,
"batch_runs_per_day": 5,
"providers": 1,
"batch_builder": False,
"sso": False,
"real_embeddings": False, # keyword fallback only
"realtime_extraction": False, # batch queue (Phase 2)
"relational_memory": False, # relational tier (Phase 3) — Pro+
"proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 200,
"folder_monthly_tokens": 100_000,
},
"pro": {
"agents": -1, # unlimited
"batch_active": 10,
"batch_runs_per_day": 50,
"providers": -1,
"batch_builder": False,
"sso": False,
"real_embeddings": True, # pgvector cosine search
"realtime_extraction": True, # fire-and-forget asyncio.create_task
"relational_memory": True, # person/project predicates
"proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 5000,
"folder_monthly_tokens": 2_000_000,
},
"power": {
"agents": -1,
"batch_active": -1, # unlimited
"batch_runs_per_day": -1, # unlimited
"providers": -1,
"batch_builder": True,
"sso": False,
"real_embeddings": True,
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
},
"team": {
"agents": -1,
"batch_active": -1,
"batch_runs_per_day": -1, # unlimited
"providers": -1,
"batch_builder": True,
"sso": True,
"real_embeddings": True,
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
},
}
# Requests-per-minute limit per tier.
RATE_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
class TierManager:
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
# ── Tier lookup ─────────────────────────────────────────────────────
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB.
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
when no subscription row exists.
"""
from app.models import Subscription # noqa: PLC0415
from app.config.settings import settings # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
tier: str | None = result.scalar_one_or_none()
if tier is None or tier not in FEATURES:
return "power" if settings.ENV == "dev" else "free"
return tier # type: ignore[return-value]
# ── Feature access ───────────────────────────────────────────────────
def check_feature(self, tier: BillingTier, feature: str) -> bool:
"""Return ``True`` if ``tier`` has ``feature`` enabled.
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
"""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if value is None:
return False
if isinstance(value, bool):
return value
return value != 0
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
if not self.check_feature(tier, feature):
detail = (
f"Feature '{feature}' requires {tier_name} tier or above."
if tier_name
else f"Feature '{feature}' is not available on your current tier."
)
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
"""Return integer feature value for tier. -1 means unlimited."""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if not isinstance(value, int):
return 0
return value
# ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int:
"""Return the requests-per-minute limit for ``tier``."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# Module-level singleton shared across the app.
tier_manager = TierManager()

View File

@@ -1,85 +0,0 @@
from typing import Literal
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
JWT_SECRET: str = "change-me-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = ""
CEREBRAS_API_KEY: str = ""
GROQ_API_KEY: str = ""
DEEPSEEK_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
# GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
GITHUB_COPILOT_TOKEN_DIR: str = ""
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
GMAIL_CLIENT_ID: str = ""
GMAIL_CLIENT_SECRET: str = ""
MS_CLIENT_ID: str = ""
MS_CLIENT_SECRET: str = ""
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
MS_TENANT_ID: str = "common"
# Google Login OAuth credentials — scope: openid email profile.
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
GOOGLE_AUTH_CLIENT_ID: str = ""
GOOGLE_AUTH_CLIENT_SECRET: str = ""
# The redirect URI registered in Google Cloud Console.
# Google redirects here after consent; this backend route then bounces to
# the adiuvai:// deep link so the Electron app receives the code.
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
OAUTH_ENCRYPTION_KEY: str = ""
CORS_ORIGINS: list[str] = [
"app://.",
"http://localhost:3000",
"http://localhost:5173",
"http://localhost:4173", # Vite preview (web SPA)
"https://app.adiuvai.com", # Production web portal
]
LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_PUBLIC_KEY: str = ""
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
SCHEDULER_ENABLED: bool = True
ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
settings = Settings()

View File

@@ -1,30 +0,0 @@
"""Minimal agent base types retained for compatibility with batch runners."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
"""Common base for non-chat agents still using the old base contract."""
def __init__(
self,
user_id: str = "",
shared_memory: dict[str, Any] | None = None,
vector_store_context: list[str] | None = None,
) -> None:
self.user_id = user_id
self.shared_memory: dict[str, Any] = shared_memory or {}
self.vector_store_context: list[str] = vector_store_context or []
@abstractmethod
def get_name(self) -> str: ...
@abstractmethod
def get_description(self) -> str: ...
@property
def skills(self) -> list[str]:
return []

File diff suppressed because it is too large Load Diff

View File

@@ -1,96 +0,0 @@
"""In-process TTL buffer for per-session LangChain message history.
Stores the full message list (including AIMessage with tool_calls and ToolMessage)
keyed by (user_id, session_id), so agents can reconstruct tool-call context across
conversation turns without it being lossy through the wire.
Single-process only. For multi-worker deployments, replace the _SessionBuffer
implementation with one backed by Redis (serialize LangChain messages to dicts via
message_to_dict / messages_from_dict from langchain_core.messages).
"""
from __future__ import annotations
import time
from threading import Lock
from langchain_core.messages import BaseMessage
SESSION_TTL_SECONDS = 1800 # 30-minute idle expiry
MAX_MESSAGES_PER_SESSION = 80 # cap to avoid unbounded memory growth
class _SessionBuffer:
def __init__(self) -> None:
self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
self._lock = Lock()
def _evict_stale(self) -> None:
now = time.monotonic()
stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
for k in stale:
del self._store[k]
def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
key = (user_id, session_id)
with self._lock:
entry = self._store.get(key)
if entry is None:
return None
ts, msgs = entry
if time.monotonic() - ts > SESSION_TTL_SECONDS:
del self._store[key]
return None
self._store[key] = (time.monotonic(), msgs)
return list(msgs)
def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
key = (user_id, session_id)
capped = messages[-MAX_MESSAGES_PER_SESSION:]
with self._lock:
self._evict_stale()
self._store[key] = (time.monotonic(), capped)
def clear(self, user_id: str, session_id: str) -> None:
with self._lock:
self._store.pop((user_id, session_id), None)
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
"""Append a synthetic system message to the buffer for the given session.
Creates the session slot if it does not yet exist. Used by the
contextual_scope_update handler to inject navigation events without
making an LLM call.
"""
from langchain_core.messages import SystemMessage # noqa: PLC0415
key = (user_id, session_id)
with self._lock:
entry = self._store.get(key)
if entry is None:
msgs: list[BaseMessage] = [SystemMessage(content=text)]
else:
_, existing = entry
msgs = list(existing) + [SystemMessage(content=text)]
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
self._store[key] = (time.monotonic(), capped)
class ContextualBufferProxy:
"""Thin wrapper around _SessionBuffer that closes over user_id + session_id.
Returned by get_session_buffer() so callers can call
``proxy.append_system_message(text)`` without threading user_id/session_id
through every call site.
"""
def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
self._buf = buf
self._user_id = user_id
self._session_id = session_id
def append_system_message(self, text: str) -> None:
self._buf.append_system_message(self._user_id, self._session_id, text)
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
session_buffer = _SessionBuffer()

View File

@@ -1,228 +0,0 @@
"""Brief agent — produces plain-text home and project status briefs.
Read-only tool subset only. Never calls _normalize_tagged_list_lines —
the brief prompt forbids XML tags, so skipping post-processing is intentional.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from datetime import date
from typing import Any
from app.agents.note_agent import NOTE_READ_TOOLS
from app.agents.project_agent import PROJECT_READ_TOOLS
from app.agents.task_agent import TASK_READ_TOOLS
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
from app.core.deep_agent import (
_language_instruction,
_proactive_hints_injection,
_read_only_memory_tools,
_relational_memory_injection,
_run_single_agent_stream,
_trace_id_from_context,
build_brief_multi_project_manifest,
)
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
_LANGUAGE_NAMES: dict[str, str] = {
"en": "English", "it": "Italian", "es": "Spanish",
"fr": "French", "de": "German",
"english": "English", "italian": "Italian", "italiano": "Italian",
"spanish": "Spanish", "español": "Spanish",
"french": "French", "français": "French",
"german": "German", "deutsch": "German",
}
_HOME_BRIEF_FALLBACK = """\
You are the user's personal assistant producing a short daily brief.
ROLE
Act like a calm, attentive secretary writing a stand-up note for your boss.
Warm and human, never breezy. Never cheerful filler, never emojis, never
"here is your brief" meta-text. The user is opening the app mid-workday and
is probably stressed — your job is to lower cognitive load, not add noise.
TOOLS — always call before writing
Pull fresh data every run. Do not invent counts or titles. Use at minimum:
- list_tasks_due_today — tasks the user owes today
- list_timelines_today — events starting or ending today
- list_all_projects — projects currently in progress or at risk
- memory_list_blocks / memory_get — personal context about people, clients,
payment habits, working preferences
If a tool returns nothing, simply omit that topic. Never report zeros.
WHAT TO INCLUDE
1. Tasks due today (title + priority; group the 1-2 most important).
2. Timeline events starting or ending today (and anything that starts/ends
tomorrow if the user has a very light day).
3. Active projects that need a nudge — stalled, blocked, or awaiting input.
4. Memory-aware colour where it sharpens the brief. Examples:
- "Client Rossi tends to pay late — the Acme invoice is 6 days out."
- "You usually dislike meetings before 10:00 — the call at 09:30 is unusual."
Only add a memory line when it changes what the user does. Do not pad.
WHAT TO OMIT
- Zero-counts ("no overdue items", "0 meetings today").
- Statistics ("2 active projects, 3 completed tasks").
- Headers, titles, greetings, sign-offs, dates, emojis, slang.
- Meta-phrases ("here is", "let me know if", "hope this helps").
- XML/HTML tags of any kind. Plain prose only.
LIGHT-DAY CLAUSE
If tasks + events + active-project-nudges together produce fewer than two
sentences of content, also list 1-2 projects in status on_hold or waiting
and ask a single, specific question about them — e.g. "Is the Bianchi
redesign still paused, or ready to pick back up?" One question max, grounded
in a real project name.
VOICE
- Calm. Concise. Human. Short sentences.
- Use **bold** sparingly for task titles, project names, and people's names.
- No bullet lists. Flow as 2-4 sentences of prose.
LENGTH
2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence.
Respond in the user's language ({language}). Today is {today}.\
"""
_PROJECT_BRIEF_FALLBACK = """\
You are the project assistant producing a short status brief for ONE project.
ROLE
A senior project manager summarising state-of-play for the owner. Factual,
sharp, forward-looking. Never reassuring filler, never emojis.
SCOPE
Work only with project_id = {project_id}. Do not mention or pull data from
other projects. Use tools to fetch fresh data:
- get_project — current status, dates, description
- list_tasks(project_id) — open work, split by status
- list_timelines(project_id) — milestones hit, upcoming, overdue
- list_notes(project_id) — any recent decisions or blockers
- memory_get — relevant context about the client, collaborators, constraints
STRUCTURE — follow exactly, one short paragraph per section, no headers
1. **State.** One sentence: current phase, health (on track / at risk / blocked),
and why. Cite the concrete signal (overdue milestone, stalled tasks, recent
blocker note).
2. **What's moving.** What was completed or progressed recently. Name specific
tasks or milestones.
3. **Next steps.** The 1-3 most important things the user should do next, in
priority order. Be concrete — task name, who owns it, when due if known.
If waiting on someone else, name them and what the ask is.
4. **Risks / memory-flagged items.** One line max. Only include when there is
a real risk or a relevant memory (e.g. late-paying client, tight deadline,
scope change). Omit the section entirely if nothing to say.
WHAT TO OMIT
- Zero-counts ("no overdue tasks").
- Generic advice ("keep up the good work").
- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases.
- XML/HTML tags or bracketed id lists. Plain prose only.
VOICE
- Direct. Factual. No fluff.
- Use **bold** sparingly for task titles, milestone names, and the owner's name.
- Short sentences. Prefer verbs over nouns ("Client review is blocking release"
not "There is a blocker which is the client review").
LENGTH
4-8 sentences total across the 3-4 sections. Hard cap 8.
Respond in the user's language ({language}). Today is {today}.\
"""
def _resolve_language(context: dict[str, Any]) -> str:
core = context.get("core_memory") or {}
raw = (core.get("language") or "en").strip().lower()
return _LANGUAGE_NAMES.get(raw, raw.title()) or "English"
def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]:
return [
*TASK_READ_TOOLS,
*PROJECT_READ_TOOLS,
*TIMELINE_READ_TOOLS,
*NOTE_READ_TOOLS,
*_read_only_memory_tools(user_id, trace_id),
]
async def run_home_brief(
user_id: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
"""Stream a plain-text daily home brief.
Yields (event_type, data) tuples identical to _run_single_agent_stream.
Do NOT post-process output through _normalize_tagged_list_lines.
"""
from app.agents.folder_agent import FOLDER_TOOLS
trace_id = _trace_id_from_context(context)
today = date.today().isoformat()
language = _resolve_language(context)
raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK)
system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today)
system_prompt += _relational_memory_injection(context)
system_prompt += _proactive_hints_injection(context)
system_prompt += _language_instruction(context)
if today not in system_prompt:
system_prompt += f"\nToday is {today}."
brief_manifest = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message="Generate the daily brief.",
context=context,
langfuse_prompt=langfuse_prompt,
agent_name="brief-agent",
tools=tools,
):
yield event
async def run_project_brief(
user_id: str,
project_id: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
"""Stream a plain-text project status brief for project_id.
Yields (event_type, data) tuples identical to _run_single_agent_stream.
Do NOT post-process output through _normalize_tagged_list_lines.
"""
trace_id = _trace_id_from_context(context)
today = date.today().isoformat()
language = _resolve_language(context)
raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK)
system_prompt = compile_prompt(
raw_template, langfuse_prompt,
language=language, today=today, project_id=project_id,
)
system_prompt += _relational_memory_injection(context)
system_prompt += _proactive_hints_injection(context)
system_prompt += _language_instruction(context)
if today not in system_prompt:
system_prompt += f"\nToday is {today}."
tools = _build_read_tools(user_id, trace_id)
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=f"Generate the project status brief for project {project_id}.",
context=context,
langfuse_prompt=langfuse_prompt,
agent_name="brief-agent",
tools=tools,
):
yield event

File diff suppressed because it is too large Load Diff

View File

@@ -1,151 +0,0 @@
"""Device connection manager.
Maintains in-memory state for all active Electron → backend WebSocket
connections. One connection per user (latest replaces previous).
The manager handles the **tool-call round-trip** pattern:
- Backend sends ``tool_call`` frame → Electron executes the action →
returns ``tool_result`` frame.
- ``create_pending_call`` registers a Future keyed by ``call_id``.
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
receive the result dict from Electron.
This pattern is used by all tools (CRUD, file-system, etc.) via
``execute_on_client()`` in ``ws_context.py``.
The ``device_manager`` module-level singleton is imported by both the
device WS route and the agent runner.
"""
from __future__ import annotations
import asyncio
import json
import logging
from dataclasses import dataclass, field
from fastapi import WebSocket
logger = logging.getLogger(__name__)
@dataclass
class DeviceConnection:
"""State for a single connected Electron device."""
ws: WebSocket
device_id: str
# Futures indexed by tool_call id — resolved when tool_result arrives.
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
class DeviceConnectionManager:
"""Singleton registry of active Electron WebSocket connections.
Thread/task safety note: asyncio is single-threaded by design. All
mutations happen inside await-points on the main event loop, so no
locking is required for the in-memory dicts.
"""
def __init__(self) -> None:
self._connections: dict[str, DeviceConnection] = {}
# ── Registration ──────────────────────────────────────────────────
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
"""Store the active connection for *user_id*, replacing any previous one."""
if user_id in self._connections:
old = self._connections[user_id]
logger.info(
"device_manager: replacing existing connection for user=%s device=%s",
user_id,
old.device_id,
)
# Cancel any futures that were waiting on the old connection.
for fut in old.pending_calls.values():
if not fut.done():
fut.cancel()
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
logger.info(
"device_manager: registered user=%s device=%s", user_id, device_id
)
def unregister(self, user_id: str) -> None:
"""Remove the connection for *user_id* and cancel any pending futures."""
conn = self._connections.pop(user_id, None)
if conn is None:
return
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
logger.info("device_manager: unregistered user=%s", user_id)
# ── Presence queries ──────────────────────────────────────────────
def get_ws(self, user_id: str) -> WebSocket | None:
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
conn = self._connections.get(user_id)
return conn.ws if conn else None
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
"""Return ``True`` if the user has an active connection.
If *device_id* is provided also checks that it matches the connected device.
"""
conn = self._connections.get(user_id)
if conn is None:
return False
if device_id is not None:
return conn.device_id == device_id
return True
# ── Frame sending ─────────────────────────────────────────────────
async def send_frame(self, user_id: str, frame: dict) -> None:
"""Send *frame* as a JSON text message to the device.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"send_frame: user {user_id!r} is not connected"
)
await conn.ws.send_text(json.dumps(frame))
# ── Tool-call round-trip ──────────────────────────────────────────
def create_pending_call(
self, user_id: str, call_id: str
) -> asyncio.Future[dict]:
"""Register a Future that will be resolved when the tool_result arrives.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"create_pending_call: user {user_id!r} is not connected"
)
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
return fut
def resolve_pending_call(
self, user_id: str, call_id: str, result: dict
) -> None:
"""Fulfil the Future registered under *call_id* with the Electron result.
No-ops if the call_id is unknown (already timed out or cancelled).
"""
conn = self._connections.get(user_id)
if conn is None:
return
fut = conn.pending_calls.pop(call_id, None)
if fut is not None and not fut.done():
fut.set_result(result)
# Module-level singleton — import this everywhere.
device_manager = DeviceConnectionManager()

View File

@@ -1,34 +0,0 @@
"""OpenAI embedding helper for associative memory tier.
Single public function: ``embed_text(text) -> list[float] | None``.
Returns None on any failure — callers must implement a keyword fallback.
Never raises; all exceptions are logged as warnings.
"""
from __future__ import annotations
import logging
from openai import AsyncOpenAI
logger = logging.getLogger(__name__)
_MAX_INPUT_CHARS = 8000
_EMBEDDING_MODEL = "text-embedding-3-small"
async def embed_text(text: str) -> list[float] | None:
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
try:
client = AsyncOpenAI()
truncated = text[:_MAX_INPUT_CHARS]
response = await client.embeddings.create(
input=truncated,
model=_EMBEDDING_MODEL,
)
result: list[float] = response.data[0].embedding
logger.debug("embeddings: embed_text dims=%d", len(result))
return result
except Exception as exc:
logger.warning("embeddings: embed_text failed: %s", exc)
return None

View File

@@ -1,183 +0,0 @@
"""Per-file summarisation for project folder integration."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from langchain_core.messages import HumanMessage, SystemMessage
from pypdf import PdfReader
from docx import Document as DocxDocument
from app.core.langfuse_client import (
compile_prompt,
extract_usage,
get_langfuse,
get_prompt_or_fallback,
)
from app.core.llm import get_llm
_TEXT_FALLBACK = (
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
)
_IMAGE_FALLBACK = (
"You are summarising an image attached to a project folder.\n"
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
)
_MAX_INPUT_CHARS = 6000
@dataclass
class IndexResult:
summary: str
tokens_used: int
async def _llm_text(messages: list) -> object:
"""Make the LLM call for text summarisation.
Defined as a standalone async function so tests can patch it cleanly
without needing to mock the LLM object itself.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def _llm_vision(messages: list) -> object:
"""Make the LLM call for vision (image) summarisation.
Accepts the message list and returns the response directly, mirroring
the ``_llm_text`` caller pattern so tests can patch it at the module level.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
"""Return a compact summary of an image file using vision.
Parameters
----------
image_b64:
Base64-encoded image bytes.
mime:
MIME type of the image, e.g. ``"image/png"``.
file_name:
Optional file name, attached to the Langfuse trace as input metadata.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
messages = [
SystemMessage(content=template),
HumanMessage(content=[
{"type": "text", "text": "Summarise this image."},
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
]),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-image",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": file_name, "mime": mime},
) as gen:
response = await _llm_vision(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_vision(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
"""Return a compact summary of a text file.
Parameters
----------
content:
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
ext:
File extension including the leading dot, e.g. ``".md"``.
name:
File name, e.g. ``"kickoff.md"``.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
truncated = content[:_MAX_INPUT_CHARS]
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
messages = [
SystemMessage(content=compiled),
HumanMessage(content="Summarise this file."),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-text",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
) as gen:
response = await _llm_text(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_text(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
def _extract_pdf_text(pdf_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(pdf_b64))
reader = PdfReader(buf)
parts: list[str] = []
for page in reader.pages:
try:
parts.append(page.extract_text() or "")
except Exception:
continue
return "\n".join(parts).strip()
def _extract_docx_text(docx_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(docx_b64))
doc = DocxDocument(buf)
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a PDF file.
Parameters
----------
pdf_b64:
Base64-encoded PDF bytes.
name:
File name, e.g. ``"report.pdf"``.
"""
text = _extract_pdf_text(pdf_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".pdf", name=name)
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a DOCX file.
Parameters
----------
docx_b64:
Base64-encoded DOCX bytes.
name:
File name, e.g. ``"spec.docx"``.
"""
text = _extract_docx_text(docx_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".docx", name=name)

View File

@@ -1,190 +0,0 @@
"""Langfuse observability — singleton client and prompt helpers.
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
all helpers are no-ops so the app works without Langfuse configured.
Usage
-----
Tracing::
from app.core.langfuse_client import get_langfuse
lf = get_langfuse()
if lf:
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
span.update(input=user_message)
# ... do work ...
span.update(output=result)
lf.flush()
Prompt management::
from app.core.langfuse_client import get_prompt_or_fallback
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
# Use text as the system prompt; pass prompt_obj to generations for linking.
Linking a prompt to a generation::
with lf.start_as_current_observation(
as_type="generation",
name="llm-call",
model="gpt-4o",
prompt=prompt_obj, # links generation → prompt version in the UI
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=_usage(response))
"""
from __future__ import annotations
import hashlib
import logging
from contextlib import contextmanager
from typing import Any, Generator
logger = logging.getLogger(__name__)
_client: Any = None
_initialized: bool = False
def get_langfuse() -> Any | None:
"""Return the Langfuse singleton, or ``None`` when not configured."""
global _client, _initialized
if _initialized:
return _client
_initialized = True
from app.config.settings import settings # local import to avoid circular deps
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
logger.debug("langfuse: not configured — observability disabled")
return None
try:
from langfuse import Langfuse
_client = Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_BASE_URL,
)
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
except Exception as exc:
logger.warning("langfuse: failed to initialize: %s", exc)
_client = None
return _client
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
Returns ``(raw_template, prompt_obj_or_None)``.
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
on it directly; use :func:`compile_prompt` instead so the correct variable
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
unavailable / the fetch failed. Pass this to generation observations so
Langfuse links the generation to the exact prompt version in the UI.
"""
lf = get_langfuse()
if lf is None:
return fallback, None
try:
prompt = lf.get_prompt(name, label="production", fallback=fallback)
# For text-type prompts .prompt holds the raw template string.
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
return raw, prompt
except Exception as exc:
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
return fallback, None
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
"""Compile *template* with *variables*, choosing the right syntax.
* When *prompt_obj* is a real Langfuse prompt object, calls
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
substitution as defined in the Langfuse UI.
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
falls back to ``template.format(**variables)`` which handles the
``{variable}`` syntax used in the hardcoded fallback strings.
This keeps callers oblivious to which syntax is in use.
"""
if prompt_obj is not None:
try:
compiled = prompt_obj.compile(**variables)
# compile() returns a string for text prompts.
if isinstance(compiled, str):
return compiled
# Chat prompts return a list of dicts — join text parts.
if isinstance(compiled, list):
return "\n".join(
m.get("content", "") for m in compiled if isinstance(m, dict)
)
except Exception as exc:
logger.warning(
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
getattr(prompt_obj, "name", "?"),
exc,
)
return template.format(**variables)
def extract_usage(response: Any) -> dict[str, int]:
"""Extract token usage from a LangChain AI message into Langfuse format."""
meta = getattr(response, "usage_metadata", None)
if not meta:
return {}
return {
"input": int(meta.get("input_tokens", 0)),
"output": int(meta.get("output_tokens", 0)),
"total": int(meta.get("total_tokens", 0)),
}
def hash_user_id(user_id: str) -> str:
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
This avoids sending raw database UUIDs to external observability services
while still providing a stable, deterministic identifier for per-user
metrics in the Langfuse dashboard.
"""
return hashlib.sha256(user_id.encode()).hexdigest()
@contextmanager
def langfuse_context(
user_id: str | None = None,
session_id: str | None = None,
) -> Generator[None, None, None]:
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
No-op when Langfuse is not configured or parameters are empty.
"""
lf = get_langfuse()
if lf is None or (not user_id and not session_id):
yield
return
try:
from langfuse import propagate_attributes
except ImportError:
logger.debug("langfuse: propagate_attributes not available — skipping context")
yield
return
attrs: dict[str, str] = {}
if user_id:
attrs["user_id"] = hash_user_id(user_id)
if session_id:
attrs["session_id"] = session_id
with propagate_attributes(**attrs):
yield

View File

@@ -1,156 +0,0 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()``
instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_:
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
* Anthropic: ``anthropic/claude-3.5-sonnet``
* Google: ``gemini/gemini-pro``
* Ollama: ``ollama/llama3``
* Bedrock: ``bedrock/anthropic.claude-v2``
Switch providers by changing **LLM_MODEL** in ``.env``
— no code changes required.
"""
from __future__ import annotations
import os
import warnings
from collections.abc import Callable
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from litellm import get_supported_openai_params # noqa: F401 validates install
from app.config.settings import settings
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
# Drop them silently instead of raising UnsupportedParamsError.
litellm.drop_params = True
# Some provider responses include a plain dict in the `usage` field where a
# richer Pydantic model is expected. This warning is noisy but non-fatal.
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string."""
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("groq/"):
return settings.GROQ_API_KEY or None
if model.startswith("deepseek/"):
return settings.DEEPSEEK_API_KEY or None
if model.startswith("github_copilot/"):
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
# No API key is required; returning None lets LiteLLM handle auth.
return None
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
"""Return a LangChain chat model backed by LiteLLM.
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
``openai`` client transparently when the model string contains a provider
prefix (``anthropic/…``, ``gemini/…``, etc.).
Parameters
----------
model:
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
temperature:
Sampling temperature. ``0`` = deterministic.
"""
model = model or settings.LLM_MODEL
# Point LiteLLM to the custom token directory when configured.
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
)
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
"note-summarizer": lambda: "gpt-4o-mini",
}
def model_for_agent(agent_name: str) -> str:
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
def get_agent_llm(
agent_name: str,
*,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
per-agent override is left empty in ``.env``.
"""
model = model_for_agent(agent_name)
return get_llm(model=model, temperature=temperature)
async def embed(text: str) -> list[float]:
"""Return an embedding vector for *text*.
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
model names to preserve existing behaviour.
"""
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
# so the provider's auth mechanism is applied correctly.
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -1,450 +0,0 @@
"""Mem0-style Extract/Update pipeline — Phase 2.
Runs after every ``store_episode`` call to distil durable facts, preferences,
routines, and relations from the latest conversation turn.
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
Design notes
------------
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
- Zero-trust: never logs decrypted user content; relation subject/object labels are
treated as identifiers (safe to log per spec).
- Must not raise into the request path — caller wraps in asyncio.create_task().
"""
from __future__ import annotations
import json
import logging
from typing import Any, Literal
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
logger = logging.getLogger(__name__)
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
_EXTRACTION_FALLBACK = (
"You are a memory extractor for a personal AI secretary. Given the last conversation "
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
"preferences, routines, and person/project relations worth remembering.\n\n"
"Output JSON matching this schema exactly:\n"
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
'"content": "<short canonical statement>", '
'"target_tier": "<core|associative|relational|proactive>", '
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
"Rules:\n"
"- Skip small talk, greetings, one-off questions.\n"
"- Max 5 candidates per call.\n"
"- Only extract durable information (still true next week).\n"
"- For type=relation: subject/predicate/object required.\n"
"- Default confidence=0.7.\n\n"
"## Last turn\n{last_turn}\n\n"
"## Core memory (current)\n{core_memory}\n\n"
"## Recent episodes\n{recent_episodes}"
)
_DECIDE_FALLBACK = (
"You are a memory update decision engine. Given a new memory candidate and a list of "
"existing memories from the same tier, decide what action to take.\n\n"
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
"- ADD: new information not in existing memories.\n"
"- UPDATE: contradicts or supersedes an existing memory.\n"
"- DELETE: states something is no longer true.\n"
"- NOOP: already captured accurately.\n\n"
"## New candidate\n{candidate}\n\n"
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
)
# ── Pydantic schemas ───────────────────────────────────────────────────────────
class MemoryCandidate(BaseModel):
type: Literal["fact", "preference", "relation", "routine"]
content: str
target_tier: Literal["core", "associative", "relational", "proactive"]
subject: str | None = None
predicate: str | None = None
object: str | None = None
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
class ExtractionResult(BaseModel):
candidates: list[MemoryCandidate] = Field(default_factory=list)
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
async def extract_candidates(
last_turn: str,
core_memory: dict[str, str],
recent_episodes: list[str],
) -> ExtractionResult:
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
Returns an ExtractionResult (may be empty on failure — never raises).
"""
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
if prompt_obj is not None:
try:
system_text = prompt_obj.compile(
last_turn=last_turn,
core_memory=core_str,
recent_episodes=episodes_str,
)
if isinstance(system_text, list):
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
except Exception as exc:
logger.warning("memory_extraction: compile failed: %s", exc)
system_text = template.format(
last_turn=last_turn,
core_memory=core_str,
recent_episodes=episodes_str,
)
else:
system_text = template.format(
last_turn=last_turn,
core_memory=core_str,
recent_episodes=episodes_str,
)
llm = get_agent_llm("memory-extractor", temperature=0)
# Bind JSON mode so the model always returns parseable output.
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
lf = get_langfuse()
try:
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Extract memory candidates as JSON."),
]
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-extraction",
model=model_for_agent("memory-extractor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm_json.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm_json.ainvoke(messages)
raw = json.loads(response.content)
result = ExtractionResult.model_validate(raw)
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
return result
except Exception as exc:
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
return ExtractionResult(candidates=[])
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
async def decide_action(
candidate: MemoryCandidate,
existing: list[str],
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
"""Decide what to do with a candidate given existing memories in the same tier.
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
Never raises.
"""
if not existing:
return "ADD"
candidate_str = f"[{candidate.type}] {candidate.content}"
existing_str = "\n".join(f"- {m}" for m in existing)
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
if prompt_obj is not None:
try:
system_text = prompt_obj.compile(
candidate=candidate_str,
existing_memories=existing_str,
)
if isinstance(system_text, list):
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
except Exception as exc:
logger.warning("memory_extraction: decide compile failed: %s", exc)
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
else:
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
llm = get_agent_llm("memory-extractor", temperature=0)
lf = get_langfuse()
try:
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Decide action."),
]
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-decide-action",
model=model_for_agent("memory-extractor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm.ainvoke(messages)
verb = response.content.strip().upper()
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
return verb # type: ignore[return-value]
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
return "ADD"
except Exception as exc:
logger.warning("memory_extraction: decide_action failed: %s", exc)
return "ADD"
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
async def run_extraction(
db: AsyncSession,
user_id: str,
last_user_msg: str,
last_assistant_msg: str,
session_id: str | None,
) -> None:
"""Full Mem0-style extract/update pipeline for one conversation turn.
Steps:
1. Load core memory + last 5 episodes.
2. extract_candidates() → up to 5 MemoryCandidate objects.
3. For each candidate: find top-3 neighbours → decide_action() → apply.
4. Trace via Langfuse.
Never raises — wraps everything in try/except.
"""
try:
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
except Exception as exc:
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
async def _run_extraction_inner(
db: AsyncSession,
user_id: str,
last_user_msg: str,
last_assistant_msg: str,
session_id: str | None,
) -> None:
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
middleware = MemoryMiddleware(db)
fernet = await middleware._get_fernet(user_id)
if fernet is None:
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
return
# 1. Load context
core: dict[str, str] = await middleware._load_core(user_id, fernet)
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
lf = get_langfuse()
async def _run(trace_id: str | None) -> dict[str, Any]:
# 2. Extract candidates
result = await extract_candidates(last_turn, core, episodes)
if not result.candidates:
logger.info("memory_extraction: no candidates user=%s", user_id)
return {"candidates": 0, "applied": 0}
logger.info(
"memory_extraction: processing %d candidates user=%s trace=%s",
len(result.candidates),
user_id,
trace_id or "-",
)
# 3. Apply each candidate
applied = 0
actions: list[str] = []
for candidate in result.candidates:
try:
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
applied += 1
actions.append(f"{candidate.type}:{candidate.target_tier}")
except Exception as exc:
logger.warning(
"memory_extraction: apply failed candidate=%r user=%s: %s",
candidate.content[:80],
user_id,
exc,
)
logger.info(
"memory_extraction: applied %d/%d candidates user=%s",
applied,
len(result.candidates),
user_id,
)
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
with langfuse_context(user_id=user_id, session_id=session_id):
if lf:
with lf.start_as_current_observation(
as_type="span",
name="memory-extraction-pipeline",
input={"last_turn_preview": last_turn[:200]},
) as span:
summary = await _run(trace_id=span.id)
span.update(output=summary)
try:
lf.flush()
except Exception:
pass
else:
await _run(trace_id=None)
async def _apply_candidate(
middleware: Any,
db: AsyncSession,
user_id: str,
fernet: Any,
candidate: MemoryCandidate,
trace_id: str | None,
) -> None:
"""Fetch neighbours, decide action, apply to the appropriate tier."""
neighbours: list[str] = []
if candidate.target_tier == "core":
# For core tier: neighbours are existing core block values for similar keys.
blocks = await middleware.list_core_blocks(user_id)
neighbours = [b["value"] for b in blocks[:3]]
elif candidate.target_tier == "associative":
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
elif candidate.target_tier == "relational":
# Relation candidates handled specially — passed to upsert_relation directly.
# Neighbours: search by subject label if available.
neighbours = []
elif candidate.target_tier == "proactive":
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
action = await decide_action(candidate, neighbours)
logger.info(
"memory_extraction: candidate type=%s tier=%s action=%s",
candidate.type,
candidate.target_tier,
action,
)
if action == "NOOP":
return
if candidate.target_tier == "relational":
# Always upsert relations — decide_action skipped (no neighbour search).
if candidate.subject and candidate.predicate and candidate.object:
await _upsert_relation(
middleware, db, user_id, candidate, trace_id
)
return
if action in ("ADD", "UPDATE"):
if candidate.target_tier == "core":
# Derive a short key from the content (first 40 chars, snake_cased).
key = _content_to_key(candidate.content)
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
elif candidate.target_tier == "associative":
await middleware.store_associative(user_id, candidate.content)
elif candidate.target_tier == "proactive":
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
elif action == "DELETE":
if candidate.target_tier == "core":
key = _content_to_key(candidate.content)
await middleware.delete_core(user_id, key)
def _content_to_key(content: str) -> str:
"""Derive a short snake_case key from a content string (first 40 chars)."""
import re # noqa: PLC0415
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
return slug or "memory"
async def _upsert_relation(
middleware: Any,
db: AsyncSession,
user_id: str,
candidate: MemoryCandidate,
trace_id: str | None,
) -> None:
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
await middleware.upsert_relation(
user_id=user_id,
subject=candidate.subject or "unknown",
subject_type="unknown",
predicate=candidate.predicate or "related_to",
object_=candidate.object or "unknown",
object_type="unknown",
confidence=candidate.confidence,
)
logger.info(
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
candidate.subject,
candidate.predicate,
candidate.object,
)
async def _store_proactive_stub(
middleware: Any,
db: AsyncSession,
user_id: str,
candidate: MemoryCandidate,
fernet: Any,
) -> None:
"""Store a proactive pattern row directly (MemoryProactive model)."""
import uuid # noqa: PLC0415
from app.models import MemoryProactive # noqa: PLC0415
from app.core.memory_middleware import _encrypt # noqa: PLC0415
encrypted = _encrypt(fernet, candidate.content)
row = MemoryProactive(
id=str(uuid.uuid4()),
user_id=user_id,
pattern_encrypted=encrypted,
confidence=candidate.confidence,
source="inferred",
)
db.add(row)
try:
await db.commit()
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
except Exception as exc:
logger.warning("memory_extraction: store proactive failed: %s", exc)
await db.rollback()

View File

@@ -1,581 +0,0 @@
"""Memory maintenance jobs — Phase 3/5.
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
All are safe to call manually or from tests; they never raise.
"""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from cryptography.fernet import Fernet
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
logger = logging.getLogger(__name__)
# Decay parameters for relations
_DECAY_FACTOR = 0.95
_DECAY_PERIOD_DAYS = 30
_PRUNE_THRESHOLD = 0.2
# Proactive pattern decay: 10 % per 7 days since last sighting
_PROACTIVE_DECAY_FACTOR = 0.9
_PROACTIVE_DECAY_PERIOD_DAYS = 7
_PROACTIVE_PRUNE_THRESHOLD = 0.2
# Mining: require at least this many episodes to attempt pattern extraction
_MIN_EPISODES_FOR_MINING = 3
_MINING_LOOKBACK_DAYS = 30
# Audit: caps to control token cost
_AUDIT_MAX_FACTS = 50
_AUDIT_MAX_LABELS = 100
async def decay_relations(db: AsyncSession, user_id: str) -> None:
"""Apply confidence decay to all relation rows for a user.
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
Rows whose confidence falls below 0.2 are deleted.
Never raises — wraps in try/except.
"""
try:
await _decay_relations_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
result = await db.execute(
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
)
rows = result.scalars().all()
now = datetime.now(timezone.utc)
deleted = 0
decayed = 0
for row in rows:
reference = row.last_confirmed_at or row.created_at
if reference is None:
continue
if reference.tzinfo is None:
reference = reference.replace(tzinfo=timezone.utc)
days_elapsed = (now - reference).days
if days_elapsed < _DECAY_PERIOD_DAYS:
continue
periods = days_elapsed // _DECAY_PERIOD_DAYS
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
if new_confidence < _PRUNE_THRESHOLD:
await db.delete(row)
deleted += 1
logger.info(
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
"confidence=%.3f (below threshold)",
row.id, user_id, row.subject_label, row.predicate, new_confidence,
)
else:
row.confidence = new_confidence
decayed += 1
try:
await db.commit()
logger.info(
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
user_id, decayed, deleted,
)
except Exception as exc:
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
await db.rollback()
async def drain_extraction_queue(db: AsyncSession) -> None:
"""Process pending ExtractionQueue rows for Free-tier users.
Each row corresponds to a stored episode that should be fed through the
Mem0-style extraction pipeline. Rows are deleted after successful processing.
Never raises — wraps in try/except.
"""
try:
await _drain_extraction_queue_inner(db)
except Exception as exc:
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
from app.models import ExtractionQueue # noqa: PLC0415
result = await db.execute(select(ExtractionQueue))
rows = result.scalars().all()
if not rows:
logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
return
logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))
from app.core.memory_extraction import run_extraction # noqa: PLC0415
processed = 0
for row in rows:
try:
await run_extraction(
db=db,
user_id=row.user_id,
last_user_msg="",
last_assistant_msg="",
session_id=None,
)
await db.delete(row)
await db.commit()
processed += 1
except Exception as exc:
logger.warning(
"memory_maintenance: drain failed row=%s user=%s: %s",
row.id, row.user_id, exc,
)
await db.rollback()
logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
"""Mine recurring behavioral patterns from last 30 days of episodes (Power+ only).
Steps:
1. Gate on proactive_mining tier feature.
2. Load + decrypt last 30 days of episodic summaries.
3. Call gpt-4o-mini to identify recurring patterns.
4. Encrypt and store each pattern in memory_proactive.
5. Apply decay to existing proactive rows.
Never raises — wraps in try/except.
"""
try:
await _mine_proactive_patterns_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
from app.billing.tier_manager import tier_manager # noqa: PLC0415
tier = await tier_manager.get_tier(user_id, db)
if not tier_manager.check_feature(tier, "proactive_mining"):
logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
return
# Load user Fernet key
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
return
fernet = Fernet(user.encryption_key.encode())
cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
episodes_result = await db.execute(
select(MemoryEpisodic)
.where(
MemoryEpisodic.user_id == user_id,
MemoryEpisodic.created_at >= cutoff,
)
.order_by(MemoryEpisodic.created_at.asc())
)
episode_rows = episodes_result.scalars().all()
if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
logger.info(
"memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
)
return
summaries: list[str] = []
for ep in episode_rows:
try:
plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode()
summaries.append(plaintext)
except Exception:
pass
if not summaries:
return
patterns = await _extract_proactive_patterns(summaries)
if not patterns:
logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
return
stored = 0
for pattern_text in patterns:
try:
encrypted = fernet.encrypt(pattern_text.encode()).decode()
row = MemoryProactive(
id=str(uuid.uuid4()),
user_id=user_id,
pattern_encrypted=encrypted,
confidence=0.7,
source="inferred",
)
db.add(row)
stored += 1
except Exception as exc:
logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)
try:
await db.commit()
logger.info(
"memory_maintenance: mine_proactive_patterns user=%s stored=%d",
user_id, stored,
)
except Exception as exc:
logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
await db.rollback()
return
await _decay_proactive_patterns(db, user_id, fernet)
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
"""Call memory-miner LLM to identify recurring behavioral/temporal patterns."""
from app.core.llm import get_agent_llm # noqa: PLC0415
llm = get_agent_llm("memory-miner", temperature=0)
combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage
prompt = (
"You are analyzing conversation history for a personal AI secretary. "
"Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
"'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
"Return each pattern as a plain, short English sentence on its own line. "
"No numbering, no bullet points, no extra text.\n\n"
f"Conversation history:\n{combined}"
)
try:
response = await llm.ainvoke(prompt)
text = response.content if hasattr(response, "content") else str(response)
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
return lines[:5]
except Exception as exc:
logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
return []
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
"""Decay confidence of existing proactive patterns; prune below threshold."""
result = await db.execute(
select(MemoryProactive).where(MemoryProactive.user_id == user_id)
)
rows = result.scalars().all()
now = datetime.now(timezone.utc)
deleted = 0
decayed = 0
for row in rows:
reference = row.created_at
if reference is None:
continue
if reference.tzinfo is None:
reference = reference.replace(tzinfo=timezone.utc)
days_elapsed = (now - reference).days
if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS:
continue
periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS
new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)
if new_confidence < _PROACTIVE_PRUNE_THRESHOLD:
await db.delete(row)
deleted += 1
else:
row.confidence = new_confidence
decayed += 1
try:
await db.commit()
logger.info(
"memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
user_id, decayed, deleted,
)
except Exception as exc:
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
await db.rollback()
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
_AUDIT_CONTRADICTIONS_FALLBACK = (
"You are auditing a personal AI assistant's memory bank. "
"Each fact has an ID in brackets. "
"Find pairs that directly contradict each other "
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
'Return ONLY a valid JSON array, no markdown fences: '
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
"If no contradictions, return [].\n\n"
"Facts:\n{facts}"
)
_AUDIT_CANONICALIZE_FALLBACK = (
"You are auditing entity labels in a personal AI assistant's relational memory. "
"These are names of people, companies, projects, or topics. "
"Group labels that clearly refer to the same real-world entity "
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
"Return ONLY a valid JSON array, no markdown fences: "
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
"Only include groups with at least one variant. Singletons: omit.\n\n"
"Labels:\n{labels}"
)
async def audit_memory(db: AsyncSession, user_id: str) -> None:
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
Steps:
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
2. LLM flags rows to delete (direct contradictions); hard-delete them.
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
4. Rewrite variant labels to their canonical form in-place.
Never raises — wraps in try/except.
"""
try:
await _audit_memory_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
return
fernet = Fernet(user.encryption_key.encode())
await _scan_associative_contradictions(db, user_id, fernet)
await _canonicalize_relation_labels(db, user_id)
async def _scan_associative_contradictions(
db: AsyncSession,
user_id: str,
fernet: Fernet,
) -> None:
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
result = await db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(_AUDIT_MAX_FACTS)
)
rows = result.scalars().all()
if len(rows) < 2:
return
id_to_text: dict[str, str] = {}
for row in rows:
try:
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
id_to_text[row.id] = plaintext
except Exception:
pass
if len(id_to_text) < 2:
return
id_list = list(id_to_text.keys())
numbered = "\n".join(
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
)
template, prompt_obj = get_prompt_or_fallback(
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
)
system_text = compile_prompt(template, prompt_obj, facts=numbered)
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
llm = get_agent_llm("memory-auditor", temperature=0)
lf = get_langfuse()
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Audit facts for contradictions."),
]
try:
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-audit-contradictions",
model=model_for_agent("memory-auditor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm.ainvoke(messages)
text = response.content if hasattr(response, "content") else str(response)
deletions = json.loads(text.strip())
if not isinstance(deletions, list):
return
except Exception as exc:
logger.warning(
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
user_id, exc,
)
return
deleted = 0
for item in deletions:
if not isinstance(item, dict):
continue
rid = item.get("delete")
if not rid or rid not in id_to_text:
continue
result2 = await db.execute(
select(MemoryAssociative).where(
MemoryAssociative.id == rid,
MemoryAssociative.user_id == user_id,
)
)
target = result2.scalar_one_or_none()
if target:
await db.delete(target)
deleted += 1
logger.info(
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
rid, user_id, item.get("reason", ""),
)
if deleted:
try:
await db.commit()
except Exception as exc:
logger.warning(
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
)
await db.rollback()
logger.info(
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
)
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
result = await db.execute(
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
)
rows = result.scalars().all()
if not rows:
return
all_labels: set[str] = set()
for row in rows:
all_labels.add(row.subject_label)
all_labels.add(row.object_label)
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
if len(labels_list) < 2:
return
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
template, prompt_obj = get_prompt_or_fallback(
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
)
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
llm = get_agent_llm("memory-auditor", temperature=0)
lf = get_langfuse()
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Canonicalize entity labels."),
]
try:
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-audit-canonicalize",
model=model_for_agent("memory-auditor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm.ainvoke(messages)
text = response.content if hasattr(response, "content") else str(response)
groups = json.loads(text.strip())
if not isinstance(groups, list):
return
except Exception as exc:
logger.warning(
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
user_id, exc,
)
return
# Build variant → canonical map
remap: dict[str, str] = {}
for group in groups:
if not isinstance(group, dict):
continue
canonical = group.get("canonical", "")
variants = group.get("variants") or []
if not canonical:
continue
for v in variants:
if isinstance(v, str) and v != canonical:
remap[v] = canonical
if not remap:
return
updated = 0
for row in rows:
changed = False
if row.subject_label in remap:
row.subject_label = remap[row.subject_label]
changed = True
if row.object_label in remap:
row.object_label = remap[row.object_label]
changed = True
if changed:
updated += 1
if updated:
try:
await db.commit()
logger.info(
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
user_id, updated,
)
except Exception as exc:
logger.warning(
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
)
await db.rollback()

View File

@@ -1,733 +0,0 @@
"""Memory Middleware — enrich requests with memory context and store interactions.
Four-tier memory model (MemGPT-style):
core — persistent key/value user preferences, always injected
associative — semantic similarity search via pgvector (top-k)
episodic — recent session summaries (last N)
proactive — behavioral patterns above confidence threshold
All memory content is encrypted at rest using the per-user Fernet key
stored in User.encryption_key. Decryption happens in-memory only.
Usage:
memory = MemoryMiddleware(db_session)
context = await memory.enrich_context(user_id, message)
# ... run agent ...
await memory.store_episode(user_id, session_id, message, response)
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import (
ExtractionQueue,
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
MemoryRelation,
User,
)
logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
# Tuning constants
_ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware:
"""Enrich orchestrator context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None:
self._db = db
# ── Public API ────────────────────────────────────────────────────────────
async def enrich_context(
self,
user_id: str,
message: str,
trace_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
Returns a dict with keys:
core_memory — {key: plaintext_value, ...}
associative_memory — [plaintext_content, ...] (top-k by keyword match)
episodic_memory — [plaintext_summary, ...] (most recent N)
proactive_hints — [plaintext_pattern, ...] (above threshold)
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
return {}
user_dbg = await self._get_user_debug(user_id)
user_tier: str = user_dbg.get("tier") or "free"
core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
relational = await self._load_relational(user_id, user_tier=user_tier)
logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
trace_id or "-",
user_id,
user_tier,
len(core),
len(associative),
len(episodic),
len(proactive),
len(relational),
)
return {
"core_memory": core,
"associative_memory": associative,
"episodic_memory": episodic,
"proactive_hints": proactive,
"relational_memory": relational,
}
async def store_episode(
self,
user_id: str,
session_id: str,
message: str,
response: str,
trace_id: str | None = None,
) -> None:
"""Summarise and store a completed interaction in episodic memory.
The summary is a simple heuristic concatenation (no LLM call) to keep
latency low. After committing the episode row, dispatches the Mem0-style
extraction pipeline:
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
- Free → enqueue an ExtractionQueue row for the daily cron.
"""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
encrypted = _encrypt(fernet, summary)
episode = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=user_id,
summary_encrypted=encrypted,
session_id=session_id,
)
self._db.add(episode)
episode_id: str = episode.id
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
tier = user_dbg.get("tier") or "free"
logger.info(
"memory: store_episode trace=%s user=%s tier=%s session=%s",
trace_id or "-",
user_id,
tier,
session_id,
)
except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback()
return
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
await self._dispatch_extraction(
user_id=user_id,
episode_id=episode_id,
last_user_msg=message,
last_assistant_msg=response,
session_id=session_id,
)
async def _dispatch_extraction(
self,
user_id: str,
episode_id: str,
last_user_msg: str,
last_assistant_msg: str,
session_id: str | None,
) -> None:
"""Route extraction to realtime task or batch queue based on user tier."""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
tier = await tier_manager.get_tier(user_id, self._db)
if tier_manager.check_feature(tier, "realtime_extraction"):
# Pro/Power/Team: fire-and-forget in the background.
# Must open a fresh session — request session closes after handler returns.
from app.core.memory_extraction import run_extraction # noqa: PLC0415
from app.db import async_session # noqa: PLC0415
async def _task() -> None:
try:
async with async_session() as fresh_db:
await run_extraction(
db=fresh_db,
user_id=user_id,
last_user_msg=last_user_msg,
last_assistant_msg=last_assistant_msg,
session_id=session_id,
)
except Exception as exc:
logger.warning(
"memory: extraction task failed user=%s: %s", user_id, exc
)
asyncio.create_task(_task())
logger.info("memory: realtime extraction dispatched user=%s", user_id)
else:
# Free tier: enqueue for daily batch cron.
queue_row = ExtractionQueue(
id=str(uuid.uuid4()),
user_id=user_id,
episode_id=episode_id,
)
self._db.add(queue_row)
try:
await self._db.commit()
logger.info(
"memory: extraction enqueued (batch) user=%s episode=%s",
user_id,
episode_id,
)
except Exception as exc:
logger.warning(
"memory: extraction queue insert failed user=%s: %s", user_id, exc
)
await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
"""Upsert a core memory key/value for a user."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, value)
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == key,
)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.value_encrypted = encrypted
else:
self._db.add(MemoryCore(
id=str(uuid.uuid4()),
user_id=user_id,
key=key,
value_encrypted=encrypted,
))
try:
await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: update_core trace=%s user=%s tier=%s key=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
key,
)
except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback()
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
"""Return core memory as editable blocks (label/value)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryCore)
.where(MemoryCore.user_id == user_id)
.order_by(MemoryCore.key.asc())
)
rows = result.scalars().all()
out: list[dict[str, str]] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out.append({"label": row.key, "value": plaintext})
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
return out
async def get_core_block(self, user_id: str, label: str) -> str | None:
"""Return a single core memory block value by label."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return None
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
return None
value = _safe_decrypt(fernet, row.value_encrypted)
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
return value
async def delete_core(self, user_id: str, label: str) -> bool:
"""Delete a core memory block by label. Returns True if deleted."""
result = await self._db.execute(
select(MemoryCore).where(
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
)
row = result.scalar_one_or_none()
if row is None:
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
return False
await self._db.delete(row)
try:
await self._db.commit()
logger.info("memory: delete_core user=%s label=%s", user_id, label)
return True
except Exception as exc:
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
await self._db.rollback()
return False
async def append_core(self, user_id: str, label: str, content: str) -> None:
"""Append content to a core block, creating it if missing."""
current = await self.get_core_block(user_id, label)
if current is None:
await self.update_core(user_id, label, content)
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
return
await self.update_core(user_id, label, f"{current}\n{content}")
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
"""Replace one exact string inside a core block. Returns False if not found."""
current = await self.get_core_block(user_id, label)
if current is None or old not in current:
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
return False
await self.update_core(user_id, label, current.replace(old, new, 1))
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
return True
async def store_associative(
self,
user_id: str,
content: str,
entity_type: str | None = None,
entity_id: str | None = None,
) -> None:
"""Store associative memory; embed if user tier has real_embeddings."""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.core.embeddings import embed_text # noqa: PLC0415
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
user_dbg = await self._get_user_debug(user_id)
user_tier = user_dbg.get("tier") or "free"
embedding: list[float] | None = None
if tier_manager.check_feature(user_tier, "real_embeddings"):
embedding = await embed_text(content)
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=encrypted,
embedding=embedding,
entity_type=entity_type,
entity_id=entity_id,
)
self._db.add(row)
try:
await self._db.commit()
logger.info(
"memory: store_associative user=%s embedded=%s",
user_id,
embedding is not None,
)
except Exception as exc:
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def upsert_relation(
self,
user_id: str,
subject: str,
subject_type: str,
predicate: str,
object_: str,
object_type: str,
*,
confidence: float = 0.7,
source_episode_id: str | None = None,
notes: str | None = None,
) -> None:
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
subject_label / object_label are plaintext entity identifiers — not encrypted.
notes is optional; encrypted with user Fernet if provided.
"""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
user_dbg = await self._get_user_debug(user_id)
user_tier = user_dbg.get("tier") or "free"
if not tier_manager.check_feature(user_tier, "relational_memory"):
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
return
notes_encrypted: bytes | None = None
if notes:
fernet = await self._get_fernet(user_id)
if fernet:
notes_encrypted = fernet.encrypt(notes.encode())
result = await self._db.execute(
select(MemoryRelation).where(
MemoryRelation.user_id == user_id,
MemoryRelation.subject_label == subject,
MemoryRelation.predicate == predicate,
MemoryRelation.object_label == object_,
)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.subject_type = subject_type
existing.object_type = object_type
existing.confidence = confidence
existing.last_confirmed_at = _now()
if notes_encrypted is not None:
existing.notes_encrypted = notes_encrypted
else:
self._db.add(MemoryRelation(
id=str(uuid.uuid4()),
user_id=user_id,
subject_label=subject,
subject_type=subject_type,
predicate=predicate,
object_label=object_,
object_type=object_type,
confidence=confidence,
source_episode_id=source_episode_id,
notes_encrypted=notes_encrypted,
))
try:
await self._db.commit()
logger.info(
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
user_id, subject, predicate, object_,
)
except Exception as exc:
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def query_relations(
self,
user_id: str,
subject: str | None = None,
predicate: str | None = None,
object_: str | None = None,
limit: int = 20,
) -> list[MemoryRelation]:
"""Query relation rows for a user with optional filters."""
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
if subject is not None:
q = q.where(MemoryRelation.subject_label == subject)
if predicate is not None:
q = q.where(MemoryRelation.predicate == predicate)
if object_ is not None:
q = q.where(MemoryRelation.object_label == object_)
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
result = await self._db.execute(q)
return list(result.scalars().all())
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=encrypted,
embedding=None,
entity_type=source,
entity_id=None,
)
self._db.add(row)
try:
await self._db.commit()
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
except Exception as exc:
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search recall memory (episodic summaries) by keyword."""
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryEpisodic)
.where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc())
.limit(100)
)
rows = result.scalars().all()
needle = query.strip().lower()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out
# ── Private helpers ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None:
"""Load the user's Fernet key from DB. Returns None if missing."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory: no encryption_key for user=%s", user_id)
return None
return Fernet(user.encryption_key.encode())
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs."""
from app.config.settings import settings # noqa: PLC0415
from app.models import Subscription # noqa: PLC0415
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return {"tier": None}
sub_result = await self._db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
sub_tier: str | None = sub_result.scalar_one_or_none()
if sub_tier:
tier = sub_tier
elif settings.ENV == "dev":
tier = "power"
else:
tier = user.tier or "free"
return {"tier": tier}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id)
)
rows = result.scalars().all()
out: dict[str, str] = {}
for row in rows:
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out[row.key] = plaintext
return out
async def _load_associative(
self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
) -> list[str]:
"""Load top-k associative memories.
Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
Free / embedding failure: keyword-ordered fallback (most recent rows).
"""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.core.embeddings import embed_text # noqa: PLC0415
if tier_manager.check_feature(user_tier, "real_embeddings"):
vec = await embed_text(message)
if vec is not None:
try:
result = await self._db.execute(
select(MemoryAssociative)
.where(
MemoryAssociative.user_id == user_id,
MemoryAssociative.embedding.isnot(None),
)
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
.limit(_ASSOCIATIVE_TOP_K)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None:
out.append(plaintext)
logger.info(
"memory: _load_associative user=%s mode=vector hits=%d",
user_id,
len(out),
)
return out
except Exception as exc:
logger.warning(
"memory: vector search failed user=%s, falling back to keyword: %s",
user_id,
exc,
)
# Keyword fallback: most recent rows
result = await self._db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(_ASSOCIATIVE_TOP_K)
)
rows = result.scalars().all()
out = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_episodic(
self,
user_id: str,
fernet: Fernet,
session_id: str | None = None,
) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id:
query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute(
query
.order_by(MemoryEpisodic.created_at.desc())
.limit(_EPISODIC_RECENT_N)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
if not tier_manager.check_feature(user_tier, "relational_memory"):
return []
result = await self._db.execute(
select(MemoryRelation)
.where(MemoryRelation.user_id == user_id)
.order_by(MemoryRelation.confidence.desc())
.limit(10)
)
rows = result.scalars().all()
out = [
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
for r in rows
]
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryProactive)
.where(
MemoryProactive.user_id == user_id,
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
)
.order_by(MemoryProactive.confidence.desc())
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
# ── Encryption helpers ────────────────────────────────────────────────────────
def _encrypt(fernet: Fernet, plaintext: str) -> str:
return fernet.encrypt(plaintext.encode()).decode()
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
try:
return fernet.decrypt(ciphertext.encode()).decode()
except (InvalidToken, Exception) as exc:
logger.warning("memory: decrypt failed: %s", exc)
return None

View File

@@ -1,51 +0,0 @@
"""Note summarizer — generates a compact AI summary for a note.
Called fire-and-forget from create_note / update_note tools so the
``notes.ai_summary`` column stays current without blocking the agent loop.
"""
from __future__ import annotations
import logging
from langchain_core.messages import HumanMessage, SystemMessage
from app.core.langfuse_client import get_prompt_or_fallback
from app.core.llm import get_agent_llm
logger = logging.getLogger(__name__)
_FALLBACK_PROMPT = """\
Summarize this note in <=250 characters. Be terse and dense.
Keep proper nouns, dates, decisions, and action items.
Do not start with "This note".
Respond with the summary text only — no intro, no labels.
Title: {title}
Content: {content}"""
_MAX_CONTENT_CHARS = 4000
async def generate_note_summary(title: str, content: str) -> str:
"""Return a <=250-char summary of *title* + *content*.
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
fallback. Truncates *content* to 4000 chars before sending to avoid
token waste on large notes.
"""
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
trimmed = content[:_MAX_CONTENT_CHARS]
system_prompt = template.format(title=title, content=trimmed)
try:
llm = get_agent_llm("note-summarizer")
response = await llm.ainvoke([
SystemMessage(content=system_prompt),
HumanMessage(content="Generate the summary."),
])
text = response.content if isinstance(response.content, str) else ""
return text.strip()[:250]
except Exception as exc:
logger.warning("note_summarizer: failed to generate summary: %s", exc)
return ""

View File

@@ -1,63 +0,0 @@
"""Output formatter for deep-agent stream events."""
from __future__ import annotations
import re
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
_CANVAS_BLOCK_RE = re.compile(
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
re.DOTALL | re.IGNORECASE,
)
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
Returns ``(visible_text, canvas_content, canvas_kind)``.
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
"""
match = _CANVAS_BLOCK_RE.search(text)
if not match:
return text, None, None
canvas_kind = match.group(1).strip()
canvas_content = match.group(2).strip()
visible = text[: match.start()] + text[match.end() :]
visible = visible.strip()
return visible, canvas_content, canvas_kind
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
class StreamFormatter:
"""Convert `(event_type, data)` stream events into websocket frame models."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
started = False
async for event_type, data in event_stream:
if event_type != "token":
continue
if not started:
yield WsStreamStart(request_id=self.request_id)
started = True
text = str(data or "")
if text:
yield WsStreamText(request_id=self.request_id, chunk=text)
if not started:
yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(request_id=self.request_id)

View File

@@ -1,104 +0,0 @@
"""Preprocessor registry: detect content type and dispatch to handlers.
Public API
----------
detect_content_type(filename, raw_content) -> str
Heuristic detection based on file extension and content patterns.
preprocess(content_type, raw_content) -> PreprocessResult
Dispatch to the appropriate handler.
"""
from __future__ import annotations
import re
from app.core.preprocessors.base import PreprocessResult
# ── Heuristics ────────────────────────────────────────────────────────
# Patterns that strongly suggest an email HTML file
_EMAIL_SIGNALS = re.compile(
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
re.IGNORECASE,
)
# Patterns that suggest a generic HTML page (not an email)
_GENERIC_HTML_SIGNALS = re.compile(
r"<(nav|main|header|footer|article|section)\b",
re.IGNORECASE,
)
def detect_content_type(filename: str, raw_content: str) -> str:
"""Return a content-type string for the given file.
Supported types: ``"email_html"``, ``"generic_html"``,
``"plain_text"``, ``"unknown"``.
"""
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
if ext == "txt":
return "plain_text"
if ext in ("html", "htm", "eml", "mhtml", "mht"):
# Prefer email detection over generic HTML
if _EMAIL_SIGNALS.search(raw_content[:4096]):
return "email_html"
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
return "generic_html"
# .html without clear signals — check for any email header
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
return "email_html"
return "generic_html"
# Plain text files with email headers
if ext in ("", "txt") or not ext:
if _EMAIL_SIGNALS.search(raw_content[:4096]):
return "email_html"
# Detect binary content
try:
raw_content.encode("utf-8")
except (UnicodeEncodeError, AttributeError):
return "unknown"
# Non-text bytes heuristic: high ratio of non-printable chars
sample = raw_content[:512]
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
if len(sample) > 0 and non_printable / len(sample) > 0.1:
return "unknown"
return "unknown"
# ── Generic fallback handler ──────────────────────────────────────────
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
"""Strip HTML tags if present, return text as-is."""
try:
from bs4 import BeautifulSoup
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
except ImportError:
# No BeautifulSoup — strip tags with a simple regex
text = re.sub(r"<[^>]+>", "", raw_content)
text = re.sub(r"\n{3,}", "\n\n", text).strip()
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
# ── Dispatch ──────────────────────────────────────────────────────────
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
"""Dispatch *raw_content* to the handler registered for *content_type*.
Falls back to the generic handler for unknown types.
"""
if content_type == "email_html":
from app.core.preprocessors.email_html import preprocess_email_html
return preprocess_email_html(raw_content)
return _preprocess_generic(raw_content, content_type)
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]

View File

@@ -1,25 +0,0 @@
"""Base types for the preprocessor system."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class PreprocessResult:
"""Output of a preprocessor handler.
Attributes
----------
content_type:
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
clean_text:
Human-readable text stripped of markup/binary noise.
metadata:
Dict of extracted metadata (keys vary by handler).
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
"""
content_type: str
clean_text: str
metadata: dict = field(default_factory=dict)

View File

@@ -1,111 +0,0 @@
"""Preprocessor for email HTML files.
Handles:
- HTML stripping via BeautifulSoup
- Metadata extraction (Subject, From, To, Date)
- Thread splitting — isolates the latest reply
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from app.core.preprocessors.base import PreprocessResult
if TYPE_CHECKING:
pass
# ── Thread split markers ──────────────────────────────────────────────
# Matches patterns like:
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
# "-----Original Message-----"
# "> " (plain-text quote prefix)
_THREAD_PATTERNS = [
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
]
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
_META_PATTERNS: dict[str, list[re.Pattern]] = {
"subject": [
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
],
"from": [
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"From:\s*(.+)", re.IGNORECASE),
],
"to": [
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"To:\s*(.+)", re.IGNORECASE),
],
"date": [
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
],
}
def _extract_metadata(raw_html: str, text: str) -> dict:
"""Extract Subject/From/To/Date from raw HTML or plain text."""
metadata: dict[str, str] = {}
for field, patterns in _META_PATTERNS.items():
for pat in patterns:
m = pat.search(raw_html) or pat.search(text)
if m:
metadata[field] = m.group(1).strip()
break
return metadata
def _split_thread(text: str) -> str:
"""Return only the latest message in a threaded email."""
earliest_pos: int | None = None
for pat in _THREAD_PATTERNS:
m = pat.search(text)
if m and (earliest_pos is None or m.start() < earliest_pos):
earliest_pos = m.start()
if earliest_pos is not None and earliest_pos > 0:
return text[:earliest_pos].strip()
return text.strip()
def preprocess_email_html(raw_content: str) -> PreprocessResult:
"""Strip HTML, extract metadata, split thread from an email HTML file."""
try:
from bs4 import BeautifulSoup # lazy import — optional dep
except ImportError as exc:
raise ImportError(
"beautifulsoup4 is required for email_html preprocessing. "
"Install it with: pip install beautifulsoup4"
) from exc
# Parse with lxml if available, fall back to html.parser
try:
soup = BeautifulSoup(raw_content, "lxml")
except Exception:
soup = BeautifulSoup(raw_content, "html.parser")
# Remove noise tags
for tag in soup(["style", "script", "head", "noscript"]):
tag.decompose()
clean_text = soup.get_text(separator="\n")
# Collapse excessive blank lines
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
metadata = _extract_metadata(raw_content, clean_text)
latest_message = _split_thread(clean_text)
return PreprocessResult(
content_type="email_html",
clean_text=latest_message,
metadata=metadata,
)

View File

@@ -1,115 +0,0 @@
"""WebSocket client executor context.
Holds a per-request async callback that tools call to execute CRUD
operations on the Electron client's local SQLite / LanceDB databases.
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
"""
from __future__ import annotations
import re
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4
_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")
def _key_to_camel(key: str) -> str:
return _SNAKE_TO_CAMEL_RE.sub(lambda m: m.group(1).upper(), key)
def _keys_to_camel(obj: Any) -> Any:
"""Recursively convert dict keys from snake_case to camelCase.
Mirrors the JS-side ``toCamelCase`` applied to incoming WS frames in
``adiuvAI/src/main/api/backend-client.ts``. The Electron executor wraps
tool_result payloads in ``toSnakeCase`` before sending; this restores the
camelCase schema property names that the tool code expects to read.
"""
if isinstance(obj, dict):
return {_key_to_camel(k): _keys_to_camel(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_keys_to_camel(v) for v in obj]
return obj
# Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
"_client_executor"
)
# Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_tool_result_collector(lst: list[dict]) -> None:
"""Register *lst* as the collector for this async context."""
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
"""Clear the collector (best-effort)."""
_tool_result_collector.set(None)
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
_client_executor.set(fn)
def clear_client_executor() -> None:
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
try:
_client_executor.set(None) # type: ignore[arg-type]
except Exception:
pass
async def execute_on_client(
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Send a CRUD/vector operation to the Electron client and return the result.
Builds a ``tool_call`` payload, invokes the per-session WS callback,
and returns the ``tool_result`` dict from Electron.
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
"""
callback = _client_executor.get(None)
if callback is None:
raise RuntimeError(
"execute_on_client() called outside a WebSocket session — "
"no client executor is set."
)
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
if table is not None:
payload["table"] = table
if data is not None:
payload["data"] = data
if filters is not None:
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
if vector is not None:
payload["vector"] = vector
if limit is not None:
payload["limit"] = limit
result = await callback(payload)
result = _keys_to_camel(result)
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": result,
})
return result

View File

@@ -1,143 +0,0 @@
from contextlib import asynccontextmanager
import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.api.middleware.rate_limit import TierRateLimitMiddleware
from app.api.middleware.sanitizer import SanitizerMiddleware
from app.config.settings import settings
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
async def _memory_audit_cron_tick() -> None:
"""Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
import logging # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("memory audit cron tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
from app.models import User # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
async with async_session() as db:
result = await db.execute(select(User.id))
user_ids: list[str] = list(result.scalars().all())
for uid in user_ids:
try:
async with async_session() as db:
await audit_memory(db, uid)
except Exception as exc:
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
_log.info("memory audit cron tick: done users=%d", len(user_ids))
except Exception as exc:
_log.warning("memory audit cron tick: failed: %s", exc)
async def _memory_cron_tick() -> None:
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
import logging # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("memory cron tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns # noqa: PLC0415
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.models import User # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
async with async_session() as db:
await drain_extraction_queue(db)
# mine proactive patterns for every Power+ user
async with async_session() as db:
result = await db.execute(select(User.id))
user_ids: list[str] = list(result.scalars().all())
for uid in user_ids:
try:
async with async_session() as db:
tier = await tier_manager.get_tier(uid, db)
if tier_manager.check_feature(tier, "proactive_mining"):
await mine_proactive_patterns(db, uid)
except Exception as exc:
_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)
_log.info("memory cron tick: done users=%d", len(user_ids))
except Exception as exc:
_log.warning("memory cron tick: failed: %s", exc)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: ensure agent tool modules are loaded.
import app.agents # noqa: F401
scheduler = None
if settings.SCHEDULER_ENABLED:
from apscheduler.schedulers.asyncio import AsyncIOScheduler # noqa: PLC0415
scheduler = AsyncIOScheduler()
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
scheduler.start()
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
yield
if scheduler is not None:
scheduler.shutdown(wait=False)
# Shutdown: dispose SQLAlchemy connection pool
from app.db import engine
await engine.dispose()
def create_app() -> FastAPI:
app = FastAPI(
title="AdiuvAI Cloud API",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
# Request flow: TierRateLimit → Sanitizer → CORS → Router
# Response flow: Router → CORS → Sanitizer → TierRateLimit
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, billing, chat, device_ws, memory
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
app.include_router(memory.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "version": app.version}
return app
app = create_app()

View File

@@ -1,73 +0,0 @@
"""Contextual sidebar scope schema and prompt block renderer.
ContextualScope mirrors the TypeScript ContextualScope type sent by the
Electron renderer when the user opens the side chat anchored to a specific
view. The renderer ships camelCase keys; Pydantic's alias_generator maps
them to snake_case Python attributes automatically.
"""
from __future__ import annotations
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
PageType = Literal[
"timeline",
"tasks",
"projects-list",
"project",
"note",
]
EntityType = Literal["project", "note", "task", "timeline_event"]
class ContextualScope(BaseModel):
"""Scope payload sent by the Electron renderer for contextual chat.
The renderer ships camelCase keys (entityType, entityId, ...). Pydantic's
alias generator maps them to snake_case Python attrs.
"""
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
page: PageType
entity_type: Optional[EntityType] = None
entity_id: Optional[str] = None
entity_name: Optional[str] = None
project_id: Optional[str] = None
char_count: Optional[int] = None
counts: Optional[dict[str, int]] = None
filters: Optional[dict] = None
def render_scope_block(scope: ContextualScope) -> str:
"""Produce a single-paragraph human-readable summary of the current view
for injection into the contextual agent system prompt.
Never emits internal ids — only names. The LLM is told to use names in
prose; ids travel through tool calls.
"""
if scope.entity_type == "project":
c = scope.counts or {}
return (
f"User is viewing the project {scope.entity_name!r}. "
f"{c.get('tasks', 0)} tasks, "
f"{c.get('notes', 0)} notes, "
f"{c.get('milestones', 0)} milestones."
)
if scope.entity_type == "note":
return (
f"User is viewing the note {scope.entity_name!r} "
f"({scope.char_count or 0} characters)."
)
if scope.page == "tasks":
return "User is viewing the global Tasks list (all projects)."
if scope.page == "timeline":
return "User is viewing the global Timeline view."
if scope.page == "projects-list":
return "User is viewing the Projects list."
return f"User is on page {scope.page}."

View File

@@ -1,27 +1,34 @@
# ── Adiuva Microservices ─────────────────────────────────────────────
# docker compose up --build
# docker compose up --build auth ws-gateway chat # subset
services:
app:
build: .
# ═══════════════════════════════════════════════════════════════════
# Infrastructure
# ═══════════════════════════════════════════════════════════════════
traefik:
image: traefik:v3.1
ports:
- "8080:8000"
env_file:
- path: .env
required: false
- "80:80"
- "443:443"
- "8080:8080" # dashboard (dev only)
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
CF_DNS_API_TOKEN: ${CF_DNS_API_TOKEN:-}
volumes:
- copilot_tokens:/root/.config/litellm/github_copilot
depends_on:
db:
condition: service_healthy
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./traefik/traefik.yml:/etc/traefik/traefik.yml:ro
- ./traefik/dynamic:/etc/traefik/dynamic:ro
- traefik_acme:/etc/traefik/acme
restart: unless-stopped
db:
image: pgvector/pgvector:pg16
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuvai
POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres}
POSTGRES_DB: ${POSTGRES_DB:-adiuva}
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
@@ -31,11 +38,161 @@ services:
retries: 5
restart: unless-stopped
# Optional Redis for future rate-limit or caching needs
# redis:
# image: redis:7-alpine
redis:
image: redis:7-alpine
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5
restart: unless-stopped
# ── Optional infrastructure (uncomment as needed) ────────────────
# minio:
# image: minio/minio:latest
# command: server /data --console-address ":9001"
# ports:
# - "9000:9000"
# - "9001:9001"
# environment:
# MINIO_ROOT_USER: minioadmin
# MINIO_ROOT_PASSWORD: minioadmin
# volumes:
# - minio_data:/data
# healthcheck:
# test: ["CMD", "mc", "ready", "local"]
# interval: 5s
# timeout: 5s
# retries: 5
# restart: unless-stopped
# qdrant:
# image: qdrant/qdrant:latest
# ports:
# - "6333:6333"
# - "6334:6334"
# volumes:
# - qdrant_data:/qdrant/storage
# restart: unless-stopped
# ═══════════════════════════════════════════════════════════════════
# Migrations (run once, then exit)
# ═══════════════════════════════════════════════════════════════════
migrate:
build:
context: .
dockerfile: Dockerfile
command: ["python", "-m", "alembic", "upgrade", "head"]
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
depends_on:
db:
condition: service_healthy
restart: "no"
# ═══════════════════════════════════════════════════════════════════
# Application Services
# ═══════════════════════════════════════════════════════════════════
auth:
build:
context: .
dockerfile: services/auth/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
REDIS_URL: redis://redis:6379/0
depends_on:
db:
condition: service_healthy
migrate:
condition: service_completed_successfully
restart: unless-stopped
ws-gateway:
build:
context: .
dockerfile: services/ws-gateway/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
REDIS_URL: redis://redis:6379/0
depends_on:
redis:
condition: service_healthy
auth:
condition: service_started
restart: unless-stopped
chat:
build:
context: .
dockerfile: services/chat/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
REDIS_URL: redis://redis:6379/0
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
migrate:
condition: service_completed_successfully
restart: unless-stopped
batch-agent:
build:
context: .
dockerfile: services/batch-agent/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
REDIS_URL: redis://redis:6379/0
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
migrate:
condition: service_completed_successfully
restart: unless-stopped
billing:
build:
context: .
dockerfile: services/billing/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
depends_on:
db:
condition: service_healthy
migrate:
condition: service_completed_successfully
restart: unless-stopped
volumes:
postgres_data:
copilot_tokens:
redis_data:
traefik_acme:
# minio_data:
# qdrant_data:

View File

@@ -0,0 +1,941 @@
# Adiuva — Architettura Microservizi (MVP)
## Panoramica
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
```
┌──────────────┐
│ Cloudflare │
│ (DNS + CDN) │
└──────┬───────┘
│ HTTPS / WSS
┌──────▼───────┐
│ Traefik │
│ API Gateway │
│ (routing, │
│ TLS, rate │
│ limiting) │
└──────┬───────┘
┌──────────┬───────────┼───────────┐
│ │ │ │
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
│ Auth │ │ Chat │ │ Agent │ │Billing │
│ Service │ │Service│ │ Service │ │Service │
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
│ │ │ │
┌─────▼──────────▼──────────▼───────────▼────┐
│ Infrastruttura │
│ PostgreSQL │ Redis │ Qdrant │
└─────────────────────────────────────────────┘
```
---
## 1. Suddivisione dei Servizi
### 1.1 Auth Service (`auth-service`)
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
| Endpoint originale | Metodo |
|---|---|
| `/api/v1/auth/register` | POST |
| `/api/v1/auth/login` | POST |
| `/api/v1/auth/refresh` | POST |
| `/api/v1/auth/me` | GET / PUT |
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
**Modifica chiave — JWT con RS256**:
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
- L'Auth Service firma i JWT con la **chiave privata**.
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
```python
# auth-service/app/auth/jwt.py
from cryptography.hazmat.primitives.asymmetric import rsa
from jose import jwt
PRIVATE_KEY = ... # Da env/secret
PUBLIC_KEY = ... # Derivata o da env
def create_access_token(user_id: str, tier: str) -> str:
return jwt.encode(
{"sub": user_id, "tier": tier, "exp": ...},
PRIVATE_KEY,
algorithm="RS256",
)
```
```python
# shared/auth.py (usato da tutti gli altri servizi)
from jose import jwt
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
def verify_token(token: str) -> dict:
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
```
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
---
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
| Endpoint | Tipo |
|---|---|
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
| `/api/v1/chat` | POST (REST fallback) |
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
**Scaling**: 2N repliche. Sticky cookies per le WS + Redis per cross-instance.
---
### 1.3 Agent Service (`agent-service`) ⭐ Batch
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
| Endpoint | Tipo |
|---|---|
| `/api/v1/agents/catalog` | GET |
| `/api/v1/agents/can-create` | POST |
| `/api/v1/agents/trigger` | POST |
| `/api/v1/agents/journey/start` | POST (o WS relay) |
| `/api/v1/agents/journey/message` | POST (o WS relay) |
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
```
┌──────────────┐ ┌──────────────┐ ┌──────────┐
│ Agent Service│ │ Redis │ │ Chat │
│ (batch run) │ │ │ │ Service │
│ │ │ │ │ (ha WS) │
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
│ from │ │ │ │ al │
│ device │ │ │ │ device│
│ │ │ │ │ via WS│
│ │ SUBSCRIBE │ │ PUBLISH │ │
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
│ risultato │ │ │ │ reply │
└──────────────┘ └──────────────┘ └──────────┘
```
**Scaling**: 1N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
---
### 1.4 Billing Service (`billing-service`)
**Responsabilità**: Stripe checkout, webhook, subscription management.
| Endpoint originale | Metodo |
|---|---|
| `/api/v1/billing/checkout` | POST |
| `/api/v1/billing/webhook` | POST |
| `/api/v1/billing/subscription` | GET / DELETE |
**Database**: Tabelle `subscriptions` (schema `billing`).
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
**Scaling**: 1 replica sufficiente. Basso traffico.
---
### 1.5 Servizi esclusi dall'MVP
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
| Servizio | Responsabilità | Note |
|---|---|---|
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
---
## 2. Tier Check — Dove e Come
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
### Strategia: Tier nel JWT
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
```json
{
"sub": "user_123",
"tier": "pro",
"exp": 1742515200,
"iat": 1742511600
}
```
Ogni servizio:
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
2. Legge `payload["tier"]`**zero chiamate extra**
3. Applica le sue regole di enforcement localmente
```python
# shared/auth.py — dependency FastAPI condivisa
from fastapi import Depends, HTTPException, Request
from jose import jwt
PUBLIC_KEY = ...
class CurrentUser:
def __init__(self, user_id: str, tier: str):
self.user_id = user_id
self.tier = tier
async def get_current_user(request: Request) -> CurrentUser:
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
def require_tier(*allowed_tiers: str):
"""Dependency che blocca se il tier non è tra quelli ammessi."""
async def check(user: CurrentUser = Depends(get_current_user)):
if user.tier not in allowed_tiers:
raise HTTPException(403, "Tier insufficient")
return user
return check
```
### Cosa succede quando il tier cambia (upgrade/downgrade)?
```
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
│ │ │ Service │ (Redis pub/sub) │ Service │
└──────────┘ └──────────┘ └────┬─────┘
UPDATE users
SET tier = 'power'
Al prossimo /refresh
il JWT conterrà tier='power'
```
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 1530 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
### Dove si applica in ciascun servizio
| Servizio | Enforcement |
|---|---|
| **Auth Service** | Nessuno (è lui che scrive il tier) |
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
### Rate-limit distribuito via Redis
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
```python
# shared/middleware/rate_limit.py
import redis.asyncio as aioredis
class DistributedRateLimiter:
def __init__(self, redis: aioredis.Redis):
self._redis = redis
async def check(self, user_id: str, tier: str, service: str) -> bool:
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
max_req = limits.get(tier, 20)
key = f"rate:{service}:{user_id}"
pipe = self._redis.pipeline()
pipe.incr(key)
pipe.expire(key, 60)
count, _ = await pipe.execute()
return count <= max_req
```
---
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
`DeviceConnectionManager` è un **singleton in-memory**:
```python
class DeviceConnectionManager:
def __init__(self):
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
```
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
### La soluzione: Redis Pub/Sub + Registry
```
┌──────────────────────────────────────────────────────────────┐
│ Redis │
│ │
│ Hash: ws:connections │
│ user_123 → instance_A │
│ user_456 → instance_B │
│ │
│ Pub/Sub channels: │
│ tool_call:{user_id} → tool call payloads │
│ tool_result:{call_id} → tool result payloads │
│ stream:{user_id} → text_chunk streaming │
└──────────────────────────────────────────────────────────────┘
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
┌───────────────────────┐ ┌───────────────────────┐
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
│ tool_call:user_123│ │ → user_123 è su A │
│ │ │ │
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
│ da Redis channel │ │ tool_call:user_123 │
│ │ │ {id, action, ...} │
│ 3. Invia al device │ │ │
│ via WS │ │ 4. SUBSCRIBE │
│ │ │ tool_result:{id} │
│ 4. Device risponde │ │ │
│ tool_result │──────────│► 5. Riceve risultato │
│ │ │ │
│ 5. PUBLISH │ │ │
│ tool_result:{id} │ │ │
└───────────────────────┘ └───────────────────────┘
```
### Implementazione: `RedisDeviceManager`
```python
# chat-service/app/core/device_manager.py
import asyncio
import json
import os
import redis.asyncio as aioredis
from dataclasses import dataclass, field
from fastapi import WebSocket
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
@dataclass
class LocalConnection:
ws: WebSocket
device_id: str
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
class RedisDeviceManager:
"""Device manager backed by Redis for cross-instance communication."""
def __init__(self, redis_url: str = "redis://redis:6379"):
self._redis = aioredis.from_url(redis_url)
self._pubsub = self._redis.pubsub()
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
async def start(self):
"""Avvia il listener Redis per tool_call in arrivo."""
asyncio.create_task(self._listen_tool_calls())
# ── Registrazione ──
async def register(self, user_id: str, device_id: str, ws: WebSocket):
# Registra localmente
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
# Registra in Redis quale istanza ha la connessione
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
# Sottoscrivi ai tool_call per questo utente
await self._pubsub.subscribe(f"tool_call:{user_id}")
async def unregister(self, user_id: str):
conn = self._local.pop(user_id, None)
if conn:
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
await self._redis.hdel("ws:connections", user_id)
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
# ── Presenza ──
async def is_online(self, user_id: str) -> bool:
return await self._redis.hexists("ws:connections", user_id)
# ── Tool-call round-trip (cross-instance) ──
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
"""
Invia un tool_call al device dell'utente.
Funziona sia che la WS sia locale che su un'altra istanza.
"""
call_id = payload["id"]
# Caso 1: connessione locale → invio diretto
if user_id in self._local:
conn = self._local[user_id]
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
return await asyncio.wait_for(fut, timeout=30.0)
# Caso 2: connessione remota → Redis pub/sub
loop = asyncio.get_event_loop()
fut = loop.create_future()
self._remote_futures[call_id] = fut
# Sottoscrivi al canale di risposta
result_channel = f"tool_result:{call_id}"
await self._pubsub.subscribe(result_channel)
# Pubblica il tool_call
await self._redis.publish(
f"tool_call:{user_id}",
json.dumps(payload),
)
try:
return await asyncio.wait_for(fut, timeout=30.0)
finally:
self._remote_futures.pop(call_id, None)
await self._pubsub.unsubscribe(result_channel)
# ── Risoluzione tool_result (da WS locale) ──
def resolve_local(self, user_id: str, call_id: str, result: dict):
conn = self._local.get(user_id)
if conn:
fut = conn.pending_calls.pop(call_id, None)
if fut and not fut.done():
fut.set_result(result)
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
"""Chiamato quando il device locale invia un tool_result."""
self.resolve_local(user_id, call_id, result)
# Pubblica anche su Redis per l'istanza remota che aspetta
await self._redis.publish(
f"tool_result:{call_id}",
json.dumps(result),
)
# ── Listener Redis ──
async def _listen_tool_calls(self):
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
async for message in self._pubsub.listen():
if message["type"] != "message":
continue
channel = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
data = json.loads(message["data"])
if channel.startswith("tool_call:"):
# Un'altra istanza vuole che inviamo un tool_call al nostro device
user_id = channel.split(":", 1)[1]
conn = self._local.get(user_id)
if conn:
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
elif channel.startswith("tool_result:"):
# Risposta a un tool_call che abbiamo inviato tramite Redis
call_id = channel.split(":", 1)[1]
fut = self._remote_futures.pop(call_id, None)
if fut and not fut.done():
fut.set_result(data)
# ── Stream cross-instance ──
async def publish_stream_chunk(self, user_id: str, chunk: dict):
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
```
---
## 4. Struttura Directory Proposta (MVP)
```
adiuva-api/
├── docker-compose.yml # Orchestrazione completa
├── docker-compose.dev.yml # Override per sviluppo locale
├── shared/ # Codice condiviso (montato come volume)
│ ├── auth.py # JWT verification (chiave pubblica)
│ ├── schemas.py # Pydantic schemas condivisi
│ ├── middleware/
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
│ │ └── sanitizer.py
│ └── models/
│ └── base.py # SQLAlchemy base condivisa
├── auth-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # users, refresh_tokens
│ ├── routes/
│ │ └── auth.py
│ └── services/
│ ├── jwt_service.py # RS256 signing
│ └── user_service.py
├── chat-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # memory_*
│ ├── routes/
│ │ ├── device_ws.py # WS connection owner
│ │ └── chat.py # REST fallback
│ ├── core/
│ │ ├── device_manager.py # RedisDeviceManager
│ │ ├── deep_agent.py # Home + floating chat
│ │ ├── memory_middleware.py
│ │ ├── ws_context.py
│ │ ├── output_formatter.py
│ │ └── llm.py
│ └── agents/ # Tool definitions (used by deep_agent)
│ ├── task_agent.py
│ ├── project_agent.py
│ ├── note_agent.py
│ └── timeline_agent.py
├── agent-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
│ ├── routes/
│ │ ├── agents.py # catalog, can-create, trigger
│ │ └── agent_setup.py # journey start/message
│ ├── core/
│ │ ├── agent_runner.py # Batch classify → process
│ │ ├── agent_registry.py
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
│ │ └── llm.py
│ └── agents/
│ ├── task_agent.py # Tool definitions (batch context)
│ ├── project_agent.py
│ ├── note_agent.py
│ ├── timeline_agent.py
│ └── filesystem_agent.py
├── billing-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # subscriptions
│ ├── routes/
│ │ └── billing.py
│ └── services/
│ ├── stripe_service.py
│ └── tier_manager.py
└── infra/
├── traefik/
│ └── traefik.yml
├── keys/
│ ├── jwt_private.pem # Solo auth-service
│ └── jwt_public.pem # Tutti i servizi
└── alembic/ # Migrazioni condivise o per-servizio
```
---
## 5. Docker Compose — Configurazione MVP
```yaml
# docker-compose.yml
services:
# ══════════════════════════════════════════════════════════
# API Gateway
# ══════════════════════════════════════════════════════════
traefik:
image: traefik:v3.2
command:
- "--api.insecure=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443"
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
ports:
- "80:80"
- "443:443"
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./infra/certs:/certs:ro
restart: unless-stopped
# ══════════════════════════════════════════════════════════
# Auth Service (2 repliche)
# ══════════════════════════════════════════════════════════
auth-service:
build: ./auth-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
SERVICE_NAME: auth
secrets:
- jwt_private_key
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
- "traefik.http.services.auth.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Chat Service — Real-time WS + Chat (scalabile)
# ══════════════════════════════════════════════════════════
chat-service:
build: ./chat-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: chat
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
# REST chat endpoint
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
- "traefik.http.services.chat.loadbalancer.server.port=8000"
# WebSocket route con sticky session
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
- "traefik.http.routers.ws.service=chat-ws"
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Agent Service — Batch processing (scalabile indipendentemente)
# ══════════════════════════════════════════════════════════
agent-service:
build: ./agent-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: agent
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
- "traefik.http.services.agents.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Billing Service (1 replica)
# ══════════════════════════════════════════════════════════
billing-service:
build: ./billing-service
deploy:
replicas: 1
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: billing
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
- "traefik.http.services.billing.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Infrastruttura
# ══════════════════════════════════════════════════════════
db:
image: pgvector/pgvector:pg16
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuva
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
redis:
image: redis:7-alpine
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5
restart: unless-stopped
qdrant:
image: qdrant/qdrant:latest
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
secrets:
jwt_private_key:
file: ./infra/keys/jwt_private.pem
jwt_public_key:
file: ./infra/keys/jwt_public.pem
volumes:
postgres_data:
redis_data:
qdrant_data:
```
---
## 6. Configurazione Cloudflare + VPS
### 6.1 DNS
```
api.tuodominio.com → A record → IP del VPS
→ Proxy: ON (orange cloud)
```
### 6.2 Cloudflare Settings
| Setting | Valore | Motivo |
|---------|--------|--------|
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
| Under Attack Mode | Off (attivare se necessario) | |
### 6.3 TLS sul VPS
Due opzioni:
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
```yaml
# traefik.yml — con Cloudflare Origin Certificate
entryPoints:
websecure:
address: ":443"
tls:
certificates:
- certFile: /certs/origin.pem
keyFile: /certs/origin-key.pem
```
### 6.4 Rete VPS
```bash
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
# https://www.cloudflare.com/ips/
ufw default deny incoming
ufw allow from 173.245.48.0/20 to any port 443
ufw allow from 103.21.244.0/22 to any port 443
# ... (tutti gli IP range di Cloudflare)
ufw allow ssh
ufw enable
```
---
## 7. Comunicazione Inter-Servizio
### 7.1 Redis Pub/Sub — Event Bus
```
┌──────────┐ tier_changed:user_123 ┌──────────┐
│ Billing │ ────────────────────────► │ Auth │
│ Service │ │ Service │
└──────────┘ └──────────┘
┌──────────┐ tool_call:user_123 ┌──────────┐
│ Agent │ ────────────────────────► │ Chat │
│ Service │ │ Service │
│ (batch) │ ◄────────────────────────│ (ha WS) │
└──────────┘ tool_result:{call_id} └──────────┘
```
### 7.2 Health Checks e Service Discovery
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
- **Redis pub/sub** (tool-call cross-instance, tier events)
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
- **PostgreSQL** (dati persistenti condivisi)
---
## 8. Piano di Migrazione Incrementale (MVP)
### Fase 1 — Preparazione (nel monolite attuale)
1. Aggiungere Redis al `docker-compose.yml` attuale
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
4. Estrarre `shared/` con auth verification, schemas, middleware
### Fase 2 — Auth Service (primo split)
1. Estrarre `auth.py` routes + models in `auth-service/`
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
4. Il monolite continua a servire tutto il resto
### Fase 3 — Billing Service
1. Estrarre billing routes, Stripe service, tier manager
2. Configurare Redis pub/sub per `tier_changed` events
3. Routare via Traefik
### Fase 4 — Split Chat + Agent (il più delicato)
1. Il monolite residuo contiene WS + chat + agents
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
### Fase 5 — Scaling test
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
3. Monitoring (Prometheus + Grafana) per ogni servizio
---
## 9. Monitoraggio e Logging
```yaml
# Aggiungere al docker-compose.yml
prometheus:
image: prom/prometheus:latest
volumes:
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
restart: unless-stopped
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
volumes:
- grafana_data:/var/lib/grafana
restart: unless-stopped
loki:
image: grafana/loki:latest
restart: unless-stopped
```
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
---
## 10. Sizing VPS Minimo Consigliato (MVP)
| Componente | CPU | RAM | Note |
|---|---|---|---|
| Traefik | 0.25 | 128MB | |
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
| Billing Service | 0.25 | 128MB | |
| PostgreSQL | 1.0 | 1GB | |
| Redis | 0.25 | 256MB | |
| Qdrant | 0.5 | 512MB | |
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
---
## Riepilogo Architettura MVP
| Servizio | Repliche | Proprietario di |
|---|---|---|
| **Traefik** | 1 | Routing, TLS, sticky sessions |
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
| **Chat Service** | 2N | WebSocket, home/floating chat, streaming |
| **Agent Service** | 2N | Batch processing, directory scan, agent setup |
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
| Decisione | Scelta | Motivazione |
|---|---|---|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
| TLS | Cloudflare Origin Certificate | Zero maintenance |
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
| Storage / Plugin | Post-MVP | Non critici per il lancio |

View File

@@ -1,43 +0,0 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
langchain>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.1.0
litellm>=1.50.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
stripe>=11.0.0
boto3>=1.35.0
slowapi>=0.1.9
sqlalchemy>=2.0.0
asyncpg>=0.30.0
alembic>=1.14.0
bcrypt>=4.2.0
python-dotenv>=1.0.0
httpx>=0.28.0
websockets>=14.0
psycopg2-binary>=2.9.0
pytest>=8.0.0
pytest-asyncio>=0.24.0
aiosqlite>=0.20.0
moto[s3]>=5.0.0
pinecone>=5.0.0
qdrant-client>=1.7.0
croniter>=3.0.0
google-api-python-client>=2.130.0
google-auth>=2.29.0
google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0
msal>=1.28.0
cryptography>=42.0.0
pgvector>=0.2.5
langfuse>=3.3.1
beautifulsoup4>=4.12.0
lxml>=5.0.0
PyYAML>=6.0.0
apscheduler>=3.10.0
ruff>=0.8.0
pypdf>=4.0
python-docx>=1.1

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,19 @@
# ── Auth Service ──────────────────────────────────────────────────────────────
# This file contains env vars specific to the Auth Service.
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
# or from docker-compose environment.
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
# Generate keypair:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# openssl rsa -in private.pem -pubout -out public.pem
#
# Paste PEM content with literal \n for newlines:
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
JWT_PRIVATE_KEY=
# PUBLIC KEY — used to VERIFY JWTs.
JWT_PUBLIC_KEY=

36
services/auth/Dockerfile Normal file
View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
# Install shared + service deps in one layer
COPY services/auth/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Copy shared module (available to all services)
COPY shared/ shared/
# Copy service source
COPY services/auth/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "30"]

16
services/auth/README.md Normal file
View File

@@ -0,0 +1,16 @@
# Auth Service
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
## Tables owned
- `users`
- `refresh_tokens`
- `subscriptions` (read; Billing Service writes)
## Endpoints
- `POST /auth/register`
- `POST /auth/login`
- `POST /auth/refresh`
- `GET /auth/me`
- `PUT /auth/me`
- `GET /auth/verify` (ForwardAuth for Traefik)

View File

@@ -0,0 +1,34 @@
"""Auth Service — local configuration.
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
These are NOT in shared/config.py to prevent other services from accessing them.
"""
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class AuthSettings(BaseSettings):
# RS256 private key (PEM format). Used to SIGN JWTs.
# Only the Auth Service has this. Generate with:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# Then set the env var (newlines as \n):
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
JWT_PRIVATE_KEY: str = ""
# RS256 public key (PEM format). Used to VERIFY JWTs.
# Derived from the private key:
# openssl rsa -in private.pem -pubout -out public.pem
JWT_PUBLIC_KEY: str = ""
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
@classmethod
def _expand_pem_newlines(cls, v: str) -> str:
if isinstance(v, str) and r"\n" in v:
return v.replace(r"\n", "\n")
return v
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
auth_settings = AuthSettings()

69
services/auth/app/deps.py Normal file
View File

@@ -0,0 +1,69 @@
"""Auth dependencies — JWT validation for the Auth Service.
This is the canonical get_current_user used by protected endpoints
within the Auth Service itself (/me, /me PUT).
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.config import settings
from shared.db import get_session
from shared.models import Subscription, User
from shared.schemas import UserProfile
from app.config import auth_settings
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
The JWT is used for identity and expiry. Tier is fetched live from the
subscriptions table so upgrades/downgrades take effect immediately.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
# Live tier lookup
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname
user_result = await db.execute(
select(User.name, User.surname).where(User.id == user_id)
)
user_row = user_result.one_or_none()
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
tier=tier,
) # type: ignore[arg-type]

62
services/auth/app/main.py Normal file
View File

@@ -0,0 +1,62 @@
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
Standalone FastAPI service extracted from the adiuva-api monolith.
Owns: users, refresh_tokens, subscriptions (read).
"""
import sys
from contextlib import asynccontextmanager
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable.
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
# In local dev, we need to add the repo root (two levels up from this file).
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from shared.config import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
from shared.db import engine
await engine.dispose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Auth Service",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from app.routes import router
from app.verify import router as verify_router
app.include_router(router, prefix="/api/v1")
app.include_router(verify_router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "auth", "version": app.version}
return app
app = create_app()

249
services/auth/app/routes.py Normal file
View File

@@ -0,0 +1,249 @@
"""Auth routes: register, login, refresh, me.
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
"""
from __future__ import annotations
import hashlib
import time
import uuid
from datetime import datetime, timedelta, timezone
import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.config import settings
from shared.db import get_session
from shared.models import RefreshToken, Subscription, User
from shared.schemas import AuthTokens, UserProfile
from app.config import auth_settings
from app.deps import get_current_user
router = APIRouter(prefix="/auth", tags=["auth"])
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _hash_token(plain_token: str) -> str:
"""SHA-256 of the plain refresh token string."""
return hashlib.sha256(plain_token.encode()).hexdigest()
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (RS256-signed JWT, expires_at_ms)."""
now = int(time.time())
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": exp,
"iat": now,
}
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
return token, exp * 1000 # ms for client
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
"""Fetch authoritative tier from subscriptions table."""
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
return result.scalar_one_or_none() or default_tier
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
name: str | None = None
surname: str | None = None
class _LoginRequest(BaseModel):
email: str
password: str
class _RefreshRequest(BaseModel):
refresh_token: str
class _UpdateProfileRequest(BaseModel):
name: str | None = None
surname: str | None = None
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
body: _RegisterRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Create a new account and return JWT tokens."""
existing = await db.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user = User(
id=str(uuid.uuid4()),
email=body.email,
name=body.name,
surname=body.surname,
password_hash=_hash_password(body.password),
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(user)
await db.flush()
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/login", response_model=AuthTokens)
async def login(
body: _LoginRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
# Fetch live tier for the JWT claim
tier = await _get_live_tier(db, user.id)
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/refresh", response_model=AuthTokens)
async def refresh(
body: _RefreshRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
token_hash = _hash_token(body.refresh_token)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
rt = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id))
user = user_result.scalar_one_or_none()
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
# Fetch live tier for the new JWT
tier = await _get_live_tier(db, user.id)
plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=new_expires,
)
db.add(new_rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user."""
return current_user
@router.put("/me", response_model=UserProfile)
async def update_profile(
body: _UpdateProfileRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's name and surname."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if body.name is not None:
user.name = body.name
if body.surname is not None:
user.surname = body.surname
await db.commit()
await db.refresh(user)
return UserProfile(
id=user.id,
email=user.email,
name=user.name,
surname=user.surname,
tier=current_user.tier,
)

View File

@@ -0,0 +1,66 @@
"""ForwardAuth verification endpoint for Traefik.
Traefik calls GET /api/v1/auth/verify on every request to a protected
service. This endpoint validates the JWT from the Authorization header
and returns identity headers that Traefik injects into downstream requests.
Downstream services NEVER validate JWTs themselves — they trust the
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
"""
from __future__ import annotations
from fastapi import APIRouter, Request, Response
from fastapi import status as http_status
from jose import JWTError, jwt
from sqlalchemy import select
from shared.config import settings
from shared.db import async_session
from shared.models import Subscription
from app.config import auth_settings
router = APIRouter(tags=["auth"])
@router.get("/auth/verify")
async def verify(request: Request) -> Response:
"""Validate JWT and return identity headers for Traefik ForwardAuth.
Returns 200 with X-User-* headers on success, 401 on failure.
Traefik copies response headers to the downstream request.
"""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
token = auth_header[7:] # strip "Bearer "
try:
payload = jwt.decode(
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
except JWTError:
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
# Live tier lookup from subscriptions table
async with async_session() as db:
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
return Response(
status_code=http_status.HTTP_200_OK,
headers={
"X-User-Id": user_id,
"X-User-Email": email,
"X-User-Tier": tier,
},
)

View File

@@ -0,0 +1,11 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
bcrypt>=4.2.0
cryptography>=42.0.0
python-dotenv>=1.0.0

View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/batch-agent/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/batch-agent/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "300"]

View File

@@ -0,0 +1,23 @@
# Batch Agent Service
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
## Tables owned
- `local_agent_configs`
- `cloud_agent_configs`
- `agent_run_logs`
## Endpoints
- `GET /agents/catalog`
- `POST /agents/can-create`
- `POST /agents/trigger`
- `GET /agents/{id}/history`
## Redis channels
- Subscribe: `batch:request:{user_id}`
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
- BRPOP: `tool:result:{call_id}` (30s timeout)
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
## TODO
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.

View File

@@ -0,0 +1,910 @@
"""Agent run orchestrator — adapted for Batch Agent Service.
Key changes from monolith app/core/agent_runner.py:
- No DeviceConnectionManager — tool calls go through Redis ws_context.
- set_current_user / clear_current_user replace set_client_executor.
- run_local_agent accepts a serialized dict (from Redis / REST) instead
of SQLAlchemy model objects.
- _finalize_run writes to PostgreSQL via shared.db.async_session.
- Cloud agent import path changed to app.integrations.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from sqlalchemy import select
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from shared.agents.note_agent import NOTE_TOOLS
from shared.agents.project_agent import PROJECT_TOOLS
from shared.agents.task_agent import TASK_TOOLS
from shared.agents.timeline_agent import TIMELINE_TOOLS
from shared.llm import get_llm
from shared.ws_context import execute_on_client, set_current_user, clear_current_user
import app.tracing as tracing
from shared.db import async_session
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from shared.redis import redis_client, ws_out_channel
logger = logging.getLogger(__name__)
# ── Concurrency guard ─────────────────────────────────────────────────────
_running_agents: set[str] = set()
def is_agent_running(agent_id: str) -> bool:
return agent_id in _running_agents
# ── Timeouts ───────────────────────────────────────────────────────────────
_TOOL_CALL_TIMEOUT: int = 30
_MAX_PROCESSING_STEPS: int = 12
_MAX_SCAN_DEPTH: int = 5
# ── Data-type to tool mapping ─────────────────────────────────────────────
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS,
"notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS,
}
# ── Step 1: Classification prompt ─────────────────────────────────────────
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
"tasks": (
"Action items, to-dos, deliverables — anything that describes work to be done, "
"assigned to someone, or tracked with a due date or status."
),
"notes": (
"Documentation, meeting notes, summaries, reference material — "
"written content meant to be read and referenced rather than acted on."
),
"timelines": (
"Project milestones, deadlines, scheduled events — "
"specific dates that mark a point in the progress of a project."
),
"projects": (
"High-level project entities — only relevant if the file clearly introduces "
"a new project or updates the scope of an existing one."
),
}
_STEP1_SYSTEM_PROMPT = """\
You are a file classifier for a freelance project management tool.
Your job is to match a file to an existing project and identify which data domains to extract.
## Project matching rules (STRICT — follow in order)
1. Search the file content for any mention of a project name, client name, acronym, or topic
that overlaps with the existing projects listed below.
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
when the file has zero meaningful connection to any listed project.
4. When in doubt, pick the closest match from the list.
## Response format
Respond ONLY with a JSON object — no markdown, no explanation:
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
## Domain definitions (only consider domains in the allowed list)
{domain_definitions}
## Existing projects
{projects_list}
"""
# ── Step 2: Processing prompt ─────────────────────────────────────────────
_PROCESSING_SYSTEM_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.
Your task: extract structured data from the file content and persist it using the available tools.
## Mandatory process — follow this order for EVERY item you extract
1. READ the existing records listed below for the relevant domain.
2. SEARCH for a match by title, topic, or semantic similarity.
3. If a match exists → call the update_* tool with the existing record's id.
4. If no match exists → call the create_* tool and set isAiSuggested=1.
NEVER call create_* without first checking the existing records.
NEVER duplicate a record that already exists under a different wording.
## Existing records (source of truth)
{existing_context}
## Context
Project: {project_context}
Domains to extract: {data_types}
{custom_prompt_section}
"""
# ── Cloud processing prompt ───────────────────────────────────────────────
_CLOUD_PROCESSING_PROMPT = """\
You are a data extraction and management assistant for a freelance project
management tool.
Available tools:
Filesystem : read_file_content, list_directory, get_file_metadata
Tasks : list_tasks, create_task, update_task, add_task_comment
Notes : list_notes, get_note, create_note, update_note
Timelines : list_timelines, create_timeline, update_timeline
Projects : list_all_projects, get_project, create_project, update_project
Your task:
1. Read the full content of each file below using read_file_content.
2. For each piece of information found, ALWAYS try to match and update an
existing record before creating a new one.
3. ONLY act on these entity types: {data_types}.
4. Do NOT invent data. Only extract what is clearly present in the files.
5. If a file contains no relevant data for the target entity types, skip it.
{project_context}
Files to process:
{file_list}
{custom_prompt_section}
After processing all files, respond with a brief summary of what you updated
and what you created.
"""
# ── LLM tool-calling loop ─────────────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _run_agent_with_tools(
*,
system_prompt: str,
user_message: str,
tools: list[Any],
max_steps: int,
langfuse_handler: Any | None = None,
) -> str:
"""Run an LLM agent with tool-calling, returning the final text response."""
callbacks = [langfuse_handler] if langfuse_handler else None
llm = get_llm(callbacks=callbacks)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(content=user_message),
]
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_runner: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_runner: tool_result name=%s output=%s",
call_name,
str(tool_output)[:200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Tool list builder ─────────────────────────────────────────────────────
def _build_processing_tools(data_types: list[str]) -> list[Any]:
tools: list[Any] = list(FILESYSTEM_TOOLS)
for dt in data_types:
dt_tools = _DATA_TYPE_TOOLS.get(dt)
if dt_tools:
tools.extend(dt_tools)
return tools
# ── Code-based directory scanner ─────────────────────────────────────────
async def _scan_directories(
paths: list[str],
extensions: list[str],
last_run_at: datetime | None,
) -> list[str]:
all_files: list[str] = []
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
async def _walk(path: str, depth: int) -> None:
if depth > _MAX_SCAN_DEPTH:
return
try:
result = await execute_on_client(action="list_directory", data={"path": path})
except Exception as exc:
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
return
for entry in result.get("entries", []):
entry_path = entry.get("path", "")
if not entry_path:
continue
if entry.get("type") == "directory":
await _walk(entry_path, depth + 1)
elif entry.get("type") == "file":
if ext_set:
dot_pos = entry_path.rfind(".")
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
if file_ext not in ext_set:
continue
all_files.append(entry_path)
for root in paths:
await _walk(root, depth=0)
if last_run_at is None:
return all_files
last_run_ms = int(last_run_at.timestamp() * 1000)
filtered: list[str] = []
for file_path in all_files:
try:
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
modified_at = meta.get("modifiedAt")
if modified_at is None:
filtered.append(file_path)
continue
if isinstance(modified_at, (int, float)):
mod_ms = int(modified_at)
else:
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
if mod_ms > last_run_ms:
filtered.append(file_path)
except Exception:
filtered.append(file_path)
return filtered
# ── Code-based entity fetchers ────────────────────────────────────────────
async def _fetch_projects() -> list[dict]:
try:
result = await execute_on_client(action="select", table="projects")
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch projects: %s", exc)
return []
_DOMAIN_TABLE: dict[str, str] = {
"tasks": "tasks",
"notes": "notes",
"timelines": "timelines",
"projects": "projects",
}
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
table = _DOMAIN_TABLE.get(domain)
if not table:
return []
filters: dict[str, Any] = {}
if project_id != "standalone" and domain != "projects":
filters["projectId"] = project_id
try:
result = await execute_on_client(
action="select",
table=table,
filters=filters if filters else None,
)
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
return []
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
if not rows:
return f"No existing {domain}."
lines: list[str] = []
for r in rows:
if domain == "tasks":
desc = r.get("description") or ""
desc_part = f"{desc[:120]}" if desc else ""
assignee = r.get("assignee") or r.get("assignees") or ""
due = r.get("dueDate") or r.get("due_date") or ""
meta = ", ".join(filter(None, [
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
f"assignee: {assignee}" if assignee else "",
f"due: {due}" if due else "",
]))
lines.append(
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
f" ({meta}, id: {r['id']})"
)
elif domain == "notes":
snippet = (r.get("content") or "")[:200].replace("\n", " ")
snippet_part = f"\n Preview: {snippet}" if snippet else ""
lines.append(
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
)
elif domain == "timelines":
lines.append(
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
)
elif domain == "projects":
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
summary_part = f"{summary}" if summary else ""
lines.append(
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
f" (id: {r['id']})"
)
return f"Existing {domain}:\n" + "\n".join(lines)
# ── Step 1: LLM file classifier ───────────────────────────────────────────
async def _classify_file(
file_path: str,
file_content: str,
projects: list[dict],
config_data_types: list[str],
langfuse_handler: Any | None = None,
custom_system_prompt: str | None = None,
) -> tuple[str, list[str], str | None]:
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
if not file_content.strip():
return fallback
valid_project_ids = {p["id"] for p in projects}
def _fmt_project(p: dict) -> str:
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
summary_part = f"{summary[:100]}" if summary else ""
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
domain_definitions = "\n".join(
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
for d in config_data_types
if d in _DOMAIN_DESCRIPTIONS
)
if custom_system_prompt:
# Fixture-provided prompt takes absolute priority
system = custom_system_prompt.format_map(
{"domain_definitions": domain_definitions, "projects_list": projects_list}
)
else:
system = tracing.compile_prompt(
"batch_file_classifier",
fallback=_STEP1_SYSTEM_PROMPT,
variables={
"domain_definitions": domain_definitions,
"projects_list": projects_list,
},
)
llm = get_llm(callbacks=[langfuse_handler] if langfuse_handler else None)
try:
response = await llm.ainvoke([
SystemMessage(content=system),
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
])
raw = _as_text(response.content).strip()
if raw.startswith("```"):
raw = raw.split("```")[1]
if raw.startswith("json"):
raw = raw[4:]
parsed = json.loads(raw.strip())
raw_project_id: str = str(parsed.get("project_id") or "new")
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
new_project_name: str | None = (
str(parsed["new_project_name"]).strip() or None
if project_id == "new" and parsed.get("new_project_name")
else None
)
domains: list[str] = [
d for d in parsed.get("domains", [])
if d in config_data_types
]
if not domains:
domains = list(config_data_types)
return project_id, domains, new_project_name
except Exception as exc:
logger.warning(
"agent_runner: step1 classification failed for %r: %s", file_path, exc
)
return fallback
# ── Local agent runner (two-step per file) ────────────────────────────────
async def run_local_agent(user_id: str, trigger_data: dict[str, Any], *, langfuse_handler: Any | None = None) -> None:
"""Execute a local directory agent run.
In the microservice world, trigger_data is a serialized dict from
the REST route (forwarded via Redis), containing the agent config
fields and run_context.
set_current_user() must be called BEFORE this function.
"""
run_context: dict = trigger_data.get("run_context", {})
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
run_id = run_context.get("run_id")
_running_agents.add(agent_id)
# Extract config from trigger payload
directory_paths: list[str] = trigger_data.get("directory_paths", [])
if not directory_paths:
directory = trigger_data.get("directory", "")
if directory:
directory_paths = [directory]
data_types: list[str] = trigger_data.get("data_types", [])
file_extensions: list[str] = trigger_data.get("file_extensions", [])
prompt_template: str = trigger_data.get("prompt_template", "")
last_run_at_raw = trigger_data.get("last_run_at")
last_run_at: datetime | None = None
if last_run_at_raw:
if isinstance(last_run_at_raw, str):
last_run_at = datetime.fromisoformat(last_run_at_raw)
elif isinstance(last_run_at_raw, (int, float)):
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
errors: list[str] = []
items_processed = 0
items_created = 0
custom_section = (
f"User instructions:\n{prompt_template}"
if prompt_template
else ""
)
# Create or load run log
run_log_id = run_id
if not run_log_id:
async with async_session() as db:
run_log = AgentRunLog(
agent_id=agent_id,
agent_type="local",
user_id=user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
try:
# ── Scan directories ─────────────────────────────────────────
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
file_paths = await _scan_directories(
paths=directory_paths,
extensions=file_extensions,
last_run_at=last_run_at,
)
logger.info(
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
)
if not file_paths:
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
return
# ── Fetch all projects once ──────────────────────────────────
projects = await _fetch_projects()
for file_path in file_paths:
try:
file_result = await execute_on_client(
action="read_file_content", data={"path": file_path}
)
file_content: str = file_result.get("content", "")
if not file_content:
continue
items_processed += 1
# Step 1 — classify file
project_id, domains, new_project_name = await _classify_file(
file_path=file_path,
file_content=file_content,
projects=projects,
config_data_types=data_types,
langfuse_handler=langfuse_handler,
)
# Step 2 — resolve project_id, fetch entities, process
if project_id == "new":
proj_name = new_project_name or "Untitled Project"
try:
proj_result = await execute_on_client(
action="insert",
table="projects",
data={"name": proj_name, "clientId": None},
)
created = proj_result.get("row", {})
effective_project_id = created.get("id", "standalone")
if "id" in created:
projects.append(created)
except Exception as exc:
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
effective_project_id = "standalone"
proj_name = "unknown"
project_context = (
f"Project: {proj_name} (id: {effective_project_id}). "
"Always set projectId to this id on every record you create."
)
else:
effective_project_id = project_id
proj = next((p for p in projects if p["id"] == project_id), None)
proj_name = proj.get("name", project_id) if proj else project_id
project_context = (
f"Project: {proj_name} (id: {project_id}). "
"Always set projectId to this id on every record you create."
)
domains = [d for d in domains if d != "projects"]
existing_blocks: list[str] = []
for domain in domains:
rows = await _fetch_domain_entities(domain, effective_project_id)
existing_blocks.append(_format_entities_for_context(domain, rows))
existing_context = "\n\n".join(existing_blocks)
system_prompt = tracing.compile_prompt(
"batch_processing",
fallback=_PROCESSING_SYSTEM_PROMPT,
variables={
"existing_context": existing_context,
"project_context": project_context,
"data_types": ", ".join(domains),
"custom_prompt_section": custom_section,
},
)
processing_tools = _build_processing_tools(domains)
result_text = await _run_agent_with_tools(
system_prompt=system_prompt,
user_message=(
f"Process this file and extract relevant information.\n\n"
f"File: {file_path}\n\nContent:\n{file_content}"
),
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
langfuse_handler=langfuse_handler,
)
logger.info(
"agent_runner: run=%s file=%r result=%s",
run_log_id, file_path, result_text[:200],
)
except Exception as exc:
errors.append(f"Error processing '{file_path}': {exc}")
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
finally:
_running_agents.discard(agent_id)
# ── Finalise ────────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log_id,
status=final_status,
items_processed=items_processed,
items_created=items_created,
errors=errors,
)
# Notify Electron that the run is complete via Redis
if run_context:
try:
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps({
"type": "run_complete",
"run_context": run_context,
"status": final_status,
}))
except Exception as exc:
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
# ── Cloud agent runner ─────────────────────────────────────────────────────
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
async def run_cloud_agent(user_id: str, config_id: str, *, langfuse_handler: Any | None = None) -> None:
"""Execute a cloud connector agent run.
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
messages from the provider, and runs LLM extraction.
set_current_user() must be called BEFORE this function.
"""
from app.integrations import decrypt_token, encrypt_token, get_provider
async with async_session() as db:
result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
config = result.scalar_one_or_none()
if config is None:
logger.error("agent_runner: cloud config %s not found", config_id)
return
# Create run log
run_log = AgentRunLog(
agent_id=config.id,
agent_type="cloud",
user_id=user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
# ── Decrypt OAuth token ────────────────────────────────────────
if not config.oauth_token_encrypted:
await _finalize_run(
run_log_id,
status="error",
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
)
return
try:
credentials_info = decrypt_token(config.oauth_token_encrypted)
except ValueError as exc:
await _finalize_run(
run_log_id,
status="error",
errors=[f"Failed to decrypt OAuth token: {exc}"],
)
return
# ── Instantiate provider ──────────────────────────────────────
try:
provider = get_provider(config.provider, credentials_info)
except ValueError as exc:
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
return
# ── Fetch messages ────────────────────────────────────────────
since: datetime | None = config.last_run_at
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
if since.tzinfo is None:
since = since.replace(tzinfo=timezone.utc)
errors: list[str] = []
items_processed = 0
try:
if config.provider == "gmail":
raw_messages = await provider.fetch_messages(
filter_config=config.filter_config,
since=since,
)
elif config.provider == "outlook":
raw_messages = await provider.fetch_emails(
filter_config=config.filter_config,
since=since,
)
elif config.provider == "teams":
raw_messages = await provider.fetch_messages(
filter_config=config.filter_config,
since=since,
)
else:
raw_messages = []
except RuntimeError as exc:
await _finalize_run(
run_log_id,
status="error",
errors=[f"Provider fetch failed: {exc}"],
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
return
logger.info(
"agent_runner: cloud agent %s fetched %d item(s) from %s",
config.id, len(raw_messages), config.provider,
)
# ── Extract + insert via LLM ─────────────────────────────────
try:
processing_tools = _build_processing_tools(config.data_types)
custom_section = (
f"User instructions:\n{config.prompt_template}"
if config.prompt_template
else ""
)
for msg in raw_messages:
content_text = msg.as_text
if not content_text:
continue
items_processed += 1
processing_prompt = tracing.compile_prompt(
"batch_cloud_processing",
fallback=_CLOUD_PROCESSING_PROMPT,
variables={
"data_types": ", ".join(config.data_types),
"project_context": "Determine the appropriate project from the message context.",
"file_list": f"Message from {config.provider} (id: {msg.id})",
"custom_prompt_section": custom_section,
},
)
try:
await _run_agent_with_tools(
system_prompt=processing_prompt,
user_message=f"Process this message content:\n\n{content_text[:8000]}",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
langfuse_handler=langfuse_handler,
)
except Exception as exc:
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
# ── Persist refreshed token ───────────────────────────────────
refreshed = getattr(provider, "refreshed_credentials", None)
if refreshed:
try:
new_encrypted = encrypt_token(refreshed)
async with async_session() as db:
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
)
cfg_row = cfg_result.scalar_one_or_none()
if cfg_row:
cfg_row.oauth_token_encrypted = new_encrypted
await db.commit()
except Exception as exc:
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
# ── Finalise ──────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log_id,
status=final_status,
items_processed=items_processed,
items_created=0,
errors=errors,
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
# ── Internal helper ─────────────────────────────────────────────────────────
async def _finalize_run(
run_log_id: int | str,
*,
status: str,
items_processed: int = 0,
items_created: int = 0,
errors: list[str] | None = None,
update_config_last_run: bool = False,
config_id: str | None = None,
config_type: str | None = None,
) -> None:
"""Persist the run outcome and optionally update last_run_at on the config."""
now = datetime.now(timezone.utc)
try:
async with async_session() as db:
result = await db.execute(
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
)
managed = result.scalar_one_or_none()
if managed is None:
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
return
managed.status = status
managed.items_processed = items_processed
managed.items_created = items_created
managed.errors = errors or []
managed.completed_at = now
if update_config_last_run and config_id:
if config_type == "local":
cfg_result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
elif config_type == "cloud":
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
await db.commit()
except Exception as exc:
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)

View File

@@ -0,0 +1 @@
"""Batch Agent Service domain agents and filesystem tools."""

View File

@@ -0,0 +1,83 @@
"""Filesystem agent — tools for reading local directories and files on Electron.
Adapted for Batch Agent Service: import from app.ws_context.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from shared.ws_context import execute_on_client
@tool
async def list_directory(path: str) -> str:
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
result = await execute_on_client(
action="list_directory",
data={"path": path},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{path}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str:
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
result = await execute_on_client(
action="read_file_content",
data={"path": path},
)
content: str = result.get("content", "")
if not content:
return f"File '{path}' is empty or could not be read."
return content
@tool
async def get_file_metadata(path: str) -> str:
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
result = await execute_on_client(
action="get_file_metadata",
data={"path": path},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", path)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
FILESYSTEM_TOOLS: list[Any] = [
list_directory,
read_file_content,
get_file_metadata,
]

View File

@@ -1,20 +1,11 @@
"""Cloud provider integration utilities.
Provides:
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
* ``get_provider()`` factory that returns the correct client given a
provider name and decrypted OAuth credentials dict.
* ``encrypt_token()`` / ``decrypt_token()`` Fernet-based at-rest
encryption for OAuth tokens stored in ``cloud_agent_configs``.
Adapted for Batch Agent Service: import from shared.config instead of app.config.
Encryption rationale
--------------------
Unlike user content (which is E2E-encrypted client-side and **never**
decrypted server-side), OAuth tokens *must* be decrypted server-side
because the backend makes provider API calls on behalf of the user.
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var it
is never returned to clients.
Provides:
* Shared message dataclasses (EmailMessage, ChatMessage)
* get_provider() factory for Gmail/MS Graph clients
* encrypt_token() / decrypt_token() Fernet-based OAuth token encryption
"""
from __future__ import annotations
@@ -27,7 +18,7 @@ from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken
from app.config.settings import settings
from shared.config import settings
if TYPE_CHECKING:
from app.integrations.gmail import GmailClient
@@ -35,13 +26,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
# ── Shared message types ──────────────────────────────────────────────────
@dataclass
class EmailMessage:
"""A single email message fetched from Gmail or Outlook."""
id: str
subject: str
sender: str
@@ -51,7 +38,6 @@ class EmailMessage:
@property
def as_text(self) -> str:
"""Return a human-readable text representation for LLM extraction."""
date_str = self.date.strftime("%Y-%m-%d %H:%M")
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
return (
@@ -64,8 +50,6 @@ class EmailMessage:
@dataclass
class ChatMessage:
"""A single Teams chat or channel message fetched from MS Graph."""
id: str
content: str
sender: str
@@ -74,7 +58,6 @@ class ChatMessage:
@property
def as_text(self) -> str:
"""Return a human-readable text representation for LLM extraction."""
date_str = self.date.strftime("%Y-%m-%d %H:%M")
channel_str = f" [channel: {self.channel}]" if self.channel else ""
return (
@@ -84,15 +67,7 @@ class ChatMessage:
)
# ── Fernet helpers ────────────────────────────────────────────────────────
def _get_fernet() -> Fernet:
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set callers
must ensure this is configured before persisting OAuth tokens.
"""
key = settings.OAUTH_ENCRYPTION_KEY
if not key:
raise RuntimeError(
@@ -103,15 +78,6 @@ def _get_fernet() -> Fernet:
def encrypt_token(token_info: dict) -> str:
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
Stores the full ``{access_token, refresh_token, token_uri, client_id,
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
Raises:
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
ValueError: ``token_info`` is not a non-empty dict.
"""
if not isinstance(token_info, dict) or not token_info:
raise ValueError("token_info must be a non-empty dict")
plaintext = json.dumps(token_info).encode("utf-8")
@@ -119,13 +85,6 @@ def encrypt_token(token_info: dict) -> str:
def decrypt_token(encrypted: str) -> dict:
"""Decrypt a Fernet-encrypted token string and return the credential dict.
Raises:
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
ValueError: The encrypted string is invalid or was encrypted with a
different key.
"""
try:
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
return json.loads(plaintext)
@@ -133,25 +92,10 @@ def decrypt_token(encrypted: str) -> dict:
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
# ── Provider factory ──────────────────────────────────────────────────────
def get_provider(
provider: str,
credentials_info: dict,
) -> "GmailClient | MSGraphClient":
"""Return the correct provider client for *provider*.
Parameters
----------
provider:
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
credentials_info:
Decrypted OAuth credential dict (Google or Microsoft shape).
Raises:
ValueError: Unknown provider name.
"""
if provider == "gmail":
from app.integrations.gmail import GmailClient
return GmailClient(credentials_info)

View File

@@ -1,26 +1,7 @@
"""Gmail API client for cloud agent integration.
Wraps the Google Gmail REST API to fetch email messages matching a
``filter_config`` dict. Uses the official ``google-api-python-client``
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
blocking the event loop.
Token refresh is handled transparently: when the stored access token has
expired, ``google.auth.transport.requests.Request`` will use the refresh
token to obtain a fresh one. The caller is responsible for persisting
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
(see ``agent_runner.run_cloud_agent``).
Credential dict shape (Google OAuth2):
{
"token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "<client_id>",
"client_secret": "<client_secret>",
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
}
Adapted for Batch Agent Service: import from app.integrations instead of
app.integrations (same relative path within the service).
"""
from __future__ import annotations
@@ -38,13 +19,8 @@ from app.integrations import EmailMessage
logger = logging.getLogger(__name__)
# Gmail search date format — e.g. "after:2025/01/01"
_GMAIL_DATE_FMT = "%Y/%m/%d"
# Maximum characters of body text forwarded to the LLM.
_BODY_TRUNCATE = 8_000
# Maximum messages retrieved per run (prevents runaway quota usage).
_MAX_MESSAGES = 200
@@ -52,20 +28,9 @@ def _build_gmail_query(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
"""Build a Gmail search query string from *filter_config* and *since*.
Supported ``filter_config`` keys:
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
senders (list[str]): Sender addresses or domains to include
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
A hard ``since`` date (from last run) always overrides ``date_range.from``
when it is earlier.
"""
parts: list[str] = []
cfg = filter_config or {}
# Labels — joined with OR when multiple given.
labels: list[str] = cfg.get("labels", [])
if labels:
if len(labels) == 1:
@@ -74,17 +39,14 @@ def _build_gmail_query(
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
parts.append(f"({label_expr})")
# Senders — each prefixed with "from:".
senders: list[str] = cfg.get("senders", [])
for sender in senders:
parts.append(f"from:{sender}")
# Date range.
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
to_str: str | None = date_range.get("to")
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
effective_since: datetime | None = since
if from_str:
try:
@@ -110,18 +72,12 @@ def _build_gmail_query(
def _strip_html(raw_html: str) -> str:
"""Remove HTML tags and decode entities to get plain text."""
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
decoded = html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip()
def _parse_body(payload: dict[str, Any]) -> str:
"""Recursively extract the plain-text body from a Gmail message payload.
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
Returns an empty string if no body can be extracted.
"""
mime_type: str = payload.get("mimeType", "")
body: dict = payload.get("body", {})
parts: list[dict] = payload.get("parts", [])
@@ -139,7 +95,6 @@ def _parse_body(payload: dict[str, Any]) -> str:
return _strip_html(raw)
return ""
# Multipart — prefer text/plain part, fall back to text/html.
plain_fallback = ""
for part in parts:
part_mime = part.get("mimeType", "")
@@ -155,7 +110,6 @@ def _parse_body(payload: dict[str, Any]) -> str:
def _parse_date(raw: str) -> datetime:
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
try:
parsed = email.utils.parsedate_to_datetime(raw)
if parsed.tzinfo is None:
@@ -166,16 +120,6 @@ def _parse_date(raw: str) -> datetime:
class GmailClient:
"""Fetch email messages from a Gmail account via the Gmail REST API.
Parameters
----------
credentials_info:
Decrypted OAuth2 credential dict. Must contain at minimum
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
``client_id`` + ``client_secret``.
"""
def __init__(self, credentials_info: dict[str, Any]) -> None:
from google.oauth2.credentials import Credentials
@@ -200,38 +144,20 @@ class GmailClient:
expiry=expiry,
)
# ── Public API ─────────────────────────────────────────────────────────
async def fetch_messages(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
to avoid blocking the async event loop.
Token refresh is performed automatically when the access token has
expired. After the call, ``self.refreshed_credentials`` may be
consulted to detect whether new credentials should be persisted.
"""
query = _build_gmail_query(filter_config, since)
logger.debug("gmail: executing search query %r", query)
return await asyncio.to_thread(self._fetch_sync, query)
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
"""Return updated credential dict if the access token was refreshed.
If the credentials were refreshed during ``fetch_messages()``, returns
a new dict that should be re-encrypted and written back to the DB.
Returns ``None`` if no refresh occurred.
"""
creds = self._credentials
if not creds.valid and creds.expired:
return None
# Check whether the token changed from what was stored.
if creds.token != self._credentials_info.get("token"):
result = {
"token": creds.token,
@@ -246,15 +172,11 @@ class GmailClient:
return result
return None
# ── Internal sync worker ───────────────────────────────────────────────
def _fetch_sync(self, query: str) -> list[EmailMessage]:
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
import googleapiclient.discovery
import googleapiclient.errors
from google.auth.transport.requests import Request
# Refresh token if needed before building the service.
if self._credentials.expired and self._credentials.refresh_token:
try:
self._credentials.refresh(Request())
@@ -264,9 +186,8 @@ class GmailClient:
service = googleapiclient.discovery.build(
"gmail", "v1", credentials=self._credentials, cache_discovery=False
)
user_api = service.users() # type: ignore[attr-defined]
user_api = service.users()
# ── List matching message IDs ──────────────────────────────────────
ids: list[str] = []
page_token: str | None = None
while len(ids) < _MAX_MESSAGES:
@@ -293,12 +214,10 @@ class GmailClient:
break
if not ids:
logger.debug("gmail: no messages matched query %r", query)
return []
logger.info("gmail: fetching %d message(s)", len(ids))
# ── Fetch individual message details ──────────────────────────────
messages: list[EmailMessage] = []
for msg_id in ids:
try:
@@ -326,10 +245,8 @@ class GmailClient:
date=date,
labels=labels,
))
except googleapiclient.errors.HttpError as exc:
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
except Exception as exc:
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
logger.info("gmail: returned %d message(s)", len(messages))
return messages

View File

@@ -1,24 +1,6 @@
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
"""Microsoft Graph API client for Outlook and Teams.
Handles two data sources:
* **Outlook email** (``provider="outlook"``) ``fetch_emails()`` calls
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
* **Teams messages** (``provider="teams"``) ``fetch_messages()`` calls
``/me/chats/getAllMessages`` filtered by date.
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
dependency) is used for all API calls.
Credential dict shape (Microsoft OAuth2 / MSAL):
{
"access_token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_type": "Bearer",
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
"expires_in": 3600
}
Adapted for Batch Agent Service: import settings from shared.config.
"""
from __future__ import annotations
@@ -30,23 +12,19 @@ from typing import Any
import httpx
from app.config.settings import settings
from shared.config import settings
from app.integrations import ChatMessage, EmailMessage
logger = logging.getLogger(__name__)
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
# Max items fetched per run.
_MAX_EMAILS = 200
_MAX_MESSAGES = 200
# Max characters of body forwarded to the LLM.
_BODY_TRUNCATE = 8_000
def _strip_html(raw: str) -> str:
"""Strip HTML tags and collapse whitespace."""
no_tags = re.sub(r"<[^>]+>", " ", raw)
import html as _html
decoded = _html.unescape(no_tags)
@@ -54,7 +32,6 @@ def _strip_html(raw: str) -> str:
def _odata_datetime(dt: datetime) -> str:
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
utc = dt.astimezone(timezone.utc)
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
@@ -63,29 +40,14 @@ def _build_email_filter(
filter_config: dict[str, Any] | None,
since: datetime | None,
) -> str:
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
Supported ``filter_config`` keys:
senders (list[str]): Sender email addresses.
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
folders (list[str]): Folder display names (not directly filterable
via OData, so ignored here callers iterate
folder IDs separately if needed; listed for
completeness).
A hard ``since`` date always overrides ``date_range.from`` when it is
earlier.
"""
clauses: list[str] = []
cfg = filter_config or {}
# Senders.
senders: list[str] = cfg.get("senders", [])
if senders:
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
clauses.append("(" + " or ".join(sender_clauses) + ")")
# Date range.
date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from")
@@ -117,33 +79,16 @@ def _build_email_filter(
class MSGraphClient:
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
Parameters
----------
credentials_info:
Decrypted MSAL credential dict.
"""
def __init__(self, credentials_info: dict[str, Any]) -> None:
self._credentials_info = credentials_info
self._access_token: str = credentials_info.get("access_token", "")
self._original_access_token: str = self._access_token
self._refresh_token: str | None = credentials_info.get("refresh_token")
# ── Token management ───────────────────────────────────────────────────
def _auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._access_token}"}
async def _refresh_access_token(self) -> None:
"""Use MSAL to exchange the refresh token for a fresh access token.
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
Raises:
RuntimeError: MSAL reports an auth error.
"""
import msal
app = msal.ConfidentialClientApplication(
@@ -164,7 +109,6 @@ class MSGraphClient:
raise RuntimeError(f"MS Graph token refresh failed: {error}")
self._access_token = result["access_token"]
# MSAL may issue a new refresh token.
if "refresh_token" in result:
self._refresh_token = result["refresh_token"]
self._credentials_info["refresh_token"] = result["refresh_token"]
@@ -172,16 +116,10 @@ class MSGraphClient:
@property
def refreshed_credentials(self) -> dict[str, Any] | None:
"""Return updated credential dict if the access token was refreshed.
Returns ``None`` if no change was made.
"""
if self._access_token != self._original_access_token:
return {**self._credentials_info, "access_token": self._access_token}
return None
# ── HTTP helpers ───────────────────────────────────────────────────────
async def _get(
self,
client: httpx.AsyncClient,
@@ -190,10 +128,8 @@ class MSGraphClient:
*,
retry_on_401: bool = True,
) -> dict[str, Any]:
"""GET *url* with auth; refresh token on 401 and retry once."""
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
logger.debug("ms_graph: 401 on %s — refreshing token", url)
await self._refresh_access_token()
resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 429:
@@ -201,22 +137,11 @@ class MSGraphClient:
resp.raise_for_status()
return resp.json()
# ── Public API ─────────────────────────────────────────────────────────
async def fetch_emails(
self,
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[EmailMessage]:
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
Parameters
----------
filter_config:
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
since:
Hard lower-bound on email date (from last agent run).
"""
odata_filter = _build_email_filter(filter_config, since)
params: dict[str, Any] = {
"$top": 50,
@@ -237,7 +162,7 @@ class MSGraphClient:
if len(emails) >= _MAX_EMAILS:
break
url = data.get("@odata.nextLink", "")
params = {} # nextLink already contains encoded params.
params = {}
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
return emails
@@ -247,13 +172,6 @@ class MSGraphClient:
filter_config: dict[str, Any] | None = None,
since: datetime | None = None,
) -> list[ChatMessage]:
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
The ``filter_config.channels`` key is checked as a text-filter on
the channel name post-fetch (the API doesn't support channel OData
filter directly on ``getAllMessages``).
"""
cfg = filter_config or {}
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
params: dict[str, Any] = {"$top": 50}
@@ -268,11 +186,9 @@ class MSGraphClient:
try:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
except httpx.HTTPStatusError as exc:
# getAllMessages requires specific licensing; degrade gracefully.
if exc.response.status_code in (403, 404):
logger.warning(
"ms_graph: /me/chats/getAllMessages not available (%d)"
"check Teams license or permissions",
"ms_graph: /me/chats/getAllMessages not available (%d)",
exc.response.status_code,
)
break
@@ -292,8 +208,6 @@ class MSGraphClient:
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
return messages
# ── Parsers ────────────────────────────────────────────────────────────
@staticmethod
def _parse_email(item: dict[str, Any]) -> EmailMessage:
subject: str = item.get("subject", "(no subject)") or "(no subject)"

View File

@@ -0,0 +1,395 @@
"""Chatbot Journey — guided conversation to build an agent prompt_template.
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
and app.llm instead of monolith paths. Session state is in-memory (could
be moved to Redis for horizontal scaling in the future).
Journey flow:
1. Redis consumer dispatches ``journey_start`` with basic agent config.
2. Server creates an in-memory session, runs the setup LLM with
file-system tools to explore the directory, returns first question.
3. ``journey_message`` frames drive the conversation.
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
"""
from __future__ import annotations
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from shared.llm import get_llm
import app.tracing as tracing
logger = logging.getLogger(__name__)
# ── Session TTL ───────────────────────────────────────────────────────────
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced prompt_template.
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
_MIN_TURNS_BEFORE_NUDGE: int = 3
_MAX_TURNS: int = 15
_MAX_TOOL_STEPS: int = 6
# ── In-memory session store ───────────────────────────────────────────────
@dataclass
class JourneySession:
session_id: str
user_id: str
agent_type: str # "local" | "cloud"
directory: str
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool:
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
# session_id → session
_sessions: dict[str, JourneySession] = {}
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
"""Retrieve session; return None on missing, expired, or wrong owner."""
s = _sessions.get(session_id)
if s is None or s.is_expired():
_sessions.pop(session_id, None)
return None
if s.user_id != user_id:
return None
return s
# ── System prompt builder ─────────────────────────────────────────────────
_SYSTEM_PROMPT_TEMPLATE = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand exactly what data the user wants to extract from their
local directory and produce a concise prompt_template that a separate AI will use
as its instruction set.
You have access to file-system tools to explore the user's directory:
- list_directory: to see folder structure
- read_file_content: to peek at file contents
- get_file_metadata: to check file info
The user's configured directory is: {directory}
Target data types: {data_types}
IMPORTANT — project assignment is handled automatically. You MUST NOT ask the user
about projects, projectId, or how to link records to projects. Never include
projectId logic or project creation instructions in the generated prompt_template.
Start by exploring the directory to understand its structure. Then ask concise,
focused questions one at a time. Cover only the topics relevant to the target
data types listed above:
1. Content type and format — confirmed by your exploration.
2. For TASKS (if in scope): field mapping for title, status, priority, content,
dueDate (where is the date found? what's the fallback when absent?),
and assignee (is there a person name to assign?).
3. For NOTES when TASKS are also in scope: note vs task distinction —
what makes something a note rather than a task?
4. For TIMELINES (if in scope): the date source — what marks a milestone or event?
5. Exclusions and special handling applicable to the target data types.
Keep asking focused questions until you are at least 90% confident. Then stop and
output the final prompt_template immediately, wrapped between these exact markers
on their own lines:
{template_start}
<the complete extraction prompt here>
{template_end}
The prompt_template must be concise (bullet points, ~1525 lines maximum).
Specify only:
- Scope: what files/content qualify and what entity types to create.
- Field mapping rules per entity type (camelCase fields: title, status, priority,
dueDate, content, assignee, etc.).
- dueDate rule (if tasks in scope): source and fallback behaviour.
- Note vs task rule (if both in scope): the criterion that separates them.
- Timeline date rule (if timelines in scope): what constitutes a timeline event.
- Exclusion/filtering rules.
- 23 concrete mapping examples based on what you discovered.
{existing_section}Begin by exploring the directory, then ask your first question.\
"""
def _build_system_prompt(
directory: str,
data_types: list[str],
existing_template: str | None = None,
) -> str:
existing_section = (
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
f"---\n{existing_template}\n---\n"
if existing_template
else ""
)
# Use Langfuse compile_prompt ({{variable}} syntax) with Python .format() fallback
return tracing.compile_prompt(
"journey_system",
fallback=_SYSTEM_PROMPT_TEMPLATE,
variables={
"directory": directory,
"data_types": ", ".join(data_types),
"existing_section": existing_section,
},
)
# ── Template extraction ───────────────────────────────────────────────────
def _extract_template(text: str) -> str | None:
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
return None
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
end_idx = text.index(_TEMPLATE_END)
return text[start_idx:end_idx].strip() or None
# ── LLM call with tool support ───────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
langfuse_handler: Any | None = None,
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history:
if turn["role"] == "user":
messages.append(HumanMessage(content=turn["content"]))
else:
messages.append(AIMessage(content=turn["content"]))
callbacks = [langfuse_handler] if langfuse_handler else None
llm = get_llm(model=None, temperature=0.4, callbacks=callbacks)
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(_MAX_TOOL_STEPS):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"journey: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:500],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"journey: tool_result name=%s output=%s",
call_name,
str(tool_output)[:800],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max tool steps.
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Journey handlers (called from redis_consumer) ────────────────────────
async def handle_journey_start(
user_id: str,
frame: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> dict[str, Any]:
"""Handle a ``journey_start`` request.
Creates a session, runs the setup LLM with directory exploration,
and returns the ``journey_reply`` payload.
"""
agent_type = frame.get("agent_type", "local")
directory = frame.get("directory", "")
data_types = frame.get("data_types", [])
existing_template = frame.get("existing_template")
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt = _build_system_prompt(directory, data_types, existing_template)
session = JourneySession(
session_id=session_id,
user_id=user_id,
agent_type=agent_type,
directory=directory,
data_types=data_types,
system_prompt=system_prompt,
)
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=list(FILESYSTEM_TOOLS),
langfuse_handler=langfuse_handler,
)
session.history.extend(seed_history)
session.history.append({"role": "assistant", "content": ai_reply})
_sessions[session_id] = session
logger.info(
"journey: session %s started for user %s (directory=%s)",
session_id,
user_id,
directory,
)
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}
async def handle_journey_message(
user_id: str,
frame: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> dict[str, Any]:
"""Handle a ``journey_message`` request.
Appends the user message, calls the LLM, and returns the
``journey_reply`` payload.
"""
session_id = frame.get("session_id", "")
message = frame.get("message", "")
session = get_journey_session(session_id, user_id)
if session is None:
return {
"type": "journey_reply",
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"prompt_template": None,
}
session.history.append({"role": "user", "content": message})
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
langfuse_handler=langfuse_handler,
)
session.history.append({"role": "assistant", "content": ai_reply})
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
langfuse_handler=langfuse_handler,
)
session.history.append({"role": "assistant", "content": nudge_reply})
prompt_template = _extract_template(nudge_reply)
if prompt_template is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
if _TEMPLATE_START in ai_reply
else "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
logger.info("journey: session %s completed for user %s", session_id, user_id)
return {
"type": "journey_reply",
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
}

View File

@@ -0,0 +1,76 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Identical to services/chat/app/llm.py. Uses shared.config.settings.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from shared.config import settings
litellm.drop_params = True
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github/"):
return settings.GITHUB_TOKEN or None
if model.startswith("github_copilot/"):
return None
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
callbacks: list | None = None,
) -> ChatOpenAI | ChatLiteLLM:
model = model or settings.LLM_MODEL
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
if settings.GITHUB_TOKEN:
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
callbacks=callbacks,
)
async def embed(text: str) -> list[float]:
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -0,0 +1,79 @@
"""Batch Agent Service — FastAPI application.
Owns: agent_runner (local directory + cloud connectors), journey builder,
filesystem_agent, integrations (Gmail, MS Graph).
Communicates with WS Gateway via Redis:
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
- Publishes to ws:out:{user_id} (journey replies + tool calls)
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
"""
from __future__ import annotations
import asyncio
import logging
import sys
from pathlib import Path
# Ensure the repo root is on sys.path so ``shared`` is importable when
# running locally (in Docker the COPY already places it at /app/shared/).
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.redis_consumer import start_consumer
from app.routes import router
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Initialise Langfuse tracing (no-op if keys are missing)
from app.tracing import init_langfuse
init_langfuse()
logger.info("batch-agent: starting Redis consumer")
task = asyncio.create_task(start_consumer())
yield
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
from app.tracing import shutdown as shutdown_langfuse
shutdown_langfuse()
from shared.db import engine
await engine.dispose()
from shared.redis import redis_client
await redis_client.aclose()
logger.info("batch-agent: Redis consumer stopped")
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
app.include_router(router)
@app.get("/health")
async def health() -> dict[str, str]:
return {"status": "ok", "service": "batch-agent"}

View File

@@ -0,0 +1,183 @@
"""Redis consumer for the Batch Agent Service.
Subscribes to batch:request:* (pattern) and dispatches:
- journey_start → handle_journey_start
- journey_message → handle_journey_message
- agent_trigger → run_local_agent / run_cloud_agent
Results are published back to ws:out:{user_id} via Redis.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
from shared.redis import redis_client, batch_request_channel, ws_out_channel
import app.tracing as tracing
from shared.ws_context import set_current_user, clear_current_user
logger = logging.getLogger(__name__)
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
"""Publish a frame to the user's WS outbound channel."""
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps(payload))
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
"""Handle a journey_start request from WS Gateway."""
from app.journey import handle_journey_start
session_id = data.get("session_id", "")
set_current_user(user_id)
try:
with tracing.trace_span(
name="journey_start",
user_id=user_id,
session_id=session_id,
input=data.get("directory", ""),
metadata={"data_types": data.get("data_types", [])},
tags=["journey"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
reply = await handle_journey_start(user_id, data, langfuse_handler=langfuse_handler)
tracing.link_prompt_to_trace(span, "journey_system")
span.update(output=reply.get("message", "")[:500])
await _publish_to_user(user_id, reply)
tracing.flush()
except Exception as exc:
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey setup failed: {exc}",
"done": True,
"prompt_template": None,
})
finally:
clear_current_user()
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
"""Handle a journey_message from WS Gateway."""
from app.journey import handle_journey_message
session_id = data.get("session_id", "")
set_current_user(user_id)
try:
with tracing.trace_span(
name="journey_message",
user_id=user_id,
session_id=session_id,
input=data.get("message", "")[:200],
tags=["journey"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
reply = await handle_journey_message(user_id, data, langfuse_handler=langfuse_handler)
tracing.link_prompt_to_trace(span, "journey_system")
span.update(output=reply.get("message", "")[:500])
await _publish_to_user(user_id, reply)
tracing.flush()
except Exception as exc:
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey processing failed: {exc}",
"done": True,
"prompt_template": None,
})
finally:
clear_current_user()
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
from app.agent_runner import run_local_agent
run_context = data.get("run_context", {})
agent_id = run_context.get("agent_id", "")
set_current_user(user_id)
try:
with tracing.trace_span(
name="agent_trigger",
user_id=user_id,
trace_id=run_context.get("run_id"),
input={"agent_id": agent_id, "directory": data.get("directory", "")},
metadata={"data_types": data.get("data_types", [])},
tags=["batch", "agent_run"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
await run_local_agent(user_id, data, langfuse_handler=langfuse_handler)
tracing.link_prompt_to_trace(span, "batch_processing")
span.update(output={"status": "completed"})
tracing.flush()
except Exception as exc:
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "run_complete",
"status": "error",
"run_context": run_context,
})
finally:
clear_current_user()
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
"""Route a batch request to the correct handler."""
msg_type = message_data.get("type", "")
if msg_type == "journey_start":
await _handle_journey_start(user_id, message_data)
elif msg_type == "journey_message":
await _handle_journey_message(user_id, message_data)
elif msg_type == "agent_trigger":
await _handle_agent_trigger(user_id, message_data)
elif msg_type == "device_online":
logger.info("batch-agent: device_online user=%s device=%s", user_id, message_data.get("device_id", "?"))
else:
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
async def start_consumer() -> None:
"""Subscribe to batch:request:* and dispatch incoming frames."""
pubsub = redis_client.pubsub()
await pubsub.psubscribe("batch:request:*")
logger.info("batch-agent: subscribed to batch:request:*")
try:
async for message in pubsub.listen():
if message["type"] != "pmessage":
continue
channel: str = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
# Extract user_id from channel: batch:request:{user_id}
parts = channel.split(":", 2)
if len(parts) < 3:
continue
user_id = parts[2]
raw = message["data"]
if isinstance(raw, bytes):
raw = raw.decode()
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("batch-agent: invalid JSON on channel %s", channel)
continue
# Dispatch in a separate task to avoid blocking the consumer
asyncio.create_task(_dispatch(user_id, data))
except asyncio.CancelledError:
logger.info("batch-agent: consumer shutting down")
finally:
await pubsub.punsubscribe("batch:request:*")

View File

@@ -0,0 +1,208 @@
"""Agent REST routes — catalog, billing checks, trigger.
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
Agent trigger dispatches via Redis to the consumer instead of spawning
an in-process background task.
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Header, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.db import async_session
from shared.models import AgentRunLog
from shared.redis import redis_client, batch_request_channel
from app.agent_runner import is_agent_running
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Tier feature limits ───────────────────────────────────────────────
# Mirrors app/billing/tier_manager.py FEATURES dict.
FEATURES: dict[str, dict] = {
"free": {"batch_active": 1, "batch_runs_per_day": 3},
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
"power": {"batch_active": 20, "batch_runs_per_day": 100},
"team": {"batch_active": -1, "batch_runs_per_day": -1},
}
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
async with async_session() as db:
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog")
async def get_agent_catalog(
x_user_id: str = Header(..., alias="X-User-Id"),
) -> list[dict]:
return [
{
"type": "local_directory",
"name": "Local Directory Monitor",
"description": "Watches local directories, extracts data from files using AI",
},
{
"type": "gmail",
"name": "Gmail Connector",
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
},
{
"type": "teams",
"name": "Microsoft Teams Connector",
"description": "Monitors Teams messages, extracts action items",
},
{
"type": "outlook",
"name": "Outlook Connector",
"description": "Scans Outlook inbox, extracts tasks/notes",
},
]
# ── Can-create check ─────────────────────────────────────────────────
@router.post("/can-create")
async def can_create_agent(
body: dict,
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
active_agents = body.get("active_agents", 0)
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or active_agents < limit
return {
"allowed": allowed,
"tier": x_user_tier,
"active_agents": active_agents,
"limit": limit,
}
# ── Trigger ──────────────────────────────────────────────────────────
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: dict,
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
"""Trigger a local agent run — creates run log and dispatches via Redis."""
active_agents = body.get("active_agents", 0)
_enforce_agent_limit(x_user_tier, active_agents)
await _enforce_run_frequency(x_user_tier, x_user_id)
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running.",
)
# Create run log in DB
async with async_session() as db:
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=x_user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
run_context = {
"type": "agent_batch",
"run_id": run_log_id,
"agent_id": stable_agent_id,
}
# Dispatch to the Redis consumer for processing
trigger_data = {
"type": "agent_trigger",
"directory": body.get("directory", ""),
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
"data_types": _to_data_types(body.get("what_to_extract", [])),
"file_extensions": body.get("file_extensions", []),
"prompt_template": body.get("custom_agent_prompt", ""),
"device_id": body.get("device_id", ""),
"run_context": run_context,
}
channel = batch_request_channel(x_user_id)
await redis_client.publish(channel, json.dumps(trigger_data))
return {
"id": run_log_id,
"agent_id": stable_agent_id,
"agent_type": "local",
"status": "running",
"items_processed": 0,
"items_created": 0,
"errors": [],
"started_at": _dt_ms(run_log.started_at),
"completed_at": None,
}

View File

@@ -0,0 +1,336 @@
"""Langfuse tracing & prompt management for the Batch Agent Service (v4 SDK).
Provides:
- ``init_langfuse()`` — initialise the singleton client at startup
- ``trace_span()`` — context manager that creates a trace + span
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
- ``flush()`` / ``shutdown()`` — lifecycle management
All functions gracefully degrade to no-ops when Langfuse is not configured,
so the service works identically with or without observability keys.
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import Any
from shared.config import settings
logger = logging.getLogger(__name__)
# ── State ────────────────────────────────────────────────────────────────
_initialised: bool = False
_disabled: bool = False
def _is_configured() -> bool:
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
def init_langfuse() -> None:
"""Initialise the Langfuse singleton. Call once at startup."""
global _initialised, _disabled
if _initialised or _disabled:
return
if not _is_configured():
_disabled = True
logger.info("tracing: Langfuse keys not set — tracing disabled")
return
try:
from langfuse import Langfuse
Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST,
)
_initialised = True
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
except Exception as exc:
_disabled = True
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
def _get_client() -> Any | None:
"""Return the singleton Langfuse client, or *None* if disabled."""
if _disabled:
return None
if not _initialised:
init_langfuse()
if _disabled:
return None
try:
from langfuse import get_client
return get_client()
except Exception:
return None
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
class _NullSpan:
"""Drop-in replacement when Langfuse is disabled."""
def update(self, **_: Any) -> None: ...
def set_trace_io(self, **_: Any) -> None: ...
def score_trace(self, **_: Any) -> None: ...
# ── Trace context manager ───────────────────────────────────────────────
@contextmanager
def trace_span(
*,
name: str,
user_id: str,
session_id: str | None = None,
trace_id: str | None = None,
input: Any = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
):
"""Context manager that creates a Langfuse trace/span.
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
A ``CallbackHandler`` created inside this block auto-inherits the trace
context, so there is no need to pass trace IDs manually.
"""
lf = _get_client()
if lf is None:
yield _NullSpan()
return
try:
from langfuse import Langfuse, propagate_attributes
trace_ctx: dict[str, str] = {}
if trace_id is not None:
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
with lf.start_as_current_observation(
as_type="span",
name=name,
input=input,
metadata=metadata or {},
**({"trace_context": trace_ctx} if trace_ctx else {}),
) as span:
with propagate_attributes(
user_id=user_id,
session_id=session_id,
tags=tags or [],
):
yield span
except Exception as exc:
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
yield _NullSpan()
# ── LangChain callback handler ──────────────────────────────────────────
def get_langfuse_callback() -> Any | None:
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
Must be called inside a ``trace_span()`` block for proper linking.
Returns *None* when Langfuse is disabled.
"""
if _disabled and not _initialised:
return None
try:
from langfuse.langchain import CallbackHandler
return CallbackHandler()
except Exception as exc:
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
return None
# ── Prompt management ────────────────────────────────────────────────────
def get_prompt(
name: str,
*,
version: int | None = None,
label: str | None = None,
fallback: str | None = None,
cache_ttl_seconds: int = 300,
) -> str | None:
"""Fetch a managed prompt from Langfuse by name (without variable compilation).
Returns the raw prompt string, or *fallback* if the prompt is not
found or Langfuse is disabled.
"""
lf = _get_client()
if lf is None:
return fallback
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
return prompt.prompt
except Exception as exc:
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
return fallback
def compile_prompt(
name: str,
*,
fallback: str,
variables: dict[str, str],
version: int | None = None,
label: str | None = None,
cache_ttl_seconds: int = 300,
) -> str:
"""Fetch a managed prompt from Langfuse and compile it with ``{{variables}}``.
If the prompt exists in Langfuse, uses the SDK's ``.compile(**variables)``
which replaces ``{{key}}`` placeholders. If Langfuse is disabled or the
prompt is not found, falls back to ``fallback.format(**variables)`` (Python
``{key}`` placeholders).
This means:
- Langfuse prompts use ``{{variable}}`` syntax.
- Hardcoded fallback strings use Python ``{variable}`` syntax.
"""
lf = _get_client()
if lf is None:
return fallback.format(**variables)
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
return prompt.compile(**variables)
except Exception as exc:
logger.warning("tracing: compile_prompt(%s) failed, using fallback: %s", name, exc)
return fallback.format(**variables)
def get_prompt_object(
name: str,
*,
version: int | None = None,
label: str | None = None,
cache_ttl_seconds: int = 300,
) -> Any | None:
"""Fetch the raw Langfuse prompt *object* (not the compiled string).
Returns ``None`` when Langfuse is disabled or the prompt is not found.
Use this when you need to pass the prompt to ``start_observation(prompt=...)``
for linking the prompt to a trace in the Langfuse UI.
"""
lf = _get_client()
if lf is None:
return None
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
return lf.get_prompt(**kwargs)
except Exception as exc:
logger.warning("tracing: get_prompt_object(%s) failed: %s", name, exc)
return None
def link_prompt_to_trace(
span: Any,
prompt_name: str,
*,
version: int | None = None,
label: str | None = None,
) -> None:
"""Link a Langfuse managed prompt to a span/observation.
Uses the SDK v4 ``prompt=`` parameter so that the prompt version
appears linked in the Langfuse UI with metrics tracking.
"""
lf = _get_client()
if lf is None or isinstance(span, _NullSpan):
return
try:
prompt = get_prompt_object(prompt_name, version=version, label=label)
if prompt is not None:
span.update(prompt=prompt)
except Exception as exc:
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
# ── Scoring helper ───────────────────────────────────────────────────────
def score_trace(
trace_id: str,
name: str,
value: float,
*,
comment: str | None = None,
) -> None:
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
lf = _get_client()
if lf is None:
return
try:
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
except Exception as exc:
logger.warning("tracing: score_trace failed: %s", exc)
# ── Shutdown ─────────────────────────────────────────────────────────────
def flush() -> None:
"""Flush pending Langfuse events."""
lf = _get_client()
if lf is not None:
try:
lf.flush()
except Exception as exc:
logger.warning("tracing: flush failed: %s", exc)
def shutdown() -> None:
"""Flush and close the Langfuse client."""
global _initialised, _disabled
lf = _get_client()
if lf is not None:
try:
lf.flush()
lf.shutdown()
except Exception as exc:
logger.warning("tracing: shutdown failed: %s", exc)
_initialised = False
_disabled = False

View File

@@ -0,0 +1 @@
"""Batch Agent E2E evaluation harness."""

View File

@@ -0,0 +1,5 @@
"""Allow running the eval package as ``python -m eval``."""
from eval.cli import main
main()

View File

@@ -0,0 +1,285 @@
"""CLI entry point for the batch agent evaluation harness.
Usage::
# From services/batch-agent/:
python -m eval run # all agent fixtures, default model
python -m eval run --fixture=classify-invoices # single fixture
python -m eval run --models=gpt-4o,gpt-5.3-codex # multiple models
python -m eval run --mode=step1 # only step1 fixtures
python -m eval run --no-judge # skip LLM judge scoring
python -m eval interactive # interactive journey session
python -m eval interactive --fixture=journey-invoice-setup
python -m eval interactive --model=gpt-4o
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
python -m eval list # list all fixtures
python -m eval sync # sync fixtures to Langfuse datasets
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import sys
from pathlib import Path
# Ensure the service root and repo root are in sys.path.
# Service root must come BEFORE repo root so its ``app/`` package
# shadows the monolith ``app/`` in the repo root.
_SERVICE_ROOT = Path(__file__).resolve().parent.parent
_REPO_ROOT = _SERVICE_ROOT.parent.parent
_sr = str(_SERVICE_ROOT)
_rr = str(_REPO_ROOT)
if _rr not in sys.path:
sys.path.insert(0, _rr)
# Always force service root to position 0 (python -m may have already
# added CWD further down the list, which loses to repo root).
if _sr in sys.path:
sys.path.remove(_sr)
sys.path.insert(0, _sr)
from eval.config import discover_fixtures, discover_journey_fixtures
from eval.runner import run_fixture_eval, print_results
from eval.interactive import run_interactive
from eval import langfuse_eval
def _setup_logging(verbose: bool) -> None:
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s %(name)-20s %(levelname)-5s %(message)s",
datefmt="%H:%M:%S",
)
# Quiet noisy libraries
for name in ("httpx", "httpcore", "openai", "litellm", "urllib3"):
logging.getLogger(name).setLevel(logging.WARNING)
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Batch Agent E2E evaluation harness",
prog="python -m eval",
)
sub = parser.add_subparsers(dest="command", required=True)
# ── run ───────────────────────────────────────────────────────
run_cmd = sub.add_parser("run", help="Run evaluations")
run_cmd.add_argument(
"--fixture", "-f",
help="Run only the named fixture (default: all)",
)
run_cmd.add_argument(
"--models", "-m",
default="github_copilot/gpt-5.3-codex",
help="Comma-separated list of models to test (default: github_copilot/gpt-5.3-codex)",
)
run_cmd.add_argument(
"--mode",
default=None,
choices=["step1", "step2", "full"],
help="Only run fixtures with this mode (default: all)",
)
run_cmd.add_argument(
"--no-judge",
action="store_true",
help="Skip LLM-as-judge scoring",
)
run_cmd.add_argument(
"--judge-model",
default="gpt-4o",
help="Model for LLM judge (default: gpt-4o)",
)
run_cmd.add_argument(
"--fixtures-dir",
default=None,
help="Path to fixtures directory (default: eval/fixtures/)",
)
run_cmd.add_argument("-v", "--verbose", action="store_true")
# ── list ──────────────────────────────────────────────────────
list_cmd = sub.add_parser("list", help="List available fixtures")
list_cmd.add_argument("--fixtures-dir", default=None)
list_cmd.add_argument("-v", "--verbose", action="store_true")
# ── sync ──────────────────────────────────────────────────────
sync_cmd = sub.add_parser("sync", help="Sync fixtures to Langfuse datasets")
sync_cmd.add_argument("--fixture", "-f", default=None, help="Sync only the named fixture")
sync_cmd.add_argument("--fixtures-dir", default=None)
sync_cmd.add_argument("-v", "--verbose", action="store_true")
# ── interactive ───────────────────────────────────────────────
inter_cmd = sub.add_parser("interactive", help="Interactive journey session (human-in-the-loop)")
inter_cmd.add_argument(
"--fixture", "-f",
help="Journey fixture to use (default: pick interactively)",
)
inter_cmd.add_argument(
"--model", "-m",
default="github_copilot/gpt-5.3-codex",
help="Model for the journey AI (default: github_copilot/gpt-5.3-codex)",
)
inter_cmd.add_argument(
"--judge-model",
default="gpt-4o",
help="Model for LLM judge (default: gpt-4o)",
)
inter_cmd.add_argument(
"--fixtures-dir",
default=None,
help="Path to fixtures directory (default: eval/fixtures/)",
)
inter_cmd.add_argument(
"--data-dir",
default=None,
help="Override sample data directory (e.g. path to private test files not in git)",
)
inter_cmd.add_argument("-v", "--verbose", action="store_true")
return parser.parse_args()
def _fixtures_dir(arg: str | None) -> Path | None:
if arg:
return Path(arg)
return None
async def _cmd_run(args: argparse.Namespace) -> None:
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
if not fixtures:
print("No fixtures found. Create YAML files in eval/fixtures/.")
return
if args.fixture:
fixtures = [f for f in fixtures if f.name == args.fixture]
if not fixtures:
print(f"Fixture '{args.fixture}' not found.")
return
models = [m.strip() for m in args.models.split(",")]
all_results = []
for fixture in fixtures:
if args.mode and fixture.mode != args.mode:
continue
results = await run_fixture_eval(
fixture,
models=models,
use_llm_judge=not args.no_judge,
judge_model=args.judge_model,
)
all_results.extend(results)
print_results(all_results)
def _cmd_list(args: argparse.Namespace) -> None:
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
if not fixtures and not journey_fixtures:
print("No fixtures found.")
return
if fixtures:
print(f"\n{'[Agent Fixtures]'}")
print(f"{'Name':<30} {'Mode':<6} {'Types':<25} {'Expected'}")
print("-" * 90)
for f in fixtures:
types = ", ".join(f.data_types)
n_expected = len(f.expected) + len(f.expected_classification)
print(f"{f.name:<30} {f.mode:<6} {types:<25} {n_expected}")
if journey_fixtures:
print(f"\n{'[Journey Fixtures]'}")
print(f"{'Name':<30} {'Types':<25} {'Messages':<10} {'Criteria'}")
print("-" * 90)
for f in journey_fixtures:
types = ", ".join(f.data_types)
print(f"{f.name:<30} {types:<25} {len(f.user_messages):<10} {len(f.expected_template_criteria)}")
print()
def _cmd_sync(args: argparse.Namespace) -> None:
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
if args.fixture:
fixtures = [f for f in fixtures if f.name == args.fixture]
journey_fixtures = [f for f in journey_fixtures if f.name == args.fixture]
if not fixtures and not journey_fixtures:
print("No fixtures to sync.")
return
for fixture in fixtures:
name = langfuse_eval.sync_fixture_to_dataset(fixture)
if name:
print(f"Synced: {fixture.name}{name}")
else:
print(f"Skipped: {fixture.name} (Langfuse not configured)")
for fixture in journey_fixtures:
name = langfuse_eval.sync_journey_fixture_to_dataset(fixture)
if name:
print(f"Synced: {fixture.name}{name}")
else:
print(f"Skipped: {fixture.name} (Langfuse not configured)")
async def _cmd_interactive(args: argparse.Namespace) -> None:
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
if not journey_fixtures:
print("No journey fixtures found. Create YAML files with type: journey in eval/fixtures/.")
return
if args.fixture:
fixtures = [f for f in journey_fixtures if f.name == args.fixture]
if not fixtures:
print(f"Journey fixture '{args.fixture}' not found.")
return
fixture = fixtures[0]
elif len(journey_fixtures) == 1:
fixture = journey_fixtures[0]
else:
# Let user pick
print("\nAvailable journey fixtures:")
for i, f in enumerate(journey_fixtures, 1):
print(f" {i}. {f.name}{f.description[:60]}")
print()
try:
choice = int(input("Pick a fixture number: ").strip()) - 1
fixture = journey_fixtures[choice]
except (ValueError, IndexError, EOFError, KeyboardInterrupt):
print("Invalid choice.")
return
await run_interactive(
fixture,
model=args.model,
judge_model=args.judge_model,
data_dir=Path(args.data_dir).resolve() if args.data_dir else None,
)
def main() -> None:
args = _parse_args()
_setup_logging(args.verbose)
if args.command == "run":
asyncio.run(_cmd_run(args))
elif args.command == "interactive":
asyncio.run(_cmd_interactive(args))
elif args.command == "list":
_cmd_list(args)
elif args.command == "sync":
_cmd_sync(args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,220 @@
"""Eval configuration — YAML fixture loader and dataclasses.
Fixtures come in two families:
1. **Agent fixtures** — test the batch agent pipeline.
Three modes controlled by ``mode``:
``step1`` — classification prompt only.
``step2`` — processing prompt only.
``full`` — both steps in sequence.
2. **Journey fixtures** — test the prompt-template builder conversation
(unchanged).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
import yaml
logger = logging.getLogger(__name__)
EvalMode = Literal["step1", "step2", "full"]
@dataclass
class ExpectedRecord:
"""A single expected extraction result.
Only the fields specified are checked — unspecified fields are ignored.
"""
table: str # tasks | notes | timelines | projects
fields: dict[str, Any] # field_name → expected_value
@dataclass
class ExpectedClassification:
"""Expected output of step-1 classification for one file."""
file: str # relative path to the sample file
project_id: str # expected matched project id, or "new"
domains: list[str] # expected domain list
new_project_name: str | None = None
@dataclass
class EvalFixture:
"""A complete test scenario loaded from YAML.
``mode`` determines which pipeline steps are exercised:
- **step1**: only ``_classify_file``
- **step2**: only the processing LLM + tool loop
- **full**: both steps in sequence (``run_local_agent``)
"""
name: str
description: str
mode: EvalMode
directory: str # relative path to sample files
data_types: list[str]
file_extensions: list[str]
models: list[str] # if empty, use CLI default
fixture_path: Path = field(default_factory=lambda: Path("."))
# ── Step-1 inputs (classification) ───────────────────────────
domain_definitions: str = ""
projects_list: list[dict[str, Any]] = field(default_factory=list)
custom_step1_prompt: str = ""
# ── Step-2 inputs (processing) ───────────────────────────────
existing_context: str = ""
project_context: str = ""
custom_prompt_section: str = ""
# ── Seed records for mock executor ───────────────────────────
seed_records: dict[str, list[dict]] = field(default_factory=dict)
# ── Expected outputs ─────────────────────────────────────────
expected_classification: list[ExpectedClassification] = field(default_factory=list)
expected: list[ExpectedRecord] = field(default_factory=list)
@property
def fixture_dir(self) -> Path:
"""Absolute path to the sample files directory."""
return self.fixture_path.parent / self.directory
@classmethod
def from_yaml(cls, path: Path) -> "EvalFixture":
"""Load a fixture from a YAML file."""
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
mode: EvalMode = raw.get("mode", "full")
# Parse expected records (step2/full)
expected: list[ExpectedRecord] = []
for table, records in (raw.get("expected") or {}).items():
for rec in records:
expected.append(ExpectedRecord(table=table, fields=rec))
# Parse expected classification (step1/full)
expected_classification: list[ExpectedClassification] = []
for item in raw.get("expected_classification") or []:
expected_classification.append(ExpectedClassification(
file=item["file"],
project_id=item["project_id"],
domains=item.get("domains", []),
new_project_name=item.get("new_project_name"),
))
return cls(
name=raw["name"],
description=raw.get("description", ""),
mode=mode,
directory=raw.get("directory", "sample_files"),
data_types=raw.get("data_types", ["tasks"]),
file_extensions=raw.get("file_extensions", []),
models=raw.get("models", []),
fixture_path=path,
# Step-1 inputs
domain_definitions=raw.get("domain_definitions", ""),
projects_list=raw.get("projects_list", []),
# Step-2 inputs
existing_context=raw.get("existing_context", ""),
project_context=raw.get("project_context", ""),
custom_prompt_section=raw.get("custom_prompt_section", ""),
# Shared
seed_records=raw.get("seed_records", {}),
expected_classification=expected_classification,
expected=expected,
)
def discover_fixtures(fixtures_dir: Path | None = None) -> list[EvalFixture]:
"""Find and load all YAML fixtures in the fixtures directory."""
if fixtures_dir is None:
fixtures_dir = Path(__file__).parent / "fixtures"
fixtures: list[EvalFixture] = []
if not fixtures_dir.is_dir():
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
return fixtures
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
try:
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if raw.get("type") == "journey":
continue # Skip journey fixtures
fixtures.append(EvalFixture.from_yaml(yaml_path))
logger.info("eval: loaded fixture %s from %s", fixtures[-1].name, yaml_path.name)
except Exception as exc:
logger.error("eval: failed to load fixture %s: %s", yaml_path.name, exc)
return fixtures
# ── Journey fixtures ─────────────────────────────────────────────────────
@dataclass
class JourneyFixture:
"""A journey test scenario — tests the prompt_template builder conversation."""
name: str
description: str
directory: str # relative path to sample files
data_types: list[str]
expected_template_criteria: list[str] # what the template should contain/satisfy
user_messages: list[str] = field(default_factory=list) # for automated journey runs (unused in interactive mode)
models: list[str] = field(default_factory=list)
fixture_path: Path = field(default_factory=lambda: Path("."))
@property
def fixture_dir(self) -> Path:
"""Absolute path to the sample files directory."""
return self.fixture_path.parent / self.directory
@classmethod
def from_yaml(cls, path: Path) -> "JourneyFixture":
"""Load a journey fixture from a YAML file."""
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
return cls(
name=raw["name"],
description=raw.get("description", ""),
directory=raw.get("directory", "sample_files"),
data_types=raw.get("data_types", ["tasks"]),
user_messages=raw.get("user_messages", []),
expected_template_criteria=raw.get("expected_template_criteria", []),
models=raw.get("models", []),
fixture_path=path,
)
def discover_journey_fixtures(fixtures_dir: Path | None = None) -> list[JourneyFixture]:
"""Find and load all journey YAML fixtures in the fixtures directory."""
if fixtures_dir is None:
fixtures_dir = Path(__file__).parent / "fixtures"
fixtures: list[JourneyFixture] = []
if not fixtures_dir.is_dir():
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
return fixtures
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
try:
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if raw.get("type") != "journey":
continue
fixtures.append(JourneyFixture.from_yaml(yaml_path))
logger.info("eval: loaded journey fixture %s from %s", fixtures[-1].name, yaml_path.name)
except Exception as exc:
logger.error("eval: failed to load journey fixture %s: %s", yaml_path.name, exc)
return fixtures

View File

@@ -0,0 +1,40 @@
# Fixture: classify-invoices (step1)
# Tests _STEP1_SYSTEM_PROMPT — file classification and project matching.
# Verifies that the LLM correctly matches files to existing projects
# and identifies the right data domains.
name: classify-invoices
mode: step1
description: >
Test file classification on Italian freelance invoices and meeting notes.
Verifies project matching and domain identification.
directory: sample_files/invoices
data_types: [tasks, notes, timelines]
file_extensions: [txt, md]
# ── Step-1 prompt variables ──────────────────────────────────────
domain_definitions: |
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
projects_list:
- id: "proj-web-redesign"
name: "Redesign Sito Web Corporate"
status: "active"
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
- id: "proj-ecommerce"
name: "E-Commerce FashionStore"
status: "active"
aiSummary: "Next.js e-commerce platform for FashionStore srl"
# ── Expected classification results ─────────────────────────────
expected_classification:
- file: "sample_files/invoices/fattura_042.txt"
project_id: "proj-web-redesign"
domains: [tasks, notes, timelines]
- file: "sample_files/invoices/meeting_ecommerce.md"
project_id: "proj-ecommerce"
domains: [tasks, notes, timelines]

View File

@@ -0,0 +1,108 @@
# Fixture: full-invoices (full)
# Tests both _STEP1_SYSTEM_PROMPT and _PROCESSING_SYSTEM_PROMPT in sequence
# via run_local_agent(). Verifies end-to-end classification + extraction.
name: full-invoices
mode: full
description: >
End-to-end test: classify Italian invoices/meeting notes into the
correct project, then extract tasks, notes, and timeline events.
directory: sample_files/invoices
data_types: [tasks, notes, timelines]
file_extensions: [txt, md]
# ── Step-1 prompt variables ──────────────────────────────────────
domain_definitions: |
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
projects_list:
- id: "proj-web-redesign"
name: "Redesign Sito Web Corporate"
status: "active"
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
- id: "proj-ecommerce"
name: "E-Commerce FashionStore"
status: "active"
aiSummary: "Next.js e-commerce platform for FashionStore srl"
# ── Step-2 prompt variables ──────────────────────────────────────
existing_context: |
Existing tasks:
(none)
Existing notes:
(none)
Existing timelines:
(none)
project_context: ""
custom_prompt_section: |
User instructions:
Estrai i dati dai file come segue:
- TASK: ogni azione da fare, deliverable, o item con scadenza.
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
Mappa "media priorità" → priority: medium.
Mappa "bassa priorità" → priority: low.
Se un item è marcato come "completato" o [x], impostalo status: done.
Altrimenti status: todo.
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
- TIMELINE: date di scadenza, milestone, meeting futuri.
Imposta sempre isAiSuggested=1.
# ── Seed records (pre-existing DB state) ─────────────────────────
seed_records:
projects:
- id: "proj-web-redesign"
name: "Redesign Sito Web Corporate"
status: "active"
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
- id: "proj-ecommerce"
name: "E-Commerce FashionStore"
status: "active"
aiSummary: "Next.js e-commerce platform for FashionStore srl"
tasks: []
notes: []
timelines: []
# ── Expected classification (step 1) ─────────────────────────────
expected_classification:
- file: "sample_files/invoices/fattura_042.txt"
project_id: "proj-web-redesign"
domains: [tasks, notes, timelines]
- file: "sample_files/invoices/meeting_ecommerce.md"
project_id: "proj-ecommerce"
domains: [tasks, notes, timelines]
# ── Expected extractions (step 2) ────────────────────────────────
expected:
tasks:
- title: "Sviluppo frontend React"
priority: "high"
status: "todo"
- title: "Integrazione API backend"
priority: "medium"
status: "todo"
- title: "Testing cross-browser e fix bug responsive"
status: "todo"
- title: "Preparare wireframe homepage"
priority: "high"
status: "todo"
- title: "Setup progetto Next.js e configurare CI/CD"
priority: "medium"
status: "todo"
- title: "Ricerca plugin Stripe per gestione abbonamenti"
priority: "low"
status: "todo"
notes:
- title: "Meeting Kickoff Progetto E-Commerce"
timelines:
- title: "MVP E-Commerce pronto"
- title: "Meeting di revisione"

View File

@@ -0,0 +1,28 @@
# Journey Fixture: journey-invoice-setup
# Used by `python -m eval interactive` for human-in-the-loop testing
# of the journey chatbot's prompt-building conversation.
type: journey
name: journey-invoice-setup
description: >
Interactive test for the journey chatbot — explore a directory of
Italian invoices and meeting notes, answer the chatbot's questions,
and verify it produces a well-structured prompt_template for data
extraction.
directory: sample_files/invoices
data_types: [tasks, notes, timelines, projects]
# Criteria the generated prompt_template must satisfy
# Each is scored 0-1 by an LLM judge
expected_template_criteria:
- "Mentions creating tasks from action items and work descriptions"
- "Mentions creating notes from meeting summaries"
- "Mentions extracting timeline events from deadlines and meeting dates"
- "Mentions creating projects from relevant information"
- "Sets isAiSuggested=1 on all created records"
- "Does NOT include projectId assignment logic"
- "Uses camelCase field names (title, status, priority, dueDate, content)"
# Models to test (empty = use CLI --models default)
models: []

Some files were not shown because too many files have changed in this diff Show More