50 Commits

Author SHA1 Message Date
Roberto
cc0e258e8c fix(api): WS index frames accept both camelCase and snake_case keys (Electron toSnakeCase compat) 2026-05-13 08:58:46 +02:00
Roberto
12e203e63d fix(api): multi-project manifest lists projects even with zero indexed files 2026-05-12 18:10:57 +02:00
Roberto
ffcd7390f0 feat(api): pagination + search + PDF/DOCX extract in folder agent tools 2026-05-12 17:31:43 +02:00
Roberto
91e880f9d4 fix(api): home agent falls back to multi-project folder manifest when no project_id 2026-05-12 16:54:47 +02:00
Roberto
7d47ca54be feat(api): emit Langfuse generation traces for folder indexer 2026-05-12 16:40:20 +02:00
Roberto
956fa88853 feat(api): multi-project folder manifest for daily brief
Add build_brief_multi_project_manifest() to deep_agent.py that fetches
all project folder manifests via execute_on_client and keeps the top 5
most-recently-modified files per project. Wire into run_home_brief in
brief_agent.py, injecting the <linked_folders> block into the system
prompt alongside FOLDER_TOOLS.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:40:47 +02:00
Roberto
fb2f59ccea feat(api): inject folder manifest into home agent when project context active
Add optional project_id param to run_home_stream. When set, fetch the linked
folder manifest via _fetch_project_manifest and prepend the <linked_folder>
block to the system prompt. Also build an explicit tools list that extends
_all_tools_for_user with FOLDER_TOOLS so the home agent can read folder
files. device_ws._handle_home_request extracts project_id / projectId from
the home_request frame and forwards it to the runner.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:32:20 +02:00
Roberto
56dbb7f4cd feat(api): inject folder manifest into task brief agent
Add _fetch_project_manifest helper that calls read_project_folder_manifest
via execute_on_client. Wire it into run_task_brief_research_stream (new
optional project_id param) so the <linked_folder> block is prepended to the
system prompt when the task belongs to a linked project. Also bind
FOLDER_TOOLS into the task-brief tool palette so the agent can read folder
files. device_ws extracts project_id / projectId from the task_brief_request
frame and forwards it.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:31:21 +02:00
Roberto
506f517851 feat(api): manifest formatter with token-budget truncation 2026-05-12 11:28:13 +02:00
Roberto
520c186991 feat(api): scoped read_project_folder_file tool with traversal guard
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:26:02 +02:00
Roberto
582bf27deb feat(api): WS index_session frames + handlers
Add six v7 WsFrameType enum members (index_session_start/cancel/batch,
index_file_result/progress/done), wire dispatch in device_ws message loop,
and implement _handle_index_session_start/cancel/file_batch with per-file
summarisation, token accounting, and quota enforcement.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:22:20 +02:00
Roberto
2aeb453229 feat(api): PDF + DOCX extraction in folder indexer
Add pypdf/python-docx deps, _extract_pdf_text/_extract_docx_text helpers,
and summarize_pdf/summarize_docx wrappers that delegate to summarize_text.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:15:17 +02:00
Roberto
b7a4edac90 feat(api): folder_indexer.summarize_image via gpt-4o-mini vision
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 11:09:37 +02:00
Roberto
822b4cd8b1 feat(api): folder_indexer.summarize_text via gpt-4o-mini 2026-05-12 11:05:43 +02:00
Roberto
ab24fc4c91 feat(api): POST /billing/quota/check endpoint
Pre-flight quota check for folder_index. Returns 402 with reason
when file cap or monthly token budget would be exceeded; 200 {"ok": true}
otherwise. Also adds auth_headers_free fixture to conftest.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 09:14:56 +02:00
Roberto
a98e99f7a2 feat(api): folder quota helpers with atomic token usage
Implements check_folder_quota and add_token_usage in app/billing/quota.py
with dialect-aware upsert (pg_insert on PostgreSQL, read-then-write on SQLite).
Adds test_user_free/test_user_power fixtures and db alias to conftest.py.
6 new tests pass.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 08:23:22 +02:00
Roberto
a0ff285bcd feat(api): tier features for folder integration
Add folder_max_files and folder_monthly_tokens to all four tier dicts
in FEATURES, and add get_feature_value() helper to TierManager.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:39:36 +02:00
Roberto
177c1a87dd feat(api): MonthlyTokenUsage model + AgentRunLog.tokens_used
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:30:33 +02:00
Roberto
441a4ea05c chore(api): fix stale Revises comment in folder migration 2026-05-12 07:21:13 +02:00
Roberto
a693a64bf5 feat(api): add migration for folder token tracking
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-12 07:16:23 +02:00
Roberto
67562b8092 Add task brief research agent: Stage 1 deep-research + canvas draft emission
- run_task_brief_research() runner with brief-specific tool set and max_steps=12
- New agents: client_agent (list_clients, get_client) and relations_agent (query_relations)
- search_associative tool wrapping MemoryMiddleware semantic search
- BRIEF_RESEARCH_TOOLS constant: read-only task/project/note/timeline + memory + client/relations
- canvas block extraction in output_formatter (splits visible text from <canvas> draft)
- device_ws.py: task_brief_research request type; emits canvas_draft mutation on stream_end
- Stage 2 briefMode: briefing_context injected into floating system prompt when present
- briefingContext kwarg wired through compile_prompt call chain

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-04 15:09:58 +02:00
Roberto
6f4c68b359 Update note management from db vector to index 2026-04-30 00:11:17 +02:00
Roberto
c20c6d7853 Fix home message tools calls 2026-04-29 09:21:41 +02:00
Roberto
6787e690ba fix tools calls 2026-04-27 09:15:08 +02:00
Roberto
cb8f56d909 date format fix 2026-04-26 21:06:38 +02:00
Roberto
2c7cac9e03 Fix using tools in home agent 2026-04-19 14:48:05 +02:00
Roberto
ea9094f47f Add llm providers 2026-04-19 00:32:12 +02:00
Roberto Musso
d5fea95561 Phase 3 — WS frame + REST fallbacka 2026-04-18 22:18:53 +02:00
Roberto Musso
0b5ef48463 Phase 7: audit memory 2026-04-17 22:43:55 +02:00
Roberto Musso
ca8721e1ac PHASE 5 — Proactive mining (Power tier only) 2026-04-17 17:58:30 +02:00
Roberto Musso
f658e5e6a3 fix: clean up stale and obsolete tests
- test_deep_agent: update patch target get_llm -> get_agent_llm (8 tests)
- test_device_ws: remove 5 tests for deleted agent_data_queue API
- test_schemas_v3: remove agent_run/agent_data/agent_complete from v2 compat list
- Delete test_agent_runner.py (superseded by test_agent_runner_v2.py)
- Delete test_agent_setup.py (superseded by test_journey_v2.py)
- Delete test_classify_file.py (_classify_file removed in v2 rewrite)
2026-04-17 17:57:58 +02:00
Roberto Musso
341ee140e5 PHASE 3 — relational tier (Mem0g-light) 2026-04-17 17:04:27 +02:00
Roberto Musso
741b9b87fb PHASE 2 — Mem0-style Extract/Update pipeline 2026-04-16 17:57:49 +02:00
Roberto Musso
2d8abb6311 memory evolution phase 1 2026-04-16 15:46:12 +02:00
Roberto Musso
e668e3fd20 update setting page 2026-04-15 11:43:56 +02:00
Roberto Musso
7ccdad431f feat(i18n): inject user language into AI agent system prompts
- Add _language_instruction() to deep_agent.py, reads language from core memory
- Append language directive to all 4 run_* functions (task/project/checkpoint/note)
- Minor fixes: alembic env, route imports, test cleanup
2026-04-12 00:35:23 +02:00
Roberto Musso
4073863dc6 feat: add onboarding wizard backend - migration, schema, memory routes 2026-04-11 23:38:53 +02:00
Roberto Musso
a85f8fde29 feat(langfuse): propagate user_id and session_id to all traces
- Add hash_user_id() to SHA-256 hash user IDs before sending to Langfuse
- Add langfuse_context() helper wrapping propagate_attributes()
- deep_agent: extract session_id from _debug context, wrap all agent
  runs and classifier with langfuse_context(user_id, session_id)
- agent_runner: add session_id param, pass run_id as session for batch
- agent_setup: wrap journey LLM calls with langfuse_context
- Remove redundant metadata dicts (now handled by propagate_attributes)
2026-04-10 22:44:05 +02:00
Roberto Musso
90500a3462 fix: return 409 when unverified OAuth email conflicts with existing account
Before: branch 3 of oauth_callback attempted to INSERT a user with a
duplicate email → DB constraint violation → 500.

After: if email_verified=False and the email already exists, raise 409
with a message directing the user to sign in with their password.

Also adds test_callback_unverified_email_conflict_returns_409.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 13:46:15 +02:00
Roberto Musso
c1a8ac7669 test: add TestOAuth suite for Google OAuth routes
6 tests covering the authorize and callback endpoints:
- authorize returns URL + state, 503 when unconfigured
- callback: state mismatch → 401, new user creation, existing OAuth
  link re-login (same user sub), email-match auto-linking to password user

Provider methods (exchange_code, get_userinfo) are mocked via AsyncMock
so tests run without hitting Google APIs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 13:42:11 +02:00
Roberto Musso
c510cbaae5 feat: add OAuth web-callback route and update OAUTH_REDIRECT_URI default
GET /auth/oauth/{provider}/web-callback receives the Google redirect and
bounces immediately to adiuvai://oauth/callback deep link. Google Cloud
Console only accepts http/https redirect URIs — adiuvai:// is not valid.
Default OAUTH_REDIRECT_URI now points to localhost:8000 for dev; override
with the API domain env var in production.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 13:03:05 +02:00
Roberto Musso
ce139bbac3 feat: add OAuth DB schema — oauth_accounts table, nullable password_hash, avatar_url on User
Step 1 of Google login integration: Alembic migration for oauth_accounts +
avatar_url on users, OAuthAccount model with User relationship, UserProfile
schema extended with avatar_url, get_current_user updated to include avatar_url.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 09:20:52 +02:00
Roberto Musso
3cf067faea feat: enhance agent configuration and model management with per-agent overrides 2026-04-10 08:45:14 +02:00
Roberto Musso
7253f6fe72 testing journey agent creation 2026-04-09 00:40:16 +02:00
Roberto Musso
41db3a7089 update env variables 2026-04-08 23:52:52 +02:00
Roberto Musso
cc94194fd1 update app name 2026-04-08 23:27:34 +02:00
Roberto Musso
96c91e386d remove deprecated docs 2026-04-08 23:23:14 +02:00
Roberto Musso
c0aef71141 refactor(tests): remove non-deterministic journey eval cases 4.2–4.5
Keep only 4.1 (first reply contains question) as automated eval.
Multi-turn cases (4.2–4.5) are non-deterministic and tested manually
with results tracked in Langfuse.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 09:41:43 +02:00
Roberto Musso
467abc8d42 Merge branch 'develop' into feature/batch-agent-v2 2026-04-08 00:48:23 +02:00
Roberto Musso
5753f8def9 refactor: remove storage, backup, plugin/marketplace features
- Delete app/storage/ (blob_store, vector_store, encryption)
- Delete app/marketplace/ (plugin_registry, plugin_review, revenue_share)
- Delete routes: backup.py, plugins.py, storage.py, vectors.py
- Relocate embed endpoint to POST /chat/embed
- Rewrite migration 001 (remove storage/plugin tables)
- Delete migration 002 (seed_plugins)
- Remove S3/Pinecone/Qdrant env vars from settings
- Remove storage/backup quotas from tier_manager
- Remove MinIO and Qdrant from docker-compose
- Delete tests: test_backup, test_plugins, test_storage
- Update README.md and clean .env.example
2026-04-08 00:47:37 +02:00
102 changed files with 8504 additions and 6850 deletions

View File

@@ -2,7 +2,7 @@
ENV=dev ENV=dev
# ── Database ────────────────────────────────────────────────────────────────── # ── Database ──────────────────────────────────────────────────────────────────
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
# ── Auth ────────────────────────────────────────────────────────────────────── # ── Auth ──────────────────────────────────────────────────────────────────────
JWT_SECRET=replace-with-a-long-random-secret JWT_SECRET=replace-with-a-long-random-secret
@@ -13,38 +13,82 @@ JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
# ── LLM ─────────────────────────────────────────────────────────────────────── # ── LLM ───────────────────────────────────────────────────────────────────────
# LiteLLM model identifiers — change to swap providers without code changes. # LiteLLM model identifiers — change to swap providers without code changes.
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3 # Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
#
# API keys — only the key(s) matching your chosen provider(s) are required.
# The correct key is picked automatically from the model prefix (e.g.
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
OPENAI_API_KEY= OPENAI_API_KEY=
ANTHROPIC_API_KEY= ANTHROPIC_API_KEY=
GOOGLE_API_KEY= GOOGLE_API_KEY=
LLM_MODEL=gpt-4o CEREBRAS_API_KEY=
LLM_ROUTER_MODEL=gpt-4o-mini GROQ_API_KEY=
DEEPSEEK_API_KEY=
# Default model used by any agent that does not have a specific override below.
LLM_MODEL=gpt-5-mini
LLM_EMBED_MODEL=text-embedding-3-small
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
# In Docker, point this to a named-volume path so tokens survive restarts.
# GITHUB_COPILOT_TOKEN_DIR=
# ── Per-agent model overrides ─────────────────────────────────────────────────
# Leave a value empty to fall back to LLM_MODEL.
# Each agent resolves its API key from the model prefix automatically.
#
# Intent classifier — routes user messages to the right domain agent.
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
LLM_MODEL_CLASSIFIER=
# Home-agent — handles chat from the home screen (all tools available).
LLM_MODEL_HOME_AGENT=
# Floating-agent — handles contextual chat triggered from a task/project/note.
LLM_MODEL_FLOATING_AGENT=
# Unified-processor — processes local directory files (local agent runner).
LLM_MODEL_UNIFIED_PROCESSOR=
# Cloud-processor — fetches and processes data from cloud connectors.
LLM_MODEL_CLOUD_PROCESSOR=
# Brief-agent — produces home and project text briefs.
# A small model (e.g. gpt-4o-mini) is sufficient.
# LLM_MODEL_BRIEF_AGENT=
# Task-brief-agent — per-task deep research (Stage 1 executive assistant).
# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash).
# LLM_MODEL_TASK_BRIEF_AGENT=
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
LLM_MODEL_SETUP_AGENT=
# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2).
# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0).
LLM_MODEL_MEMORY_EXTRACTOR=
# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only).
# Defaults to gpt-4o-mini when empty.
LLM_MODEL_MEMORY_MINER=
# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7).
# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended).
LLM_MODEL_MEMORY_AUDITOR=
# Scheduler — set to false to disable memory cron jobs (automatically false in tests).
SCHEDULER_ENABLED=true
# ── Stripe (leave empty to stub billing) ────────────────────────────────────── # ── Stripe (leave empty to stub billing) ──────────────────────────────────────
STRIPE_SECRET_KEY= STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET= STRIPE_WEBHOOK_SECRET=
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
# ── Vector Store ──────────────────────────────────────────────────────────────
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
PINECONE_API_KEY=
PINECONE_INDEX=adiuva
QDRANT_URL=
QDRANT_API_KEY=
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
# ── Langfuse (leave empty to disable observability) ─────────────────────────── # ── Langfuse (leave empty to disable observability) ───────────────────────────
LANGFUSE_SECRET_KEY= LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY= LANGFUSE_PUBLIC_KEY=
# LANGFUSE_HOST=https://cloud.langfuse.com # EU (default) # LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
# LANGFUSE_HOST=https://us.cloud.langfuse.com # US # LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
# LANGFUSE_HOST=http://localhost:3000 # Self-hosted # LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
# ── CORS ────────────────────────────────────────────────────────────────────── # ── CORS ──────────────────────────────────────────────────────────────────────
# Comma-separated list parsed by Settings (override default if needed) # Comma-separated list parsed by Settings (override default if needed)

View File

@@ -48,23 +48,23 @@ jobs:
key: ${{ secrets.SSH_KEY }} key: ${{ secrets.SSH_KEY }}
script: | script: |
set -e set -e
DEPLOY_DIR="/opt/adiuva-api" DEPLOY_DIR="/opt/adiuvai-api"
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git" REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
TAG="${{ gitea.ref_name }}" TAG="${{ gitea.ref_name }}"
# ── Pull latest code ── # ── Pull latest code ──
cd /tmp && rm -rf adiuva-api-deploy cd /tmp && rm -rf adiuvai-api-deploy
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
# ── Sync source (preserve .env) ── # ── Sync source (preserve .env) ──
cp -rf /tmp/adiuva-api-deploy/app/ \ cp -rf /tmp/adiuvai-api-deploy/app/ \
/tmp/adiuva-api-deploy/alembic/ \ /tmp/adiuvai-api-deploy/alembic/ \
/tmp/adiuva-api-deploy/alembic.ini \ /tmp/adiuvai-api-deploy/alembic.ini \
/tmp/adiuva-api-deploy/Dockerfile \ /tmp/adiuvai-api-deploy/Dockerfile \
/tmp/adiuva-api-deploy/docker-compose.yml \ /tmp/adiuvai-api-deploy/docker-compose.yml \
/tmp/adiuva-api-deploy/requirements.txt \ /tmp/adiuvai-api-deploy/requirements.txt \
"$DEPLOY_DIR/" "$DEPLOY_DIR/"
rm -rf /tmp/adiuva-api-deploy rm -rf /tmp/adiuvai-api-deploy
# ── Verify .env ── # ── Verify .env ──
if [ ! -f "$DEPLOY_DIR/.env" ]; then if [ ! -f "$DEPLOY_DIR/.env" ]; then

View File

@@ -58,7 +58,7 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Build image - name: Build image
run: docker build -t adiuva-api:ci . run: docker build -t adiuvai-api:ci .
- name: Verify gunicorn installed - name: Verify gunicorn installed
run: docker run --rm adiuva-api:ci gunicorn --version run: docker run --rm adiuvai-api:ci gunicorn --version

4
.gitignore vendored
View File

@@ -21,12 +21,16 @@ env/
.pytest_cache/ .pytest_cache/
htmlcov/ htmlcov/
.coverage .coverage
tests/fixtures/private*/
# Docker # Docker
*.log *.log
# OS # OS
.DS_Store .DS_Store
# Smoke scripts (dev-only, not for CI)
scripts/smoke_*.py
Thumbs.db Thumbs.db
# Claude Code # Claude Code

794
README.md
View File

@@ -1,793 +1,5 @@
# Adiuva Cloud API ## DEV
Run in DEV with command:
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
---
## Table of Contents
- [Overview](#overview)
- [Architecture](#architecture)
- [Key Features](#key-features)
- [Tech Stack](#tech-stack)
- [Getting Started](#getting-started)
- [Docker Deployment](#docker-deployment)
- [Environment Variables](#environment-variables)
- [API Reference](#api-reference)
- [Data Model](#data-model)
- [AI Agent System](#ai-agent-system)
- [Orchestration & Execution Plans](#orchestration--execution-plans)
- [Middleware](#middleware)
- [Storage Layer](#storage-layer)
- [Billing & Tiers](#billing--tiers)
- [Plugin Marketplace](#plugin-marketplace)
- [Testing](#testing)
- [Project Structure](#project-structure)
- [License](#license)
---
## Overview
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
### Design Principles
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
---
## Architecture
``` ```
┌──────────────┐ ┌────────────────────────────────────────────────────────┐ uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
│ Desktop App │────▶│ │
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
└──────────────┘ │ │
│ ┌──────────────────┐ ┌────────────────────────────┐ │
│ │ Auth Routes │ │ Chat Routes │ │
│ │ Billing Routes │ │ ↓ │ │
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
│ │ Backup Routes │ │ ↓ classify intent │ │
│ │ Plugin Routes │ │ Agent Registry │ │
│ │ Vector Routes │ │ ↓ │ │
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
│ │ (GPT-4o + LangChain) │ │
│ └────────────────────────────┘ │
└────────────────────────────────────────────────────────┘
│ │ │
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
│ Billing, │ │ backups) │ │ (Vectors) │
│ Metadata) │ └───────────────┘ └────────────────┘
└────────────┘
┌────────▼───┐
│ Stripe │
│ (Billing, │
│ Connect) │
└────────────┘
``` ```
---
## Key Features
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
---
## Tech Stack
| Package | Version | Purpose |
|---|---|---|
| `fastapi` | ≥ 0.115.0 | Web framework |
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
| `gunicorn` | ≥ 22.0.0 | Production process manager |
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
| `boto3` | ≥ 1.35.0 | AWS S3 client |
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
| `alembic` | ≥ 1.14.0 | Database migration management |
| `bcrypt` | ≥ 4.2.0 | Password hashing |
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
| `websockets` | ≥ 14.0 | WebSocket protocol support |
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
| `pytest` | ≥ 8.0.0 | Test framework |
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
| `ruff` | ≥ 0.8.0 | Linter and formatter |
---
## Getting Started
### Prerequisites
- Python 3.12+
- PostgreSQL 16+
- An OpenAI API key (for LLM features)
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
- AWS credentials (optional — needed for S3 storage in production)
### Installation
```bash
# Clone the repository
git clone <repo-url> && cd adiuva-api
# Create a virtual environment
python -m venv .venv && source .venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Configure environment
cp .env.example .env
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
```
### Database Setup
```bash
# Start PostgreSQL (or use the Docker Compose database)
docker compose up db -d
# Run migrations
alembic upgrade head
```
### Run the Development Server
```bash
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
---
## Docker Deployment
### Quick Start
```bash
docker compose up --build
```
This starts two services:
- **app** — FastAPI server on port `8000`
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
The compose file also includes optional services for fully local deployments:
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
### Dockerfile Details
The Dockerfile uses a multi-stage build:
1. **Builder stage** — Installs Python dependencies into a virtual environment.
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
```bash
# Production command (run by the container)
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
```
---
## Homelab / Self-Hosted Deployment
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
### 1. Start all services
```bash
docker compose up -d
```
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
### 2. Create the MinIO bucket
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
```bash
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
docker compose exec minio mc mb local/adiuva
```
### 3. Configure your `.env`
```bash
# Database (uses the compose PostgreSQL)
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
# S3 → MinIO
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=http://minio:9000
AWS_ACCESS_KEY_ID=minioadmin
AWS_SECRET_ACCESS_KEY=minioadmin
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
QDRANT_URL=http://qdrant:6333
QDRANT_API_KEY=
PINECONE_API_KEY=
# Billing — leave empty to stub (no Stripe needed)
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=
# LLM — the only external service
OPENAI_API_KEY=sk-...
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# Auth
JWT_SECRET=your-secret-here
ENV=dev
```
### 4. Run migrations
```bash
docker compose exec app alembic upgrade head
```
### What runs where
| Service | Runs on | Port | Notes |
|---|---|---|---|
| FastAPI app | Docker | 8000 | API server |
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
| Stripe | — | — | Stubbed when keys are empty |
| OpenAI / LLM | Cloud | — | Only external dependency |
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
---
## Environment Variables
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
| Variable | Type | Default | Description |
|---|---|---|---|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
| `S3_REGION` | `str` | `us-east-1` | AWS region |
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
---
## API Reference
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
### Health
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
### Auth
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
### Chat
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
### Plans
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
### Storage (Cloud Records)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
### Vectors (Cloud Vector Store)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
### Backup
| Method | Path | Auth | Description |
|---|---|---|---|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
### Plugins (Marketplace)
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
### Billing
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
---
## Data Model
9 tables managed by Alembic migrations. Source: `app/models.py`
### Tables
| Table | Primary Key | Key Columns | Purpose |
|---|---|---|---|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
### Enum Types
| Enum | Values |
|---|---|
| `billing_tier` | `free`, `pro`, `power`, `team` |
| `plugin_status` | `pending_review`, `approved`, `rejected` |
| `review_decision` | `approved`, `rejected` |
### Migrations
| Version | Description |
|---|---|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
---
## AI Agent System
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
### Architecture
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
### Registered Agents
| Agent | Registry Name | Tools | Description |
|---|---|---|---|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
### Switching LLM Providers
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
```bash
# OpenAI (default)
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# Anthropic
LLM_MODEL=anthropic/claude-3.5-sonnet
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
# Google Gemini
LLM_MODEL=gemini/gemini-pro
LLM_ROUTER_MODEL=gemini/gemini-flash
# Local Ollama
LLM_MODEL=ollama/llama3
LLM_ROUTER_MODEL=ollama/llama3
# AWS Bedrock
LLM_MODEL=bedrock/anthropic.claude-v2
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
```
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
---
## Orchestration & Execution Plans
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
### Orchestrator
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
### Execution Plans
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
### Built-in Templates (6)
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
### Built-in Playbooks (2)
| Playbook | Description |
|---|---|
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
---
## Middleware
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
### JWT Authentication
Source: `app/api/middleware/auth.py`
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
- Falls back to `free` when no subscription row exists.
- Raises `401 Unauthorized` on invalid or expired tokens.
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
### Tier-Based Rate Limiter
Source: `app/api/middleware/rate_limit.py`
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
- Per-user 60-second window sized by subscription tier:
| Tier | Requests / Minute |
|---|---|
| Free | 20 |
| Pro | 60 |
| Power | 120 |
| Team | 200 |
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
- **Exempt paths:** register, login, webhook, health
### Response Sanitizer
Source: `app/api/middleware/sanitizer.py`
- Runs only on `/api/v1/chat` endpoints.
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
- Logs sanitization events as `WARNING`.
- Binary responses (storage, backup) are never touched.
---
## Storage Layer
### Blob Store
Source: `app/storage/blob_store.py`
- S3-backed storage for E2E encrypted blobs.
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
- The backend **never inspects or decrypts blob content**.
### Vector Store
Source: `app/storage/vector_store.py`
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
- Methods: `upsert()`, `search()`, `delete()`
### Encryption Utilities
Source: `app/storage/encryption.py`
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
- **No decryption key ever reaches the backend.**
---
## Billing & Tiers
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
### Feature Matrix
| Feature | Free | Pro | Power | Team |
|---|---|---|---|---|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
| Batch Active | 2 | 10 | Unlimited | Unlimited |
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
| Batch Builder | — | — | ✓ | ✓ |
| Plugin Marketplace | — | — | ✓ | ✓ |
| SSO | — | — | — | ✓ |
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
### Stripe Integration
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
### Tier Manager
- `get_tier(user_id)` — Returns the user's current billing tier.
- `check_feature(tier, feature)` — Boolean feature gate check.
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
---
## Plugin Marketplace
Source: `app/marketplace/`
### Plugin Registry
- PostgreSQL-backed catalog of submitted and approved plugins.
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
### Review Queue
- Automated security checklist before human review:
- Plugin ID must match `^[a-z0-9-]+$`
- Permissions must be from the allowed set only
- No binary blobs in the manifest
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
- `get_pending(db)` — Lists plugins awaiting review.
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
### Revenue Sharing
- **70% developer / 30% platform** split on all paid plugin sales.
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
- Gracefully stubs transfers when Stripe is not configured.
### Seed Plugins
| Plugin | Category | Price |
|---|---|---|
| GitHub Sync | Productivity | Free |
| Slack Notifier | Communication | €4.99 |
| Time Tracker | Productivity | €9.99 |
---
## Testing
### Running Tests
```bash
# Run all tests
pytest
# Run a specific test file
pytest tests/test_auth.py
# Run with verbose output
pytest -v
```
### Test Infrastructure
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
- **No external dependencies** — all tests run fully offline.
### Test Coverage
| File | Coverage |
|---|---|
| `test_auth.py` | Register, login, token access, refresh, expiration |
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
---
## Project Structure
```
adiuva-api/
├── alembic.ini # Alembic configuration
├── BACKEND_PLAN.md # Architecture & design decisions
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
├── Dockerfile # Multi-stage production build
├── requirements.txt # Python dependencies
├── alembic/ # Database migrations
│ ├── env.py # Alembic environment config
│ ├── script.py.mako # Migration template
│ └── versions/
│ ├── 001_initial_schema.py # Tables, indexes, FKs
│ └── 002_seed_plugins.py # Seed marketplace plugins
├── app/ # Application source
│ ├── main.py # FastAPI app factory, middleware, routes
│ ├── db.py # Async SQLAlchemy engine & session
│ ├── models.py # SQLAlchemy ORM models (9 tables)
│ ├── schemas.py # Pydantic request/response schemas
│ │
│ ├── config/
│ │ └── settings.py # Pydantic Settings (env vars)
│ │
│ ├── agents/ # LLM-powered domain agents
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
│ │ ├── project_agent.py # Project lifecycle (6 tools)
│ │ ├── timeline_agent.py # Milestones (4 tools)
│ │ └── note_agent.py # Markdown notes (5 tools)
│ │
│ ├── core/ # Orchestration engine
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
│ │ ├── orchestrator.py # Intent classification & routing
│ │ └── execution_plan.py # Plan builder, templates, cache
│ │
│ ├── api/ # HTTP layer
│ │ ├── deps.py # Shared FastAPI dependencies
│ │ ├── middleware/
│ │ │ ├── auth.py # JWT validation, live tier lookup
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
│ │ │ └── sanitizer.py # Prompt IP leak protection
│ │ └── routes/
│ │ ├── auth.py # Register, login, refresh, me
│ │ ├── chat.py # Chat + WebSocket streaming
│ │ ├── plans.py # Execution plan playbooks
│ │ ├── storage.py # E2E encrypted record CRUD
│ │ ├── vectors.py # Vector upsert, search, delete
│ │ ├── backup.py # Encrypted backup management
│ │ ├── plugins.py # Marketplace browse & install
│ │ └── billing.py # Stripe checkout & webhooks
│ │
│ ├── storage/ # Storage backends
│ │ ├── blob_store.py # S3 blob storage
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
│ │ └── encryption.py # Checksum verification utilities
│ │
│ ├── billing/ # Subscription management
│ │ ├── stripe_service.py # Stripe API integration
│ │ └── tier_manager.py # Feature matrix & quota enforcement
│ │
│ └── marketplace/ # Plugin ecosystem
│ ├── plugin_registry.py # Catalog CRUD & search
│ ├── plugin_review.py # Security checklist & review queue
│ └── revenue_share.py # 70/30 split & Stripe Connect
└── tests/ # Test suite
├── conftest.py # Fixtures: DB, S3, auth, seeds
├── test_auth.py
├── test_orchestrator.py
├── test_agents.py
├── test_storage.py
├── test_backup.py
├── test_plugins.py
├── test_agent_registry.py
├── test_execution_plan.py
└── test_middleware.py
```
---
## License
*To be determined.*

View File

@@ -16,7 +16,7 @@ import re
from logging.config import fileConfig from logging.config import fileConfig
from alembic import context from alembic import context
from sqlalchemy import engine_from_config, pool from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.ext.asyncio import create_async_engine
# Alembic Config object (gives access to alembic.ini values). # Alembic Config object (gives access to alembic.ini values).

View File

@@ -1,5 +1,4 @@
"""Initial schema: users, refresh_tokens, subscriptions, storage_records, """Initial schema: users, refresh_tokens, subscriptions.
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
Revision ID: 001 Revision ID: 001
Revises: Revises:
@@ -28,18 +27,6 @@ def upgrade() -> None:
EXCEPTION WHEN duplicate_object THEN NULL; EXCEPTION WHEN duplicate_object THEN NULL;
END $$; END $$;
""") """)
op.execute("""
DO $$ BEGIN
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
# ── users ───────────────────────────────────────────────────────────── # ── users ─────────────────────────────────────────────────────────────
op.create_table( op.create_table(
@@ -88,122 +75,10 @@ def upgrade() -> None:
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"]) op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"]) op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
# ── storage_records ───────────────────────────────────────────────────
op.create_table(
"storage_records",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("table_name", sa.String(100), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
# ── backup_metadata ───────────────────────────────────────────────────
op.create_table(
"backup_metadata",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column("timestamp", sa.BigInteger, nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
# ── plugins ───────────────────────────────────────────────────────────
op.create_table(
"plugins",
sa.Column("id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=False, server_default=""),
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
sa.Column("category", sa.String(100), nullable=False, server_default=""),
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
sa.Column("s3_package_key", sa.String(500), nullable=True),
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
sa.Column("rejection_reason", sa.Text, nullable=True),
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
)
# ── plugin_installations ──────────────────────────────────────────────
op.create_table(
"plugin_installations",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
)
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
# ── plugin_reviews ────────────────────────────────────────────────────
op.create_table(
"plugin_reviews",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
sa.Column("notes", sa.Text, nullable=True),
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
)
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
# ── revenue_events ────────────────────────────────────────────────────
op.create_table(
"revenue_events",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
def downgrade() -> None: def downgrade() -> None:
op.drop_table("revenue_events")
op.drop_table("plugin_reviews")
op.drop_table("plugin_installations")
op.drop_table("plugins")
op.drop_table("backup_metadata")
op.drop_table("storage_records")
op.drop_table("subscriptions") op.drop_table("subscriptions")
op.drop_table("refresh_tokens") op.drop_table("refresh_tokens")
op.drop_table("users") op.drop_table("users")
op.execute("DROP TYPE IF EXISTS review_decision")
op.execute("DROP TYPE IF EXISTS plugin_status")
op.execute("DROP TYPE IF EXISTS billing_tier") op.execute("DROP TYPE IF EXISTS billing_tier")

View File

@@ -1,92 +0,0 @@
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
Revision ID: 002
Revises: 001
Create Date: 2026-03-03
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_SEED_PLUGINS = [
{
"id": "plugin-github-sync",
"name": "GitHub Sync",
"description": "Sync tasks with GitHub Issues and pull requests.",
"version": "1.0.0",
"author_name": "Adiuva",
"category": "productivity",
"price_cents": 0,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-slack-notify",
"name": "Slack Notifier",
"description": "Post task and timeline updates to Slack channels.",
"version": "1.2.0",
"author_name": "Adiuva",
"category": "communication",
"price_cents": 499,
"permissions": json.dumps(["read:tasks", "read:timelines"]),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-time-tracker",
"name": "Time Tracker",
"description": "Track time spent on tasks with automatic reporting.",
"version": "0.9.1",
"author_name": "Third Party",
"category": "productivity",
"price_cents": 999,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
]
def upgrade() -> None:
plugins = sa.table(
"plugins",
sa.column("id", sa.String),
sa.column("name", sa.String),
sa.column("description", sa.Text),
sa.column("version", sa.String),
sa.column("author_name", sa.String),
sa.column("category", sa.String),
sa.column("price_cents", sa.Integer),
sa.column("permissions", sa.Text),
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
sa.column("s3_package_key", sa.String),
sa.column("install_count", sa.Integer),
sa.column("avg_rating", sa.Float),
)
op.bulk_insert(plugins, _SEED_PLUGINS)
def downgrade() -> None:
op.execute(
"DELETE FROM plugins WHERE id IN ("
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
")"
)

View File

@@ -14,7 +14,7 @@ from alembic import op
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
revision: str = "003" revision: str = "003"
down_revision: Union[str, None] = "002" down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None

View File

@@ -0,0 +1,54 @@
"""Phase 1 — confirm pgvector activation on memory_associative.
Migration 004 created the embedding column as vector(1536) and added the
IVFFlat index. This migration is the Phase-1 checkpoint:
1. Ensures the pgvector extension is enabled (idempotent).
2. Ensures the canonical Phase-1 IVFFlat index exists under the name
memory_associative_embedding_idx (creates it only if absent).
Revision ID: 005
Revises: 9a1f2d0b6c7e
Create Date: 2026-04-15
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
revision: str = "005"
down_revision: Union[str, None] = "e04100e88ace"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Ensure pgvector extension is enabled (also done in 004, idempotent).
op.execute("CREATE EXTENSION IF NOT EXISTS vector;")
# Ensure the canonical Phase-1 IVFFlat index exists.
# 004 may have created ix_memory_associative_embedding; this adds the
# Phase-1 name memory_associative_embedding_idx if it is missing.
op.execute(
"""
DO $$
BEGIN
IF NOT EXISTS (
SELECT 1
FROM pg_indexes
WHERE tablename = 'memory_associative'
AND indexname = 'memory_associative_embedding_idx'
) THEN
CREATE INDEX memory_associative_embedding_idx
ON memory_associative
USING ivfflat (embedding vector_cosine_ops)
WITH (lists = 100);
END IF;
END $$;
"""
)
def downgrade() -> None:
op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;")

View File

@@ -0,0 +1,74 @@
"""Add memory_relations table (Phase 3 — relational tier).
Revision ID: 006
Revises: 1f5975a4f3f4
Create Date: 2026-04-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "006"
down_revision: Union[str, None] = "1f5975a4f3f4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"memory_relations",
sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True),
sa.Column(
"user_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("users.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("subject_label", sa.String(128), nullable=False),
sa.Column("subject_type", sa.String(32), nullable=False),
sa.Column("predicate", sa.String(64), nullable=False),
sa.Column("object_label", sa.String(128), nullable=False),
sa.Column("object_type", sa.String(32), nullable=False),
sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"),
sa.Column(
"source_episode_id",
postgresql.UUID(as_uuid=False),
sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("notes_encrypted", sa.LargeBinary, nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column(
"updated_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.func.now(),
),
sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True),
)
op.create_index(
"memory_relations_user_subject_idx",
"memory_relations",
["user_id", "subject_label"],
)
op.create_index(
"memory_relations_user_predicate_idx",
"memory_relations",
["user_id", "predicate"],
)
def downgrade() -> None:
op.drop_index("memory_relations_user_predicate_idx", "memory_relations")
op.drop_index("memory_relations_user_subject_idx", "memory_relations")
op.drop_table("memory_relations")

View File

@@ -0,0 +1,38 @@
"""add extraction_queue
Revision ID: 1f5975a4f3f4
Revises: 005
Create Date: 2026-04-16 17:26:25.790870
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = '1f5975a4f3f4'
down_revision: Union[str, None] = '005'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
'extraction_queue',
sa.Column('id', sa.Uuid(as_uuid=False), nullable=False),
sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False),
sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id'),
)
op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False)
def downgrade() -> None:
op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue')
op.drop_table('extraction_queue')

View File

@@ -1,4 +1,8 @@
"""add agent_config to local_agent_configs """Restore agent config tables and add agent_config column.
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
ORM models are still active. This migration recreates them with agent_config
added to local_agent_configs.
Revision ID: a3b9c0d1e2f3 Revision ID: a3b9c0d1e2f3
Revises: 9a1f2d0b6c7e Revises: 9a1f2d0b6c7e
@@ -9,8 +13,9 @@ from __future__ import annotations
from typing import Sequence, Union from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic. # revision identifiers, used by Alembic.
@@ -21,11 +26,82 @@ depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None: def upgrade() -> None:
op.add_column( # Recreate enum types (idempotent — they may already exist from migration 003)
op.execute("""
DO $$ BEGIN
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
bind = op.get_bind()
inspector = sa.inspect(bind)
existing = set(inspector.get_table_names())
# ── local_agent_configs (with agent_config column) ────────────────────
if "local_agent_configs" not in existing:
op.create_table(
"local_agent_configs", "local_agent_configs",
sa.Column("agent_config", sa.JSON(), nullable=True), sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("agent_config", sa.JSON, nullable=True),
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
) )
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
# ── cloud_agent_configs ───────────────────────────────────────────────
if "cloud_agent_configs" not in existing:
op.create_table(
"cloud_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"provider",
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
nullable=False,
),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
sa.Column("filter_config", sa.JSON, nullable=True),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
def downgrade() -> None: def downgrade() -> None:
op.drop_column("local_agent_configs", "agent_config") op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
op.drop_table("cloud_agent_configs")
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
op.drop_table("local_agent_configs")

View File

@@ -0,0 +1,56 @@
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
Revision ID: b4c0d1e2f3a4
Revises: a3b9c0d1e2f3
Create Date: 2026-04-10 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "b4c0d1e2f3a4"
down_revision: Union[str, None] = "a3b9c0d1e2f3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── users: make password_hash nullable (social users have no password) ──
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
# ── users: add avatar_url ─────────────────────────────────────────────
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
# ── oauth_accounts ────────────────────────────────────────────────────
op.create_table(
"oauth_accounts",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("provider", sa.String(50), nullable=False),
sa.Column("provider_user_id", sa.String(255), nullable=False),
sa.Column("provider_email", sa.String(255), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
)
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
def downgrade() -> None:
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
op.drop_table("oauth_accounts")
op.drop_column("users", "avatar_url")
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)

View File

@@ -0,0 +1,31 @@
"""Add onboarding_completed_at column to users table.
Revision ID: c5d1e2f3a4b5
Revises: b4c0d1e2f3a4
Create Date: 2026-04-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c5d1e2f3a4b5"
down_revision: Union[str, None] = "b4c0d1e2f3a4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"users",
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("users", "onboarding_completed_at")

View File

@@ -0,0 +1,46 @@
"""Add token tracking columns for folder integration.
Revision ID: d6e3f4a5b6c7
Revises: 006
Create Date: 2026-05-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects.postgresql import UUID
# revision identifiers, used by Alembic.
revision: str = "d6e3f4a5b6c7"
down_revision: Union[str, None] = "006"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"agent_run_logs",
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
)
op.create_table(
"monthly_token_usage",
sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False),
sa.Column("year_month", sa.String(7), nullable=False),
sa.Column("feature", sa.String(64), nullable=False),
sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"),
sa.PrimaryKeyConstraint("user_id", "year_month", "feature"),
)
op.create_index(
"ix_monthly_token_usage_user_month",
"monthly_token_usage",
["user_id", "year_month"],
)
def downgrade() -> None:
op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage")
op.drop_table("monthly_token_usage")
op.drop_column("agent_run_logs", "tokens_used")

View File

@@ -0,0 +1,34 @@
"""avatar_url_varchar_to_text
Revision ID: e04100e88ace
Revises: c5d1e2f3a4b5
Create Date: 2026-04-13 09:13:06.733674
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'e04100e88ace'
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.alter_column('users', 'avatar_url',
existing_type=sa.VARCHAR(length=2048),
type_=sa.Text(),
existing_nullable=True)
def downgrade() -> None:
op.alter_column('users', 'avatar_url',
existing_type=sa.Text(),
type_=sa.VARCHAR(length=2048),
existing_nullable=True)

View File

@@ -0,0 +1,52 @@
"""Client agent — read-only tools for the clients table."""
from __future__ import annotations
import json
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
@tool
async def list_clients(search: str = "", limit: int = 20) -> str:
"""List clients, optionally filtered by a name/email substring search.
search: optional substring to match against client name or email.
limit: max rows to return (default 20).
"""
filters: dict[str, Any] = {"limit": limit}
if search:
filters["search"] = search
result = await execute_on_client(action="select", table="clients", filters=filters)
rows = result.get("rows", [])
if not rows:
return "No clients found."
lines = [
f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, "
f"company: {r.get('company', '')})"
for r in rows
]
return f"Found {len(rows)} client(s):\n" + "\n".join(lines)
@tool
async def get_client(id: str) -> str:
"""Get full details for one client by UUID.
id: the client's UUID.
"""
if not id:
return "Client id is required."
result = await execute_on_client(action="get", table="clients", data={"id": id})
row = result.get("row") or result.get("rows", [None])[0] if result else None
if not row:
return f"Client '{id}' not found."
return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}"
CLIENT_TOOLS: list[Any] = [list_clients, get_client]

View File

@@ -7,12 +7,31 @@ handles actual disk I/O and responds with ``tool_result`` frames.
from __future__ import annotations from __future__ import annotations
import os
import re
from pathlib import Path
from typing import Any from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
# Max characters returned by read_file_content in journey (exploration) tools.
# The journey only needs to understand file structure, not full content.
_JOURNEY_READ_MAX_CHARS: int = 4000
def _resolve_path(path: str, base: str) -> str:
"""Resolve *path* against *base* when *path* is relative.
The LLM often passes ``"."`` meaning "the configured directory".
Without this, Electron resolves ``"."`` relative to its own CWD instead
of the user's chosen directory.
"""
if os.path.isabs(path):
return path
return str(Path(base) / path)
@tool @tool
async def list_directory(path: str) -> str: async def list_directory(path: str) -> str:
@@ -83,3 +102,93 @@ FILESYSTEM_TOOLS: list[Any] = [
read_file_content, read_file_content,
get_file_metadata, get_file_metadata,
] ]
def make_directory_tools(base_directory: str) -> list[Any]:
"""Return filesystem tools that resolve relative paths against *base_directory*.
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
from the LLM are resolved to the correct absolute path before being sent to
the Electron client, preventing it from falling back to its own CWD.
"""
def _compact_for_journey(raw: str) -> str:
"""Strip HTML noise and truncate for journey exploration.
The journey LLM only needs to understand file structure (headers,
first paragraphs). Full CSS/style blocks are pure noise that eat
up context window budget.
"""
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
if len(text) > _JOURNEY_READ_MAX_CHARS:
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
return text
@tool
async def list_directory(path: str) -> str: # noqa: F811
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="list_directory",
data={"path": resolved},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{resolved}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str: # noqa: F811
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="read_file_content",
data={"path": resolved},
)
content: str = result.get("content", "")
if not content:
return f"File '{resolved}' is empty or could not be read."
return _compact_for_journey(content)
@tool
async def get_file_metadata(path: str) -> str: # noqa: F811
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="get_file_metadata",
data={"path": resolved},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", resolved)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
return [list_directory, read_file_content, get_file_metadata]

168
app/agents/folder_agent.py Normal file
View File

@@ -0,0 +1,168 @@
"""Scoped file-read and search tools for the project folder feature."""
from __future__ import annotations
from langchain_core.tools import tool
from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text
from app.core.ws_context import execute_on_client
# Cap returned slice size to keep tool output under control.
_MAX_RETURN_CHARS = 50_000
_MAX_SEARCH_MATCHES = 20
def _is_unsafe_path(rel: str) -> bool:
if not rel:
return True
norm = rel.replace("\\", "/")
if norm.startswith("/"):
return True
# Windows drive letter
if len(rel) >= 2 and rel[1] == ":":
return True
parts = norm.split("/")
return ".." in parts
async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict:
"""Return the raw Electron tool_result dict for a file read."""
return await execute_on_client(
action="read_project_folder_file",
data={
"projectId": project_id,
"relativePath": relative_path,
"offset": offset,
"length": length,
},
)
def _decode(result: dict) -> tuple[str, str, int]:
"""Decode a tool_result into (text, kind, total_size). For pdf/docx,
extracts text from base64. For images, returns a placeholder string.
For text, content is already a sliced utf-8 string.
"""
kind = result.get("kind", "text")
content = result.get("content", "") or ""
total = int(result.get("totalSize", 0) or 0)
if kind == "image":
return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total)
if kind == "pdf":
return (_extract_pdf_text(content), kind, total)
if kind == "docx":
return (_extract_docx_text(content), kind, total)
return (content, kind, total)
@tool
async def read_project_folder_file(
project_id: str,
relative_path: str,
offset: int = 0,
length: int = _MAX_RETURN_CHARS,
) -> str:
"""Read a slice of a file inside the project's linked folder.
Args:
project_id: project ID.
relative_path: path relative to the linked folder root.
offset: char offset to start reading from (0 = beginning).
length: max chars to return. Default 50000. Use smaller values to save tokens.
Returns text content slice with a header showing position. Header tells you
when more content is available; call again with the suggested next offset.
For PDF / DOCX files the backend extracts text first, then applies offset/length
on the extracted text. For images returns a placeholder; navigate with the
manifest summary instead.
"""
if _is_unsafe_path(relative_path):
return "Access denied"
result = await _fetch_file(project_id, relative_path, offset, length)
text, kind, total_size = _decode(result)
if not text and kind in ("missing", "error"):
return f"File not found or unreadable: {relative_path}"
if kind in ("pdf", "docx"):
# Backend extracted full text — apply offset/length on chars.
sliced = text[offset:offset + length]
slice_end = min(offset + length, len(text))
header = (
f"[file={relative_path} kind={kind} offset={offset} end={slice_end} "
f"totalChars={len(text)}]"
)
if slice_end < len(text):
header += f"\n[More content available — call again with offset={slice_end}.]"
return header + "\n" + sliced
if kind == "text":
slice_end = offset + len(text)
header = (
f"[file={relative_path} kind=text offset={offset} end={slice_end} "
f"totalBytes={total_size}]"
)
if slice_end < total_size:
header += f"\n[More content available — call again with offset={slice_end}.]"
return header + "\n" + text
# image or unknown
return text
@tool
async def search_project_folder_file(
project_id: str,
relative_path: str,
query: str,
context_lines: int = 3,
) -> str:
"""Search a project folder file for a query string (case-insensitive substring).
Args:
project_id: project ID.
relative_path: path relative to the linked folder root.
query: text to search for.
context_lines: number of lines of context around each match (default 3).
Returns matching line ranges with surrounding context and 1-based line numbers.
Capped at 20 matches; if more exist the header shows the total.
Works on text, code, markdown, PDF (extracted), and DOCX (extracted).
Images and binary files are not searchable.
"""
if _is_unsafe_path(relative_path):
return "Access denied"
if not query:
return "Empty query."
# For text we still need full file; pass length=very large.
result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000)
text, kind, _ = _decode(result)
if not text and kind in ("missing", "error"):
return f"File not found or unreadable: {relative_path}"
if kind == "image":
return "Cannot search inside images."
lines = text.splitlines()
q = query.lower()
matches = [i for i, line in enumerate(lines) if q in line.lower()]
if not matches:
return f"No matches for '{query}' in {relative_path}."
shown = matches[:_MAX_SEARCH_MATCHES]
snippets: list[str] = []
for i in shown:
start = max(0, i - context_lines)
end = min(len(lines), i + context_lines + 1)
block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end))
snippets.append(block)
header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']"
body = "\n---\n".join(snippets)
return header + "\n" + body
FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file]

View File

@@ -1,13 +1,14 @@
"""Note agent — Markdown note management (list, get, create, update, delete).""" """Note agent — Markdown note management (list, get, create, update, propose edit)."""
from __future__ import annotations from __future__ import annotations
import asyncio
import re import re
from typing import Any from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.llm import embed from app.core.note_summarizer import generate_note_summary
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
_UUID_RE = re.compile( _UUID_RE = re.compile(
@@ -18,25 +19,22 @@ _UUID_RE = re.compile(
def _is_uuid(value: str) -> bool: def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value)) return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n" def _fmt_summary(row: dict) -> str:
"and delete Markdown notes in their workspace.\n\n" summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip()
"Rules:\n" if summary:
" - content is always Markdown; preserve formatting when updating\n" return f"{summary}"
" - project_id is optional; link a note to a project when mentioned\n" snippet = (row.get("content") or "")[:120].replace("\n", " ").strip()
" - When updating, call get_note first if you need to read existing content\n" return f"{snippet}" if snippet else ""
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool @tool
async def list_notes(project_id: str = "") -> str: async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id.""" """List notes with AI summaries, optionally scoped to a project by project_id.
Returns id, title, and ai_summary for each note so you can decide which
note to read in full with get_note before creating or updating.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client( result = await execute_on_client(
action="select", action="select",
@@ -46,7 +44,7 @@ async def list_notes(project_id: str = "") -> str:
rows = result.get("rows", []) rows = result.get("rows", [])
if not rows: if not rows:
return "No notes found." return "No notes found."
lines = [f"- {r['title']} (id: {r['id']})" for r in rows] lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines) return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@@ -81,14 +79,10 @@ async def create_note(
}, },
) )
row = result["row"] row = result["row"]
# Index the note content in the vector store. note_id: str = row["id"]
vector = await embed(content) # Generate summary asynchronously — fire-and-forget.
await execute_on_client( asyncio.create_task(_refresh_summary(note_id, title, content))
action="vector_upsert", return f"Note created: '{row['title']}' (id: {note_id})."
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note created: '{row['title']}' (id: {row['id']})."
@tool @tool
@@ -97,7 +91,8 @@ async def update_note(
title: str = "", title: str = "",
content: str = "", content: str = "",
) -> str: ) -> str:
"""Update an existing note. Only pass fields that should change. """Update an existing note directly (no approval required).
Use propose_note_edit instead when human review is needed.
note_id: UUID of the note (required) note_id: UUID of the note (required)
If you need to preserve existing content, call get_note first. If you need to preserve existing content, call get_note first.
""" """
@@ -112,17 +107,63 @@ async def update_note(
data={"id": note_id, "updates": updates}, data={"id": note_id, "updates": updates},
) )
row = result["row"] row = result["row"]
# Re-index if content changed.
if content: if content:
vector = await embed(content) new_title = title or row.get("title", "")
await execute_on_client( asyncio.create_task(_refresh_summary(note_id, new_title, content))
action="vector_upsert",
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note updated: '{row['title']}' (id: {row['id']})." return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def propose_note_edit(
note_id: str,
edit_type: str,
proposed_content: str,
reasoning: str = "",
anchor_before: str = "",
anchor_text: str = "",
agent_id: str = "",
run_id: str = "",
) -> str:
"""Propose an AI edit to an existing note, pending human approval.
Use this instead of update_note when review_required is true.
The user will see the proposal highlighted before it is merged.
note_id: UUID of the target note (required)
edit_type: 'append' | 'insert' | 'replace'
- append: adds proposed_content at the end of the note
- insert: inserts proposed_content immediately after anchor_before text
- replace: replaces the first occurrence of anchor_text with proposed_content
proposed_content: the new Markdown text to add or substitute (required)
reasoning: brief explanation shown to the user (recommended)
anchor_before: for 'insert' — the text snippet that precedes the insertion point
anchor_text: for 'replace' — the exact text to be replaced
agent_id: agent identifier (for traceability)
run_id: run identifier (for traceability)
"""
if edit_type not in ("append", "insert", "replace"):
return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'."
result = await execute_on_client(
action="propose_note_edit",
data={
"noteId": note_id,
"type": edit_type,
"proposedContent": proposed_content,
"reasoning": reasoning or None,
"anchorBefore": anchor_before or None,
"anchorText": anchor_text or None,
"agentId": agent_id or None,
"runId": run_id or None,
},
)
edit_id = result.get("id", "?")
return (
f"Edit proposal created (id: {edit_id}) for note {note_id}. "
f"Status: pending user approval."
)
@tool @tool
async def delete_note(note_id: str) -> str: async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID.""" """Delete a note permanently by its UUID."""
@@ -130,10 +171,36 @@ async def delete_note(note_id: str) -> str:
return f"Note {note_id} deleted." return f"Note {note_id} deleted."
async def _refresh_summary(note_id: str, title: str, content: str) -> None:
"""Generate and persist the AI summary for a note. Fire-and-forget."""
try:
summary = await generate_note_summary(title, content)
if summary:
await execute_on_client(
action="update",
table="notes",
data={
"id": note_id,
"updates": {
"aiSummary": summary,
"aiSummaryUpdatedAt": int(__import__("time").time() * 1000),
},
},
)
except Exception:
pass # fire-and-forget; errors logged by generate_note_summary
NOTE_TOOLS: list[Any] = [ NOTE_TOOLS: list[Any] = [
list_notes, list_notes,
get_note, get_note,
create_note, create_note,
update_note, update_note,
propose_note_edit,
delete_note, delete_note,
] ]
NOTE_READ_TOOLS: list[Any] = [
list_notes,
get_note,
]

View File

@@ -8,22 +8,6 @@ from langchain_core.tools import tool
from app.core.ws_context import execute_on_client from app.core.ws_context import execute_on_client
PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool @tool
async def list_projects( async def list_projects(
@@ -141,3 +125,9 @@ PROJECT_TOOLS: list[Any] = [
update_project, update_project,
delete_project, delete_project,
] ]
PROJECT_READ_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
]

View File

@@ -0,0 +1,63 @@
"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations."""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
# Injected at tool-factory time by _brief_research_tools(); not a module-level global.
# Each tool closure captures the user_id bound at factory time.
def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any:
"""Return a query_relations tool bound to *user_id*."""
@tool
async def query_relations(
subject_label: str = "",
predicate: str = "",
object_label: str = "",
limit: int = 10,
) -> str:
"""Query the relational memory graph for entity relationships.
Returns rows where subject ↔ predicate ↔ object match the given filters.
All parameters are optional — omit to retrieve all relations up to limit.
subject_label: entity label on the left side (e.g. a client name, "Acme Corp").
predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to").
object_label: entity label on the right side (e.g. a project name, "Website Redesign").
limit: max rows to return (default 10).
"""
import logging
logger = logging.getLogger(__name__)
logger.info(
"relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r",
trace_id or "-", user_id, subject_label, predicate, object_label,
)
async with async_session() as db:
memory = MemoryMiddleware(db)
rows = await memory.query_relations(
user_id=user_id,
subject=subject_label or None,
predicate=predicate or None,
object_=object_label or None,
limit=limit,
)
if not rows:
return "No relational memory entries found for the given filters."
lines = [
f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}"
+ (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "")
for r in rows
]
return f"Found {len(rows)} relation(s):\n" + "\n".join(lines)
return query_relations

View File

@@ -18,23 +18,6 @@ _UUID_RE = re.compile(
def _is_uuid(value: str) -> bool: def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value)) return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ──────────────────────────────────────────────────────── # ── Task tools ────────────────────────────────────────────────────────
@@ -43,32 +26,141 @@ TASK_SYSTEM_PROMPT = (
async def list_tasks( async def list_tasks(
project_id: str = "", project_id: str = "",
status: str = "", status: str = "",
priority: str = "",
assignee: str = "",
search: str = "", search: str = "",
order_by: str = "", order_by: str = "",
order_dir: str = "",
due_date_from: int = -1,
due_date_to: int = -1,
created_at_from: int = -1,
created_at_to: int = -1,
completed_at_from: int = -1,
completed_at_to: int = -1,
is_ai_suggested: int = -1,
limit: int = 50,
offset: int = 0,
) -> str: ) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done), """List tasks with optional filters. Returns up to `limit` results (default 50).
a search string, or an order_by field name (dueDate|priority|createdAt)."""
project_id: UUID of the project to scope results to.
status: filter by status — todo | in_progress | done.
priority: filter by priority — high | medium | low.
assignee: substring to match against assignee names. OMIT unless the user explicitly
names a person or refers to themselves ("my tasks", "assigned to me", "mine").
Do NOT default to the current user.
search: substring search across title and description.
order_by: sort field — dueDate | priority | createdAt | completedAt.
order_dir: asc (default) | desc.
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
limit: max rows to return (default 50). Use with offset to paginate.
offset: skip first N rows (default 0).
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
Tip — prefer count_tasks for "how many" questions to avoid listing rows.
Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.)
take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt;
do not compute boundaries from the current UTC instant.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client( filters: dict[str, Any] = {
action="select",
table="tasks",
filters={
"projectId": normalized_project_id or None, "projectId": normalized_project_id or None,
"status": status or None, "status": status or None,
"priority": priority or None,
"search": search or None, "search": search or None,
"orderBy": order_by or None, "orderBy": order_by or None,
}, "orderDir": order_dir or None,
) "limit": limit,
"offset": offset,
}
if assignee:
filters["assignee"] = assignee
if due_date_from != -1:
filters["dueDateFrom"] = due_date_from
if due_date_to != -1:
filters["dueDateTo"] = due_date_to
if created_at_from != -1:
filters["createdAtFrom"] = created_at_from
if created_at_to != -1:
filters["createdAtTo"] = created_at_to
if completed_at_from != -1:
filters["completedAtFrom"] = completed_at_from
if completed_at_to != -1:
filters["completedAtTo"] = completed_at_to
if is_ai_suggested != -1:
filters["isAiSuggested"] = is_ai_suggested
result = await execute_on_client(action="select", table="tasks", filters=filters)
rows = result.get("rows", []) rows = result.get("rows", [])
if not rows: if not rows:
return "No tasks found matching the given filters." return "No tasks found matching the given filters."
lines = [ lines = [
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})" f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, "
f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, "
f"projectId: {r.get('projectId')}, id: {r['id']})"
for r in rows for r in rows
] ]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines) return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
@tool
async def count_tasks(
project_id: str = "",
status: str = "",
priority: str = "",
assignee: str = "",
search: str = "",
due_date_from: int = -1,
due_date_to: int = -1,
created_at_from: int = -1,
created_at_to: int = -1,
completed_at_from: int = -1,
completed_at_to: int = -1,
is_ai_suggested: int = -1,
) -> str:
"""Count tasks matching the given filters without returning rows.
Use this instead of list_tasks for "how many" questions — it is much cheaper.
Same filter parameters as list_tasks (no limit/offset/order_by needed).
assignee: OMIT unless the user explicitly names a person or refers to themselves
("my tasks"). Do NOT default to the current user.
due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit.
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block;
do not compute boundaries from the current UTC instant.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
filters: dict[str, Any] = {
"projectId": normalized_project_id or None,
"status": status or None,
"priority": priority or None,
"search": search or None,
}
if assignee:
filters["assignee"] = assignee
if due_date_from != -1:
filters["dueDateFrom"] = due_date_from
if due_date_to != -1:
filters["dueDateTo"] = due_date_to
if created_at_from != -1:
filters["createdAtFrom"] = created_at_from
if created_at_to != -1:
filters["createdAtTo"] = created_at_to
if completed_at_from != -1:
filters["completedAtFrom"] = completed_at_from
if completed_at_to != -1:
filters["completedAtTo"] = completed_at_to
if is_ai_suggested != -1:
filters["isAiSuggested"] = is_ai_suggested
result = await execute_on_client(action="count", table="tasks", filters=filters)
return f"Task count: {result.get('count', 0)}"
@tool @tool
async def create_task( async def create_task(
title: str, title: str,
@@ -89,6 +181,8 @@ async def create_task(
due_date: Unix timestamp in milliseconds; 0 means no due date due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested is_ai_suggested: 1 if proactively suggested, 0 if user-requested
completedAt is set automatically when status is 'done'.
""" """
result = await execute_on_client( result = await execute_on_client(
action="insert", action="insert",
@@ -107,7 +201,7 @@ async def create_task(
row = result["row"] row = result["row"]
return ( return (
f"Task created: '{row['title']}' " f"Task created: '{row['title']}' "
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})" f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})"
) )
@@ -125,6 +219,10 @@ async def update_task(
"""Update fields on an existing task. Only pass fields you want to change. """Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required) task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
completedAt is managed automatically:
- setting status to 'done' records the current timestamp
- changing status away from 'done' clears completedAt
""" """
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if title: if title:
@@ -147,7 +245,7 @@ async def update_task(
data={"id": task_id, "updates": updates}, data={"id": task_id, "updates": updates},
) )
row = result["row"] row = result["row"]
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})" return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})"
@tool @tool
@@ -158,21 +256,36 @@ async def delete_task(task_id: str) -> str:
@tool @tool
async def list_tasks_due_today() -> str: async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str:
"""List all tasks whose due date falls on today's date.""" """List all tasks whose due date falls on today's date.
now = datetime.now(tz=timezone.utc)
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000) user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
end_ms = start_ms + 86_400_000 - 1 # last ms of today Always pass the user's timezone so 'today' is computed in their local time.
include_done: set True to also include already-completed tasks due today (default False).
"""
try:
from zoneinfo import ZoneInfo
tz = ZoneInfo(user_timezone or "UTC")
except Exception:
tz = timezone.utc
now_local = datetime.now(tz=tz)
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
start_ms = int(start_dt.timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1
filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms}
if not include_done:
filters["status"] = "todo"
result = await execute_on_client( result = await execute_on_client(
action="select", action="select",
table="tasks", table="tasks",
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms}, filters=filters,
) )
rows = result.get("rows", []) rows = result.get("rows", [])
if not rows: if not rows:
return "No tasks are due today." return "No tasks are due today."
lines = [ lines = [
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})" f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, "
f"projectId: {r.get('projectId')}, id: {r['id']})"
for r in rows for r in rows
] ]
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines) return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
@@ -210,7 +323,6 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
) )
row = result.get("row", {}) row = result.get("row", {})
row_author = row.get("author", author) row_author = row.get("author", author)
# Electron payloads can vary (taskId vs task_id). Fall back to input task_id.
row_task_id = row.get("taskId") or row.get("task_id") or task_id row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown") row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})." return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@@ -228,6 +340,7 @@ async def delete_task_comment(comment_id: str) -> str:
TASK_TOOLS: list[Any] = [ TASK_TOOLS: list[Any] = [
list_tasks, list_tasks,
count_tasks,
create_task, create_task,
update_task, update_task,
delete_task, delete_task,
@@ -236,3 +349,10 @@ TASK_TOOLS: list[Any] = [
add_task_comment, add_task_comment,
delete_task_comment, delete_task_comment,
] ]
TASK_READ_TOOLS: list[Any] = [
list_tasks,
count_tasks,
list_tasks_due_today,
list_task_comments,
]

View File

@@ -3,6 +3,7 @@
from __future__ import annotations from __future__ import annotations
import re import re
from datetime import datetime, timezone
from typing import Any from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
@@ -17,35 +18,130 @@ _UUID_RE = re.compile(
def _is_uuid(value: str) -> bool: def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value)) return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n" @tool
"track progress on a project — they are not calendar events.\n\n" async def list_timelines(
"Rules:\n" project_id: str = "",
" - project_id is REQUIRED for every create; confirm with the user if unknown\n" type: str = "",
" - For listing, project_id must be a UUID; never pass plain names as project_id\n" is_completed: int = -1,
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n" is_ai_suggested: int = -1,
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n" order_by: str = "",
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n" order_dir: str = "",
" - For update_timeline, use -1 for integer fields you do not want to change\n" date_from: int = -1,
" - Listing without a project_id returns all timelines across projects\n" date_to: int = -1,
" - Always echo the title and formatted date in your confirmation." created_at_from: int = -1,
) created_at_to: int = -1,
completed_at_from: int = -1,
completed_at_to: int = -1,
limit: int = 50,
offset: int = 0,
) -> str:
"""List timeline events (milestones, checkpoints, activities) with optional filters.
project_id: UUID to scope results to a specific project.
type: filter by event type — milestone | checkpoint | activity.
is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default).
is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any.
order_by: sort field — date (default) | createdAt | completedAt.
order_dir: asc (default) | desc.
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit.
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
limit: max rows to return (default 50). Use with offset to paginate.
offset: skip first N rows (default 0).
Tip — combine *_from and *_to for a closed range; pass only one for open-ended.
Tip — prefer count_timelines for "how many" questions to avoid listing rows.
Tip — for natural-language windows ("today", "this week", "last month", etc.)
take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt;
do not compute boundaries from the current UTC instant.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
filters: dict[str, Any] = {
"projectId": normalized_project_id or None,
"orderBy": order_by or None,
"orderDir": order_dir or None,
"limit": limit,
"offset": offset,
}
if type:
filters["type"] = type
if is_completed != -1:
filters["isCompleted"] = is_completed
if is_ai_suggested != -1:
filters["isAiSuggested"] = is_ai_suggested
if date_from != -1:
filters["dateFrom"] = date_from
if date_to != -1:
filters["dateTo"] = date_to
if created_at_from != -1:
filters["createdAtFrom"] = created_at_from
if created_at_to != -1:
filters["createdAtTo"] = created_at_to
if completed_at_from != -1:
filters["completedAtFrom"] = completed_at_from
if completed_at_to != -1:
filters["completedAtTo"] = completed_at_to
result = await execute_on_client(action="select", table="timelines", filters=filters)
rows = result.get("rows", [])
if not rows:
return "No timeline events found."
lines = [
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, "
f"projectId: {r.get('projectId')}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines)
@tool @tool
async def list_timelines(project_id: str = "") -> str: async def count_timelines(
"""List timelines. Provide project_id to scope to a specific project.""" project_id: str = "",
type: str = "",
is_completed: int = -1,
is_ai_suggested: int = -1,
date_from: int = -1,
date_to: int = -1,
created_at_from: int = -1,
created_at_to: int = -1,
completed_at_from: int = -1,
completed_at_to: int = -1,
) -> str:
"""Count timeline events matching the given filters without returning rows.
Use this instead of list_timelines for "how many" questions — it is much cheaper.
Same filter parameters as list_timelines (no limit/offset/order_by needed).
date_from / date_to: ms epoch range for the event date. Use -1 to omit.
completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit.
Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block;
do not compute boundaries from the current UTC instant.
"""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client( filters: dict[str, Any] = {"projectId": normalized_project_id or None}
action="select", if type:
table="timelines", filters["type"] = type
filters={"projectId": normalized_project_id or None}, if is_completed != -1:
) filters["isCompleted"] = is_completed
rows = result.get("rows", []) if is_ai_suggested != -1:
if not rows: filters["isAiSuggested"] = is_ai_suggested
return "No timelines found." if date_from != -1:
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows] filters["dateFrom"] = date_from
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines) if date_to != -1:
filters["dateTo"] = date_to
if created_at_from != -1:
filters["createdAtFrom"] = created_at_from
if created_at_to != -1:
filters["createdAtTo"] = created_at_to
if completed_at_from != -1:
filters["completedAtFrom"] = completed_at_from
if completed_at_to != -1:
filters["completedAtTo"] = completed_at_to
result = await execute_on_client(action="count", table="timelines", filters=filters)
return f"Timeline event count: {result.get('count', 0)}"
@tool @tool
@@ -53,13 +149,19 @@ async def create_timeline(
project_id: str, project_id: str,
title: str, title: str,
date: int, date: int,
type: str = "milestone",
is_completed: int = 0,
is_ai_suggested: int = 0, is_ai_suggested: int = 0,
) -> str: ) -> str:
"""Create a project timeline (milestone). """Create a project timeline event.
project_id: REQUIRED UUID of the parent project project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone title: descriptive name for the event
date: Unix timestamp in milliseconds date: Unix timestamp in milliseconds for the event date
type: milestone (default) | checkpoint | activity
is_completed: 1 if already completed, 0 if not (default 0)
is_ai_suggested: 1 if proactively suggested, 0 if user-requested is_ai_suggested: 1 if proactively suggested, 0 if user-requested
completedAt is set automatically when is_completed is 1.
""" """
result = await execute_on_client( result = await execute_on_client(
action="insert", action="insert",
@@ -68,11 +170,13 @@ async def create_timeline(
"projectId": project_id, "projectId": project_id,
"title": title, "title": title,
"date": date, "date": date,
"type": type,
"isCompleted": is_completed,
"isAiSuggested": is_ai_suggested, "isAiSuggested": is_ai_suggested,
}, },
) )
row = result["row"] row = result["row"]
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})" return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})"
@tool @tool
@@ -80,35 +184,87 @@ async def update_timeline(
timeline_id: str, timeline_id: str,
title: str = "", title: str = "",
date: int = -1, date: int = -1,
is_completed: int = -1,
) -> str: ) -> str:
"""Update a timeline. Only pass fields that should change. """Update a timeline event. Only pass fields that should change.
timeline_id: UUID of the timeline (required) timeline_id: UUID of the event (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp) date: -1 means unchanged; any other value sets the new date (ms timestamp)
is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged
completedAt is managed automatically:
- setting is_completed to 1 records the current timestamp
- setting is_completed to 0 clears completedAt
""" """
updates: dict[str, Any] = {} updates: dict[str, Any] = {}
if title: if title:
updates["title"] = title updates["title"] = title
if date != -1: if date != -1:
updates["date"] = date updates["date"] = date
if is_completed != -1:
updates["isCompleted"] = is_completed
result = await execute_on_client( result = await execute_on_client(
action="update", action="update",
table="timelines", table="timelines",
data={"id": timeline_id, "updates": updates}, data={"id": timeline_id, "updates": updates},
) )
row = result["row"] row = result["row"]
return f"Timeline updated: '{row['title']}' (id: {row['id']})" return f"Timeline event updated: '{row['title']}' (id: {row['id']})"
@tool @tool
async def delete_timeline(timeline_id: str) -> str: async def delete_timeline(timeline_id: str) -> str:
"""Delete a timeline permanently by its UUID.""" """Delete a timeline event permanently by its UUID."""
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id}) await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
return f"Timeline {timeline_id} deleted." return f"Timeline event {timeline_id} deleted."
@tool
async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str:
"""List all timeline events whose date falls on today.
user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York').
Always pass the user's timezone so 'today' is computed in their local time.
include_completed: set False to exclude already-completed events (default True).
"""
try:
from zoneinfo import ZoneInfo
tz = ZoneInfo(user_timezone or "UTC")
except Exception:
tz = timezone.utc
now_local = datetime.now(tz=tz)
start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz)
start_ms = int(start_dt.timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1
filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms}
if not include_completed:
filters["isCompleted"] = 0
result = await execute_on_client(
action="select",
table="timelines",
filters=filters,
)
rows = result.get("rows", [])
if not rows:
return "No timeline events today."
lines = [
f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, "
f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})"
for r in rows
]
return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines)
TIMELINE_TOOLS: list[Any] = [ TIMELINE_TOOLS: list[Any] = [
list_timelines, list_timelines,
count_timelines,
list_timelines_today,
create_timeline, create_timeline,
update_timeline, update_timeline,
delete_timeline, delete_timeline,
] ]
TIMELINE_READ_TOOLS: list[Any] = [
list_timelines,
count_timelines,
list_timelines_today,
]

View File

@@ -65,16 +65,39 @@ async def get_current_user(
default_tier = "power" if settings.ENV == "dev" else "free" default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname from user row. # Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
user_result = await db.execute( user_result = await db.execute(
select(User.name, User.surname).where(User.id == user_id) select(
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
User.password_hash,
).where(User.id == user_id)
) )
user_row = user_result.one_or_none() user_row = user_result.one_or_none()
# Convert onboarding_completed_at to epoch ms (int) or None.
onboarding_ms: int | None = None
if user_row and user_row.onboarding_completed_at is not None:
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
# Load decrypted core memory.
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
memory_dict: dict[str, str] = {}
try:
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(user_id)
memory_dict = {b["label"]: b["value"] for b in blocks}
except Exception:
pass # Non-critical — return empty memory on failure
return UserProfile( return UserProfile(
id=user_id, id=user_id,
email=email, email=email,
name=user_row.name if user_row else None, name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None, surname=user_row.surname if user_row else None,
avatar_url=user_row.avatar_url if user_row else None,
has_password=bool(user_row.password_hash) if user_row else False,
tier=tier, tier=tier,
onboarding_completed_at=onboarding_ms,
memory=memory_dict,
) # type: ignore[arg-type] ) # type: ignore[arg-type]

View File

@@ -8,8 +8,7 @@ that could reveal server-side prompt IP:
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …) - Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints - Exact-match known prompt fingerprints
Binary responses (storage blobs, backup data) are never touched — the The middleware only activates for paths under /api/v1/chat.
middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified. names of the fields that were modified.

View File

@@ -31,10 +31,9 @@ from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import FILESYSTEM_TOOLS from app.agents.filesystem_agent import make_directory_tools
from app.config.settings import settings from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback from app.core.llm import get_agent_llm, model_for_agent
from app.core.llm import get_llm
from app.schemas import AgentConfig from app.schemas import AgentConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -257,15 +256,17 @@ async def _call_llm_with_tools(
else: else:
messages.append(AIMessage(content=turn["content"])) messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4) llm = get_agent_llm("setup", temperature=0.4)
llm_with_tools = llm.bind_tools(tools) llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools} tool_map = {tool_def.name: tool_def for tool_def in tools}
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
_lf_ctx.__enter__()
_span_ctx = ( _span_ctx = (
lf.start_as_current_observation( lf.start_as_current_observation(
as_type="span", as_type="span",
name="journey-setup", name="journey-setup",
metadata={"user_id": user_id or None, "session_id": session_id or None},
input=history[-1]["content"] if history else "", input=history[-1]["content"] if history else "",
) )
if lf else None if lf else None
@@ -273,12 +274,12 @@ async def _call_llm_with_tools(
_span = _span_ctx.__enter__() if _span_ctx else None _span = _span_ctx.__enter__() if _span_ctx else None
try: try:
for _ in range(_MAX_TOOL_STEPS): for step in range(_MAX_TOOL_STEPS):
_gen_ctx = ( _gen_ctx = (
lf.start_as_current_observation( lf.start_as_current_observation(
as_type="generation", as_type="generation",
name="journey-setup-llm", name="journey-setup-llm",
model=settings.LLM_MODEL, model=model_for_agent("setup"),
prompt=langfuse_prompt, prompt=langfuse_prompt,
input=messages, input=messages,
) )
@@ -287,15 +288,27 @@ async def _call_llm_with_tools(
_gen = _gen_ctx.__enter__() if _gen_ctx else None _gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages) response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx: if _gen_ctx:
_gen.update(output=_as_text(response.content), usage=extract_usage(response)) _gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
_gen_ctx.__exit__(None, None, None) _gen_ctx.__exit__(None, None, None)
resp_text = _as_text(response.content)
# Guard against empty responses (e.g. model returned finish_reason
# 'error' which LiteLLM maps to 'stop' with empty content).
if not response.tool_calls and not resp_text.strip():
logger.warning(
"agent_setup: journey LLM returned empty response at step %d — retrying",
step,
)
# Drop the empty AIMessage so we don't pollute history, and retry.
continue
messages.append(response) messages.append(response)
if not response.tool_calls: if not response.tool_calls:
if _span: if _span:
_span.update(output=_as_text(response.content)) _span.update(output=resp_text)
return _as_text(response.content) return resp_text
for call in response.tool_calls: for call in response.tool_calls:
call_name = str(call.get("name", "")) call_name = str(call.get("name", ""))
@@ -324,10 +337,14 @@ async def _call_llm_with_tools(
final_text = _as_text(final.content) final_text = _as_text(final.content)
if _span: if _span:
_span.update(output=final_text) _span.update(output=final_text)
return final_text return final_text or (
"Sorry, I had trouble processing the files. "
"Could you try again? If the issue persists, the files might be too large for me to analyse."
)
finally: finally:
if _span_ctx: if _span_ctx:
_span_ctx.__exit__(None, None, None) _span_ctx.__exit__(None, None, None)
_lf_ctx.__exit__(None, None, None)
if lf: if lf:
lf.flush() lf.flush()
@@ -372,7 +389,7 @@ async def handle_journey_start(
ai_reply = await _call_llm_with_tools( ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt, system_prompt=system_prompt,
history=seed_history, history=seed_history,
tools=list(FILESYSTEM_TOOLS), tools=make_directory_tools(directory),
user_id=user_id, user_id=user_id,
session_id=session_id, session_id=session_id,
langfuse_prompt=langfuse_prompt, langfuse_prompt=langfuse_prompt,
@@ -436,10 +453,11 @@ async def handle_journey_message(
session.history.append({"role": "user", "content": message}) session.history.append({"role": "user", "content": message})
# Call the LLM with tools. # Call the LLM with tools.
session_tools = make_directory_tools(session.directory)
ai_reply = await _call_llm_with_tools( ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt, system_prompt=session.system_prompt,
history=session.history, history=session.history,
tools=list(FILESYSTEM_TOOLS), tools=session_tools,
user_id=session.user_id, user_id=session.user_id,
session_id=session_id, session_id=session_id,
langfuse_prompt=session.langfuse_prompt, langfuse_prompt=session.langfuse_prompt,
@@ -464,7 +482,7 @@ async def handle_journey_message(
nudge_reply = await _call_llm_with_tools( nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt, system_prompt=session.system_prompt,
history=session.history, history=session.history,
tools=list(FILESYSTEM_TOOLS), tools=session_tools,
user_id=session.user_id, user_id=session.user_id,
session_id=session_id, session_id=session_id,
langfuse_prompt=session.langfuse_prompt, langfuse_prompt=session.langfuse_prompt,

View File

@@ -12,17 +12,21 @@ in backend agent-config tables.
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import logging
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES from app.billing.tier_manager import FEATURES
from app.core.agent_runner import is_agent_running, run_local_agent from app.core.agent_runner import is_agent_running, run_local_agent
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.core.note_summarizer import generate_note_summary
from app.db import get_session from app.db import get_session
from app.models import AgentRunLog, LocalAgentConfig from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import ( from app.schemas import (
@@ -34,6 +38,8 @@ from app.schemas import (
UserProfile, UserProfile,
) )
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/agents", tags=["agents"]) router = APIRouter(prefix="/agents", tags=["agents"])
@@ -177,6 +183,11 @@ async def trigger_agent_run(
_enforce_agent_limit(current_user.tier, body.active_agents) _enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db) await _enforce_run_frequency(current_user.tier, current_user.id, db)
last_run_dt = (
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
if body.last_run_at
else None
)
config = LocalAgentConfig( config = LocalAgentConfig(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
user_id=current_user.id, user_id=current_user.id,
@@ -184,10 +195,12 @@ async def trigger_agent_run(
name="Local Directory Monitor", name="Local Directory Monitor",
directory_paths=[body.directory], directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract), data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt, prompt_template=body.custom_agent_prompt or "",
agent_config=body.agent_config,
file_extensions=[], file_extensions=[],
schedule_cron=body.batch_interval, schedule_cron=body.batch_interval,
enabled=True, enabled=True,
last_run_at=last_run_dt,
) )
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id. # Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
@@ -220,3 +233,25 @@ async def trigger_agent_run(
) )
return _to_run_log_response(run_log) return _to_run_log_response(run_log)
# ── Note summary endpoint ──────────────────────────────────────────────────────
class NoteSummarizeRequest(BaseModel):
title: str
content: str
class NoteSummarizeResponse(BaseModel):
summary: str
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
async def summarize_note(
body: NoteSummarizeRequest,
current_user: UserProfile = Depends(get_current_user),
) -> NoteSummarizeResponse:
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
summary = await generate_note_summary(body.title, body.content)
return NoteSummarizeResponse(summary=summary)

View File

@@ -1,34 +1,68 @@
"""Auth routes: register, login, refresh, me. """Auth routes: register, login, refresh, me, OAuth social login, onboarding.
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
tables). Passwords are hashed with bcrypt; refresh tokens are stored as tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB. SHA-256 hashes so plaintext never reaches the DB.
OAuth (Google):
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
""" """
from __future__ import annotations from __future__ import annotations
import hashlib import hashlib
import json
import time import time
import urllib.parse
import uuid import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Literal
import bcrypt import bcrypt
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from jose import jwt from jose import jwt
from pydantic import BaseModel from pydantic import BaseModel, Field
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
from app.config.settings import settings from app.config.settings import settings
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session from app.db import get_session
from app.models import RefreshToken, User from app.models import OAuthAccount, RefreshToken, User
from app.schemas import AuthTokens, UserProfile from app.schemas import AuthTokens, UserProfile
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
# ── OAuth provider registry ───────────────────────────────────────────
def _get_google_provider() -> GoogleOAuthProvider:
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
raise HTTPException(
status.HTTP_503_SERVICE_UNAVAILABLE,
"Google login is not configured on this server",
)
return GoogleOAuthProvider(
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
redirect_uri=settings.OAUTH_REDIRECT_URI,
)
_PROVIDERS = {"google": _get_google_provider}
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
# Production note: replace with Redis for multi-process deployments.
_pending_states: dict[str, tuple[str, float]] = {}
_STATE_TTL_SECONDS = 600 # 10 minutes
# ── Internal helpers ───────────────────────────────────────────────── # ── Internal helpers ─────────────────────────────────────────────────
@@ -231,5 +265,531 @@ async def update_profile(
email=user.email, email=user.email,
name=user.name, name=user.name,
surname=user.surname, surname=user.surname,
avatar_url=user.avatar_url,
tier=current_user.tier, tier=current_user.tier,
) )
# ── OAuth helpers ─────────────────────────────────────────────────────
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
"""Create a refresh token row and return (plain_token, AuthTokens)."""
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return plain_token, AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
# ── OAuth request/response schemas ───────────────────────────────────
class _OAuthAuthorizeResponse(BaseModel):
url: str
state: str
class _OAuthCallbackRequest(BaseModel):
code: str
state: str
# ── OAuth routes ──────────────────────────────────────────────────────
@router.get(
"/oauth/{provider}/web-callback",
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
include_in_schema=False,
)
async def oauth_web_callback(
provider: Literal["google"],
code: str,
state: str,
) -> RedirectResponse:
"""Google redirects here after user consent.
This endpoint immediately redirects to the Electron deep-link URI so the
desktop app receives the authorization code. It is intentionally simple —
no state validation here (the Electron app + backend callback do that).
Registered in Google Cloud Console as:
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
"""
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
deep_link = f"adiuvai://oauth/callback?{params}"
return RedirectResponse(url=deep_link, status_code=302)
@router.get(
"/oauth/{provider}/authorize",
response_model=_OAuthAuthorizeResponse,
summary="Start OAuth flow — returns the provider consent-screen URL",
)
async def oauth_authorize(
provider: Literal["google"],
) -> _OAuthAuthorizeResponse:
"""Generate a PKCE state + code_challenge and return the authorization URL.
The client opens this URL in the system browser. After the user grants
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
with ``code`` and ``state`` query params. The client then calls
``POST /auth/oauth/{provider}/callback`` with those values.
"""
provider_factory = _PROVIDERS.get(provider)
if provider_factory is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
oauth_provider = provider_factory()
state = str(uuid.uuid4())
code_verifier, code_challenge = generate_pkce_pair()
# Purge expired states to prevent unbounded growth.
now = time.time()
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
for s in expired:
del _pending_states[s]
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
return _OAuthAuthorizeResponse(url=url, state=state)
@router.post(
"/oauth/{provider}/callback",
response_model=AuthTokens,
summary="Complete OAuth flow — exchange code and issue JWT tokens",
)
async def oauth_callback(
provider: Literal["google"],
body: _OAuthCallbackRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate state, exchange the authorization code, and sign in (or register) the user.
Resolution order:
1. ``oauth_accounts`` row match → existing user, log in.
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
3. No match → create new user (password_hash=None, avatar from provider).
"""
provider_factory = _PROVIDERS.get(provider)
if provider_factory is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
# Validate state (CSRF protection).
now = time.time()
entry = _pending_states.pop(body.state, None)
if entry is None or entry[1] < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
code_verifier, _ = entry
oauth_provider = provider_factory()
# Exchange code for tokens.
try:
token_data = await oauth_provider.exchange_code(
code=body.code,
code_verifier=code_verifier,
redirect_uri=settings.OAUTH_REDIRECT_URI,
)
except Exception:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
)
access_token_google = token_data.get("access_token")
if not access_token_google:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
# Fetch user identity.
try:
userinfo = await oauth_provider.get_userinfo(access_token_google)
except Exception:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
# ── Resolution order ──────────────────────────────────────────────
# 1. Existing OAuth link?
oauth_result = await db.execute(
select(OAuthAccount).where(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == userinfo.provider_user_id,
)
)
oauth_account = oauth_result.scalar_one_or_none()
if oauth_account is not None:
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
user = user_result.scalar_one()
# Backfill avatar if the user doesn't have one yet.
if user.avatar_url is None and userinfo.avatar_url:
user.avatar_url = userinfo.avatar_url
await db.commit()
plain_token, tokens = await _issue_refresh_token(user, db)
await db.commit()
return tokens
# 2. Email match with a verified Google email → link accounts.
if userinfo.email_verified:
email_result = await db.execute(select(User).where(User.email == userinfo.email))
existing_user = email_result.scalar_one_or_none()
if existing_user is not None:
new_link = OAuthAccount(
user_id=existing_user.id,
provider=provider,
provider_user_id=userinfo.provider_user_id,
provider_email=userinfo.email,
)
db.add(new_link)
if existing_user.avatar_url is None and userinfo.avatar_url:
existing_user.avatar_url = userinfo.avatar_url
plain_token, tokens = await _issue_refresh_token(existing_user, db)
await db.commit()
return tokens
# Guard: if the email is already taken but we couldn't auto-link (e.g.
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
if not userinfo.email_verified:
conflict = await db.execute(select(User).where(User.email == userinfo.email))
if conflict.scalar_one_or_none() is not None:
raise HTTPException(
status.HTTP_409_CONFLICT,
"An account with this email already exists. "
"Please sign in with your password.",
)
# 3. New user — social-only account (no password).
new_user = User(
id=str(uuid.uuid4()),
email=userinfo.email,
name=userinfo.name,
password_hash=None,
avatar_url=userinfo.avatar_url,
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(new_user)
await db.flush() # populate new_user.id
new_oauth = OAuthAccount(
user_id=new_user.id,
provider=provider,
provider_user_id=userinfo.provider_user_id,
provider_email=userinfo.email,
)
db.add(new_oauth)
plain_token, tokens = await _issue_refresh_token(new_user, db)
await db.commit()
return tokens
# ── Onboarding helpers ────────────────────────────────────────────────
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
# We can't call the FastAPI dependency directly, but we can replicate
# the core logic inline. Instead, we just re-query the same way.
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
user_result = await db.execute(
select(
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
User.password_hash,
).where(User.id == user_id)
)
user_row = user_result.one_or_none()
onboarding_ms: int | None = None
if user_row and user_row.onboarding_completed_at is not None:
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
memory_dict: dict[str, str] = {}
try:
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(user_id)
memory_dict = {b["label"]: b["value"] for b in blocks}
except Exception:
pass
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
avatar_url=user_row.avatar_url if user_row else None,
has_password=bool(user_row.password_hash) if user_row else False,
tier=tier,
onboarding_completed_at=onboarding_ms,
memory=memory_dict,
)
# ── Onboarding routes ────────────────────────────────────────────────
class _UpdateMemoryRequest(BaseModel):
memory: dict[str, str] = Field(default_factory=dict)
mark_onboarded: bool = False
@router.put("/me/memory", response_model=UserProfile)
async def update_memory(
body: _UpdateMemoryRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update core memory key/value pairs and optionally mark onboarding complete."""
mw = MemoryMiddleware(db)
for key, value in body.memory.items():
await mw.update_core(current_user.id, key, value)
if body.mark_onboarded:
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.onboarding_completed_at = datetime.now(timezone.utc)
await db.commit()
return await _build_profile(current_user.id, current_user.email, db)
@router.post("/me/onboarding/reset")
async def reset_onboarding(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
):
"""Reset onboarding so the wizard runs again on next login."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.onboarding_completed_at = None
await db.commit()
return {"status": "reset"}
class _NormalizeRequest(BaseModel):
inputs: dict[str, str]
class _NormalizeResponse(BaseModel):
normalized: dict[str, str]
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
async def normalize_onboarding(
body: _NormalizeRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _NormalizeResponse:
"""One-shot LLM normalization for free-text onboarding answers."""
if not body.inputs:
return _NormalizeResponse(normalized={})
try:
llm = get_llm(model="gpt-4o-mini", temperature=0)
prompt = (
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
"Return a JSON object with the same keys and normalized values.\n"
"Examples: 'i build websites''Web Developer', 'tech-ish stuff''Technology'\n"
f"Input: {json.dumps(body.inputs)}"
)
response = await llm.ainvoke(
[
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
{"role": "user", "content": prompt},
],
)
normalized = json.loads(response.content)
return _NormalizeResponse(normalized=normalized)
except Exception:
# LLM failure must never block onboarding — return inputs unchanged
return _NormalizeResponse(normalized=body.inputs)
# ── Password management ───────────────────────────────────────────────
class _ChangePasswordRequest(BaseModel):
current_password: str = Field(min_length=1)
new_password: str = Field(min_length=8)
@router.put("/me/password", status_code=status.HTTP_200_OK)
async def change_password(
body: _ChangePasswordRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Change the authenticated user's password.
Requires the current password for verification.
Returns 400 for social-only users (no password set).
"""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if user.password_hash is None:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"This account uses social login and has no password to change",
)
if not _verify_password(body.current_password, user.password_hash):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
user.password_hash = _hash_password(body.new_password)
await db.commit()
return {"ok": True}
# ── OAuth account management ─────────────────────────────────────────
@router.get("/me/oauth-accounts", response_model=list[dict])
async def list_oauth_accounts(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[dict]:
"""List all OAuth providers linked to the authenticated user."""
result = await db.execute(
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
)
accounts = result.scalars().all()
return [
{
"provider": a.provider,
"provider_email": a.provider_email,
"created_at": int(a.created_at.timestamp() * 1000),
}
for a in accounts
]
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
async def unlink_oauth_account(
provider: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unlink an OAuth provider from the authenticated user.
Refuses if the user has no password and this is their only login method.
"""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
oauth_result = await db.execute(
select(OAuthAccount).where(
OAuthAccount.user_id == current_user.id,
OAuthAccount.provider == provider,
)
)
account = oauth_result.scalar_one_or_none()
if account is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
# Safety: don't let users lock themselves out.
all_oauth = await db.execute(
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
)
oauth_count = len(all_oauth.scalars().all())
if user.password_hash is None and oauth_count <= 1:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Cannot unlink the only login method. Set a password first.",
)
await db.delete(account)
await db.commit()
return {"ok": True}
# ── Avatar update ─────────────────────────────────────────────────────
class _UpdateAvatarRequest(BaseModel):
avatar_url: str = Field(min_length=1)
@router.put("/me/avatar", response_model=UserProfile)
async def update_avatar(
body: _UpdateAvatarRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's avatar URL.
Accepts {"avatar_url": "https://..."} — the client uploads the image
to its own storage and passes the resulting URL here.
"""
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.avatar_url = body.avatar_url
await db.commit()
return await _build_profile(current_user.id, current_user.email, db)
# ── Account deletion ─────────────────────────────────────────────────
@router.delete("/me", status_code=status.HTTP_200_OK)
async def delete_account(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Permanently delete the authenticated user's account.
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
is cancelled if active.
"""
# Cancel Stripe subscription if present.
try:
from app.billing.stripe_service import stripe_service # noqa: PLC0415
await stripe_service.cancel_subscription(current_user.id, db)
except HTTPException:
pass # No subscription — that's fine
# Delete all memory rows (core, associative, episodic, proactive).
try:
from app.models import ( # noqa: PLC0415
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
)
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
await db.execute(
model.__table__.delete().where(model.user_id == current_user.id)
)
except Exception:
pass # Non-critical — cascade on User will handle most
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
await db.delete(user)
await db.commit()
return {"ok": True}

View File

@@ -1,171 +0,0 @@
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
PostgreSQL ``backup_metadata`` table.
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
treating "history" as a ``{backup_id}`` path parameter.
"""
from __future__ import annotations
import uuid
from email.utils import parsedate_to_datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import BackupMetadata as BackupMetadataModel
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/backup", tags=["backup"])
_blob_store = BlobStore()
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total backup bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
BackupMetadataModel.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_backup_quota(
user: UserProfile, size_bytes: int, db: AsyncSession
) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
current = await _current_backup_bytes(user.id, db)
tier_manager.enforce_backup_quota(
user.tier, current_bytes=current, additional_bytes=size_bytes
)
@router.put("")
async def upload_backup(
request: Request,
x_backup_version: int = Header(..., alias="X-Backup-Version"),
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Upload an E2E-encrypted backup blob.
Metadata is passed via custom headers; the raw body is the encrypted blob.
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
await _check_backup_quota(current_user, len(blob), db)
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
)
row = BackupMetadataModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
s3_key=s3_key,
version=x_backup_version,
timestamp=x_backup_timestamp,
checksum=x_backup_checksum,
size_bytes=len(blob),
)
db.add(row)
await db.commit()
return {"ok": True}
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[BackupMetadata]:
"""Return backup metadata records for the authenticated user (no blob bytes)."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
)
rows = result.scalars().all()
return [
BackupMetadata(
version=r.version,
timestamp=r.timestamp,
checksum=r.checksum,
chunk_count=1,
)
for r in rows
]
@router.get("")
async def download_backup(
request: Request,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
.limit(1)
)
latest = result.scalar_one_or_none()
if latest is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
ims_header = request.headers.get("If-Modified-Since")
if ims_header:
try:
ims_dt = parsedate_to_datetime(ims_header)
ims_ms = int(ims_dt.timestamp() * 1000)
if latest.timestamp <= ims_ms:
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
except Exception:
pass # malformed header — ignore and serve the blob
blob = await _blob_store.download(current_user.id, latest.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={
"X-Backup-Version": str(latest.version),
"X-Backup-Timestamp": str(latest.timestamp),
"X-Checksum": latest.checksum,
},
)
@router.delete("/{backup_id}", response_model=dict)
async def delete_backup(
backup_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a specific backup by ID."""
result = await db.execute(
select(BackupMetadataModel).where(
BackupMetadataModel.id == backup_id,
BackupMetadataModel.user_id == current_user.id,
)
)
target = result.scalar_one_or_none()
if target is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
await _blob_store.delete(current_user.id, target.s3_key)
await db.delete(target)
await db.commit()
return {"ok": True}

View File

@@ -9,7 +9,7 @@ from __future__ import annotations
from typing import Any from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status from fastapi import APIRouter, Depends, Header, HTTPException, Request, status
from pydantic import BaseModel from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
@@ -83,3 +83,50 @@ async def cancel_subscription(
"""Cancel the active subscription.""" """Cancel the active subscription."""
await stripe_service.cancel_subscription(current_user.id, db) await stripe_service.cancel_subscription(current_user.id, db)
return {"ok": True} return {"ok": True}
@router.get("/invoices", response_model=list[dict])
async def list_invoices(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[dict[str, Any]]:
"""Return billing history (invoices) from Stripe.
Returns an empty list when Stripe is not configured.
"""
invoices = await stripe_service.list_invoices(current_user.id, db)
return invoices
# ── Quota check ────────────────────────────────────────────────────────
from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402
class QuotaCheckRequest(BaseModel):
feature: str
estimated_files: int
@router.post("/quota/check")
async def quota_check(
payload: QuotaCheckRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict:
"""Pre-flight folder quota check. 402 if tier limits would be exceeded."""
if payload.feature != "folder_index":
raise HTTPException(status_code=400, detail="Unknown feature")
try:
await check_folder_quota(
user_id=current_user.id,
tier=current_user.tier,
estimated_files=payload.estimated_files,
db=db,
)
except QuotaExceeded as exc:
raise HTTPException(
status_code=402,
detail={"reason": exc.reason, "message": str(exc)},
)
return {"ok": True}

View File

@@ -1,20 +1,42 @@
"""Chat routes: POST /chat (REST fallback). """Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device). WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
""" """
from __future__ import annotations from __future__ import annotations
from fastapi import APIRouter, Depends import uuid
from typing import Literal
from fastapi import APIRouter, Depends, HTTPException
from fastapi.responses import JSONResponse from fastapi.responses import JSONResponse
from pydantic import BaseModel
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_home from app.core.deep_agent import run_home
from app.core.llm import embed
from app.core.memory_middleware import MemoryMiddleware
from app.db import async_session
from app.schemas import ChatRequest, UserProfile from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"]) router = APIRouter(prefix="/chat", tags=["chat"])
# ── Embed helpers ─────────────────────────────────────────────────────────
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("") @router.post("")
async def chat( async def chat(
body: ChatRequest, body: ChatRequest,
@@ -27,3 +49,68 @@ async def chat(
context=body.context.model_dump(), context=body.context.model_dump(),
) )
return JSONResponse(content={"response": response}) return JSONResponse(content={"response": response})
class _BriefRequest(BaseModel):
mode: Literal["home", "project"]
project_id: str | None = None
class _BriefResponse(BaseModel):
response: str
@router.post("/brief", response_model=_BriefResponse)
async def brief(
body: _BriefRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _BriefResponse:
"""REST fallback for brief when the device WebSocket is not ready."""
if body.mode == "project":
if not body.project_id:
raise HTTPException(status_code=422, detail="project_id required for project mode")
try:
uuid.UUID(body.project_id)
except ValueError:
raise HTTPException(status_code=422, detail="project_id must be a valid UUID")
request_id = str(uuid.uuid4())
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
current_user.id,
"",
trace_id=request_id,
session_id=request_id,
)
context: dict = {
"_debug": {"request_id": request_id, "user_id": current_user.id},
**memory_context,
}
chunks: list[str] = []
if body.mode == "project":
stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type]
else:
stream = run_home_brief(current_user.id, context)
async for event_type, data in stream:
if event_type == "token" and data:
chunks.append(str(data))
return _BriefResponse(response="".join(chunks))
@router.post("/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by Electron (vectordb.ts) for local note search.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

View File

@@ -42,19 +42,25 @@ from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs from app.core.agent_runner import trigger_pending_runs
from app.core.deep_agent import run_floating_stream, run_home_stream from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
from app.core.output_formatter import extract_canvas_block
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session from app.db import async_session
from app.models import AgentRunLog from app.models import AgentRunLog
from app.schemas import WsFrameType from app.schemas import WsFrameType, WsStreamEnd
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"]) router = APIRouter(prefix="/ws", tags=["device-ws"])
# ── v7 folder index session state ─────────────────────────────────────
# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled }
_index_sessions: dict[str, dict] = {}
_HEARTBEAT_INTERVAL = 30 # seconds _HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping _PONG_TIMEOUT = 10 # seconds — grace window after a ping
@@ -158,6 +164,16 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_floating_request(websocket, user_id, frame) _handle_floating_request(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.brief_request:
asyncio.create_task(
_handle_brief_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.task_brief_request:
asyncio.create_task(
_handle_task_brief_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_start: elif frame_type == WsFrameType.journey_start:
asyncio.create_task( asyncio.create_task(
_handle_journey_start(websocket, user_id, frame) _handle_journey_start(websocket, user_id, frame)
@@ -168,6 +184,19 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_journey_message(websocket, user_id, frame) _handle_journey_message(websocket, user_id, frame)
) )
elif frame_type == WsFrameType.index_session_start:
asyncio.create_task(
_handle_index_session_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_file_batch:
asyncio.create_task(
_handle_index_file_batch(websocket, user_id, frame)
)
elif frame_type == WsFrameType.index_session_cancel:
await _handle_index_session_cancel(websocket, frame)
elif frame_type == "pong": elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive. # Heartbeat ack — nothing to do, connection is alive.
pass pass
@@ -199,11 +228,13 @@ async def _handle_home_request(
request_id = frame.get("request_id") or str(uuid4()) request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "") message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4()) session_id: str = frame.get("session_id") or str(uuid4())
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info( logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s", "device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s",
user_id, user_id,
request_id, request_id,
session_id, session_id,
project_id,
message[:200], message[:200],
) )
@@ -220,6 +251,7 @@ async def _handle_home_request(
context: dict = { context: dict = {
"conversation_history": frame.get("conversation_history", []), "conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context, **memory_context,
} }
@@ -227,7 +259,7 @@ async def _handle_home_request(
set_client_executor(executor) set_client_executor(executor)
response_chunks: list[str] = [] response_chunks: list[str] = []
try: try:
event_stream = run_home_stream(user_id, message, context) event_stream = run_home_stream(user_id, message, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id) formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream): async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json()) await websocket.send_text(ws_frame.model_dump_json())
@@ -287,8 +319,10 @@ async def _handle_floating_request(
) )
context: dict = { context: dict = {
"conversation_history": frame.get("conversation_history", []),
"scope": scope, "scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context, **memory_context,
} }
@@ -325,6 +359,179 @@ async def _handle_floating_request(
) )
async def _handle_brief_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a brief_request frame — streams plain-text brief back on the socket.
No episode storage — briefs are not conversations.
"""
import uuid as _uuid
request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4())
mode: str = frame.get("mode", "home")
project_id: str | None = frame.get("project_id")
logger.info(
"device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s",
user_id, request_id, mode, project_id,
)
# Validate project_id for project mode before touching LLM.
if mode == "project":
try:
if not project_id:
raise ValueError("project_id required for project mode")
_uuid.UUID(project_id)
except (ValueError, AttributeError) as exc:
logger.warning(
"device_ws: brief_request invalid project_id user=%s req=%s: %s",
user_id, request_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
return
# Enrich context with memory (no user message — use empty string as probe).
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
"",
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
if mode == "project":
event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type]
else:
event_stream = run_home_brief(user_id, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
except Exception as exc:
logger.error(
"device_ws: brief_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
finally:
clear_client_executor()
logger.info(
"device_ws: brief_request_end user=%s req=%s mode=%s",
user_id, request_id, mode,
)
# ── v6 Task Brief Handler ────────────────────────────────────────────
async def _handle_task_brief_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a task_brief_request frame — Stage-1 executive assistant deep research.
Streams the briefing markdown back to the client.
On stream_end, emits a ``canvas_draft`` mutation if the agent produced one.
"""
request_id = frame.get("request_id") or str(uuid4())
session_id = frame.get("session_id") or str(uuid4())
task_id: str = frame.get("task_id") or frame.get("taskId") or ""
project_id: str | None = frame.get("project_id") or frame.get("projectId") or None
logger.info(
"device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]",
user_id, request_id, task_id, project_id,
)
if not task_id:
await websocket.send_text(
WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json()
)
return
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
f"task brief: {task_id}",
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
await websocket.send_text(ws_frame.model_dump_json())
elif ws_frame.type == "stream_start":
await websocket.send_text(ws_frame.model_dump_json())
# stream_end is emitted below with mutations — skip formatter's version
except Exception as exc:
logger.error(
"device_ws: task_brief_request failed user=%s req=%s task=%s: %s",
user_id, request_id, task_id, exc,
)
await websocket.send_text(
WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json()
)
return
finally:
clear_client_executor()
# Extract canvas block then emit stream_end with optional mutations.
full_response = "".join(response_chunks)
_visible, canvas_content, canvas_kind = extract_canvas_block(full_response)
mutations: list[dict] = []
if canvas_content:
mutations.append({
"type": "canvas_draft",
"content": canvas_content,
"kind": canvas_kind,
})
await websocket.send_text(
WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json()
)
logger.info(
"device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s",
user_id, request_id, task_id, len(full_response), canvas_kind or "none",
)
# ── v4 Journey Handlers ───────────────────────────────────────────── # ── v4 Journey Handlers ─────────────────────────────────────────────
@@ -382,6 +589,174 @@ async def _handle_journey_message(
clear_client_executor() clear_client_executor()
# ── v7 Folder Index Handlers ──────────────────────────────────────────
async def _handle_index_session_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Register a new folder index session. No response sent — client is declaring intent."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
project_id: str | None = frame.get("projectId") or frame.get("project_id")
total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0)
if not session_id:
logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id)
return
_index_sessions[session_id] = {
"user_id": user_id,
"project_id": project_id,
"processed": 0,
"total": total,
"cancelled": False,
}
logger.info(
"device_ws: index_session_start user=%s session=%s project=%s total=%d",
user_id, session_id, project_id, total,
)
async def _handle_index_session_cancel(
websocket: WebSocket,
frame: dict,
) -> None:
"""Mark a session as cancelled and emit index_session_done(cancelled)."""
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
session = _index_sessions.get(session_id)
if session:
session["cancelled"] = True
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "cancelled",
}))
_index_sessions.pop(session_id, None)
logger.info("device_ws: index_session_cancel session=%s", session_id)
async def _handle_index_file_batch(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Process a batch of files for an index session, streaming results back."""
# Lazy imports to avoid heavy load at module startup.
from app.core.folder_indexer import ( # noqa: PLC0415
summarize_image,
summarize_pdf,
summarize_docx,
summarize_text,
)
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.billing.quota import add_token_usage # noqa: PLC0415
session_id: str = frame.get("sessionId") or frame.get("session_id") or ""
files: list[dict] = frame.get("files", [])
session = _index_sessions.get(session_id)
if not session or session.get("cancelled"):
return
async with async_session() as db:
tier = await tier_manager.get_tier(user_id, db)
raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens")
cap: int | None = None if raw_cap == -1 else raw_cap
for file_info in files:
if session.get("cancelled"):
return
# Electron's toSnakeCase converts payload keys, so accept both forms.
rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or ""
kind: str = file_info.get("kind") or "text"
content: str = file_info.get("content") or ""
ext: str = file_info.get("ext") or ""
mime: str = file_info.get("mime") or "application/octet-stream"
name: str = rel_path.split("/")[-1] or rel_path
try:
if kind == "image":
res = await summarize_image(image_b64=content, mime=mime)
elif kind == "pdf":
res = await summarize_pdf(pdf_b64=content, name=name)
elif kind == "docx":
res = await summarize_docx(docx_b64=content, name=name)
else:
res = await summarize_text(content=content, ext=ext, name=name)
except Exception as exc:
logger.warning(
"device_ws: index_file_batch summarize failed session=%s path=%s: %s",
session_id, rel_path, exc,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": None,
"tokensUsed": 0,
"error": str(exc),
}))
session["processed"] += 1
continue
# Account for token usage and check cap.
usage = await add_token_usage(
user_id=user_id,
feature="folder_index",
tokens=res.tokens_used,
db=db,
cap=cap,
)
await websocket.send_text(json.dumps({
"type": WsFrameType.index_file_result,
"sessionId": session_id,
"relPath": rel_path,
"summary": res.summary,
"tokensUsed": res.tokens_used,
}))
session["processed"] += 1
if usage.exhausted:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "quota_exceeded",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session quota_exceeded user=%s session=%s",
user_id, session_id,
)
return
# After processing the batch, emit progress.
processed = session["processed"]
total = session["total"]
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_progress,
"sessionId": session_id,
"processed": processed,
"total": total,
}))
if processed >= total:
await websocket.send_text(json.dumps({
"type": WsFrameType.index_session_done,
"sessionId": session_id,
"status": "completed",
}))
_index_sessions.pop(session_id, None)
logger.info(
"device_ws: index_session_done completed user=%s session=%s processed=%d",
user_id, session_id, processed,
)
# ── Heartbeat ───────────────────────────────────────────────────────── # ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None: async def _heartbeat_loop(websocket: WebSocket) -> None:

225
app/api/routes/memory.py Normal file
View File

@@ -0,0 +1,225 @@
"""Memory management routes — view/edit/delete user memory tiers.
All routes require authentication. Data is always user-scoped.
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Annotated
from fastapi import APIRouter, Depends, Header, HTTPException, status
from pydantic import BaseModel, Field
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.models import (
ExtractionQueue,
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
MemoryRelation,
)
from app.schemas import UserProfile
router = APIRouter(prefix="/memory", tags=["memory"])
logger = logging.getLogger(__name__)
_ALLOWED_PREDICATES = {
"works_at",
"reports_to",
"stakeholder_of",
"last_contacted_on",
"owes_followup",
"manages",
"collaborates_with",
"owns",
"member_of",
"custom",
}
# ── Response schemas ─────────────────────────────────────────────────────────
class RelationOut(BaseModel):
id: str
subject_label: str
subject_type: str
predicate: str
object_label: str
object_type: str
confidence: float
last_confirmed_at: int | None = None # epoch ms
class RelationPatch(BaseModel):
subject_label: str | None = None
object_label: str | None = None
predicate: str | None = None
confidence: float | None = Field(None, ge=0.0, le=1.0)
class CoreAddBody(BaseModel):
key: str = Field(..., min_length=1, max_length=255)
value: str = Field(..., min_length=1)
# ── Helpers ──────────────────────────────────────────────────────────────────
def _relation_to_out(row: MemoryRelation) -> RelationOut:
last_ms: int | None = None
if row.last_confirmed_at is not None:
last_ms = int(row.last_confirmed_at.timestamp() * 1000)
return RelationOut(
id=row.id,
subject_label=row.subject_label,
subject_type=row.subject_type,
predicate=row.predicate,
object_label=row.object_label,
object_type=row.object_type,
confidence=row.confidence,
last_confirmed_at=last_ms,
)
# ── Routes ───────────────────────────────────────────────────────────────────
@router.get("/core", response_model=dict[str, str])
async def get_core_memory(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
"""Return all core memory k/v pairs (plaintext) for the current user."""
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(current_user.id)
return {b["label"]: b["value"] for b in blocks}
@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_core_key(
key: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Delete a single core memory key (GDPR Art. 17)."""
mw = MemoryMiddleware(db)
deleted = await mw.delete_core(current_user.id, key)
if not deleted:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found")
@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str])
async def add_core_key(
body: CoreAddBody,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, str]:
"""Add or overwrite a core memory key/value pair."""
mw = MemoryMiddleware(db)
await mw.update_core(current_user.id, body.key, body.value)
return {body.key: body.value}
@router.get("/relational", response_model=list[RelationOut])
async def get_relational_memory(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[RelationOut]:
"""Return all relational memory rows for the current user."""
mw = MemoryMiddleware(db)
rows = await mw.query_relations(current_user.id, limit=200)
return [_relation_to_out(r) for r in rows]
@router.patch("/relational/{relation_id}", response_model=RelationOut)
async def patch_relation(
relation_id: str,
body: RelationPatch,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> RelationOut:
"""Edit a relation row's labels, predicate, or confidence."""
if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES:
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}",
)
result = await db.execute(
select(MemoryRelation).where(
MemoryRelation.id == relation_id,
MemoryRelation.user_id == current_user.id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
if body.subject_label is not None:
row.subject_label = body.subject_label
if body.object_label is not None:
row.object_label = body.object_label
if body.predicate is not None:
row.predicate = body.predicate
if body.confidence is not None:
row.confidence = body.confidence
row.last_confirmed_at = datetime.now(timezone.utc)
await db.commit()
await db.refresh(row)
logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id)
return _relation_to_out(row)
@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT)
async def delete_relation(
relation_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Hard-delete a relation row (GDPR Art. 17)."""
result = await db.execute(
select(MemoryRelation).where(
MemoryRelation.id == relation_id,
MemoryRelation.user_id == current_user.id,
)
)
row = result.scalar_one_or_none()
if row is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found")
await db.delete(row)
await db.commit()
logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id)
@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT)
async def forget_all(
x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> None:
"""Wipe all memory tiers for the current user (GDPR Art. 17).
Requires ``X-Confirm: true`` header. Does NOT delete the user account.
"""
if x_confirm != "true":
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.",
)
uid = current_user.id
await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid))
await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid))
await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid))
await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid))
await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid))
await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid))
await db.commit()
logger.warning("memory: forget_all GDPR wipe user=%s", uid)

View File

@@ -1,148 +0,0 @@
"""Plugins routes: browse and install plugins from the marketplace.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
"""
from __future__ import annotations
from typing import Any, Literal
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db import get_session
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.models import PluginInstallation, PluginReview as PluginReviewModel
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
# ── Tier gate ─────────────────────────────────────────────────────────
def _require_plugin_tier(user: UserProfile) -> None:
"""Raise HTTP 403 for users below Power tier."""
if user.tier not in ("power", "team"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Plugin marketplace requires Power tier or above",
)
# ── Local detail schema ────────────────────────────────────────────────
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any]
# ── Routes ────────────────────────────────────────────────────────────
@router.get("", response_model=PluginListResponse)
async def list_plugins(
category: str | None = Query(default=None),
q: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Fetch review ratings for this plugin
review_result = await db.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
)
reviews = review_result.scalars().all()
ratings = [
{
"reviewer_id": r.reviewer_id,
"decision": r.decision,
"notes": r.notes,
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
}
for r in reviews
]
return _PluginDetail(
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=ratings,
)
@router.post("/{plugin_id}/install", response_model=dict)
async def install_plugin(
plugin_id: str,
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Record the installation in plugin_installations
installation = PluginInstallation(
plugin_id=plugin_id,
user_id=current_user.id,
)
db.add(installation)
await db.flush()
await revenue_share.record_install(
db,
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
)
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
return {"ok": True, "download_url": download_url}
@router.delete("/{plugin_id}/install", response_model=dict)
async def uninstall_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
result = await db.execute(
select(PluginInstallation).where(
PluginInstallation.plugin_id == plugin_id,
PluginInstallation.user_id == current_user.id,
)
)
installation = result.scalar_one_or_none()
if installation is not None:
await db.delete(installation)
await db.commit()
await registry.record_uninstall(db, plugin_id)
return {"ok": True}

View File

@@ -1,195 +0,0 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# ── Local response schemas ─────────────────────────────────────────────
class _CreateResponse(BaseModel):
id: str
created_at: int
class _RecordMeta(BaseModel):
id: str
table: str
checksum: str
created_at: int
updated_at: int
# ── Helpers ────────────────────────────────────────────────────────────
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
async def list_records(
table: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in rows
]
@router.get("/records/{record_id}")
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record.checksum},
)
@router.put("/records/{record_id}", response_model=dict)
async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record.size_bytes
if delta > 0:
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@router.delete("/records/{record_id}", response_model=dict)
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}

View File

@@ -1,79 +0,0 @@
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
from __future__ import annotations
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.core.llm import embed
from app.schemas import (
UserProfile,
VectorSearchRequest,
VectorSearchResponse,
VectorUpsertRequest,
)
from app.storage.encryption import reject_if_tampered
from app.storage.vector_store import VectorStore
router = APIRouter(prefix="/storage", tags=["vectors"])
_vector_store = VectorStore()
class _VectorDeleteRequest(BaseModel):
ids: list[str]
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
@router.post("/vectors/upsert", response_model=dict)
async def upsert_vectors(
body: VectorUpsertRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, int]:
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
for item in body.vectors:
reject_if_tampered(item.blob, item.checksum)
await _vector_store.upsert(current_user.id, body.vectors)
return {"upserted": len(body.vectors)}
@router.post("/vectors/search", response_model=VectorSearchResponse)
async def search_vectors(
body: VectorSearchRequest,
current_user: UserProfile = Depends(get_current_user),
) -> VectorSearchResponse:
"""Search the user-scoped vector namespace with an encrypted query blob."""
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
return VectorSearchResponse(results=results)
@router.delete("/vectors", response_model=dict)
async def delete_vectors(
body: _VectorDeleteRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Delete vectors by ID, scoped to the authenticated user."""
await _vector_store.delete(current_user.id, body.ids)
return {"ok": True}
@router.post("/vectors/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

1
app/auth/__init__.py Normal file
View File

@@ -0,0 +1 @@
"OAuth provider abstractions and utilities."

135
app/auth/oauth_providers.py Normal file
View File

@@ -0,0 +1,135 @@
"""OAuth 2.0 + PKCE provider abstractions.
Each provider implements a three-step flow designed for a desktop (public) client:
1. get_authorization_url(state, code_challenge) → str
Build the provider's consent-screen URL. State and code_challenge are
generated server-side; the client opens this URL in the system browser.
2. exchange_code(code, code_verifier, redirect_uri) → dict
Exchange the short-lived authorization code for an access token.
The code_verifier proves ownership of the PKCE challenge.
3. get_userinfo(access_token) → OAuthUserInfo
Fetch the canonical user identity from the provider.
Currently supported providers:
- GoogleOAuthProvider (scope: openid email profile)
Adding a new provider:
- Implement the three methods above.
- Register in _PROVIDERS inside routes/auth.py.
"""
from __future__ import annotations
import base64
import hashlib
import os
import urllib.parse
from dataclasses import dataclass
import httpx
# ── Data transfer objects ─────────────────────────────────────────────
@dataclass
class OAuthUserInfo:
"""Normalized user identity returned by any provider."""
provider_user_id: str
email: str
email_verified: bool
avatar_url: str | None
name: str | None
# ── PKCE helpers ──────────────────────────────────────────────────────
def generate_pkce_pair() -> tuple[str, str]:
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
The code_verifier is a random 32-byte URL-safe base64 string.
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
"""
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
return code_verifier, code_challenge
# ── Google provider ───────────────────────────────────────────────────
class GoogleOAuthProvider:
"""Google OAuth 2.0 provider (openid email profile scope).
Uses Google's standard authorization endpoint with PKCE S256.
Does NOT use google-auth-oauthlib to keep the flow generic and async.
"""
name = "google"
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self, state: str, code_challenge: str) -> str:
"""Build the Google consent-screen URL."""
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code",
"scope": "openid email profile",
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "select_account",
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
async def exchange_code(
self, code: str, code_verifier: str, redirect_uri: str
) -> dict:
"""Exchange authorization code for an access token."""
async with httpx.AsyncClient() as client:
response = await client.post(
self._TOKEN_URL,
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"code_verifier": code_verifier,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
},
)
response.raise_for_status()
return response.json()
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
"""Fetch the authenticated user's identity from Google."""
async with httpx.AsyncClient() as client:
response = await client.get(
self._USERINFO_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
data = response.json()
return OAuthUserInfo(
provider_user_id=data["sub"],
email=data["email"],
email_verified=data.get("email_verified", False),
avatar_url=data.get("picture"),
name=data.get("name"),
)

139
app/billing/quota.py Normal file
View File

@@ -0,0 +1,139 @@
"""Quota checks and atomic token-usage accounting for folder integration."""
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime, timezone
from sqlalchemy import select, update
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.ext.asyncio import AsyncSession
from app.billing.tier_manager import TierManager
from app.models import MonthlyTokenUsage
from app.schemas import BillingTier
class QuotaExceeded(Exception):
"""Raised when a folder operation cannot proceed under the user's tier."""
def __init__(self, reason: str, message: str) -> None:
super().__init__(message)
self.reason = reason # "max_files" | "monthly_tokens"
@dataclass
class TokenUsageResult:
tokens_used: int
exhausted: bool
def _current_year_month() -> str:
return datetime.now(timezone.utc).strftime("%Y-%m")
_tier_manager = TierManager()
async def check_folder_quota(
*,
user_id: str,
tier: BillingTier,
estimated_files: int,
db: AsyncSession,
) -> None:
"""Raise QuotaExceeded if folder_max_files or folder_monthly_tokens
would be violated. -1 in either feature means unlimited."""
max_files = _tier_manager.get_feature_value(tier, "folder_max_files")
if max_files != -1 and estimated_files > max_files:
raise QuotaExceeded(
"max_files",
f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.",
)
cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens")
if cap == -1:
return
ym = _current_year_month()
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)
).scalar_one_or_none()
used = row.tokens_used if row else 0
if used >= cap:
raise QuotaExceeded(
"monthly_tokens",
f"Monthly token budget exhausted ({used}/{cap}); resets next month.",
)
async def add_token_usage(
*,
user_id: str,
feature: str,
tokens: int,
db: AsyncSession,
cap: int | None = None,
) -> TokenUsageResult:
"""Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature).
Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls
back to a read-then-write on other engines (e.g. aiosqlite in tests).
Returns post-update total and whether cap is exhausted.
"""
ym = _current_year_month()
# Detect dialect to choose between native upsert and portable fallback.
dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr]
if dialect_name == "postgresql":
# Native atomic upsert — production path.
stmt = (
pg_insert(MonthlyTokenUsage)
.values(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
.on_conflict_do_update(
index_elements=["user_id", "year_month", "feature"],
set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens},
)
.returning(MonthlyTokenUsage.tokens_used)
)
used: int = (await db.execute(stmt)).scalar_one()
await db.commit()
else:
# Portable fallback — used in tests (SQLite) and any non-PG engine.
row = (
await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == user_id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == feature,
)
)
).scalar_one_or_none()
if row is None:
row = MonthlyTokenUsage(
user_id=user_id,
year_month=ym,
feature=feature,
tokens_used=tokens,
)
db.add(row)
else:
row.tokens_used += tokens
await db.commit()
await db.refresh(row)
used = row.tokens_used
exhausted = cap is not None and cap != -1 and used >= cap
return TokenUsageResult(tokens_used=used, exhausted=exhausted)

View File

@@ -43,8 +43,8 @@ class StripeService:
self, self,
user_id: str, user_id: str,
tier: str, tier: str,
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}", success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuva.app/billing/cancel", cancel_url: str = "https://app.adiuvai.app/billing/cancel",
) -> str: ) -> str:
"""Create a Stripe checkout session and return the URL. """Create a Stripe checkout session and return the URL.
@@ -200,6 +200,45 @@ class StripeService:
sub.status = "canceled" sub.status = "canceled"
await db.commit() await db.commit()
async def list_invoices(
self, user_id: str, db: AsyncSession, limit: int = 24
) -> list[dict[str, Any]]:
"""Return recent invoices for the user from Stripe.
Returns an empty list when Stripe is not configured or the user has
no ``stripe_customer_id``.
"""
if not self._configured():
return []
from app.models import User # noqa: PLC0415
result = await db.execute(
select(User.stripe_customer_id).where(User.id == user_id)
)
customer_id = result.scalar_one_or_none()
if not customer_id:
return []
try:
s = self._client()
invoices = s.Invoice.list(customer=customer_id, limit=limit)
return [
{
"id": inv.id,
"amount_due": inv.amount_due,
"amount_paid": inv.amount_paid,
"currency": inv.currency,
"status": inv.status,
"created": inv.created * 1000, # epoch ms
"invoice_url": inv.hosted_invoice_url,
"invoice_pdf": inv.invoice_pdf,
}
for inv in invoices.auto_paging_iter()
]
except Exception:
return []
# ── Private DB helpers ─────────────────────────────────────────────── # ── Private DB helpers ───────────────────────────────────────────────
async def _upsert_subscription( async def _upsert_subscription(

View File

@@ -22,45 +22,57 @@ FEATURES: dict[str, dict[str, Any]] = {
"agents": 3, "agents": 3,
"batch_active": 2, "batch_active": 2,
"batch_runs_per_day": 5, "batch_runs_per_day": 5,
"cloud_storage_gb": 0,
"backup_gb": 0,
"providers": 1, "providers": 1,
"batch_builder": False, "batch_builder": False,
"plugin_marketplace": False,
"sso": False, "sso": False,
"real_embeddings": False, # keyword fallback only
"realtime_extraction": False, # batch queue (Phase 2)
"relational_memory": False, # relational tier (Phase 3) — Pro+
"proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 200,
"folder_monthly_tokens": 100_000,
}, },
"pro": { "pro": {
"agents": -1, # unlimited "agents": -1, # unlimited
"batch_active": 10, "batch_active": 10,
"batch_runs_per_day": 50, "batch_runs_per_day": 50,
"cloud_storage_gb": 5,
"backup_gb": 5,
"providers": -1, "providers": -1,
"batch_builder": False, "batch_builder": False,
"plugin_marketplace": False,
"sso": False, "sso": False,
"real_embeddings": True, # pgvector cosine search
"realtime_extraction": True, # fire-and-forget asyncio.create_task
"relational_memory": True, # person/project predicates
"proactive_mining": False, # Power+ only (Phase 5)
"folder_max_files": 5000,
"folder_monthly_tokens": 2_000_000,
}, },
"power": { "power": {
"agents": -1, "agents": -1,
"batch_active": -1, # unlimited "batch_active": -1, # unlimited
"batch_runs_per_day": -1, # unlimited "batch_runs_per_day": -1, # unlimited
"cloud_storage_gb": 25,
"backup_gb": 25,
"providers": -1, "providers": -1,
"batch_builder": True, "batch_builder": True,
"plugin_marketplace": True,
"sso": False, "sso": False,
"real_embeddings": True,
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
}, },
"team": { "team": {
"agents": -1, "agents": -1,
"batch_active": -1, "batch_active": -1,
"batch_runs_per_day": -1, # unlimited "batch_runs_per_day": -1, # unlimited
"cloud_storage_gb": -1, # unlimited
"backup_gb": -1, # unlimited
"providers": -1, "providers": -1,
"batch_builder": True, "batch_builder": True,
"plugin_marketplace": True,
"sso": True, "sso": True,
"real_embeddings": True,
"realtime_extraction": True,
"relational_memory": True, # all predicates incl. custom
"proactive_mining": True, # scheduled pattern mining (Phase 5)
"folder_max_files": -1, # unlimited
"folder_monthly_tokens": -1, # unlimited
}, },
} }
@@ -119,77 +131,19 @@ class TierManager:
) )
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def get_feature_value(self, tier: BillingTier, feature: str) -> int:
"""Return integer feature value for tier. -1 means unlimited."""
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if not isinstance(value, int):
return 0
return value
# ── Rate limiting ──────────────────────────────────────────────────── # ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int: def get_rate_limit(self, tier: BillingTier) -> int:
"""Return the requests-per-minute limit for ``tier``.""" """Return the requests-per-minute limit for ``tier``."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# ── Storage quota ────────────────────────────────────────────────────
def enforce_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Cloud storage is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
def enforce_backup_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
def check_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False
if limit_gb == -1:
return True
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app. # Module-level singleton shared across the app.
tier_manager = TierManager() tier_manager = TierManager()

View File

@@ -3,7 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings): class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva" DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
JWT_SECRET: str = "change-me-in-production" JWT_SECRET: str = "change-me-in-production"
JWT_ALGORITHM: str = "HS256" JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
@@ -12,26 +12,29 @@ class Settings(BaseSettings):
STRIPE_SECRET_KEY: str = "" STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = "" STRIPE_WEBHOOK_SECRET: str = ""
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
OPENAI_API_KEY: str = "" OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = "" ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = "" GOOGLE_API_KEY: str = ""
CEREBRAS_API_KEY: str = "" CEREBRAS_API_KEY: str = ""
GROQ_API_KEY: str = ""
DEEPSEEK_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o" LLM_MODEL: str = "gpt-4o"
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
LLM_EMBED_MODEL: str = "text-embedding-3-small" LLM_EMBED_MODEL: str = "text-embedding-3-small"
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research)
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide)
LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining)
LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit)
# GitHub Copilot OAuth token storage directory. # GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot). # Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
# In Docker, set this to a path backed by a named volume so tokens survive restarts. # In Docker, set this to a path backed by a named volume so tokens survive restarts.
@@ -45,20 +48,39 @@ class Settings(BaseSettings):
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts). # MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
MS_TENANT_ID: str = "common" MS_TENANT_ID: str = "common"
# Google Login OAuth credentials — scope: openid email profile.
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
GOOGLE_AUTH_CLIENT_ID: str = ""
GOOGLE_AUTH_CLIENT_SECRET: str = ""
# The redirect URI registered in Google Cloud Console.
# Google redirects here after consent; this backend route then bounces to
# the adiuvai:// deep link so the Electron app receives the code.
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth # Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted. # tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key() # Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
OAUTH_ENCRYPTION_KEY: str = "" OAUTH_ENCRYPTION_KEY: str = ""
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"] CORS_ORIGINS: list[str] = [
"app://.",
"http://localhost:3000",
"http://localhost:5173",
"http://localhost:4173", # Vite preview (web SPA)
"https://app.adiuvai.com", # Production web portal
]
LANGFUSE_SECRET_KEY: str = "" LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_PUBLIC_KEY: str = "" LANGFUSE_PUBLIC_KEY: str = ""
LANGFUSE_HOST: str = "https://cloud.langfuse.com" LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
SCHEDULER_ENABLED: bool = True
ENV: Literal["dev", "prod"] = "dev" ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8") model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore")
settings = Settings() settings = Settings()

View File

@@ -30,7 +30,6 @@ import asyncio
import json import json
import logging import logging
import os import os
import uuid
from datetime import datetime, timedelta, timezone from datetime import datetime, timedelta, timezone
from typing import Any from typing import Any
@@ -43,10 +42,9 @@ from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS from app.agents.timeline_agent import TIMELINE_TOOLS
from app.config.settings import settings
from app.core.device_manager import DeviceConnectionManager from app.core.device_manager import DeviceConnectionManager
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_llm from app.core.llm import get_agent_llm, model_for_agent
from app.core.preprocessors import detect_content_type, preprocess from app.core.preprocessors import detect_content_type, preprocess
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
from app.db import async_session from app.db import async_session
@@ -74,13 +72,13 @@ _MAX_PROCESSING_STEPS: int = 12
_MAX_SCAN_DEPTH: int = 5 _MAX_SCAN_DEPTH: int = 5
# ── Data-type to tool mapping ───────────────────────────────────────────── # ── Data-type to tool mapping ─────────────────────────────────────────────
# NOTE: "projects" is intentionally excluded — project creation/assignment is
# handled in code by the runner, never delegated to the Step 2 LLM.
_DATA_TYPE_TOOLS: dict[str, list[Any]] = { _DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS, "tasks": TASK_TOOLS,
"notes": NOTE_TOOLS, "notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS, "timelines": TIMELINE_TOOLS,
"timelineEvents": TIMELINE_TOOLS,
"projects": PROJECT_TOOLS,
} }
# ── V2: Unified processing prompt (hot-swappable via Langfuse "unified_processing") ── # ── V2: Unified processing prompt (hot-swappable via Langfuse "unified_processing") ──
@@ -228,6 +226,7 @@ async def _run_agent_with_tools(
tools: list[Any], tools: list[Any],
max_steps: int, max_steps: int,
user_id: str = "", user_id: str = "",
session_id: str = "",
langfuse_prompt: Any = None, langfuse_prompt: Any = None,
agent_name: str = "batch-agent", agent_name: str = "batch-agent",
_tool_calls_out: list[str] | None = None, _tool_calls_out: list[str] | None = None,
@@ -238,7 +237,7 @@ async def _run_agent_with_tools(
run is appended to it (used by the caller to count ``create_*`` calls). run is appended to it (used by the caller to count ``create_*`` calls).
""" """
lf = get_langfuse() lf = get_langfuse()
llm = get_llm() llm = get_agent_llm(agent_name)
llm_with_tools = llm.bind_tools(tools) llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [ messages: list[Any] = [
SystemMessage(content=system_prompt), SystemMessage(content=system_prompt),
@@ -247,6 +246,9 @@ async def _run_agent_with_tools(
tool_map = {tool_def.name: tool_def for tool_def in tools} tool_map = {tool_def.name: tool_def for tool_def in tools}
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
_lf_ctx.__enter__()
_span_ctx = ( _span_ctx = (
lf.start_as_current_observation( lf.start_as_current_observation(
as_type="span", as_type="span",
@@ -264,7 +266,7 @@ async def _run_agent_with_tools(
lf.start_as_current_observation( lf.start_as_current_observation(
as_type="generation", as_type="generation",
name=f"{agent_name}-llm", name=f"{agent_name}-llm",
model=settings.LLM_MODEL, model=model_for_agent(agent_name),
prompt=langfuse_prompt, prompt=langfuse_prompt,
input=messages, input=messages,
) )
@@ -273,7 +275,7 @@ async def _run_agent_with_tools(
_gen = _gen_ctx.__enter__() if _gen_ctx else None _gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages) response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx: if _gen_ctx:
_gen.update(output=_as_text(response.content), usage=extract_usage(response)) _gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
_gen_ctx.__exit__(None, None, None) _gen_ctx.__exit__(None, None, None)
messages.append(response) messages.append(response)
@@ -285,7 +287,6 @@ async def _run_agent_with_tools(
return final_text return final_text
for call in response.tool_calls: for call in response.tool_calls:
call_id = str(call.get("id", ""))
call_name = str(call.get("name", "")) call_name = str(call.get("name", ""))
call_args = call.get("args", {}) call_args = call.get("args", {})
logger.info( logger.info(
@@ -318,6 +319,7 @@ async def _run_agent_with_tools(
finally: finally:
if _span_ctx: if _span_ctx:
_span_ctx.__exit__(None, None, None) _span_ctx.__exit__(None, None, None)
_lf_ctx.__exit__(None, None, None)
if lf: if lf:
lf.flush() lf.flush()
@@ -386,7 +388,8 @@ async def _scan_directories(
for file_path in all_files: for file_path in all_files:
try: try:
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path}) meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
modified_at = meta.get("modifiedAt") # FE sends snake_case keys on the wire (toSnakeCase transform)
modified_at = meta.get("modified_at") or meta.get("modifiedAt")
if modified_at is None: if modified_at is None:
filtered.append(file_path) filtered.append(file_path)
continue continue
@@ -607,7 +610,6 @@ async def run_local_agent(
try: try:
# ── Code: scan directories ─────────────────────────────────── # ── Code: scan directories ───────────────────────────────────
logger.info("agent_runner: run=%s scanning directories user=%s", run_id, user_id)
file_paths = await _scan_directories( file_paths = await _scan_directories(
paths=config.directory_paths, paths=config.directory_paths,
extensions=config.file_extensions or [], extensions=config.file_extensions or [],
@@ -656,9 +658,14 @@ async def run_local_agent(
# ── Phase B: single LLM call ───────────────────────── # ── Phase B: single LLM call ─────────────────────────
extraction_rules = _get_extraction_rules(agent_config, content_type) extraction_rules = _get_extraction_rules(agent_config, content_type)
no_match_behavior = _get_no_match_behavior(agent_config) no_match_behavior = _get_no_match_behavior(agent_config)
global_rules_lines = "\n".join( base_global_rules = list(agent_config.get("global_rules", []))
f"- {r}" for r in agent_config.get("global_rules", []) if "notes" in config.data_types:
base_global_rules.append(
"For notes: when updating an existing note use `propose_note_edit` "
"(type=append/insert/replace) so the user can review AI changes. "
"Only call `update_note` for complete content replacement without review."
) )
global_rules_lines = "\n".join(f"- {r}" for r in base_global_rules)
metadata_section = _format_metadata(preprocessed.metadata) metadata_section = _format_metadata(preprocessed.metadata)
system_prompt = compile_prompt( system_prompt = compile_prompt(
@@ -686,6 +693,7 @@ async def run_local_agent(
tools=processing_tools, tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS, max_steps=_MAX_PROCESSING_STEPS,
user_id=user_id, user_id=user_id,
session_id=run_id,
langfuse_prompt=prompt_obj, langfuse_prompt=prompt_obj,
agent_name="unified-processor", agent_name="unified-processor",
_tool_calls_out=file_tool_calls, _tool_calls_out=file_tool_calls,
@@ -696,6 +704,12 @@ async def run_local_agent(
) )
items_created += file_created items_created += file_created
# Refresh project list when a project was created so
# subsequent files see it in the prompt context.
if "create_project" in file_tool_calls:
projects = await _fetch_projects()
projects_block = _format_projects(projects)
logger.info( logger.info(
"agent_runner: run=%s file=%r created=%d result=%s", "agent_runner: run=%s file=%r created=%d result=%s",
run_id, file_path, file_created, result_text[:200], run_id, file_path, file_created, result_text[:200],
@@ -911,6 +925,7 @@ async def run_cloud_agent(
tools=processing_tools, tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS, max_steps=_MAX_PROCESSING_STEPS,
user_id=user_id, user_id=user_id,
session_id=run_id,
langfuse_prompt=cloud_prompt_obj, langfuse_prompt=cloud_prompt_obj,
agent_name="cloud-processor", agent_name="cloud-processor",
) )

View File

@@ -0,0 +1,59 @@
"""In-process TTL buffer for per-session LangChain message history.
Stores the full message list (including AIMessage with tool_calls and ToolMessage)
keyed by (user_id, session_id), so agents can reconstruct tool-call context across
conversation turns without it being lossy through the wire.
Single-process only. For multi-worker deployments, replace the _SessionBuffer
implementation with one backed by Redis (serialize LangChain messages to dicts via
message_to_dict / messages_from_dict from langchain_core.messages).
"""
from __future__ import annotations
import time
from threading import Lock
from langchain_core.messages import BaseMessage
SESSION_TTL_SECONDS = 1800 # 30-minute idle expiry
MAX_MESSAGES_PER_SESSION = 80 # cap to avoid unbounded memory growth
class _SessionBuffer:
def __init__(self) -> None:
self._store: dict[tuple[str, str], tuple[float, list[BaseMessage]]] = {}
self._lock = Lock()
def _evict_stale(self) -> None:
now = time.monotonic()
stale = [k for k, (ts, _) in self._store.items() if now - ts > SESSION_TTL_SECONDS]
for k in stale:
del self._store[k]
def get(self, user_id: str, session_id: str) -> list[BaseMessage] | None:
key = (user_id, session_id)
with self._lock:
entry = self._store.get(key)
if entry is None:
return None
ts, msgs = entry
if time.monotonic() - ts > SESSION_TTL_SECONDS:
del self._store[key]
return None
self._store[key] = (time.monotonic(), msgs)
return list(msgs)
def set(self, user_id: str, session_id: str, messages: list[BaseMessage]) -> None:
key = (user_id, session_id)
capped = messages[-MAX_MESSAGES_PER_SESSION:]
with self._lock:
self._evict_stale()
self._store[key] = (time.monotonic(), capped)
def clear(self, user_id: str, session_id: str) -> None:
with self._lock:
self._store.pop((user_id, session_id), None)
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
session_buffer = _SessionBuffer()

228
app/core/brief_agent.py Normal file
View File

@@ -0,0 +1,228 @@
"""Brief agent — produces plain-text home and project status briefs.
Read-only tool subset only. Never calls _normalize_tagged_list_lines —
the brief prompt forbids XML tags, so skipping post-processing is intentional.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from datetime import date
from typing import Any
from app.agents.note_agent import NOTE_READ_TOOLS
from app.agents.project_agent import PROJECT_READ_TOOLS
from app.agents.task_agent import TASK_READ_TOOLS
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
from app.core.deep_agent import (
_language_instruction,
_proactive_hints_injection,
_read_only_memory_tools,
_relational_memory_injection,
_run_single_agent_stream,
_trace_id_from_context,
build_brief_multi_project_manifest,
)
from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback
_LANGUAGE_NAMES: dict[str, str] = {
"en": "English", "it": "Italian", "es": "Spanish",
"fr": "French", "de": "German",
"english": "English", "italian": "Italian", "italiano": "Italian",
"spanish": "Spanish", "español": "Spanish",
"french": "French", "français": "French",
"german": "German", "deutsch": "German",
}
_HOME_BRIEF_FALLBACK = """\
You are the user's personal assistant producing a short daily brief.
ROLE
Act like a calm, attentive secretary writing a stand-up note for your boss.
Warm and human, never breezy. Never cheerful filler, never emojis, never
"here is your brief" meta-text. The user is opening the app mid-workday and
is probably stressed — your job is to lower cognitive load, not add noise.
TOOLS — always call before writing
Pull fresh data every run. Do not invent counts or titles. Use at minimum:
- list_tasks_due_today — tasks the user owes today
- list_timelines_today — events starting or ending today
- list_all_projects — projects currently in progress or at risk
- memory_list_blocks / memory_get — personal context about people, clients,
payment habits, working preferences
If a tool returns nothing, simply omit that topic. Never report zeros.
WHAT TO INCLUDE
1. Tasks due today (title + priority; group the 1-2 most important).
2. Timeline events starting or ending today (and anything that starts/ends
tomorrow if the user has a very light day).
3. Active projects that need a nudge — stalled, blocked, or awaiting input.
4. Memory-aware colour where it sharpens the brief. Examples:
- "Client Rossi tends to pay late — the Acme invoice is 6 days out."
- "You usually dislike meetings before 10:00 — the call at 09:30 is unusual."
Only add a memory line when it changes what the user does. Do not pad.
WHAT TO OMIT
- Zero-counts ("no overdue items", "0 meetings today").
- Statistics ("2 active projects, 3 completed tasks").
- Headers, titles, greetings, sign-offs, dates, emojis, slang.
- Meta-phrases ("here is", "let me know if", "hope this helps").
- XML/HTML tags of any kind. Plain prose only.
LIGHT-DAY CLAUSE
If tasks + events + active-project-nudges together produce fewer than two
sentences of content, also list 1-2 projects in status on_hold or waiting
and ask a single, specific question about them — e.g. "Is the Bianchi
redesign still paused, or ready to pick back up?" One question max, grounded
in a real project name.
VOICE
- Calm. Concise. Human. Short sentences.
- Use **bold** sparingly for task titles, project names, and people's names.
- No bullet lists. Flow as 2-4 sentences of prose.
LENGTH
2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence.
Respond in the user's language ({language}). Today is {today}.\
"""
_PROJECT_BRIEF_FALLBACK = """\
You are the project assistant producing a short status brief for ONE project.
ROLE
A senior project manager summarising state-of-play for the owner. Factual,
sharp, forward-looking. Never reassuring filler, never emojis.
SCOPE
Work only with project_id = {project_id}. Do not mention or pull data from
other projects. Use tools to fetch fresh data:
- get_project — current status, dates, description
- list_tasks(project_id) — open work, split by status
- list_timelines(project_id) — milestones hit, upcoming, overdue
- list_notes(project_id) — any recent decisions or blockers
- memory_get — relevant context about the client, collaborators, constraints
STRUCTURE — follow exactly, one short paragraph per section, no headers
1. **State.** One sentence: current phase, health (on track / at risk / blocked),
and why. Cite the concrete signal (overdue milestone, stalled tasks, recent
blocker note).
2. **What's moving.** What was completed or progressed recently. Name specific
tasks or milestones.
3. **Next steps.** The 1-3 most important things the user should do next, in
priority order. Be concrete — task name, who owns it, when due if known.
If waiting on someone else, name them and what the ask is.
4. **Risks / memory-flagged items.** One line max. Only include when there is
a real risk or a relevant memory (e.g. late-paying client, tight deadline,
scope change). Omit the section entirely if nothing to say.
WHAT TO OMIT
- Zero-counts ("no overdue tasks").
- Generic advice ("keep up the good work").
- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases.
- XML/HTML tags or bracketed id lists. Plain prose only.
VOICE
- Direct. Factual. No fluff.
- Use **bold** sparingly for task titles, milestone names, and the owner's name.
- Short sentences. Prefer verbs over nouns ("Client review is blocking release"
not "There is a blocker which is the client review").
LENGTH
4-8 sentences total across the 3-4 sections. Hard cap 8.
Respond in the user's language ({language}). Today is {today}.\
"""
def _resolve_language(context: dict[str, Any]) -> str:
core = context.get("core_memory") or {}
raw = (core.get("language") or "en").strip().lower()
return _LANGUAGE_NAMES.get(raw, raw.title()) or "English"
def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]:
return [
*TASK_READ_TOOLS,
*PROJECT_READ_TOOLS,
*TIMELINE_READ_TOOLS,
*NOTE_READ_TOOLS,
*_read_only_memory_tools(user_id, trace_id),
]
async def run_home_brief(
user_id: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
"""Stream a plain-text daily home brief.
Yields (event_type, data) tuples identical to _run_single_agent_stream.
Do NOT post-process output through _normalize_tagged_list_lines.
"""
from app.agents.folder_agent import FOLDER_TOOLS
trace_id = _trace_id_from_context(context)
today = date.today().isoformat()
language = _resolve_language(context)
raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK)
system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today)
system_prompt += _relational_memory_injection(context)
system_prompt += _proactive_hints_injection(context)
system_prompt += _language_instruction(context)
if today not in system_prompt:
system_prompt += f"\nToday is {today}."
brief_manifest = await build_brief_multi_project_manifest()
system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "")
tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS]
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message="Generate the daily brief.",
context=context,
langfuse_prompt=langfuse_prompt,
agent_name="brief-agent",
tools=tools,
):
yield event
async def run_project_brief(
user_id: str,
project_id: str,
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
"""Stream a plain-text project status brief for project_id.
Yields (event_type, data) tuples identical to _run_single_agent_stream.
Do NOT post-process output through _normalize_tagged_list_lines.
"""
trace_id = _trace_id_from_context(context)
today = date.today().isoformat()
language = _resolve_language(context)
raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK)
system_prompt = compile_prompt(
raw_template, langfuse_prompt,
language=language, today=today, project_id=project_id,
)
system_prompt += _relational_memory_injection(context)
system_prompt += _proactive_hints_injection(context)
system_prompt += _language_instruction(context)
if today not in system_prompt:
system_prompt += f"\nToday is {today}."
tools = _build_read_tools(user_id, trace_id)
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=f"Generate the project status brief for project {project_id}.",
context=context,
langfuse_prompt=langfuse_prompt,
agent_name="brief-agent",
tools=tools,
):
yield event

File diff suppressed because it is too large Load Diff

34
app/core/embeddings.py Normal file
View File

@@ -0,0 +1,34 @@
"""OpenAI embedding helper for associative memory tier.
Single public function: ``embed_text(text) -> list[float] | None``.
Returns None on any failure — callers must implement a keyword fallback.
Never raises; all exceptions are logged as warnings.
"""
from __future__ import annotations
import logging
from openai import AsyncOpenAI
logger = logging.getLogger(__name__)
_MAX_INPUT_CHARS = 8000
_EMBEDDING_MODEL = "text-embedding-3-small"
async def embed_text(text: str) -> list[float] | None:
"""Call OpenAI text-embedding-3-small. Return None on failure (caller falls back to keyword)."""
try:
client = AsyncOpenAI()
truncated = text[:_MAX_INPUT_CHARS]
response = await client.embeddings.create(
input=truncated,
model=_EMBEDDING_MODEL,
)
result: list[float] = response.data[0].embedding
logger.debug("embeddings: embed_text dims=%d", len(result))
return result
except Exception as exc:
logger.warning("embeddings: embed_text failed: %s", exc)
return None

183
app/core/folder_indexer.py Normal file
View File

@@ -0,0 +1,183 @@
"""Per-file summarisation for project folder integration."""
from __future__ import annotations
import base64
import io
from dataclasses import dataclass
from langchain_core.messages import HumanMessage, SystemMessage
from pypdf import PdfReader
from docx import Document as DocxDocument
from app.core.langfuse_client import (
compile_prompt,
extract_usage,
get_langfuse,
get_prompt_or_fallback,
)
from app.core.llm import get_llm
_TEXT_FALLBACK = (
"You are summarising a file for an AI assistant that helps the user manage a project.\n"
"Produce a single sentence (<=30 words, <=200 chars) that captures the file's purpose "
"and most important detail.\nFile extension: {ext}\nFile name: {name}\nContent (truncated if long):\n{content}"
)
_IMAGE_FALLBACK = (
"You are summarising an image attached to a project folder.\n"
"Produce a single sentence (<=30 words, <=200 chars) describing what the image shows "
"and any obvious purpose (logo, screenshot, diagram, photo of a whiteboard, etc.)."
)
_MAX_INPUT_CHARS = 6000
@dataclass
class IndexResult:
summary: str
tokens_used: int
async def _llm_text(messages: list) -> object:
"""Make the LLM call for text summarisation.
Defined as a standalone async function so tests can patch it cleanly
without needing to mock the LLM object itself.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def _llm_vision(messages: list) -> object:
"""Make the LLM call for vision (image) summarisation.
Accepts the message list and returns the response directly, mirroring
the ``_llm_text`` caller pattern so tests can patch it at the module level.
"""
llm = get_llm(model="gpt-4o-mini", temperature=0.2)
return await llm.ainvoke(messages)
async def summarize_image(*, image_b64: str, mime: str, file_name: str | None = None) -> IndexResult:
"""Return a compact summary of an image file using vision.
Parameters
----------
image_b64:
Base64-encoded image bytes.
mime:
MIME type of the image, e.g. ``"image/png"``.
file_name:
Optional file name, attached to the Langfuse trace as input metadata.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_image", _IMAGE_FALLBACK)
messages = [
SystemMessage(content=template),
HumanMessage(content=[
{"type": "text", "text": "Summarise this image."},
{"type": "image_url", "image_url": {"url": f"data:{mime};base64,{image_b64}"}},
]),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-image",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": file_name, "mime": mime},
) as gen:
response = await _llm_vision(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_vision(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
async def summarize_text(*, content: str, ext: str, name: str) -> IndexResult:
"""Return a compact summary of a text file.
Parameters
----------
content:
Raw text content of the file (will be truncated to _MAX_INPUT_CHARS).
ext:
File extension including the leading dot, e.g. ``".md"``.
name:
File name, e.g. ``"kickoff.md"``.
"""
template, prompt_obj = get_prompt_or_fallback("folder_file_summary_text", _TEXT_FALLBACK)
truncated = content[:_MAX_INPUT_CHARS]
compiled = compile_prompt(template, prompt_obj, ext=ext, name=name, content=truncated)
messages = [
SystemMessage(content=compiled),
HumanMessage(content="Summarise this file."),
]
lf = get_langfuse()
if lf is not None:
with lf.start_as_current_observation(
as_type="generation",
name="folder-summarize-text",
model="gpt-4o-mini",
prompt=prompt_obj,
input={"file_name": name, "ext": ext, "content_chars": len(truncated)},
) as gen:
response = await _llm_text(messages)
usage = extract_usage(response)
gen.update(output=response.content, usage_details=usage)
else:
response = await _llm_text(messages)
usage = extract_usage(response)
summary = (response.content or "").strip()[:500]
return IndexResult(summary=summary, tokens_used=usage.get("total", 0))
def _extract_pdf_text(pdf_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(pdf_b64))
reader = PdfReader(buf)
parts: list[str] = []
for page in reader.pages:
try:
parts.append(page.extract_text() or "")
except Exception:
continue
return "\n".join(parts).strip()
def _extract_docx_text(docx_b64: str) -> str:
buf = io.BytesIO(base64.b64decode(docx_b64))
doc = DocxDocument(buf)
return "\n".join(p.text for p in doc.paragraphs if p.text).strip()
async def summarize_pdf(*, pdf_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a PDF file.
Parameters
----------
pdf_b64:
Base64-encoded PDF bytes.
name:
File name, e.g. ``"report.pdf"``.
"""
text = _extract_pdf_text(pdf_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".pdf", name=name)
async def summarize_docx(*, docx_b64: str, name: str) -> IndexResult:
"""Return a compact summary of a DOCX file.
Parameters
----------
docx_b64:
Base64-encoded DOCX bytes.
name:
File name, e.g. ``"spec.docx"``.
"""
text = _extract_docx_text(docx_b64)
if not text:
return IndexResult(summary="Could not extract text", tokens_used=0)
return await summarize_text(content=text, ext=".docx", name=name)

View File

@@ -39,8 +39,10 @@ Linking a prompt to a generation::
from __future__ import annotations from __future__ import annotations
import hashlib
import logging import logging
from typing import Any from contextlib import contextmanager
from typing import Any, Generator
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -67,9 +69,9 @@ def get_langfuse() -> Any | None:
_client = Langfuse( _client = Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY, secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY, public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST, host=settings.LANGFUSE_BASE_URL,
) )
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_HOST) logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
except Exception as exc: except Exception as exc:
logger.warning("langfuse: failed to initialize: %s", exc) logger.warning("langfuse: failed to initialize: %s", exc)
_client = None _client = None
@@ -145,3 +147,44 @@ def extract_usage(response: Any) -> dict[str, int]:
"output": int(meta.get("output_tokens", 0)), "output": int(meta.get("output_tokens", 0)),
"total": int(meta.get("total_tokens", 0)), "total": int(meta.get("total_tokens", 0)),
} }
def hash_user_id(user_id: str) -> str:
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
This avoids sending raw database UUIDs to external observability services
while still providing a stable, deterministic identifier for per-user
metrics in the Langfuse dashboard.
"""
return hashlib.sha256(user_id.encode()).hexdigest()
@contextmanager
def langfuse_context(
user_id: str | None = None,
session_id: str | None = None,
) -> Generator[None, None, None]:
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
No-op when Langfuse is not configured or parameters are empty.
"""
lf = get_langfuse()
if lf is None or (not user_id and not session_id):
yield
return
try:
from langfuse import propagate_attributes
except ImportError:
logger.debug("langfuse: propagate_attributes not available — skipping context")
yield
return
attrs: dict[str, str] = {}
if user_id:
attrs["user_id"] = hash_user_id(user_id)
if session_id:
attrs["session_id"] = session_id
with propagate_attributes(**attrs):
yield

View File

@@ -1,6 +1,6 @@
"""LLM factory — centralised model instantiation via LiteLLM. """LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()`` Every agent and the orchestrator call ``get_llm()``
instead of directly constructing a provider-specific class. The model string instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_: <https://docs.litellm.ai/docs/providers>`_:
@@ -11,7 +11,7 @@ follows the `LiteLLM model naming convention
* Ollama: ``ollama/llama3`` * Ollama: ``ollama/llama3``
* Bedrock: ``bedrock/anthropic.claude-v2`` * Bedrock: ``bedrock/anthropic.claude-v2``
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env`` Switch providers by changing **LLM_MODEL** in ``.env``
— no code changes required. — no code changes required.
""" """
@@ -19,6 +19,7 @@ from __future__ import annotations
import os import os
import warnings import warnings
from collections.abc import Callable
from openai import AsyncOpenAI from openai import AsyncOpenAI
import litellm import litellm
@@ -50,6 +51,10 @@ def _api_key_for_model(model: str) -> str | None:
return settings.GOOGLE_API_KEY or None return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"): if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None return settings.CEREBRAS_API_KEY or None
if model.startswith("groq/"):
return settings.GROQ_API_KEY or None
if model.startswith("deepseek/"):
return settings.DEEPSEEK_API_KEY or None
if model.startswith("github_copilot/"): if model.startswith("github_copilot/"):
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM. # GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
# No API key is required; returning None lets LiteLLM handle auth. # No API key is required; returning None lets LiteLLM handle auth.
@@ -95,12 +100,39 @@ def get_llm(
) )
def get_router_llm( _AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,
"task-brief-agent": lambda: settings.LLM_MODEL_TASK_BRIEF_AGENT or settings.LLM_MODEL,
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
"memory-extractor": lambda: settings.LLM_MODEL_MEMORY_EXTRACTOR or "gpt-4o-mini",
"memory-miner": lambda: settings.LLM_MODEL_MEMORY_MINER or "gpt-4o-mini",
"memory-auditor": lambda: settings.LLM_MODEL_MEMORY_AUDITOR or settings.LLM_MODEL,
"note-summarizer": lambda: "gpt-4o-mini",
}
def model_for_agent(agent_name: str) -> str:
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
def get_agent_llm(
agent_name: str,
*, *,
temperature: float = 0, temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM: ) -> ChatOpenAI | ChatLiteLLM:
"""Return the lighter model used for intent classification / routing.""" """Return an LLM configured for *agent_name*, respecting per-agent overrides.
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
per-agent override is left empty in ``.env``.
"""
model = model_for_agent(agent_name)
return get_llm(model=model, temperature=temperature)
async def embed(text: str) -> list[float]: async def embed(text: str) -> list[float]:

View File

@@ -0,0 +1,450 @@
"""Mem0-style Extract/Update pipeline — Phase 2.
Runs after every ``store_episode`` call to distil durable facts, preferences,
routines, and relations from the latest conversation turn.
Entry point: ``run_extraction(db, user_id, last_user_msg, last_assistant_msg, session_id)``
Design notes
------------
- Two gpt-4o-mini calls per turn: extract candidates, then decide action per candidate.
- Short-circuit: if no existing neighbours → ADD without a second LLM call (cost saving).
- Zero-trust: never logs decrypted user content; relation subject/object labels are
treated as identifiers (safe to log per spec).
- Must not raise into the request path — caller wraps in asyncio.create_task().
"""
from __future__ import annotations
import json
import logging
from typing import Any, Literal
from pydantic import BaseModel, Field
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.langfuse_client import get_langfuse, get_prompt_or_fallback, extract_usage, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
logger = logging.getLogger(__name__)
# ── Fallback prompts (used when Langfuse unavailable) ─────────────────────────
_EXTRACTION_FALLBACK = (
"You are a memory extractor for a personal AI secretary. Given the last conversation "
"turn, the user's core memory, and recent episode summaries, identify durable facts, "
"preferences, routines, and person/project relations worth remembering.\n\n"
"Output JSON matching this schema exactly:\n"
'{{"candidates": [{{"type": "<fact|preference|relation|routine>", '
'"content": "<short canonical statement>", '
'"target_tier": "<core|associative|relational|proactive>", '
'"subject": null, "predicate": null, "object": null, "confidence": 0.7}}]}}\n\n'
"Rules:\n"
"- Skip small talk, greetings, one-off questions.\n"
"- Max 5 candidates per call.\n"
"- Only extract durable information (still true next week).\n"
"- For type=relation: subject/predicate/object required.\n"
"- Default confidence=0.7.\n\n"
"## Last turn\n{last_turn}\n\n"
"## Core memory (current)\n{core_memory}\n\n"
"## Recent episodes\n{recent_episodes}"
)
_DECIDE_FALLBACK = (
"You are a memory update decision engine. Given a new memory candidate and a list of "
"existing memories from the same tier, decide what action to take.\n\n"
"Respond with exactly one word: ADD, UPDATE, DELETE, or NOOP.\n\n"
"- ADD: new information not in existing memories.\n"
"- UPDATE: contradicts or supersedes an existing memory.\n"
"- DELETE: states something is no longer true.\n"
"- NOOP: already captured accurately.\n\n"
"## New candidate\n{candidate}\n\n"
"## Existing memories (same tier, top neighbours)\n{existing_memories}"
)
# ── Pydantic schemas ───────────────────────────────────────────────────────────
class MemoryCandidate(BaseModel):
type: Literal["fact", "preference", "relation", "routine"]
content: str
target_tier: Literal["core", "associative", "relational", "proactive"]
subject: str | None = None
predicate: str | None = None
object: str | None = None
confidence: float = Field(default=0.7, ge=0.0, le=1.0)
class ExtractionResult(BaseModel):
candidates: list[MemoryCandidate] = Field(default_factory=list)
# ── Task 2.1 — Extract candidates ─────────────────────────────────────────────
async def extract_candidates(
last_turn: str,
core_memory: dict[str, str],
recent_episodes: list[str],
) -> ExtractionResult:
"""Call gpt-4o-mini to extract memory candidates from the latest turn.
Returns an ExtractionResult (may be empty on failure — never raises).
"""
core_str = "\n".join(f"{k}: {v}" for k, v in core_memory.items()) or "(empty)"
episodes_str = "\n---\n".join(recent_episodes[-5:]) or "(none)"
template, prompt_obj = get_prompt_or_fallback("memory_extraction", _EXTRACTION_FALLBACK)
# Compile with Langfuse variable syntax ({{var}}) or fallback {var}
if prompt_obj is not None:
try:
system_text = prompt_obj.compile(
last_turn=last_turn,
core_memory=core_str,
recent_episodes=episodes_str,
)
if isinstance(system_text, list):
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
except Exception as exc:
logger.warning("memory_extraction: compile failed: %s", exc)
system_text = template.format(
last_turn=last_turn,
core_memory=core_str,
recent_episodes=episodes_str,
)
else:
system_text = template.format(
last_turn=last_turn,
core_memory=core_str,
recent_episodes=episodes_str,
)
llm = get_agent_llm("memory-extractor", temperature=0)
# Bind JSON mode so the model always returns parseable output.
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
lf = get_langfuse()
try:
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Extract memory candidates as JSON."),
]
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-extraction",
model=model_for_agent("memory-extractor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm_json.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm_json.ainvoke(messages)
raw = json.loads(response.content)
result = ExtractionResult.model_validate(raw)
logger.info("memory_extraction: extracted %d candidates", len(result.candidates))
return result
except Exception as exc:
logger.warning("memory_extraction: extract_candidates failed: %s", exc)
return ExtractionResult(candidates=[])
# ── Task 2.2 — Decide action ──────────────────────────────────────────────────
async def decide_action(
candidate: MemoryCandidate,
existing: list[str],
) -> Literal["ADD", "UPDATE", "DELETE", "NOOP"]:
"""Decide what to do with a candidate given existing memories in the same tier.
Short-circuits to ADD without an LLM call when existing is empty (cost saving).
Never raises.
"""
if not existing:
return "ADD"
candidate_str = f"[{candidate.type}] {candidate.content}"
existing_str = "\n".join(f"- {m}" for m in existing)
template, prompt_obj = get_prompt_or_fallback("memory_decide_action", _DECIDE_FALLBACK)
if prompt_obj is not None:
try:
system_text = prompt_obj.compile(
candidate=candidate_str,
existing_memories=existing_str,
)
if isinstance(system_text, list):
system_text = "\n".join(m.get("content", "") for m in system_text if isinstance(m, dict))
except Exception as exc:
logger.warning("memory_extraction: decide compile failed: %s", exc)
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
else:
system_text = template.format(candidate=candidate_str, existing_memories=existing_str)
llm = get_agent_llm("memory-extractor", temperature=0)
lf = get_langfuse()
try:
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Decide action."),
]
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-decide-action",
model=model_for_agent("memory-extractor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm.ainvoke(messages)
verb = response.content.strip().upper()
if verb in ("ADD", "UPDATE", "DELETE", "NOOP"):
return verb # type: ignore[return-value]
logger.warning("memory_extraction: unexpected decide verb=%r, defaulting ADD", verb)
return "ADD"
except Exception as exc:
logger.warning("memory_extraction: decide_action failed: %s", exc)
return "ADD"
# ── Task 2.3 — Pipeline orchestrator ──────────────────────────────────────────
async def run_extraction(
db: AsyncSession,
user_id: str,
last_user_msg: str,
last_assistant_msg: str,
session_id: str | None,
) -> None:
"""Full Mem0-style extract/update pipeline for one conversation turn.
Steps:
1. Load core memory + last 5 episodes.
2. extract_candidates() → up to 5 MemoryCandidate objects.
3. For each candidate: find top-3 neighbours → decide_action() → apply.
4. Trace via Langfuse.
Never raises — wraps everything in try/except.
"""
try:
await _run_extraction_inner(db, user_id, last_user_msg, last_assistant_msg, session_id)
except Exception as exc:
logger.warning("memory_extraction: run_extraction failed user=%s: %s", user_id, exc)
async def _run_extraction_inner(
db: AsyncSession,
user_id: str,
last_user_msg: str,
last_assistant_msg: str,
session_id: str | None,
) -> None:
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
middleware = MemoryMiddleware(db)
fernet = await middleware._get_fernet(user_id)
if fernet is None:
logger.warning("memory_extraction: no fernet for user=%s, skipping", user_id)
return
# 1. Load context
core: dict[str, str] = await middleware._load_core(user_id, fernet)
episodes: list[str] = await middleware._load_episodic(user_id, fernet, session_id=session_id)
last_turn = f"User: {last_user_msg}\nAssistant: {last_assistant_msg}"
lf = get_langfuse()
async def _run(trace_id: str | None) -> dict[str, Any]:
# 2. Extract candidates
result = await extract_candidates(last_turn, core, episodes)
if not result.candidates:
logger.info("memory_extraction: no candidates user=%s", user_id)
return {"candidates": 0, "applied": 0}
logger.info(
"memory_extraction: processing %d candidates user=%s trace=%s",
len(result.candidates),
user_id,
trace_id or "-",
)
# 3. Apply each candidate
applied = 0
actions: list[str] = []
for candidate in result.candidates:
try:
await _apply_candidate(middleware, db, user_id, fernet, candidate, trace_id)
applied += 1
actions.append(f"{candidate.type}:{candidate.target_tier}")
except Exception as exc:
logger.warning(
"memory_extraction: apply failed candidate=%r user=%s: %s",
candidate.content[:80],
user_id,
exc,
)
logger.info(
"memory_extraction: applied %d/%d candidates user=%s",
applied,
len(result.candidates),
user_id,
)
return {"candidates": len(result.candidates), "applied": applied, "actions": actions}
with langfuse_context(user_id=user_id, session_id=session_id):
if lf:
with lf.start_as_current_observation(
as_type="span",
name="memory-extraction-pipeline",
input={"last_turn_preview": last_turn[:200]},
) as span:
summary = await _run(trace_id=span.id)
span.update(output=summary)
try:
lf.flush()
except Exception:
pass
else:
await _run(trace_id=None)
async def _apply_candidate(
middleware: Any,
db: AsyncSession,
user_id: str,
fernet: Any,
candidate: MemoryCandidate,
trace_id: str | None,
) -> None:
"""Fetch neighbours, decide action, apply to the appropriate tier."""
neighbours: list[str] = []
if candidate.target_tier == "core":
# For core tier: neighbours are existing core block values for similar keys.
blocks = await middleware.list_core_blocks(user_id)
neighbours = [b["value"] for b in blocks[:3]]
elif candidate.target_tier == "associative":
neighbours = await middleware.search_archival(user_id, candidate.content, top_k=3)
elif candidate.target_tier == "relational":
# Relation candidates handled specially — passed to upsert_relation directly.
# Neighbours: search by subject label if available.
neighbours = []
elif candidate.target_tier == "proactive":
neighbours = await middleware.search_recall(user_id, candidate.content, top_k=3)
action = await decide_action(candidate, neighbours)
logger.info(
"memory_extraction: candidate type=%s tier=%s action=%s",
candidate.type,
candidate.target_tier,
action,
)
if action == "NOOP":
return
if candidate.target_tier == "relational":
# Always upsert relations — decide_action skipped (no neighbour search).
if candidate.subject and candidate.predicate and candidate.object:
await _upsert_relation(
middleware, db, user_id, candidate, trace_id
)
return
if action in ("ADD", "UPDATE"):
if candidate.target_tier == "core":
# Derive a short key from the content (first 40 chars, snake_cased).
key = _content_to_key(candidate.content)
await middleware.update_core(user_id, key, candidate.content, trace_id=trace_id)
elif candidate.target_tier == "associative":
await middleware.store_associative(user_id, candidate.content)
elif candidate.target_tier == "proactive":
await _store_proactive_stub(middleware, db, user_id, candidate, fernet)
elif action == "DELETE":
if candidate.target_tier == "core":
key = _content_to_key(candidate.content)
await middleware.delete_core(user_id, key)
def _content_to_key(content: str) -> str:
"""Derive a short snake_case key from a content string (first 40 chars)."""
import re # noqa: PLC0415
slug = re.sub(r"[^a-z0-9]+", "_", content[:40].lower()).strip("_")
return slug or "memory"
async def _upsert_relation(
middleware: Any,
db: AsyncSession,
user_id: str,
candidate: MemoryCandidate,
trace_id: str | None,
) -> None:
"""Upsert a relation row via MemoryMiddleware.upsert_relation (Phase 3)."""
await middleware.upsert_relation(
user_id=user_id,
subject=candidate.subject or "unknown",
subject_type="unknown",
predicate=candidate.predicate or "related_to",
object_=candidate.object or "unknown",
object_type="unknown",
confidence=candidate.confidence,
)
logger.info(
"memory_extraction: upserted relation subject=%s predicate=%s object=%s",
candidate.subject,
candidate.predicate,
candidate.object,
)
async def _store_proactive_stub(
middleware: Any,
db: AsyncSession,
user_id: str,
candidate: MemoryCandidate,
fernet: Any,
) -> None:
"""Store a proactive pattern row directly (MemoryProactive model)."""
import uuid # noqa: PLC0415
from app.models import MemoryProactive # noqa: PLC0415
from app.core.memory_middleware import _encrypt # noqa: PLC0415
encrypted = _encrypt(fernet, candidate.content)
row = MemoryProactive(
id=str(uuid.uuid4()),
user_id=user_id,
pattern_encrypted=encrypted,
confidence=candidate.confidence,
source="inferred",
)
db.add(row)
try:
await db.commit()
logger.info("memory_extraction: stored proactive pattern user=%s", user_id)
except Exception as exc:
logger.warning("memory_extraction: store proactive failed: %s", exc)
await db.rollback()

View File

@@ -0,0 +1,581 @@
"""Memory maintenance jobs — Phase 3/5.
Three entrypoints called by the scheduler (APScheduler) registered in app/main.py:
drain_extraction_queue(db) — Free-tier batch extraction (Phase 2/5).
mine_proactive_patterns(db, user_id) — Power+ pattern mining (Phase 5).
decay_relations(db, user_id) — confidence decay + pruning for memory_relations (Phase 3).
All are safe to call manually or from tests; they never raise.
"""
from __future__ import annotations
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from cryptography.fernet import Fernet
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback
from app.models import MemoryAssociative, MemoryEpisodic, MemoryProactive, MemoryRelation, User
logger = logging.getLogger(__name__)
# Decay parameters for relations
_DECAY_FACTOR = 0.95
_DECAY_PERIOD_DAYS = 30
_PRUNE_THRESHOLD = 0.2
# Proactive pattern decay: 10 % per 7 days since last sighting
_PROACTIVE_DECAY_FACTOR = 0.9
_PROACTIVE_DECAY_PERIOD_DAYS = 7
_PROACTIVE_PRUNE_THRESHOLD = 0.2
# Mining: require at least this many episodes to attempt pattern extraction
_MIN_EPISODES_FOR_MINING = 3
_MINING_LOOKBACK_DAYS = 30
# Audit: caps to control token cost
_AUDIT_MAX_FACTS = 50
_AUDIT_MAX_LABELS = 100
async def decay_relations(db: AsyncSession, user_id: str) -> None:
"""Apply confidence decay to all relation rows for a user.
Decay rule: confidence *= 0.95 for every 30 days since last_confirmed_at.
Rows whose confidence falls below 0.2 are deleted.
Never raises — wraps in try/except.
"""
try:
await _decay_relations_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: decay_relations failed user=%s: %s", user_id, exc)
async def _decay_relations_inner(db: AsyncSession, user_id: str) -> None:
result = await db.execute(
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
)
rows = result.scalars().all()
now = datetime.now(timezone.utc)
deleted = 0
decayed = 0
for row in rows:
reference = row.last_confirmed_at or row.created_at
if reference is None:
continue
if reference.tzinfo is None:
reference = reference.replace(tzinfo=timezone.utc)
days_elapsed = (now - reference).days
if days_elapsed < _DECAY_PERIOD_DAYS:
continue
periods = days_elapsed // _DECAY_PERIOD_DAYS
new_confidence = row.confidence * (_DECAY_FACTOR ** periods)
if new_confidence < _PRUNE_THRESHOLD:
await db.delete(row)
deleted += 1
logger.info(
"memory_maintenance: pruned relation id=%s user=%s subject=%s predicate=%s "
"confidence=%.3f (below threshold)",
row.id, user_id, row.subject_label, row.predicate, new_confidence,
)
else:
row.confidence = new_confidence
decayed += 1
try:
await db.commit()
logger.info(
"memory_maintenance: decay_relations user=%s decayed=%d deleted=%d",
user_id, decayed, deleted,
)
except Exception as exc:
logger.warning("memory_maintenance: decay_relations commit failed user=%s: %s", user_id, exc)
await db.rollback()
async def drain_extraction_queue(db: AsyncSession) -> None:
"""Process pending ExtractionQueue rows for Free-tier users.
Each row corresponds to a stored episode that should be fed through the
Mem0-style extraction pipeline. Rows are deleted after successful processing.
Never raises — wraps in try/except.
"""
try:
await _drain_extraction_queue_inner(db)
except Exception as exc:
logger.warning("memory_maintenance: drain_extraction_queue failed: %s", exc)
async def _drain_extraction_queue_inner(db: AsyncSession) -> None:
from app.models import ExtractionQueue # noqa: PLC0415
result = await db.execute(select(ExtractionQueue))
rows = result.scalars().all()
if not rows:
logger.debug("memory_maintenance: drain_extraction_queue nothing to drain")
return
logger.info("memory_maintenance: drain_extraction_queue pending=%d", len(rows))
from app.core.memory_extraction import run_extraction # noqa: PLC0415
processed = 0
for row in rows:
try:
await run_extraction(
db=db,
user_id=row.user_id,
last_user_msg="",
last_assistant_msg="",
session_id=None,
)
await db.delete(row)
await db.commit()
processed += 1
except Exception as exc:
logger.warning(
"memory_maintenance: drain failed row=%s user=%s: %s",
row.id, row.user_id, exc,
)
await db.rollback()
logger.info("memory_maintenance: drain_extraction_queue processed=%d/%d", processed, len(rows))
async def mine_proactive_patterns(db: AsyncSession, user_id: str) -> None:
"""Mine recurring behavioral patterns from last 30 days of episodes (Power+ only).
Steps:
1. Gate on proactive_mining tier feature.
2. Load + decrypt last 30 days of episodic summaries.
3. Call gpt-4o-mini to identify recurring patterns.
4. Encrypt and store each pattern in memory_proactive.
5. Apply decay to existing proactive rows.
Never raises — wraps in try/except.
"""
try:
await _mine_proactive_patterns_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: mine_proactive_patterns failed user=%s: %s", user_id, exc)
async def _mine_proactive_patterns_inner(db: AsyncSession, user_id: str) -> None:
from app.billing.tier_manager import tier_manager # noqa: PLC0415
tier = await tier_manager.get_tier(user_id, db)
if not tier_manager.check_feature(tier, "proactive_mining"):
logger.debug("memory_maintenance: mine_proactive_patterns skipped (tier=%s)", tier)
return
# Load user Fernet key
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory_maintenance: mine_proactive_patterns no encryption_key user=%s", user_id)
return
fernet = Fernet(user.encryption_key.encode())
cutoff = datetime.now(timezone.utc) - timedelta(days=_MINING_LOOKBACK_DAYS)
episodes_result = await db.execute(
select(MemoryEpisodic)
.where(
MemoryEpisodic.user_id == user_id,
MemoryEpisodic.created_at >= cutoff,
)
.order_by(MemoryEpisodic.created_at.asc())
)
episode_rows = episodes_result.scalars().all()
if len(episode_rows) < _MIN_EPISODES_FOR_MINING:
logger.info(
"memory_maintenance: mine_proactive_patterns skipped user=%s episodes=%d (< %d)",
user_id, len(episode_rows), _MIN_EPISODES_FOR_MINING,
)
return
summaries: list[str] = []
for ep in episode_rows:
try:
plaintext = fernet.decrypt(ep.summary_encrypted.encode()).decode()
summaries.append(plaintext)
except Exception:
pass
if not summaries:
return
patterns = await _extract_proactive_patterns(summaries)
if not patterns:
logger.info("memory_maintenance: mine_proactive_patterns user=%s no patterns extracted", user_id)
return
stored = 0
for pattern_text in patterns:
try:
encrypted = fernet.encrypt(pattern_text.encode()).decode()
row = MemoryProactive(
id=str(uuid.uuid4()),
user_id=user_id,
pattern_encrypted=encrypted,
confidence=0.7,
source="inferred",
)
db.add(row)
stored += 1
except Exception as exc:
logger.warning("memory_maintenance: failed to store pattern user=%s: %s", user_id, exc)
try:
await db.commit()
logger.info(
"memory_maintenance: mine_proactive_patterns user=%s stored=%d",
user_id, stored,
)
except Exception as exc:
logger.warning("memory_maintenance: mine_proactive_patterns commit failed user=%s: %s", user_id, exc)
await db.rollback()
return
await _decay_proactive_patterns(db, user_id, fernet)
async def _extract_proactive_patterns(summaries: list[str]) -> list[str]:
"""Call memory-miner LLM to identify recurring behavioral/temporal patterns."""
from app.core.llm import get_agent_llm # noqa: PLC0415
llm = get_agent_llm("memory-miner", temperature=0)
combined = "\n---\n".join(summaries[-20:]) # cap at last 20 to control token usage
prompt = (
"You are analyzing conversation history for a personal AI secretary. "
"Identify 3-5 recurring temporal or behavioral patterns (e.g. 'always works late on Thursdays', "
"'prefers bullet-point summaries', 'frequently asks about Project Acme status'). "
"Return each pattern as a plain, short English sentence on its own line. "
"No numbering, no bullet points, no extra text.\n\n"
f"Conversation history:\n{combined}"
)
try:
response = await llm.ainvoke(prompt)
text = response.content if hasattr(response, "content") else str(response)
lines = [line.strip() for line in str(text).splitlines() if line.strip()]
return lines[:5]
except Exception as exc:
logger.warning("memory_maintenance: _extract_proactive_patterns LLM failed: %s", exc)
return []
async def _decay_proactive_patterns(db: AsyncSession, user_id: str, fernet: Fernet) -> None:
"""Decay confidence of existing proactive patterns; prune below threshold."""
result = await db.execute(
select(MemoryProactive).where(MemoryProactive.user_id == user_id)
)
rows = result.scalars().all()
now = datetime.now(timezone.utc)
deleted = 0
decayed = 0
for row in rows:
reference = row.created_at
if reference is None:
continue
if reference.tzinfo is None:
reference = reference.replace(tzinfo=timezone.utc)
days_elapsed = (now - reference).days
if days_elapsed < _PROACTIVE_DECAY_PERIOD_DAYS:
continue
periods = days_elapsed // _PROACTIVE_DECAY_PERIOD_DAYS
new_confidence = row.confidence * (_PROACTIVE_DECAY_FACTOR ** periods)
if new_confidence < _PROACTIVE_PRUNE_THRESHOLD:
await db.delete(row)
deleted += 1
else:
row.confidence = new_confidence
decayed += 1
try:
await db.commit()
logger.info(
"memory_maintenance: decay_proactive user=%s decayed=%d deleted=%d",
user_id, decayed, deleted,
)
except Exception as exc:
logger.warning("memory_maintenance: decay_proactive commit failed user=%s: %s", user_id, exc)
await db.rollback()
# ── Phase 7: weekly memory audit ──────────────────────────────────────────────
_AUDIT_CONTRADICTIONS_FALLBACK = (
"You are auditing a personal AI assistant's memory bank. "
"Each fact has an ID in brackets. "
"Find pairs that directly contradict each other "
"(e.g. 'prefers morning meetings' vs 'never schedules before noon'). "
"For each contradiction, pick the ID to DELETE (the older or less specific one). "
'Return ONLY a valid JSON array, no markdown fences: '
'[{{"delete": "<id>", "reason": "<one line>"}}]. '
"If no contradictions, return [].\n\n"
"Facts:\n{facts}"
)
_AUDIT_CANONICALIZE_FALLBACK = (
"You are auditing entity labels in a personal AI assistant's relational memory. "
"These are names of people, companies, projects, or topics. "
"Group labels that clearly refer to the same real-world entity "
"(e.g. 'giulia', 'Giulia', 'Giulia R.' → canonical 'Giulia'). "
"Return ONLY a valid JSON array, no markdown fences: "
'[{{"canonical": "<best label>", "variants": ["<v1>", "<v2>"]}}]. '
"Only include groups with at least one variant. Singletons: omit.\n\n"
"Labels:\n{labels}"
)
async def audit_memory(db: AsyncSession, user_id: str) -> None:
"""Weekly audit: contradiction scan on associative facts + label canonicalization on relations.
Steps:
1. Decrypt up to _AUDIT_MAX_FACTS associative rows; send list to memory-auditor LLM.
2. LLM flags rows to delete (direct contradictions); hard-delete them.
3. Collect unique subject/object labels from memory_relations; ask LLM to group duplicates.
4. Rewrite variant labels to their canonical form in-place.
Never raises — wraps in try/except.
"""
try:
await _audit_memory_inner(db, user_id)
except Exception as exc:
logger.warning("memory_maintenance: audit_memory failed user=%s: %s", user_id, exc)
async def _audit_memory_inner(db: AsyncSession, user_id: str) -> None:
result = await db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory_maintenance: audit_memory no encryption_key user=%s", user_id)
return
fernet = Fernet(user.encryption_key.encode())
await _scan_associative_contradictions(db, user_id, fernet)
await _canonicalize_relation_labels(db, user_id)
async def _scan_associative_contradictions(
db: AsyncSession,
user_id: str,
fernet: Fernet,
) -> None:
"""Decrypt associative facts, ask LLM to flag contradictions, delete superseded rows."""
result = await db.execute(
select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc())
.limit(_AUDIT_MAX_FACTS)
)
rows = result.scalars().all()
if len(rows) < 2:
return
id_to_text: dict[str, str] = {}
for row in rows:
try:
plaintext = fernet.decrypt(row.content_encrypted.encode()).decode()
id_to_text[row.id] = plaintext
except Exception:
pass
if len(id_to_text) < 2:
return
id_list = list(id_to_text.keys())
numbered = "\n".join(
f"{i + 1}. [{rid}] {id_to_text[rid]}" for i, rid in enumerate(id_list)
)
template, prompt_obj = get_prompt_or_fallback(
"memory_audit_contradictions", _AUDIT_CONTRADICTIONS_FALLBACK
)
system_text = compile_prompt(template, prompt_obj, facts=numbered)
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
llm = get_agent_llm("memory-auditor", temperature=0)
lf = get_langfuse()
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Audit facts for contradictions."),
]
try:
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-audit-contradictions",
model=model_for_agent("memory-auditor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm.ainvoke(messages)
text = response.content if hasattr(response, "content") else str(response)
deletions = json.loads(text.strip())
if not isinstance(deletions, list):
return
except Exception as exc:
logger.warning(
"memory_maintenance: _scan_associative_contradictions LLM/parse failed user=%s: %s",
user_id, exc,
)
return
deleted = 0
for item in deletions:
if not isinstance(item, dict):
continue
rid = item.get("delete")
if not rid or rid not in id_to_text:
continue
result2 = await db.execute(
select(MemoryAssociative).where(
MemoryAssociative.id == rid,
MemoryAssociative.user_id == user_id,
)
)
target = result2.scalar_one_or_none()
if target:
await db.delete(target)
deleted += 1
logger.info(
"memory_maintenance: audit deleted contradiction id=%s user=%s reason=%s",
rid, user_id, item.get("reason", ""),
)
if deleted:
try:
await db.commit()
except Exception as exc:
logger.warning(
"memory_maintenance: audit contradiction commit failed user=%s: %s", user_id, exc
)
await db.rollback()
logger.info(
"memory_maintenance: _scan_associative_contradictions user=%s deleted=%d", user_id, deleted
)
async def _canonicalize_relation_labels(db: AsyncSession, user_id: str) -> None:
"""Group near-duplicate entity labels in memory_relations and unify to canonical form."""
result = await db.execute(
select(MemoryRelation).where(MemoryRelation.user_id == user_id)
)
rows = result.scalars().all()
if not rows:
return
all_labels: set[str] = set()
for row in rows:
all_labels.add(row.subject_label)
all_labels.add(row.object_label)
labels_list = sorted(all_labels)[:_AUDIT_MAX_LABELS]
if len(labels_list) < 2:
return
labels_block = "\n".join(f"- {lbl}" for lbl in labels_list)
template, prompt_obj = get_prompt_or_fallback(
"memory_audit_canonicalize", _AUDIT_CANONICALIZE_FALLBACK
)
system_text = compile_prompt(template, prompt_obj, labels=labels_block)
from app.core.llm import get_agent_llm, model_for_agent # noqa: PLC0415
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
llm = get_agent_llm("memory-auditor", temperature=0)
lf = get_langfuse()
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Canonicalize entity labels."),
]
try:
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="memory-audit-canonicalize",
model=model_for_agent("memory-auditor"),
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm.ainvoke(messages)
text = response.content if hasattr(response, "content") else str(response)
groups = json.loads(text.strip())
if not isinstance(groups, list):
return
except Exception as exc:
logger.warning(
"memory_maintenance: _canonicalize_relation_labels LLM/parse failed user=%s: %s",
user_id, exc,
)
return
# Build variant → canonical map
remap: dict[str, str] = {}
for group in groups:
if not isinstance(group, dict):
continue
canonical = group.get("canonical", "")
variants = group.get("variants") or []
if not canonical:
continue
for v in variants:
if isinstance(v, str) and v != canonical:
remap[v] = canonical
if not remap:
return
updated = 0
for row in rows:
changed = False
if row.subject_label in remap:
row.subject_label = remap[row.subject_label]
changed = True
if row.object_label in remap:
row.object_label = remap[row.object_label]
changed = True
if changed:
updated += 1
if updated:
try:
await db.commit()
logger.info(
"memory_maintenance: _canonicalize_relation_labels user=%s updated=%d",
user_id, updated,
)
except Exception as exc:
logger.warning(
"memory_maintenance: canonicalize commit failed user=%s: %s", user_id, exc
)
await db.rollback()

View File

@@ -18,8 +18,10 @@ Usage:
from __future__ import annotations from __future__ import annotations
import asyncio
import logging import logging
import uuid import uuid
from datetime import datetime, timezone
from typing import Any from typing import Any
from cryptography.fernet import Fernet, InvalidToken from cryptography.fernet import Fernet, InvalidToken
@@ -27,15 +29,22 @@ from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.models import ( from app.models import (
ExtractionQueue,
MemoryAssociative, MemoryAssociative,
MemoryCore, MemoryCore,
MemoryEpisodic, MemoryEpisodic,
MemoryProactive, MemoryProactive,
MemoryRelation,
User, User,
) )
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
def _now() -> datetime:
return datetime.now(timezone.utc)
# Tuning constants # Tuning constants
_ASSOCIATIVE_TOP_K = 5 _ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10 _EPISODIC_RECENT_N = 10
@@ -64,26 +73,31 @@ class MemoryMiddleware:
associative_memory — [plaintext_content, ...] (top-k by keyword match) associative_memory — [plaintext_content, ...] (top-k by keyword match)
episodic_memory — [plaintext_summary, ...] (most recent N) episodic_memory — [plaintext_summary, ...] (most recent N)
proactive_hints — [plaintext_pattern, ...] (above threshold) proactive_hints — [plaintext_pattern, ...] (above threshold)
relational_memory — ["subject --predicate--> object", ...] (top 10, Pro+)
""" """
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return {} return {}
user_dbg = await self._get_user_debug(user_id)
user_tier: str = user_dbg.get("tier") or "free"
core = await self._load_core(user_id, fernet) core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet) associative = await self._load_associative(user_id, message, fernet, user_tier=user_tier)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id) episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet) proactive = await self._load_proactive(user_id, fernet)
relational = await self._load_relational(user_id, user_tier=user_tier)
user_dbg = await self._get_user_debug(user_id)
logger.info( logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d", "memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d relational=%d",
trace_id or "-", trace_id or "-",
user_id, user_id,
user_dbg.get("tier") or "-", user_tier,
len(core), len(core),
len(associative), len(associative),
len(episodic), len(episodic),
len(proactive), len(proactive),
len(relational),
) )
return { return {
@@ -91,6 +105,7 @@ class MemoryMiddleware:
"associative_memory": associative, "associative_memory": associative,
"episodic_memory": episodic, "episodic_memory": episodic,
"proactive_hints": proactive, "proactive_hints": proactive,
"relational_memory": relational,
} }
async def store_episode( async def store_episode(
@@ -104,7 +119,10 @@ class MemoryMiddleware:
"""Summarise and store a completed interaction in episodic memory. """Summarise and store a completed interaction in episodic memory.
The summary is a simple heuristic concatenation (no LLM call) to keep The summary is a simple heuristic concatenation (no LLM call) to keep
latency low. Full LLM summarisation can be added in a later step. latency low. After committing the episode row, dispatches the Mem0-style
extraction pipeline:
- Pro/Power/Team → asyncio.create_task (fire-and-forget, realtime).
- Free → enqueue an ExtractionQueue row for the daily cron.
""" """
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
@@ -113,26 +131,95 @@ class MemoryMiddleware:
summary = f"User: {message[:200]}\nAssistant: {response[:200]}" summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
encrypted = _encrypt(fernet, summary) encrypted = _encrypt(fernet, summary)
row = MemoryEpisodic( episode = MemoryEpisodic(
id=str(uuid.uuid4()), id=str(uuid.uuid4()),
user_id=user_id, user_id=user_id,
summary_encrypted=encrypted, summary_encrypted=encrypted,
session_id=session_id, session_id=session_id,
) )
self._db.add(row) self._db.add(episode)
episode_id: str = episode.id
try: try:
await self._db.commit() await self._db.commit()
user_dbg = await self._get_user_debug(user_id) user_dbg = await self._get_user_debug(user_id)
tier = user_dbg.get("tier") or "free"
logger.info( logger.info(
"memory: store_episode trace=%s user=%s tier=%s session=%s", "memory: store_episode trace=%s user=%s tier=%s session=%s",
trace_id or "-", trace_id or "-",
user_id, user_id,
user_dbg.get("tier") or "-", tier,
session_id, session_id,
) )
except Exception as exc: except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc) logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback() await self._db.rollback()
return
# ── Dispatch extraction pipeline (Phase 2) ────────────────────────────
await self._dispatch_extraction(
user_id=user_id,
episode_id=episode_id,
last_user_msg=message,
last_assistant_msg=response,
session_id=session_id,
)
async def _dispatch_extraction(
self,
user_id: str,
episode_id: str,
last_user_msg: str,
last_assistant_msg: str,
session_id: str | None,
) -> None:
"""Route extraction to realtime task or batch queue based on user tier."""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
tier = await tier_manager.get_tier(user_id, self._db)
if tier_manager.check_feature(tier, "realtime_extraction"):
# Pro/Power/Team: fire-and-forget in the background.
# Must open a fresh session — request session closes after handler returns.
from app.core.memory_extraction import run_extraction # noqa: PLC0415
from app.db import async_session # noqa: PLC0415
async def _task() -> None:
try:
async with async_session() as fresh_db:
await run_extraction(
db=fresh_db,
user_id=user_id,
last_user_msg=last_user_msg,
last_assistant_msg=last_assistant_msg,
session_id=session_id,
)
except Exception as exc:
logger.warning(
"memory: extraction task failed user=%s: %s", user_id, exc
)
asyncio.create_task(_task())
logger.info("memory: realtime extraction dispatched user=%s", user_id)
else:
# Free tier: enqueue for daily batch cron.
queue_row = ExtractionQueue(
id=str(uuid.uuid4()),
user_id=user_id,
episode_id=episode_id,
)
self._db.add(queue_row)
try:
await self._db.commit()
logger.info(
"memory: extraction enqueued (batch) user=%s episode=%s",
user_id,
episode_id,
)
except Exception as exc:
logger.warning(
"memory: extraction queue insert failed user=%s: %s", user_id, exc
)
await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None: async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
"""Upsert a core memory key/value for a user.""" """Upsert a core memory key/value for a user."""
@@ -255,6 +342,143 @@ class MemoryMiddleware:
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label) logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
return True return True
async def store_associative(
self,
user_id: str,
content: str,
entity_type: str | None = None,
entity_id: str | None = None,
) -> None:
"""Store associative memory; embed if user tier has real_embeddings."""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.core.embeddings import embed_text # noqa: PLC0415
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
user_dbg = await self._get_user_debug(user_id)
user_tier = user_dbg.get("tier") or "free"
embedding: list[float] | None = None
if tier_manager.check_feature(user_tier, "real_embeddings"):
embedding = await embed_text(content)
row = MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=encrypted,
embedding=embedding,
entity_type=entity_type,
entity_id=entity_id,
)
self._db.add(row)
try:
await self._db.commit()
logger.info(
"memory: store_associative user=%s embedded=%s",
user_id,
embedding is not None,
)
except Exception as exc:
logger.error("memory: store_associative failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def upsert_relation(
self,
user_id: str,
subject: str,
subject_type: str,
predicate: str,
object_: str,
object_type: str,
*,
confidence: float = 0.7,
source_episode_id: str | None = None,
notes: str | None = None,
) -> None:
"""Insert or update a relation row. Matches on (user_id, subject_label, predicate, object_label).
subject_label / object_label are plaintext entity identifiers — not encrypted.
notes is optional; encrypted with user Fernet if provided.
"""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
user_dbg = await self._get_user_debug(user_id)
user_tier = user_dbg.get("tier") or "free"
if not tier_manager.check_feature(user_tier, "relational_memory"):
logger.debug("memory: upsert_relation skipped (tier=%s no relational_memory)", user_tier)
return
notes_encrypted: bytes | None = None
if notes:
fernet = await self._get_fernet(user_id)
if fernet:
notes_encrypted = fernet.encrypt(notes.encode())
result = await self._db.execute(
select(MemoryRelation).where(
MemoryRelation.user_id == user_id,
MemoryRelation.subject_label == subject,
MemoryRelation.predicate == predicate,
MemoryRelation.object_label == object_,
)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.subject_type = subject_type
existing.object_type = object_type
existing.confidence = confidence
existing.last_confirmed_at = _now()
if notes_encrypted is not None:
existing.notes_encrypted = notes_encrypted
else:
self._db.add(MemoryRelation(
id=str(uuid.uuid4()),
user_id=user_id,
subject_label=subject,
subject_type=subject_type,
predicate=predicate,
object_label=object_,
object_type=object_type,
confidence=confidence,
source_episode_id=source_episode_id,
notes_encrypted=notes_encrypted,
))
try:
await self._db.commit()
logger.info(
"memory: upsert_relation user=%s subject=%s predicate=%s object=%s",
user_id, subject, predicate, object_,
)
except Exception as exc:
logger.error("memory: upsert_relation failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def query_relations(
self,
user_id: str,
subject: str | None = None,
predicate: str | None = None,
object_: str | None = None,
limit: int = 20,
) -> list[MemoryRelation]:
"""Query relation rows for a user with optional filters."""
q = select(MemoryRelation).where(MemoryRelation.user_id == user_id)
if subject is not None:
q = q.where(MemoryRelation.subject_label == subject)
if predicate is not None:
q = q.where(MemoryRelation.predicate == predicate)
if object_ is not None:
q = q.where(MemoryRelation.object_label == object_)
q = q.order_by(MemoryRelation.confidence.desc()).limit(limit)
result = await self._db.execute(q)
return list(result.scalars().all())
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None: async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry.""" """Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
@@ -343,13 +567,26 @@ class MemoryMiddleware:
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]: async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs.""" """Load lightweight user debug fields for trace logs."""
from app.config.settings import settings # noqa: PLC0415
from app.models import Subscription # noqa: PLC0415
result = await self._db.execute(select(User).where(User.id == user_id)) result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if user is None: if user is None:
return {"tier": None} return {"tier": None}
return {
"tier": user.tier, sub_result = await self._db.execute(
} select(Subscription.tier).where(Subscription.user_id == user_id)
)
sub_tier: str | None = sub_result.scalar_one_or_none()
if sub_tier:
tier = sub_tier
elif settings.ENV == "dev":
tier = "power"
else:
tier = user.tier or "free"
return {"tier": tier}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]: async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute( result = await self._db.execute(
@@ -364,14 +601,49 @@ class MemoryMiddleware:
return out return out
async def _load_associative( async def _load_associative(
self, user_id: str, message: str, fernet: Fernet self, user_id: str, message: str, fernet: Fernet, *, user_tier: str = "free"
) -> list[str]: ) -> list[str]:
"""Load top-k associative memories. """Load top-k associative memories.
Production: uses pgvector cosine similarity on the message embedding. Pro+: pgvector cosine similarity on the message embedding (real_embeddings feature).
Current implementation: keyword-based fallback (no external embedding call) Free / embedding failure: keyword-ordered fallback (most recent rows).
so tests pass without a live OpenAI key.
""" """
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.core.embeddings import embed_text # noqa: PLC0415
if tier_manager.check_feature(user_tier, "real_embeddings"):
vec = await embed_text(message)
if vec is not None:
try:
result = await self._db.execute(
select(MemoryAssociative)
.where(
MemoryAssociative.user_id == user_id,
MemoryAssociative.embedding.isnot(None),
)
.order_by(MemoryAssociative.embedding.cosine_distance(vec))
.limit(_ASSOCIATIVE_TOP_K)
)
rows = result.scalars().all()
out: list[str] = []
for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None:
out.append(plaintext)
logger.info(
"memory: _load_associative user=%s mode=vector hits=%d",
user_id,
len(out),
)
return out
except Exception as exc:
logger.warning(
"memory: vector search failed user=%s, falling back to keyword: %s",
user_id,
exc,
)
# Keyword fallback: most recent rows
result = await self._db.execute( result = await self._db.execute(
select(MemoryAssociative) select(MemoryAssociative)
.where(MemoryAssociative.user_id == user_id) .where(MemoryAssociative.user_id == user_id)
@@ -379,7 +651,7 @@ class MemoryMiddleware:
.limit(_ASSOCIATIVE_TOP_K) .limit(_ASSOCIATIVE_TOP_K)
) )
rows = result.scalars().all() rows = result.scalars().all()
out: list[str] = [] out = []
for row in rows: for row in rows:
plaintext = _safe_decrypt(fernet, row.content_encrypted) plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None: if plaintext is not None:
@@ -408,6 +680,26 @@ class MemoryMiddleware:
out.append(plaintext) out.append(plaintext)
return out return out
async def _load_relational(self, user_id: str, *, user_tier: str = "free") -> list[str]:
"""Return top-10 relation strings for Pro+ users; empty list for Free."""
from app.billing.tier_manager import tier_manager # noqa: PLC0415
if not tier_manager.check_feature(user_tier, "relational_memory"):
return []
result = await self._db.execute(
select(MemoryRelation)
.where(MemoryRelation.user_id == user_id)
.order_by(MemoryRelation.confidence.desc())
.limit(10)
)
rows = result.scalars().all()
out = [
f"{r.subject_label} --{r.predicate}--> {r.object_label}"
for r in rows
]
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]: async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute( result = await self._db.execute(
select(MemoryProactive) select(MemoryProactive)

View File

@@ -0,0 +1,51 @@
"""Note summarizer — generates a compact AI summary for a note.
Called fire-and-forget from create_note / update_note tools so the
``notes.ai_summary`` column stays current without blocking the agent loop.
"""
from __future__ import annotations
import logging
from langchain_core.messages import HumanMessage, SystemMessage
from app.core.langfuse_client import get_prompt_or_fallback
from app.core.llm import get_agent_llm
logger = logging.getLogger(__name__)
_FALLBACK_PROMPT = """\
Summarize this note in <=250 characters. Be terse and dense.
Keep proper nouns, dates, decisions, and action items.
Do not start with "This note".
Respond with the summary text only — no intro, no labels.
Title: {title}
Content: {content}"""
_MAX_CONTENT_CHARS = 4000
async def generate_note_summary(title: str, content: str) -> str:
"""Return a <=250-char summary of *title* + *content*.
Uses the Langfuse ``note_summary`` prompt (hot-swappable) with a local
fallback. Truncates *content* to 4000 chars before sending to avoid
token waste on large notes.
"""
template, _ = get_prompt_or_fallback("note_summary", _FALLBACK_PROMPT)
trimmed = content[:_MAX_CONTENT_CHARS]
system_prompt = template.format(title=title, content=trimmed)
try:
llm = get_agent_llm("note-summarizer")
response = await llm.ainvoke([
SystemMessage(content=system_prompt),
HumanMessage(content="Generate the summary."),
])
text = response.content if isinstance(response.content, str) else ""
return text.strip()[:250]
except Exception as exc:
logger.warning("note_summarizer: failed to generate summary: %s", exc)
return ""

View File

@@ -2,11 +2,35 @@
from __future__ import annotations from __future__ import annotations
import re
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
_CANVAS_BLOCK_RE = re.compile(
r'<canvas\s+kind=["\']([^"\']+)["\']>(.*?)</canvas>',
re.DOTALL | re.IGNORECASE,
)
def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
"""Strip the first <canvas kind="...">...</canvas> block from *text*.
Returns ``(visible_text, canvas_content, canvas_kind)``.
``canvas_content`` and ``canvas_kind`` are ``None`` when no block is found.
"""
match = _CANVAS_BLOCK_RE.search(text)
if not match:
return text, None, None
canvas_kind = match.group(1).strip()
canvas_content = match.group(2).strip()
visible = text[: match.start()] + text[match.end() :]
visible = visible.strip()
return visible, canvas_content, canvas_kind
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain

View File

@@ -7,10 +7,32 @@ The callback sends a `tool_call` WS frame and awaits the `tool_result`.
from __future__ import annotations from __future__ import annotations
import re
from contextvars import ContextVar from contextvars import ContextVar
from typing import Any, Callable, Coroutine from typing import Any, Callable, Coroutine
from uuid import uuid4 from uuid import uuid4
_SNAKE_TO_CAMEL_RE = re.compile(r"_([a-z])")
def _key_to_camel(key: str) -> str:
return _SNAKE_TO_CAMEL_RE.sub(lambda m: m.group(1).upper(), key)
def _keys_to_camel(obj: Any) -> Any:
"""Recursively convert dict keys from snake_case to camelCase.
Mirrors the JS-side ``toCamelCase`` applied to incoming WS frames in
``adiuvAI/src/main/api/backend-client.ts``. The Electron executor wraps
tool_result payloads in ``toSnakeCase`` before sending; this restores the
camelCase schema property names that the tool code expects to read.
"""
if isinstance(obj, dict):
return {_key_to_camel(k): _keys_to_camel(v) for k, v in obj.items()}
if isinstance(obj, list):
return [_keys_to_camel(v) for v in obj]
return obj
# Holds the execute callback for the current WS session. # Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after. # Set by the chat WS handler before the orchestrator runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar( _client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
@@ -82,6 +104,7 @@ async def execute_on_client(
payload["limit"] = limit payload["limit"] = limit
result = await callback(payload) result = await callback(payload)
result = _keys_to_camel(result)
collector = _tool_result_collector.get(None) collector = _tool_result_collector.get(None)
if collector is not None: if collector is not None:
collector.append({ collector.append({

View File

@@ -25,7 +25,7 @@ from __future__ import annotations
import logging import logging
import re import re
from datetime import datetime, timedelta, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
import httpx import httpx

View File

@@ -4,6 +4,10 @@ import logging
from fastapi import FastAPI from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware from fastapi.middleware.cors import CORSMiddleware
from app.api.middleware.rate_limit import TierRateLimitMiddleware
from app.api.middleware.sanitizer import SanitizerMiddleware
from app.config.settings import settings
logging.basicConfig( logging.basicConfig(
level=logging.INFO, level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s", format="%(asctime)s %(levelname)s %(name)s: %(message)s",
@@ -11,9 +15,66 @@ logging.basicConfig(
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING) logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING) logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
from app.api.middleware.rate_limit import TierRateLimitMiddleware
from app.api.middleware.sanitizer import SanitizerMiddleware async def _memory_audit_cron_tick() -> None:
from app.config.settings import settings """Weekly cron: contradiction scan + label canonicalization for all users (Phase 7)."""
import logging # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("memory audit cron tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.core.memory_maintenance import audit_memory # noqa: PLC0415
from app.models import User # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
async with async_session() as db:
result = await db.execute(select(User.id))
user_ids: list[str] = list(result.scalars().all())
for uid in user_ids:
try:
async with async_session() as db:
await audit_memory(db, uid)
except Exception as exc:
_log.warning("memory audit cron tick: audit_memory failed user=%s: %s", uid, exc)
_log.info("memory audit cron tick: done users=%d", len(user_ids))
except Exception as exc:
_log.warning("memory audit cron tick: failed: %s", exc)
async def _memory_cron_tick() -> None:
"""Hourly cron: drain Free-tier extraction queue + mine proactive patterns for Power+ users."""
import logging # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("memory cron tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.core.memory_maintenance import drain_extraction_queue, mine_proactive_patterns # noqa: PLC0415
from app.billing.tier_manager import tier_manager # noqa: PLC0415
from app.models import User # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
async with async_session() as db:
await drain_extraction_queue(db)
# mine proactive patterns for every Power+ user
async with async_session() as db:
result = await db.execute(select(User.id))
user_ids: list[str] = list(result.scalars().all())
for uid in user_ids:
try:
async with async_session() as db:
tier = await tier_manager.get_tier(uid, db)
if tier_manager.check_feature(tier, "proactive_mining"):
await mine_proactive_patterns(db, uid)
except Exception as exc:
_log.warning("memory cron tick: mine_proactive_patterns failed user=%s: %s", uid, exc)
_log.info("memory cron tick: done users=%d", len(user_ids))
except Exception as exc:
_log.warning("memory cron tick: failed: %s", exc)
@asynccontextmanager @asynccontextmanager
@@ -21,8 +82,21 @@ async def lifespan(app: FastAPI):
# Startup: ensure agent tool modules are loaded. # Startup: ensure agent tool modules are loaded.
import app.agents # noqa: F401 import app.agents # noqa: F401
scheduler = None
if settings.SCHEDULER_ENABLED:
from apscheduler.schedulers.asyncio import AsyncIOScheduler # noqa: PLC0415
scheduler = AsyncIOScheduler()
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
scheduler.start()
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
yield yield
if scheduler is not None:
scheduler.shutdown(wait=False)
# Shutdown: dispose SQLAlchemy connection pool # Shutdown: dispose SQLAlchemy connection pool
from app.db import engine from app.db import engine
await engine.dispose() await engine.dispose()
@@ -30,7 +104,7 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI: def create_app() -> FastAPI:
app = FastAPI( app = FastAPI(
title="Adiuva Cloud API", title="AdiuvAI Cloud API",
version="0.1.0", version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None, docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None, redoc_url=None,
@@ -50,17 +124,14 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware) app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware) app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors from app.api.routes import agents, auth, billing, chat, device_ws, memory
app.include_router(auth.router, prefix="/api/v1") app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1") app.include_router(chat.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")
app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1") app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1") app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1") app.include_router(device_ws.router, prefix="/api/v1")
app.include_router(memory.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"]) @app.get("/api/v1/health", tags=["health"])
async def health() -> dict: async def health() -> dict:

View File

@@ -1,7 +0,0 @@
"""Plugin marketplace package.
Three service classes introduced in Step 10:
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
- ``ReviewQueue`` — approval workflow + security checklist
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
"""

View File

@@ -1,212 +0,0 @@
"""Plugin catalog registry backed by PostgreSQL.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. All data is persisted in the ``plugins`` table.
Module-level singleton::
from app.marketplace.plugin_registry import registry
"""
from __future__ import annotations
import json
from typing import Any, Literal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plugin
from app.schemas import PluginListResponse, PluginManifest
_PAGE_SIZE = 20
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
try:
permissions = json.loads(p.permissions) if p.permissions else []
except (json.JSONDecodeError, TypeError):
permissions = []
return PluginManifest(
id=p.id,
name=p.name,
description=p.description,
version=p.version,
author=p.author_name,
permissions=permissions,
category=p.category,
price_cents=p.price_cents,
)
class PluginRegistry:
"""PostgreSQL-backed plugin catalog.
All methods accept an ``AsyncSession`` parameter so the calling route
controls the session lifecycle.
"""
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
db: AsyncSession,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
base = select(Plugin).where(Plugin.status == "approved")
if category:
base = base.where(Plugin.category == category)
if query:
pattern = f"%{query}%"
base = base.where(
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
)
# Count
count_q = select(func.count()).select_from(base.subquery())
total = (await db.execute(count_q)).scalar_one()
# Sort
if sort == "installs":
base = base.order_by(Plugin.install_count.desc())
elif sort == "rating":
base = base.order_by(Plugin.avg_rating.desc())
else: # newest
base = base.order_by(Plugin.created_at.desc())
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
rows = (await db.execute(base)).scalars().all()
return PluginListResponse(
plugins=[_plugin_to_manifest(r) for r in rows],
total=total,
page=page,
)
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
p = result.scalar_one_or_none()
if p is None:
return None
return {
"manifest": _plugin_to_manifest(p),
"status": p.status,
"install_count": p.install_count,
"avg_rating": p.avg_rating,
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
db: AsyncSession,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
"""Add *manifest* to the catalog with ``status='pending_review'``.
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = existing.scalar_one_or_none()
if row is not None:
row.name = manifest.name
row.description = manifest.description
row.version = manifest.version
row.author_name = manifest.author
row.category = manifest.category
row.price_cents = manifest.price_cents
row.permissions = json.dumps(manifest.permissions)
row.status = "pending_review"
row.s3_package_key = package_s3_key
row.rejection_reason = None
else:
row = Plugin(
id=plugin_id,
name=manifest.name,
description=manifest.description,
version=manifest.version,
author_name=manifest.author,
category=manifest.category,
price_cents=manifest.price_cents,
permissions=json.dumps(manifest.permissions),
status="pending_review",
s3_package_key=package_s3_key,
install_count=0,
avg_rating=0.0,
)
db.add(row)
await db.commit()
return plugin_id
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "approved"
row.rejection_reason = None
await db.commit()
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "rejected"
row.rejection_reason = reason
await db.commit()
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = row.install_count + 1
await db.commit()
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = max(0, row.install_count - 1)
await db.commit()
# ── Internal helpers used by ReviewQueue ─────────────────────────
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review'."""
result = await db.execute(
select(Plugin).where(Plugin.status == "pending_review")
)
rows = result.scalars().all()
return [
{
"manifest": _plugin_to_manifest(r),
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
}
for r in rows
]
# Module-level singleton
registry = PluginRegistry()

View File

@@ -1,125 +0,0 @@
"""Plugin review workflow backed by PostgreSQL.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
Module-level singleton::
from app.marketplace.plugin_review import review_queue
"""
from __future__ import annotations
import re
from typing import Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import registry
from app.models import PluginReview as PluginReviewModel
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
{
"read:tasks",
"write:tasks",
"read:projects",
"write:projects",
"read:notes",
"write:notes",
"read:timelines",
"write:timelines",
"read:calendar",
"write:calendar",
}
)
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
def validate_manifest(manifest: PluginManifest) -> None:
"""Enforce the plugin security checklist.
Raises:
``ValueError`` on the first violation found. Callers should catch
this and return HTTP 422 / reject the submission.
Checks:
1. Plugin id matches ``^[a-z0-9-]+$``
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
3. No manifest field contains raw binary data
"""
if not _PLUGIN_ID_RE.match(manifest.id):
raise ValueError(
f"Invalid plugin id format: '{manifest.id}'. "
"Only lowercase letters, digits, and hyphens are allowed."
)
for perm in manifest.permissions:
if perm not in ALLOWED_PERMISSIONS:
raise ValueError(
f"Unknown permission: '{perm}'. "
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
)
for field_name, value in manifest.model_dump().items():
if isinstance(value, (bytes, bytearray)):
raise ValueError(
f"Binary content is not allowed in manifest field '{field_name}'."
)
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton.
Review records are persisted in the ``plugin_reviews`` table.
"""
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = await registry.get_pending_entries(db)
return [
{
"plugin_id": e["manifest"].id,
"manifest": e["manifest"],
"submitted_at": e["submitted_at"],
}
for e in entries
]
async def submit_review(
self,
db: AsyncSession,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
notes: str = "",
) -> None:
"""Record a review decision and update the plugin's status.
Raises:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(db, plugin_id)
else:
await registry.reject_plugin(db, plugin_id, reason=notes)
review = PluginReviewModel(
plugin_id=plugin_id,
reviewer_id=reviewer_id,
decision=decision,
notes=notes,
)
db.add(review)
await db.commit()
# Module-level singleton
review_queue = ReviewQueue()

View File

@@ -1,233 +0,0 @@
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
in the ``revenue_events`` table.
Module-level singleton::
from app.marketplace.revenue_share import revenue_share
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.models import Plugin, RevenueEvent
logger = logging.getLogger(__name__)
# ── Revenue split constants ───────────────────────────────────────────
DEVELOPER_SHARE: float = 0.70
PLATFORM_SHARE: float = 0.30
class RevenueShare:
"""Records installation revenue events and coordinates developer payouts.
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
is not configured, consistent with the rest of the billing layer.
"""
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
@staticmethod
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Core operations ──────────────────────────────────────────────
async def record_install(
self,
db: AsyncSession,
plugin_id: str,
user_id: str,
amount_cents: int,
) -> None:
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
For free plugins (``amount_cents == 0``) no payment is initiated but
the event is still recorded for analytics.
For paid plugins the developer receives 70 % via a Stripe Connect
destination charge. If Stripe is not configured or the charge fails
the installation still succeeds (the event is recorded and the install
count is incremented) — a warning is logged for monitoring.
"""
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
# Look up the plugin's author Stripe account from the DB
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = result.scalar_one_or_none()
developer_stripe_account: str | None = None
if plugin_row and plugin_row.author_id:
# Future: look up user.stripe_connect_account_id
developer_stripe_account = None # no real account yet
if developer_stripe_account:
try:
s = self._stripe()
transfer = s.Transfer.create(
amount=developer_share_cents,
currency="eur",
destination=developer_stripe_account,
description=f"Revenue share for plugin {plugin_id}",
metadata={"plugin_id": plugin_id, "user_id": user_id},
)
stripe_transfer_id = transfer["id"]
except Exception as exc:
logger.warning(
"Stripe Connect transfer failed for plugin %s: %s",
plugin_id,
exc,
)
else:
logger.debug(
"No Stripe account on file for plugin %s developer; "
"skipping transfer.",
plugin_id,
)
event = RevenueEvent(
plugin_id=plugin_id,
user_id=user_id,
amount_cents=amount_cents,
developer_share_cents=developer_share_cents,
stripe_transfer_id=stripe_transfer_id,
)
db.add(event)
await db.commit()
await registry.record_install(db, plugin_id)
async def get_earnings(
self,
db: AsyncSession,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
"""Return aggregated earnings for *developer_id*.
``period`` is an optional ``YYYY-MM`` string to restrict the window.
Returns::
{
"developer_id": str,
"period": str | None,
"total_installs": int,
"total_revenue_cents": int,
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer (by author_name match)
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
plugin_result = await db.execute(plugin_q)
developer_plugin_ids = [row[0] for row in plugin_result.all()]
if not developer_plugin_ids:
return {
"developer_id": developer_id,
"period": period,
"total_installs": 0,
"total_revenue_cents": 0,
"developer_share_cents": 0,
}
query = select(
func.count().label("total_installs"),
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
if period:
# Filter by YYYY-MM: extract year and month from created_at
try:
year, month = period.split("-")
query = query.where(
extract("year", RevenueEvent.created_at) == int(year),
extract("month", RevenueEvent.created_at) == int(month),
)
except ValueError:
pass # invalid period format — return all
result = await db.execute(query)
row = result.one()
return {
"developer_id": developer_id,
"period": period,
"total_installs": row.total_installs,
"total_revenue_cents": row.total_revenue,
"developer_share_cents": row.dev_share,
}
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
try:
year, month = period.split("-")
year_int, month_int = int(year), int(month)
except ValueError:
logger.warning("Invalid period format: %s", period)
return
result = await db.execute(
select(RevenueEvent).where(
RevenueEvent.plugin_id == plugin_id,
RevenueEvent.paid_at.is_(None),
extract("year", RevenueEvent.created_at) == year_int,
extract("month", RevenueEvent.created_at) == month_int,
)
)
unpaid = list(result.scalars().all())
total_dev_share = sum(e.developer_share_cents for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = plugin_result.scalar_one_or_none()
developer_stripe_account: str | None = None # Future: fetch from DB
if plugin_row and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
amount=total_dev_share,
currency="eur",
destination=developer_stripe_account,
description=f"Payout for plugin {plugin_id} period {period}",
)
except Exception as exc:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = datetime.now(timezone.utc)
for event in unpaid:
event.paid_at = paid_ts
await db.commit()
# Module-level singleton
revenue_share = RevenueShare()

View File

@@ -1,23 +1,20 @@
"""SQLAlchemy ORM models for all persistent tables. """SQLAlchemy ORM models for all persistent tables.
Only auth, billing, storage metadata, and marketplace data live here. Only auth, billing, agent config, and memory data live here.
User content (notes, tasks, etc.) is NEVER persisted server-side — User content (notes, tasks, etc.) lives exclusively on the client.
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
Table inventory: Table inventory:
users — account credentials + tier users — account credentials + tier
refresh_tokens — hashed refresh token store refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records subscriptions — Stripe subscription records
storage_records — S3 blob metadata (no plaintext) local_agent_configs — per-device batch agent configs
backup_metadata — encrypted backup manifests cloud_agent_configs — OAuth-backed cloud agent configs
plugins — marketplace plugin catalog agent_run_logs — execution history for all agents
plugin_installations — per-user install records
plugin_reviews — admin review decisions
revenue_events — Stripe Connect 70/30 split ledger
memory_core — per-user persistent key/value preferences (encrypted) memory_core — per-user persistent key/value preferences (encrypted)
memory_associative — per-user semantic memory with embeddings (encrypted) memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted) memory_episodic — per-user session summaries (encrypted)
memory_proactive — per-user behavioral patterns (encrypted) memory_proactive — per-user behavioral patterns (encrypted)
memory_relations — per-user entity/relation graph (Mem0g-light, Phase 3)
""" """
from __future__ import annotations from __future__ import annotations
@@ -25,8 +22,8 @@ from __future__ import annotations
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from pgvector.sqlalchemy import Vector
from sqlalchemy import ( from sqlalchemy import (
BigInteger,
Boolean, Boolean,
DateTime, DateTime,
Enum, Enum,
@@ -34,9 +31,9 @@ from sqlalchemy import (
ForeignKey, ForeignKey,
Integer, Integer,
JSON, JSON,
LargeBinary,
String, String,
Text, Text,
UniqueConstraint,
Uuid, Uuid,
func, func,
) )
@@ -58,8 +55,6 @@ def _now() -> datetime:
# ── Enum types ──────────────────────────────────────────────────────────── # ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier") TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
AgentTypeEnum = Enum("local", "cloud", name="agent_type") AgentTypeEnum = Enum("local", "cloud", name="agent_type")
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status") AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider") CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
@@ -77,7 +72,8 @@ class User(Base):
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True) email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
name: Mapped[str | None] = mapped_column(String(100), nullable=True) name: Mapped[str | None] = mapped_column(String(100), nullable=True)
surname: Mapped[str | None] = mapped_column(String(100), nullable=True) surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False) password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free") tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True) stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration. # Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
@@ -86,6 +82,9 @@ class User(Base):
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, default=None
)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now() DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
) )
@@ -96,6 +95,9 @@ class User(Base):
subscription: Mapped[Subscription | None] = relationship( subscription: Mapped[Subscription | None] = relationship(
back_populates="user", uselist=False, cascade="all, delete-orphan" back_populates="user", uselist=False, cascade="all, delete-orphan"
) )
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
class RefreshToken(Base): class RefreshToken(Base):
@@ -116,6 +118,25 @@ class RefreshToken(Base):
user: Mapped[User] = relationship(back_populates="refresh_tokens") user: Mapped[User] = relationship(back_populates="refresh_tokens")
class OAuthAccount(Base):
__tablename__ = "oauth_accounts"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
provider: Mapped[str] = mapped_column(String(50), nullable=False)
provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False)
provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="oauth_accounts")
class Subscription(Base): class Subscription(Base):
__tablename__ = "subscriptions" __tablename__ = "subscriptions"
@@ -137,151 +158,6 @@ class Subscription(Base):
user: Mapped[User] = relationship(back_populates="subscription") user: Mapped[User] = relationship(back_populates="subscription")
class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
# nullable until developer account system is built
author_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
class LocalAgentConfig(Base): class LocalAgentConfig(Base):
__tablename__ = "local_agent_configs" __tablename__ = "local_agent_configs"
@@ -367,6 +243,7 @@ class AgentRunLog(Base):
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running") status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0) items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0) items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
errors: Mapped[list | None] = mapped_column(JSON, nullable=True) errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column( started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
@@ -387,6 +264,17 @@ class AgentRunLog(Base):
) )
class MonthlyTokenUsage(Base):
__tablename__ = "monthly_token_usage"
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), primary_key=True
)
year_month: Mapped[str] = mapped_column(String(7), primary_key=True) # 'YYYY-MM'
feature: Mapped[str] = mapped_column(String(64), primary_key=True)
tokens_used: Mapped[int] = mapped_column(Integer, nullable=False, default=0, server_default="0")
# ── Memory models ───────────────────────────────────────────────────────────── # ── Memory models ─────────────────────────────────────────────────────────────
@@ -426,8 +314,8 @@ class MemoryAssociative(Base):
nullable=False, index=True, nullable=False, index=True,
) )
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False) content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration. # vector(1536) via pgvector; SQLite tests use NULL embeddings so no dialect issue.
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True) embedding: Mapped[list | None] = mapped_column(Vector(1536), nullable=True)
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True) entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True) entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
@@ -475,3 +363,85 @@ class MemoryProactive(Base):
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now() DateTime(timezone=True), nullable=False, server_default=func.now()
) )
class ExtractionQueue(Base):
"""Batch extraction queue for Free-tier users (Phase 2).
Pro/Power/Team users get realtime asyncio.create_task() extraction.
Free users get a queue row here; a daily cron (Phase 5) drains it.
"""
__tablename__ = "extraction_queue"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
episode_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), nullable=True,
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class MemoryRelation(Base):
"""Per-user entity/relation graph row (Mem0g-light, Phase 3).
subject_label/object_label are plaintext entity identifiers (not user content).
notes_encrypted is optional Fernet-encrypted per-user commentary.
confidence in [0.0, 1.0] — decays 5 % per 30 days since last_confirmed_at.
"""
__tablename__ = "memory_relations"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
subject_label: Mapped[str] = mapped_column(String(128), nullable=False)
subject_type: Mapped[str] = mapped_column(String(32), nullable=False)
predicate: Mapped[str] = mapped_column(String(64), nullable=False)
object_label: Mapped[str] = mapped_column(String(128), nullable=False)
object_type: Mapped[str] = mapped_column(String(32), nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.7)
source_episode_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False),
ForeignKey("memory_episodic.id", ondelete="SET NULL"),
nullable=True,
)
notes_encrypted: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
last_confirmed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True
)
class Plugin(Base):
"""Plugin marketplace catalog entry."""
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False)
version: Mapped[str] = mapped_column(String(50), nullable=False)
author_name: Mapped[str] = mapped_column(String(255), nullable=False)
category: Mapped[str] = mapped_column(String(100), nullable=False)
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
status: Mapped[str] = mapped_column(String(50), nullable=False, default="pending")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)

View File

@@ -30,6 +30,16 @@ class UserProfile(BaseModel):
name: str | None = None name: str | None = None
surname: str | None = None surname: str | None = None
tier: BillingTier tier: BillingTier
avatar_url: str | None = None
has_password: bool = True
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
class OAuthAccountInfo(BaseModel):
provider: str
provider_email: str | None = None
created_at: int # epoch ms
# ── Chat ───────────────────────────────────────────────────────────── # ── Chat ─────────────────────────────────────────────────────────────
@@ -50,88 +60,6 @@ class ChatResponse(BaseModel):
response: str response: str
# ── Backup ───────────────────────────────────────────────────────────
class BackupMetadata(BaseModel):
version: int
timestamp: int
checksum: str
chunk_count: int
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
class StorageRecord(BaseModel):
id: str
user_id: str
table: str
blob: bytes
checksum: str
created_at: int
updated_at: int
class StorageRecordCreate(BaseModel):
table: str
blob: bytes
checksum: str
class StorageRecordUpdate(BaseModel):
blob: bytes
checksum: str
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
class VectorItem(BaseModel):
id: str
blob: bytes # encrypted vector + metadata — backend never decrypts
checksum: str
class VectorUpsertRequest(BaseModel):
vectors: list[VectorItem]
class VectorSearchRequest(BaseModel):
query_blob: bytes # encrypted query — backend never decrypts
top_k: int = 10
class VectorSearchResult(BaseModel):
id: str
score: float
blob: bytes
class VectorSearchResponse(BaseModel):
results: list[VectorSearchResult]
# ── Plugin Marketplace ────────────────────────────────────────────────
class PluginManifest(BaseModel):
id: str
name: str
description: str
version: str
author: str
permissions: list[str]
category: str
price_cents: int = 0
class PluginListResponse(BaseModel):
plugins: list[PluginManifest]
total: int
page: int
class PluginInstallRequest(BaseModel):
plugin_id: str
# ── WebSocket Frame Protocol ────────────────────────────────────────── # ── WebSocket Frame Protocol ──────────────────────────────────────────
class WsFrameType(str, Enum): class WsFrameType(str, Enum):
@@ -157,6 +85,17 @@ class WsFrameType(str, Enum):
journey_start = "journey_start" journey_start = "journey_start"
journey_message = "journey_message" journey_message = "journey_message"
journey_reply = "journey_reply" journey_reply = "journey_reply"
# ── v5 brief frame types ──────────────────────────────────────────
brief_request = "brief_request"
# ── v6 task brief frame types ─────────────────────────────────────
task_brief_request = "task_brief_request"
# ── v7 folder index frame types ───────────────────────────────────
index_session_start = "index_session_start"
index_file_batch = "index_file_batch"
index_session_cancel = "index_session_cancel"
index_file_result = "index_file_result"
index_session_progress = "index_session_progress"
index_session_done = "index_session_done"
class WsToolCall(BaseModel): class WsToolCall(BaseModel):
@@ -212,6 +151,16 @@ class WsDeviceHello(BaseModel):
# ── WebSocket v3 Frame Models ───────────────────────────────────────── # ── WebSocket v3 Frame Models ─────────────────────────────────────────
class FormatPrefsModel(BaseModel):
"""User display preferences sent by Electron on each request."""
timezone: str = "UTC"
date_format: str = "dd/MM/yyyy"
time_format: str = "24h"
locale: str = "en-US"
now_iso: str = ""
class WsFloatingScope(BaseModel): class WsFloatingScope(BaseModel):
"""Scope for a floating request — narrows the agent to a specific entity.""" """Scope for a floating request — narrows the agent to a specific entity."""
@@ -225,6 +174,7 @@ class WsHomeRequest(BaseModel):
type: Literal[WsFrameType.home_request] = WsFrameType.home_request type: Literal[WsFrameType.home_request] = WsFrameType.home_request
message: str message: str
conversation_history: list[dict[str, Any]] = Field(default_factory=list) conversation_history: list[dict[str, Any]] = Field(default_factory=list)
format_prefs: FormatPrefsModel | None = None
class WsFloatingRequest(BaseModel): class WsFloatingRequest(BaseModel):
@@ -233,6 +183,18 @@ class WsFloatingRequest(BaseModel):
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str message: str
scope: WsFloatingScope scope: WsFloatingScope
format_prefs: FormatPrefsModel | None = None
class WsBriefRequest(BaseModel):
"""Client → Server: Request a plain-text brief (home or project)."""
type: Literal[WsFrameType.brief_request] = WsFrameType.brief_request
request_id: str | None = None
session_id: str | None = None
mode: Literal["home", "project"]
project_id: str | None = None
format_prefs: FormatPrefsModel | None = None
class WsStreamStart(BaseModel): class WsStreamStart(BaseModel):
@@ -255,6 +217,8 @@ class WsStreamEnd(BaseModel):
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
request_id: str request_id: str
error: str | None = None
mutations: list[dict[str, Any]] | None = None
class WsDomain(BaseModel): class WsDomain(BaseModel):
@@ -318,10 +282,11 @@ class AgentTriggerRequest(BaseModel):
device_id: str = Field(default="") device_id: str = Field(default="")
agent_id: str | None = None # FE stable agent ID (electron-store UUID) agent_id: str | None = None # FE stable agent ID (electron-store UUID)
what_to_extract: list[str] = Field(min_length=1) what_to_extract: list[str] = Field(min_length=1)
actions_by_type: dict[str, list[str]] | None = None
batch_interval: str = Field(min_length=1) batch_interval: str = Field(min_length=1)
custom_agent_prompt: str = Field(min_length=1) custom_agent_prompt: str | None = None
agent_config: dict | None = None
active_agents: int = Field(ge=0, default=0) active_agents: int = Field(ge=0, default=0)
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
# ── Agent Run Log ───────────────────────────────────────────────────── # ── Agent Run Log ─────────────────────────────────────────────────────

View File

@@ -1 +0,0 @@
"""Cloud storage layer — E2E encrypted blobs and vectors."""

View File

@@ -1,106 +0,0 @@
"""S3-backed store for E2E-encrypted blobs.
Keys are structured as ``{user_id}/{table}/{record_id}``.
The backend never inspects blob content — it stores and retrieves opaque bytes.
"""
from __future__ import annotations
from typing import Any
import boto3
from app.config.settings import settings
class BlobStore:
"""Thin wrapper around boto3 S3.
All blobs must be E2E encrypted by the client before upload.
The backend adds SSE-S3 as an extra layer of at-rest encryption
but cannot decrypt the inner client-side payload.
"""
def _client(self) -> Any:
kwargs: dict[str, Any] = {
"region_name": settings.S3_REGION,
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
}
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
return boto3.client("s3", **kwargs)
@staticmethod
def _key(user_id: str, table: str, record_id: str) -> str:
return f"{user_id}/{table}/{record_id}"
async def upload(
self,
user_id: str,
table: str,
record_id: str,
blob: bytes,
checksum: str,
) -> str:
"""Store *blob* in S3 and return the S3 key.
Args:
user_id: Owner of the blob (used as key prefix).
table: Logical table name (e.g. ``"tasks"``).
record_id: Record UUID.
blob: Raw bytes (pre-encrypted by client).
checksum: SHA-256 hex digest supplied by the client; stored as
object metadata for download-time verification.
Returns:
The S3 key under which the blob was stored.
"""
key = self._key(user_id, table, record_id)
self._client().put_object(
Bucket=settings.S3_BUCKET,
Key=key,
Body=blob,
ServerSideEncryption="AES256", # SSE-S3 at rest
Metadata={"checksum": checksum},
)
return key
async def download(self, user_id: str, s3_key: str) -> bytes:
"""Retrieve the blob stored at *s3_key*.
*user_id* is retained in the signature so higher-level code can
enforce ownership without re-parsing the key.
Raises:
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
object does not exist.
"""
response = self._client().get_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
return response["Body"].read()
async def delete(self, user_id: str, s3_key: str) -> None:
"""Delete the object at *s3_key*.
S3 ``delete_object`` is idempotent — it succeeds even if the key does
not exist.
"""
self._client().delete_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
async def list_keys(self, user_id: str, table: str) -> list[str]:
"""Return all S3 keys for a given user + table combination.
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
"""
prefix = f"{user_id}/{table}/"
response = self._client().list_objects_v2(
Bucket=settings.S3_BUCKET,
Prefix=prefix,
)
return [obj["Key"] for obj in response.get("Contents", [])]

View File

@@ -1,32 +0,0 @@
"""Integrity verification only — the backend NEVER decrypts user data."""
from __future__ import annotations
import hashlib
import hmac
from fastapi import HTTPException
def verify_checksum(blob: bytes, checksum: str) -> bool:
"""Return ``True`` if SHA-256(blob) matches *checksum*.
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
timing-based side-channel attacks.
"""
computed = hashlib.sha256(blob).hexdigest()
return hmac.compare_digest(computed, checksum)
def reject_if_tampered(blob: bytes, checksum: str) -> None:
"""Raise ``HTTP 400`` if the blob does not match its checksum.
Call this before storing or forwarding any client-provided blob.
The backend never holds decryption keys — this check only verifies
that the opaque bytes arrived intact.
"""
if not verify_checksum(blob, checksum):
raise HTTPException(
status_code=400,
detail="Checksum mismatch: blob integrity check failed",
)

View File

@@ -1,205 +0,0 @@
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
Vectors are pre-encrypted blobs from the client. The backend stores them
alongside a deterministic 32-dim float representation derived from the blob's
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
is a known trade-off documented in the backend plan.
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
``user_id`` payload field on a shared collection.
"""
from __future__ import annotations
import base64
import hashlib
from typing import Any
from pinecone import Pinecone
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
from app.config.settings import settings
from app.schemas import VectorItem, VectorSearchResult
_QDRANT_COLLECTION = "adiuva_vectors"
def _blob_to_vector(blob: bytes) -> list[float]:
"""Derive a 32-dim float vector from *blob* for storage purposes only.
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
normalises each byte to the range [-1.0, 1.0]. This vector carries no
semantic meaning on encrypted data.
"""
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
class VectorStore:
"""Thin wrapper around Pinecone or Qdrant.
The backend to use is selected at runtime:
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
"""
def _use_pinecone(self) -> bool:
return bool(settings.PINECONE_API_KEY)
# ── Pinecone helpers ──────────────────────────────────────────────
def _pinecone_index(self) -> Any:
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
return pc.Index(settings.PINECONE_INDEX)
# ── Qdrant helpers ────────────────────────────────────────────────
def _qdrant_client(self) -> Any:
return QdrantClient(
url=settings.QDRANT_URL,
api_key=settings.QDRANT_API_KEY or None,
)
# ── Public API ────────────────────────────────────────────────────
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
"""Store encrypted vectors in the backend.
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
so it can be returned verbatim during search.
Args:
user_id: Used as Pinecone namespace or Qdrant payload field.
vectors: List of encrypted vector items from the client.
"""
if self._use_pinecone():
await self._pinecone_upsert(user_id, vectors)
else:
await self._qdrant_upsert(user_id, vectors)
async def search(
self,
user_id: str,
query_blob: bytes,
top_k: int,
) -> list[VectorSearchResult]:
"""Query the vector store and return encrypted result blobs.
The query vector is derived from *query_blob* using the same
deterministic mapping as upsert.
Args:
user_id: Scopes the search to this user's namespace.
query_blob: Encrypted query from the client.
top_k: Maximum number of results to return.
Returns:
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
"""
if self._use_pinecone():
return await self._pinecone_search(user_id, query_blob, top_k)
return await self._qdrant_search(user_id, query_blob, top_k)
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
"""Remove vectors by ID, scoped to *user_id*.
Args:
user_id: Namespace / payload filter to prevent cross-user deletion.
vector_ids: List of vector IDs to remove.
"""
if self._use_pinecone():
await self._pinecone_delete(user_id, vector_ids)
else:
await self._qdrant_delete(user_id, vector_ids)
# ── Pinecone implementation ───────────────────────────────────────
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
index = self._pinecone_index()
records = [
{
"id": v.id,
"values": _blob_to_vector(v.blob),
"metadata": {
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
}
for v in vectors
]
index.upsert(vectors=records, namespace=user_id)
async def _pinecone_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
index = self._pinecone_index()
query_vector = _blob_to_vector(query_blob)
response = index.query(
vector=query_vector,
top_k=top_k,
namespace=user_id,
include_metadata=True,
)
results: list[VectorSearchResult] = []
for match in response.get("matches", []):
blob_bytes = base64.b64decode(match["metadata"]["blob"])
results.append(
VectorSearchResult(
id=match["id"],
score=match["score"],
blob=blob_bytes,
)
)
return results
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
index = self._pinecone_index()
index.delete(ids=vector_ids, namespace=user_id)
# ── Qdrant implementation ─────────────────────────────────────────
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
client = self._qdrant_client()
points = [
PointStruct(
id=v.id,
vector=_blob_to_vector(v.blob),
payload={
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
)
for v in vectors
]
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
async def _qdrant_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
client = self._qdrant_client()
query_vector = _blob_to_vector(query_blob)
hits = client.search(
collection_name=_QDRANT_COLLECTION,
query_vector=query_vector,
query_filter=Filter(
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
),
limit=top_k,
)
return [
VectorSearchResult(
id=str(hit.id),
score=hit.score,
blob=base64.b64decode(hit.payload["blob"]),
)
for hit in hits
]
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
client = self._qdrant_client()
client.delete(
collection_name=_QDRANT_COLLECTION,
points_selector=PointIdsList(points=vector_ids),
)

View File

@@ -7,7 +7,7 @@ services:
- path: .env - path: .env
required: false required: false
environment: environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
volumes: volumes:
- copilot_tokens:/root/.config/litellm/github_copilot - copilot_tokens:/root/.config/litellm/github_copilot
@@ -21,7 +21,7 @@ services:
environment: environment:
POSTGRES_USER: postgres POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuva POSTGRES_DB: adiuvai
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
healthcheck: healthcheck:
@@ -36,37 +36,6 @@ services:
# image: redis:7-alpine # image: redis:7-alpine
# restart: unless-stopped # restart: unless-stopped
# ── Local S3-compatible storage (MinIO) ──
minio:
image: minio/minio:latest
command: server /data --console-address ":9001"
ports:
- "9000:9000"
- "9001:9001"
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin
volumes:
- minio_data:/data
healthcheck:
test: ["CMD", "mc", "ready", "local"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
# ── Local vector store (Qdrant) ──
qdrant:
image: qdrant/qdrant:latest
ports:
- "6333:6333"
- "6334:6334"
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
volumes: volumes:
postgres_data: postgres_data:
minio_data:
qdrant_data:
copilot_tokens: copilot_tokens:

View File

@@ -1,941 +0,0 @@
# Adiuva — Architettura Microservizi (MVP)
## Panoramica
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
```
┌──────────────┐
│ Cloudflare │
│ (DNS + CDN) │
└──────┬───────┘
│ HTTPS / WSS
┌──────▼───────┐
│ Traefik │
│ API Gateway │
│ (routing, │
│ TLS, rate │
│ limiting) │
└──────┬───────┘
┌──────────┬───────────┼───────────┐
│ │ │ │
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
│ Auth │ │ Chat │ │ Agent │ │Billing │
│ Service │ │Service│ │ Service │ │Service │
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
│ │ │ │
┌─────▼──────────▼──────────▼───────────▼────┐
│ Infrastruttura │
│ PostgreSQL │ Redis │ Qdrant │
└─────────────────────────────────────────────┘
```
---
## 1. Suddivisione dei Servizi
### 1.1 Auth Service (`auth-service`)
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
| Endpoint originale | Metodo |
|---|---|
| `/api/v1/auth/register` | POST |
| `/api/v1/auth/login` | POST |
| `/api/v1/auth/refresh` | POST |
| `/api/v1/auth/me` | GET / PUT |
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
**Modifica chiave — JWT con RS256**:
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
- L'Auth Service firma i JWT con la **chiave privata**.
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
```python
# auth-service/app/auth/jwt.py
from cryptography.hazmat.primitives.asymmetric import rsa
from jose import jwt
PRIVATE_KEY = ... # Da env/secret
PUBLIC_KEY = ... # Derivata o da env
def create_access_token(user_id: str, tier: str) -> str:
return jwt.encode(
{"sub": user_id, "tier": tier, "exp": ...},
PRIVATE_KEY,
algorithm="RS256",
)
```
```python
# shared/auth.py (usato da tutti gli altri servizi)
from jose import jwt
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
def verify_token(token: str) -> dict:
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
```
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
---
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
| Endpoint | Tipo |
|---|---|
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
| `/api/v1/chat` | POST (REST fallback) |
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
**Scaling**: 2N repliche. Sticky cookies per le WS + Redis per cross-instance.
---
### 1.3 Agent Service (`agent-service`) ⭐ Batch
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
| Endpoint | Tipo |
|---|---|
| `/api/v1/agents/catalog` | GET |
| `/api/v1/agents/can-create` | POST |
| `/api/v1/agents/trigger` | POST |
| `/api/v1/agents/journey/start` | POST (o WS relay) |
| `/api/v1/agents/journey/message` | POST (o WS relay) |
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
```
┌──────────────┐ ┌──────────────┐ ┌──────────┐
│ Agent Service│ │ Redis │ │ Chat │
│ (batch run) │ │ │ │ Service │
│ │ │ │ │ (ha WS) │
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
│ from │ │ │ │ al │
│ device │ │ │ │ device│
│ │ │ │ │ via WS│
│ │ SUBSCRIBE │ │ PUBLISH │ │
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
│ risultato │ │ │ │ reply │
└──────────────┘ └──────────────┘ └──────────┘
```
**Scaling**: 1N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
---
### 1.4 Billing Service (`billing-service`)
**Responsabilità**: Stripe checkout, webhook, subscription management.
| Endpoint originale | Metodo |
|---|---|
| `/api/v1/billing/checkout` | POST |
| `/api/v1/billing/webhook` | POST |
| `/api/v1/billing/subscription` | GET / DELETE |
**Database**: Tabelle `subscriptions` (schema `billing`).
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
**Scaling**: 1 replica sufficiente. Basso traffico.
---
### 1.5 Servizi esclusi dall'MVP
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
| Servizio | Responsabilità | Note |
|---|---|---|
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
---
## 2. Tier Check — Dove e Come
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
### Strategia: Tier nel JWT
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
```json
{
"sub": "user_123",
"tier": "pro",
"exp": 1742515200,
"iat": 1742511600
}
```
Ogni servizio:
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
2. Legge `payload["tier"]`**zero chiamate extra**
3. Applica le sue regole di enforcement localmente
```python
# shared/auth.py — dependency FastAPI condivisa
from fastapi import Depends, HTTPException, Request
from jose import jwt
PUBLIC_KEY = ...
class CurrentUser:
def __init__(self, user_id: str, tier: str):
self.user_id = user_id
self.tier = tier
async def get_current_user(request: Request) -> CurrentUser:
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
def require_tier(*allowed_tiers: str):
"""Dependency che blocca se il tier non è tra quelli ammessi."""
async def check(user: CurrentUser = Depends(get_current_user)):
if user.tier not in allowed_tiers:
raise HTTPException(403, "Tier insufficient")
return user
return check
```
### Cosa succede quando il tier cambia (upgrade/downgrade)?
```
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
│ │ │ Service │ (Redis pub/sub) │ Service │
└──────────┘ └──────────┘ └────┬─────┘
UPDATE users
SET tier = 'power'
Al prossimo /refresh
il JWT conterrà tier='power'
```
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 1530 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
### Dove si applica in ciascun servizio
| Servizio | Enforcement |
|---|---|
| **Auth Service** | Nessuno (è lui che scrive il tier) |
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
### Rate-limit distribuito via Redis
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
```python
# shared/middleware/rate_limit.py
import redis.asyncio as aioredis
class DistributedRateLimiter:
def __init__(self, redis: aioredis.Redis):
self._redis = redis
async def check(self, user_id: str, tier: str, service: str) -> bool:
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
max_req = limits.get(tier, 20)
key = f"rate:{service}:{user_id}"
pipe = self._redis.pipeline()
pipe.incr(key)
pipe.expire(key, 60)
count, _ = await pipe.execute()
return count <= max_req
```
---
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
`DeviceConnectionManager` è un **singleton in-memory**:
```python
class DeviceConnectionManager:
def __init__(self):
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
```
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
### La soluzione: Redis Pub/Sub + Registry
```
┌──────────────────────────────────────────────────────────────┐
│ Redis │
│ │
│ Hash: ws:connections │
│ user_123 → instance_A │
│ user_456 → instance_B │
│ │
│ Pub/Sub channels: │
│ tool_call:{user_id} → tool call payloads │
│ tool_result:{call_id} → tool result payloads │
│ stream:{user_id} → text_chunk streaming │
└──────────────────────────────────────────────────────────────┘
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
┌───────────────────────┐ ┌───────────────────────┐
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
│ tool_call:user_123│ │ → user_123 è su A │
│ │ │ │
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
│ da Redis channel │ │ tool_call:user_123 │
│ │ │ {id, action, ...} │
│ 3. Invia al device │ │ │
│ via WS │ │ 4. SUBSCRIBE │
│ │ │ tool_result:{id} │
│ 4. Device risponde │ │ │
│ tool_result │──────────│► 5. Riceve risultato │
│ │ │ │
│ 5. PUBLISH │ │ │
│ tool_result:{id} │ │ │
└───────────────────────┘ └───────────────────────┘
```
### Implementazione: `RedisDeviceManager`
```python
# chat-service/app/core/device_manager.py
import asyncio
import json
import os
import redis.asyncio as aioredis
from dataclasses import dataclass, field
from fastapi import WebSocket
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
@dataclass
class LocalConnection:
ws: WebSocket
device_id: str
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
class RedisDeviceManager:
"""Device manager backed by Redis for cross-instance communication."""
def __init__(self, redis_url: str = "redis://redis:6379"):
self._redis = aioredis.from_url(redis_url)
self._pubsub = self._redis.pubsub()
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
async def start(self):
"""Avvia il listener Redis per tool_call in arrivo."""
asyncio.create_task(self._listen_tool_calls())
# ── Registrazione ──
async def register(self, user_id: str, device_id: str, ws: WebSocket):
# Registra localmente
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
# Registra in Redis quale istanza ha la connessione
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
# Sottoscrivi ai tool_call per questo utente
await self._pubsub.subscribe(f"tool_call:{user_id}")
async def unregister(self, user_id: str):
conn = self._local.pop(user_id, None)
if conn:
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
await self._redis.hdel("ws:connections", user_id)
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
# ── Presenza ──
async def is_online(self, user_id: str) -> bool:
return await self._redis.hexists("ws:connections", user_id)
# ── Tool-call round-trip (cross-instance) ──
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
"""
Invia un tool_call al device dell'utente.
Funziona sia che la WS sia locale che su un'altra istanza.
"""
call_id = payload["id"]
# Caso 1: connessione locale → invio diretto
if user_id in self._local:
conn = self._local[user_id]
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
return await asyncio.wait_for(fut, timeout=30.0)
# Caso 2: connessione remota → Redis pub/sub
loop = asyncio.get_event_loop()
fut = loop.create_future()
self._remote_futures[call_id] = fut
# Sottoscrivi al canale di risposta
result_channel = f"tool_result:{call_id}"
await self._pubsub.subscribe(result_channel)
# Pubblica il tool_call
await self._redis.publish(
f"tool_call:{user_id}",
json.dumps(payload),
)
try:
return await asyncio.wait_for(fut, timeout=30.0)
finally:
self._remote_futures.pop(call_id, None)
await self._pubsub.unsubscribe(result_channel)
# ── Risoluzione tool_result (da WS locale) ──
def resolve_local(self, user_id: str, call_id: str, result: dict):
conn = self._local.get(user_id)
if conn:
fut = conn.pending_calls.pop(call_id, None)
if fut and not fut.done():
fut.set_result(result)
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
"""Chiamato quando il device locale invia un tool_result."""
self.resolve_local(user_id, call_id, result)
# Pubblica anche su Redis per l'istanza remota che aspetta
await self._redis.publish(
f"tool_result:{call_id}",
json.dumps(result),
)
# ── Listener Redis ──
async def _listen_tool_calls(self):
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
async for message in self._pubsub.listen():
if message["type"] != "message":
continue
channel = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
data = json.loads(message["data"])
if channel.startswith("tool_call:"):
# Un'altra istanza vuole che inviamo un tool_call al nostro device
user_id = channel.split(":", 1)[1]
conn = self._local.get(user_id)
if conn:
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
elif channel.startswith("tool_result:"):
# Risposta a un tool_call che abbiamo inviato tramite Redis
call_id = channel.split(":", 1)[1]
fut = self._remote_futures.pop(call_id, None)
if fut and not fut.done():
fut.set_result(data)
# ── Stream cross-instance ──
async def publish_stream_chunk(self, user_id: str, chunk: dict):
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
```
---
## 4. Struttura Directory Proposta (MVP)
```
adiuva-api/
├── docker-compose.yml # Orchestrazione completa
├── docker-compose.dev.yml # Override per sviluppo locale
├── shared/ # Codice condiviso (montato come volume)
│ ├── auth.py # JWT verification (chiave pubblica)
│ ├── schemas.py # Pydantic schemas condivisi
│ ├── middleware/
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
│ │ └── sanitizer.py
│ └── models/
│ └── base.py # SQLAlchemy base condivisa
├── auth-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # users, refresh_tokens
│ ├── routes/
│ │ └── auth.py
│ └── services/
│ ├── jwt_service.py # RS256 signing
│ └── user_service.py
├── chat-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # memory_*
│ ├── routes/
│ │ ├── device_ws.py # WS connection owner
│ │ └── chat.py # REST fallback
│ ├── core/
│ │ ├── device_manager.py # RedisDeviceManager
│ │ ├── deep_agent.py # Home + floating chat
│ │ ├── memory_middleware.py
│ │ ├── ws_context.py
│ │ ├── output_formatter.py
│ │ └── llm.py
│ └── agents/ # Tool definitions (used by deep_agent)
│ ├── task_agent.py
│ ├── project_agent.py
│ ├── note_agent.py
│ └── timeline_agent.py
├── agent-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
│ ├── routes/
│ │ ├── agents.py # catalog, can-create, trigger
│ │ └── agent_setup.py # journey start/message
│ ├── core/
│ │ ├── agent_runner.py # Batch classify → process
│ │ ├── agent_registry.py
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
│ │ └── llm.py
│ └── agents/
│ ├── task_agent.py # Tool definitions (batch context)
│ ├── project_agent.py
│ ├── note_agent.py
│ ├── timeline_agent.py
│ └── filesystem_agent.py
├── billing-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # subscriptions
│ ├── routes/
│ │ └── billing.py
│ └── services/
│ ├── stripe_service.py
│ └── tier_manager.py
└── infra/
├── traefik/
│ └── traefik.yml
├── keys/
│ ├── jwt_private.pem # Solo auth-service
│ └── jwt_public.pem # Tutti i servizi
└── alembic/ # Migrazioni condivise o per-servizio
```
---
## 5. Docker Compose — Configurazione MVP
```yaml
# docker-compose.yml
services:
# ══════════════════════════════════════════════════════════
# API Gateway
# ══════════════════════════════════════════════════════════
traefik:
image: traefik:v3.2
command:
- "--api.insecure=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443"
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
ports:
- "80:80"
- "443:443"
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./infra/certs:/certs:ro
restart: unless-stopped
# ══════════════════════════════════════════════════════════
# Auth Service (2 repliche)
# ══════════════════════════════════════════════════════════
auth-service:
build: ./auth-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
SERVICE_NAME: auth
secrets:
- jwt_private_key
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
- "traefik.http.services.auth.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Chat Service — Real-time WS + Chat (scalabile)
# ══════════════════════════════════════════════════════════
chat-service:
build: ./chat-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: chat
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
# REST chat endpoint
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
- "traefik.http.services.chat.loadbalancer.server.port=8000"
# WebSocket route con sticky session
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
- "traefik.http.routers.ws.service=chat-ws"
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Agent Service — Batch processing (scalabile indipendentemente)
# ══════════════════════════════════════════════════════════
agent-service:
build: ./agent-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: agent
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
- "traefik.http.services.agents.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Billing Service (1 replica)
# ══════════════════════════════════════════════════════════
billing-service:
build: ./billing-service
deploy:
replicas: 1
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: billing
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
- "traefik.http.services.billing.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Infrastruttura
# ══════════════════════════════════════════════════════════
db:
image: pgvector/pgvector:pg16
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuva
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
redis:
image: redis:7-alpine
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5
restart: unless-stopped
qdrant:
image: qdrant/qdrant:latest
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
secrets:
jwt_private_key:
file: ./infra/keys/jwt_private.pem
jwt_public_key:
file: ./infra/keys/jwt_public.pem
volumes:
postgres_data:
redis_data:
qdrant_data:
```
---
## 6. Configurazione Cloudflare + VPS
### 6.1 DNS
```
api.tuodominio.com → A record → IP del VPS
→ Proxy: ON (orange cloud)
```
### 6.2 Cloudflare Settings
| Setting | Valore | Motivo |
|---------|--------|--------|
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
| Under Attack Mode | Off (attivare se necessario) | |
### 6.3 TLS sul VPS
Due opzioni:
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
```yaml
# traefik.yml — con Cloudflare Origin Certificate
entryPoints:
websecure:
address: ":443"
tls:
certificates:
- certFile: /certs/origin.pem
keyFile: /certs/origin-key.pem
```
### 6.4 Rete VPS
```bash
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
# https://www.cloudflare.com/ips/
ufw default deny incoming
ufw allow from 173.245.48.0/20 to any port 443
ufw allow from 103.21.244.0/22 to any port 443
# ... (tutti gli IP range di Cloudflare)
ufw allow ssh
ufw enable
```
---
## 7. Comunicazione Inter-Servizio
### 7.1 Redis Pub/Sub — Event Bus
```
┌──────────┐ tier_changed:user_123 ┌──────────┐
│ Billing │ ────────────────────────► │ Auth │
│ Service │ │ Service │
└──────────┘ └──────────┘
┌──────────┐ tool_call:user_123 ┌──────────┐
│ Agent │ ────────────────────────► │ Chat │
│ Service │ │ Service │
│ (batch) │ ◄────────────────────────│ (ha WS) │
└──────────┘ tool_result:{call_id} └──────────┘
```
### 7.2 Health Checks e Service Discovery
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
- **Redis pub/sub** (tool-call cross-instance, tier events)
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
- **PostgreSQL** (dati persistenti condivisi)
---
## 8. Piano di Migrazione Incrementale (MVP)
### Fase 1 — Preparazione (nel monolite attuale)
1. Aggiungere Redis al `docker-compose.yml` attuale
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
4. Estrarre `shared/` con auth verification, schemas, middleware
### Fase 2 — Auth Service (primo split)
1. Estrarre `auth.py` routes + models in `auth-service/`
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
4. Il monolite continua a servire tutto il resto
### Fase 3 — Billing Service
1. Estrarre billing routes, Stripe service, tier manager
2. Configurare Redis pub/sub per `tier_changed` events
3. Routare via Traefik
### Fase 4 — Split Chat + Agent (il più delicato)
1. Il monolite residuo contiene WS + chat + agents
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
### Fase 5 — Scaling test
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
3. Monitoring (Prometheus + Grafana) per ogni servizio
---
## 9. Monitoraggio e Logging
```yaml
# Aggiungere al docker-compose.yml
prometheus:
image: prom/prometheus:latest
volumes:
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
restart: unless-stopped
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
volumes:
- grafana_data:/var/lib/grafana
restart: unless-stopped
loki:
image: grafana/loki:latest
restart: unless-stopped
```
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
---
## 10. Sizing VPS Minimo Consigliato (MVP)
| Componente | CPU | RAM | Note |
|---|---|---|---|
| Traefik | 0.25 | 128MB | |
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
| Billing Service | 0.25 | 128MB | |
| PostgreSQL | 1.0 | 1GB | |
| Redis | 0.25 | 256MB | |
| Qdrant | 0.5 | 512MB | |
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
---
## Riepilogo Architettura MVP
| Servizio | Repliche | Proprietario di |
|---|---|---|
| **Traefik** | 1 | Routing, TLS, sticky sessions |
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
| **Chat Service** | 2N | WebSocket, home/floating chat, streaming |
| **Agent Service** | 2N | Batch processing, directory scan, agent setup |
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
| Decisione | Scelta | Motivazione |
|---|---|---|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
| TLS | Cloudflare Origin Certificate | Zero maintenance |
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
| Storage / Plugin | Post-MVP | Non critici per il lancio |

View File

@@ -32,8 +32,12 @@ google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0 google-auth-httplib2>=0.2.0
msal>=1.28.0 msal>=1.28.0
cryptography>=42.0.0 cryptography>=42.0.0
langfuse>=2.0.0 pgvector>=0.2.5
langfuse>=3.3.1
beautifulsoup4>=4.12.0 beautifulsoup4>=4.12.0
lxml>=5.0.0 lxml>=5.0.0
PyYAML>=6.0.0 PyYAML>=6.0.0
apscheduler>=3.10.0
ruff>=0.8.0 ruff>=0.8.0
pypdf>=4.0
python-docx>=1.1

1
results.xml Normal file

File diff suppressed because one or more lines are too long

View File

@@ -17,6 +17,8 @@ from jose import jwt
from sqlalchemy import StaticPool, event from sqlalchemy import StaticPool, event
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy import select
from app.config.settings import settings from app.config.settings import settings
from app.db import Base, get_session from app.db import Base, get_session
from app.main import app from app.main import app
@@ -134,6 +136,38 @@ def auth_header(tier: str = "power", user_id: str | None = None) -> dict[str, st
return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"} return {"Authorization": f"Bearer {make_jwt(tier, user_id)}"}
# ── Convenience aliases and per-tier user fixtures ────────────────────
@pytest_asyncio.fixture
async def db(db_session: AsyncSession) -> AsyncSession:
"""Alias for db_session — used by folder quota tests."""
return db_session
@pytest_asyncio.fixture
async def test_user_free(db_session: AsyncSession):
"""Return the seeded free-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["free"])
)
return result.scalar_one()
@pytest_asyncio.fixture
async def test_user_power(db_session: AsyncSession):
"""Return the seeded power-tier User row."""
result = await db_session.execute(
select(User).where(User.id == TEST_USER_IDS["power"])
)
return result.scalar_one()
@pytest.fixture
def auth_headers_free() -> dict[str, str]:
"""Authorization header for the seeded free-tier user."""
return auth_header("free")
# ── CLI options ─────────────────────────────────────────────────────── # ── CLI options ───────────────────────────────────────────────────────
def pytest_addoption(parser): def pytest_addoption(parser):

View File

@@ -1,19 +1,11 @@
# Journey V2 eval test cases — Step 4 # Journey V2 eval test cases — Step 4
# #
# Each case simulates a complete journey session: # Only case 4.1 is kept as an automated eval. Cases 4.24.5 (multi-turn
# 1. handle_journey_start is called with directory + data_types # conversations that expect the LLM to produce a complete AgentConfig)
# 2. handle_journey_message is called for each entry in user_messages # are non-deterministic and tested manually — results tracked in Langfuse.
# 3. Assertions are evaluated on the final reply
#
# directory_files: list of {path, content_file} — content_file is relative to data/
# #
# Assertion keys: # Assertion keys:
# expect_question: true → first reply must contain "?" # expect_question: true → first reply must contain "?"
# expect_done: true → final reply must have done=True
# expect_valid_config: true → agent_config must be parseable as AgentConfig with content_types > 0
# expect_content_type_id: <str> → AgentConfig.content_types must contain an entry with this id
# expect_extraction_contains: <str> → first content_type extraction_prompt must contain this word
# expect_global_rules: true → AgentConfig.global_rules must be non-empty
- id: "4.1" - id: "4.1"
description: "Journey start explores directory, first reply contains a question" description: "Journey start explores directory, first reply contains a question"
@@ -25,63 +17,3 @@
user_messages: [] user_messages: []
score_name: "journey.start" score_name: "journey.start"
expect_question: true expect_question: true
- id: "4.2"
description: "Full 3-turn conversation produces a valid AgentConfig JSON"
directory: "/test/emails"
data_types: ["tasks", "notes", "timelines"]
directory_files:
- path: "/test/emails/email_backup.html"
content_file: "email_action.html"
user_messages:
- "These are email exports from Outlook in HTML format"
- "Create tasks for emails with direct action requests, notes for informational emails"
- "Yes, that looks correct. No other rules."
score_name: "journey.valid_json"
expect_done: true
expect_valid_config: true
- id: "4.3"
description: "Journey detects email_html content type from directory exploration"
directory: "/test/emails"
data_types: ["tasks", "notes"]
directory_files:
- path: "/test/emails/message.html"
content_file: "email_action.html"
user_messages:
- "HTML email backups from my mail client, exported from Outlook"
- "Create tasks from emails that contain assignments or direct action items"
- "Correct, no other rules needed"
score_name: "journey.detect_email"
expect_done: true
expect_content_type_id: "email_html"
- id: "4.4"
description: "Custom user rule (only notes, no tasks) reflected in extraction_prompt"
directory: "/test/emails"
data_types: ["notes"]
directory_files:
- path: "/test/emails/email.html"
content_file: "email_info.html"
user_messages:
- "HTML emails from my work inbox"
- "Create only notes from all emails — I do not want tasks or timelines to be created"
- "Yes, exactly"
score_name: "journey.custom_rules"
expect_done: true
expect_extraction_contains: "note"
- id: "4.5"
description: "Global rule (no project = no entity) appears in AgentConfig.global_rules"
directory: "/test/emails"
data_types: ["tasks", "notes"]
directory_files:
- path: "/test/emails/email.html"
content_file: "email_action.html"
user_messages:
- "Email backups from Outlook"
- "Create tasks from action request emails, notes from informational emails"
- "If the email cannot be matched to any project, do not create any entity at all"
score_name: "journey.global_rules"
expect_done: true
expect_global_rules: true

View File

@@ -1,810 +0,0 @@
"""Tests for Step 3.4: agent_runner module.
Coverage:
Unit:
- _is_overdue — cron schedule overdue detection
- _extract_items_from_content — LLM extraction + JSON parsing + validation
- _send_insert_to_client — tool_call frame construction + timeout
- run_local_agent — end-to-end local agent happy path
- run_local_agent — device offline path
- run_local_agent — file-read timeout path
- run_local_agent — LLM extraction error path
- run_cloud_agent — stub returns error immediately
- trigger_pending_runs — skipped when config is client-owned
- trigger_pending_runs — non-overdue skipped
- trigger_pending_runs — device_id filter for local agents
Integration:
- POST /agents/can-create — billing eligibility check
- POST /agents/trigger — creates run log + dispatches background task
"""
from __future__ import annotations
import asyncio
import json
import uuid
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from app.core.agent_runner import (
_extract_items_from_content,
_is_overdue,
_send_insert_to_client,
run_cloud_agent,
run_local_agent,
trigger_pending_runs,
)
from app.core.device_manager import DeviceConnectionManager
from app.db import get_session
from app.main import app
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_FREE_UID = TEST_USER_IDS["free"]
_PRO_UID = TEST_USER_IDS["pro"]
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
return LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=user_id,
device_id=device_id,
name="Test Local Agent",
directory_paths=["/home/user/emails"],
data_types=["tasks", "notes"],
prompt_template="Extract tasks and notes from this document.",
file_extensions=[".txt", ".eml"],
schedule_cron="0 */6 * * *",
enabled=True,
last_run_at=None,
)
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
return CloudAgentConfig(
id=str(uuid.uuid4()),
user_id=user_id,
provider="gmail",
name="Test Gmail Agent",
data_types=["tasks"],
prompt_template="Extract tasks from email.",
schedule_cron="0 */6 * * *",
enabled=True,
last_run_at=None,
)
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
return AgentRunLog(
id=str(uuid.uuid4()),
agent_id=agent_id,
agent_type=agent_type,
user_id=user_id,
status="running",
started_at=datetime.now(timezone.utc),
)
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
mgr = DeviceConnectionManager()
ws = MagicMock()
ws.send_text = AsyncMock()
mgr.register(user_id, device_id, ws)
return mgr
# ---------------------------------------------------------------------------
# _is_overdue
# ---------------------------------------------------------------------------
def test_is_overdue_never_run():
"""An agent that has never run is always overdue."""
assert _is_overdue("0 */6 * * *", None) is True
def test_is_overdue_very_recently_run():
"""An agent that just ran is not overdue."""
last = datetime.now(timezone.utc)
assert _is_overdue("0 */6 * * *", last) is False
def test_is_overdue_long_ago():
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
from datetime import timedelta
last = datetime.now(timezone.utc) - timedelta(days=2)
assert _is_overdue("0 */6 * * *", last) is True
def test_is_overdue_invalid_cron_returns_false():
"""Unparseable cron must not raise and should return False (fail-safe)."""
assert _is_overdue("not a cron", None) is False
def test_is_overdue_naive_datetime():
"""Naive datetime objects are handled without raising."""
from datetime import timedelta
last = datetime.utcnow() - timedelta(days=1) # naive
# Should not raise.
result = _is_overdue("0 */6 * * *", last)
assert isinstance(result, bool)
# ---------------------------------------------------------------------------
# _extract_items_from_content
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_extract_items_happy_path():
"""LLM returns valid JSON array; items with allowed tables are returned."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps([
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
items = await _extract_items_from_content(
"Extract tasks and notes.",
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
["tasks", "notes"],
)
assert len(items) == 2
assert items[0]["table"] == "tasks"
assert items[0]["data"]["title"] == "Buy milk"
assert items[1]["table"] == "notes"
@pytest.mark.asyncio
async def test_extract_items_strips_forbidden_fields():
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps([
{
"table": "tasks",
"data": {
"title": "Review PR",
"id": "should-be-removed",
"createdAt": 99999,
"isAiSuggested": 0,
"isApproved": 1,
},
}
])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
assert len(items) == 1
data = items[0]["data"]
assert "id" not in data
assert "createdAt" not in data
assert "isAiSuggested" not in data
assert "isApproved" not in data
assert data["title"] == "Review PR"
@pytest.mark.asyncio
async def test_extract_items_invalid_json_returns_empty():
"""LLM returning invalid JSON must return empty list without raising."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = "Sorry, I cannot extract anything."
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
assert items == []
@pytest.mark.asyncio
async def test_extract_items_disallowed_table_filtered():
"""Items whose table is not in data_types are discarded."""
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps([
{"table": "tasks", "data": {"title": "Valid task"}},
{"table": "projects", "data": {"name": "Should be filtered"}},
])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
# Only "tasks" is in data_types — "projects" should be filtered.
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
assert len(items) == 1
assert items[0]["table"] == "tasks"
@pytest.mark.asyncio
async def test_extract_items_empty_data_types_returns_empty():
"""If no allowed data_types match, skip LLM call and return immediately."""
mock_llm = MagicMock()
mock_llm.ainvoke = AsyncMock()
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
items = await _extract_items_from_content("Extract.", "content", [])
mock_llm.ainvoke.assert_not_called()
assert items == []
@pytest.mark.asyncio
async def test_extract_items_llm_error_propagates():
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
mock_llm = MagicMock()
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
with pytest.raises(RuntimeError, match="API unavailable"):
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
# ---------------------------------------------------------------------------
# _send_insert_to_client
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_send_insert_to_client_happy_path():
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
mgr = _make_manager()
sent_payloads: list[dict] = []
original_send = mgr.send_frame
async def _capture_send(uid: str, frame: dict) -> None:
sent_payloads.append(frame)
# Immediately resolve the pending call with a success result.
call_id = frame["id"]
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
mgr.send_frame = _capture_send # type: ignore[method-assign]
result = await _send_insert_to_client(
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
)
assert len(sent_payloads) == 1
payload = sent_payloads[0]
assert payload["action"] == "insert"
assert payload["table"] == "tasks"
assert payload["data"]["title"] == "Buy milk"
assert payload["data"]["isAiSuggested"] == 1
assert payload["data"]["isApproved"] == 0
assert result["row"]["title"] == "Buy milk"
@pytest.mark.asyncio
async def test_send_insert_to_client_timeout():
"""asyncio.TimeoutError is raised when Electron does not respond."""
mgr = _make_manager()
async def _slow_send(uid: str, frame: dict) -> None:
# Never resolve the pending call.
pass
mgr.send_frame = _slow_send # type: ignore[method-assign]
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
with pytest.raises(asyncio.TimeoutError):
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
# ---------------------------------------------------------------------------
# run_local_agent
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_local_agent_device_offline():
"""run_local_agent marks run as error when device is offline."""
config = _make_local_config()
run_log = _make_run_log(config.id)
mgr = DeviceConnectionManager() # Empty — no device registered.
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_local_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("not connected" in e for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_local_agent_happy_path():
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
config = _make_local_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
# Build a fake agent_data frame (will be queued after send).
file_frame = {
"type": "agent_data",
"run_id": run_log.id,
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
}
agent_complete_frame = None # sentinel
sent_frames: list[dict] = []
async def _mock_send(uid: str, frame: dict) -> None:
sent_frames.append(frame)
if frame.get("type") == "agent_run":
# Simulate Electron responding with file data then agent_complete.
q = mgr.get_agent_data_queue(uid, frame["run_id"])
await q.put(file_frame)
await q.put(agent_complete_frame)
elif frame.get("type") == "tool_call":
# Resolve the pending insert immediately.
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
mgr.send_frame = _mock_send # type: ignore[method-assign]
mock_llm = MagicMock()
mock_response = MagicMock()
mock_response.content = json.dumps([
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
])
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_local_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "success"
assert kwargs["items_processed"] == 1
assert kwargs["items_created"] == 1
assert kwargs["errors"] == []
assert kwargs["update_config_last_run"] is False
# Verify agent_run frame was sent.
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
assert len(agent_run_frames) == 1
assert agent_run_frames[0]["agent_id"] == config.id
assert "paths" in agent_run_frames[0]["config"]
# Verify insert frame was sent with AI flags.
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
assert len(insert_frames) == 1
assert insert_frames[0]["data"]["isAiSuggested"] == 1
assert insert_frames[0]["data"]["isApproved"] == 0
@pytest.mark.asyncio
async def test_run_local_agent_file_read_timeout():
"""run_local_agent marks run as partial/error when device stops sending files."""
config = _make_local_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
async def _mock_send(uid: str, frame: dict) -> None:
# Don't put anything in the queue — simulate stalled device.
pass
mgr.send_frame = _mock_send # type: ignore[method-assign]
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_local_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error" # No items created, so error (not partial).
assert any("timed out" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_local_agent_llm_extraction_error():
"""LLM errors per-file are recorded; run continues for remaining files."""
config = _make_local_config()
run_log = _make_run_log(config.id)
mgr = _make_manager()
file_frame = {
"type": "agent_data",
"run_id": run_log.id,
"files": [
{"path": "/file1.eml", "content": "Email one."},
{"path": "/file2.eml", "content": "Email two."},
],
}
async def _mock_send(uid: str, frame: dict) -> None:
if frame.get("type") == "agent_run":
q = mgr.get_agent_data_queue(uid, frame["run_id"])
await q.put(file_frame)
await q.put(None) # agent_complete sentinel
mgr.send_frame = _mock_send # type: ignore[method-assign]
mock_llm = MagicMock()
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_local_agent(_FREE_UID, config, run_log, mgr)
_args, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert kwargs["items_processed"] == 2 # Both files attempted.
assert kwargs["items_created"] == 0
assert len(kwargs["errors"]) == 2 # One error per file.
# ---------------------------------------------------------------------------
# run_cloud_agent (stub)
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_cloud_agent_device_offline():
"""Cloud agent aborts immediately when no device is connected."""
config = _make_cloud_config()
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = DeviceConnectionManager() # empty — no devices registered
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
mock_finalize.assert_called_once()
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("device" in e.lower() or "connected" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_no_oauth_token():
"""Cloud agent errors when no OAuth token is stored."""
config = _make_cloud_config()
config.oauth_token_encrypted = None
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_token_decrypt_failure():
"""Cloud agent errors gracefully when the stored token cannot be decrypted."""
config = _make_cloud_config()
config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
from cryptography.fernet import Fernet as _Fernet
valid_key = _Fernet.generate_key().decode()
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
patch("app.integrations.settings") as mock_settings:
mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("decrypt" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_happy_path_gmail():
"""Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success."""
from app.integrations import EmailMessage, encrypt_token
from cryptography.fernet import Fernet as _Fernet
fernet_key = _Fernet.generate_key().decode()
credentials = {
"token": "access_abc",
"refresh_token": "refresh_xyz",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "cid",
"client_secret": "csec",
}
config = _make_cloud_config()
config.provider = "gmail"
config.prompt_template = "Extract tasks from this email."
config.data_types = ["tasks"]
with patch("app.integrations.settings") as ms:
ms.OAUTH_ENCRYPTION_KEY = fernet_key
config.oauth_token_encrypted = encrypt_token(credentials)
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
sample_email = EmailMessage(
id="msg001",
subject="Action required",
sender="boss@company.com",
body_text="Please fix the bug by Friday.",
date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
)
extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]
with patch("app.integrations.settings") as mock_int_settings, \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
patch("app.core.agent_runner.async_session"):
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
mock_gmail = AsyncMock()
mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
mock_gmail.refreshed_credentials = None
with patch("app.integrations.decrypt_token", return_value=credentials), \
patch("app.integrations.get_provider", return_value=mock_gmail):
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
mock_extract.assert_called_once()
mock_insert.assert_called_once()
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "success"
assert kwargs["items_processed"] == 1
assert kwargs["items_created"] == 1
assert kwargs["config_type"] == "cloud"
@pytest.mark.asyncio
async def test_run_cloud_agent_provider_fetch_error():
"""Cloud agent records error status when provider fetch raises RuntimeError."""
credentials = {"token": "abc"}
config = _make_cloud_config()
config.oauth_token_encrypted = "some_encrypted_value" # non-empty so decrypt step is reached
config.prompt_template = "Extract tasks."
config.data_types = ["tasks"]
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
mock_provider = AsyncMock()
mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
mock_provider.refreshed_credentials = None
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
patch("app.integrations.decrypt_token", return_value=credentials), \
patch("app.integrations.get_provider", return_value=mock_provider), \
patch("app.core.agent_runner.async_session"):
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
_, kwargs = mock_finalize.call_args
assert kwargs["status"] == "error"
assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
@pytest.mark.asyncio
async def test_run_cloud_agent_refreshed_token_persisted():
"""When the provider refreshes its token, the new ciphertext is written to DB."""
from app.integrations import EmailMessage, encrypt_token
from cryptography.fernet import Fernet as _Fernet
fernet_key = _Fernet.generate_key().decode()
credentials = {"token": "old_token", "refresh_token": "rt_old"}
fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}
config = _make_cloud_config()
config.prompt_template = "Extract tasks."
config.data_types = ["tasks"]
with patch("app.integrations.settings") as ms:
ms.OAUTH_ENCRYPTION_KEY = fernet_key
config.oauth_token_encrypted = encrypt_token(credentials)
run_log = _make_run_log(config.id, agent_type="cloud")
mgr = _make_manager()
mock_provider = AsyncMock()
mock_provider.fetch_messages = AsyncMock(return_value=[])
mock_provider.refreshed_credentials = fresh_credentials # token was refreshed
# Track DB writes via mock async_session.
mock_cfg_row = MagicMock()
mock_cfg_row.oauth_token_encrypted = None
mock_db = AsyncMock()
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
mock_db.__aexit__ = AsyncMock(return_value=False)
mock_db.scalar_one_or_none = AsyncMock(return_value=mock_cfg_row)
cfg_result = MagicMock()
cfg_result.scalar_one_or_none.return_value = mock_cfg_row
mock_db.execute = AsyncMock(return_value=cfg_result)
mock_db.commit = AsyncMock()
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
patch("app.integrations.decrypt_token", return_value=credentials), \
patch("app.integrations.get_provider", return_value=mock_provider), \
patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
patch("app.core.agent_runner.async_session", return_value=mock_db), \
patch("app.integrations.settings") as mock_int_settings:
mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
# The new encrypted token should have been written to the config row.
mock_encrypt.assert_called_once_with(fresh_credentials)
assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
@pytest.mark.asyncio
async def test_finalize_run_updates_cloud_config_last_run_at():
"""_finalize_run with config_type='cloud' updates CloudAgentConfig.last_run_at."""
from app.core.agent_runner import _finalize_run
run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
run_log.id = str(uuid.uuid4())
mock_cfg = MagicMock()
mock_cfg.last_run_at = None
cfg_result = MagicMock()
cfg_result.scalar_one_or_none.return_value = mock_cfg
mock_db = AsyncMock()
mock_db.__aenter__ = AsyncMock(return_value=mock_db)
mock_db.__aexit__ = AsyncMock(return_value=False)
mock_db.merge = AsyncMock(return_value=run_log)
mock_db.execute = AsyncMock(return_value=cfg_result)
mock_db.commit = AsyncMock()
config_id = str(uuid.uuid4())
with patch("app.core.agent_runner.async_session", return_value=mock_db):
await _finalize_run(
run_log,
status="success",
update_config_last_run=True,
config_id=config_id,
config_type="cloud",
)
# CloudAgentConfig.last_run_at should have been set.
assert mock_cfg.last_run_at is not None
mock_db.commit.assert_called()
# ---------------------------------------------------------------------------
# trigger_pending_runs
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_trigger_pending_runs_no_overdue():
"""Pending-run scan is skipped because agent config is client-owned."""
mgr = _make_manager()
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called()
@pytest.mark.asyncio
async def test_trigger_pending_runs_device_id_filter():
"""Device filtering is no longer backend-managed in pending runs."""
mgr = _make_manager(device_id="dev-001")
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called()
@pytest.mark.asyncio
async def test_trigger_pending_runs_dispatches_overdue():
"""No pending runs are dispatched by backend after config deprecation."""
mgr = _make_manager()
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called()
# ---------------------------------------------------------------------------
# Integration: POST /agents/can-create and /agents/trigger
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
"""Route all get_session calls to the test SQLite session."""
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
@pytest.mark.asyncio
async def test_can_create_agent_allows_when_under_limit(client):
"""POST /agents/can-create returns allowed=True when under tier limit."""
resp = client.post(
"/api/v1/agents/can-create",
json={"active_agents": 0},
headers=auth_header("free"),
)
assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is True
assert body["tier"] == "free"
assert body["active_agents"] == 0
assert body["limit"] == 2
@pytest.mark.asyncio
async def test_can_create_agent_denies_when_at_limit(client):
"""POST /agents/can-create returns allowed=False at free-tier limit."""
resp = client.post(
"/api/v1/agents/can-create",
json={"active_agents": 2},
headers=auth_header("free"),
)
assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is False
assert body["limit"] == 2
@pytest.mark.asyncio
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
"""POST /agents/trigger creates a local run log and dispatches background task."""
dispatched: list[tuple[str, str]] = []
async def _fake_run(user_id, cfg, run_log, device_mgr):
dispatched.append((user_id, cfg.id))
def _fake_create_task(coro):
coro.close()
return MagicMock()
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
patch("asyncio.create_task") as mock_create_task:
mock_create_task.side_effect = _fake_create_task
resp = client.post(
"/api/v1/agents/trigger",
json={
"directory": "/home/user/docs",
"what_to_extract": ["task", "note"],
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
"batch_interval": "0 */6 * * *",
"custom_agent_prompt": "Extract tasks and notes.",
"active_agents": 0,
},
headers=auth_header("power"),
)
assert resp.status_code == 202
data = resp.json()
assert isinstance(data["agent_id"], str)
assert data["agent_id"]
assert data["status"] == "running"
assert data["agent_type"] == "local"
# Verify create_task was called (dispatching background run).
mock_create_task.assert_called_once()

View File

@@ -40,7 +40,6 @@ from app.core.agent_runner import (
_format_projects, _format_projects,
_get_extraction_rules, _get_extraction_rules,
_get_no_match_behavior, _get_no_match_behavior,
_is_overdue,
run_local_agent, run_local_agent,
) )
from app.core.device_manager import DeviceConnectionManager from app.core.device_manager import DeviceConnectionManager
@@ -383,7 +382,6 @@ async def test_eval_runner(runner_case, pytestconfig):
await run_local_agent(_USER_ID, config, run_log, mgr) await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args _, kwargs = mock_fin.call_args
inserts = [c for c in calls if c["action"] == "insert"]
score, comment = _evaluate_case(case, calls, kwargs) score, comment = _evaluate_case(case, calls, kwargs)
if obs is not None: if obs is not None:

View File

@@ -1,243 +0,0 @@
"""Tests for the Chatbot Journey endpoints.
Covers:
1. Start journey for local agent → session_id + first question, done=False
2. Start journey for cloud agent → contextual email-focused question
3. Start journey with existing agent_id → session seeded, first question returned
4. Start journey with non-existent agent_id → still succeeds (graceful fallback)
5. Message: continue conversation → done=False, follow-up question returned
6. Message: LLM wraps up → done=True + prompt_template extracted correctly
7. Message with max-turns nudge → no crash, returns response
8. Invalid session_id → 404
9. Expired session → 404
10. Session ownership: user B cannot access user A's session
11. No JWT on /start → 401
12. No JWT on /message → 401
"""
from __future__ import annotations
import time
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from fastapi.testclient import TestClient
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.routes.agent_setup import (
_SESSION_TTL_SECONDS,
_TEMPLATE_END,
_TEMPLATE_START,
_extract_template,
_sessions,
)
from app.models import LocalAgentConfig
from tests.conftest import TEST_USER_IDS, auth_header
# ── Helpers ──────────────────────────────────────────────────────────────
def _start(client: TestClient, agent_type: str = "local", agent_id: str | None = None, tier: str = "power") -> dict:
body: dict = {"agent_type": agent_type}
if agent_id:
body["agent_id"] = agent_id
resp = client.post("/api/v1/agents/journey/start", json=body, headers=auth_header(tier))
return resp
def _message(client: TestClient, session_id: str, message: str, tier: str = "power") -> dict:
return client.post(
"/api/v1/agents/journey/message",
json={"session_id": session_id, "message": message},
headers=auth_header(tier),
)
# ── Unit: _extract_template ───────────────────────────────────────────────
def test_extract_template_present():
text = f"Some preamble.\n{_TEMPLATE_START}\nExtract tasks from emails.\n{_TEMPLATE_END}\nTrailing text."
result = _extract_template(text)
assert result == "Extract tasks from emails."
def test_extract_template_absent():
assert _extract_template("No markers here.") is None
def test_extract_template_empty_content():
text = f"{_TEMPLATE_START}\n{_TEMPLATE_END}"
assert _extract_template(text) is None
# ── Start journey ─────────────────────────────────────────────────────────
def test_start_journey_local(client: TestClient):
resp = _start(client, agent_type="local")
assert resp.status_code == 200
body = resp.json()
assert "session_id" in body
assert body["done"] is False
assert body["prompt_template"] is None
assert len(body["message"]) > 0
# Local question should be about files/directories
assert any(w in body["message"].lower() for w in ("file", "director", "document", "monitor"))
def test_start_journey_cloud(client: TestClient):
resp = _start(client, agent_type="cloud")
assert resp.status_code == 200
body = resp.json()
assert body["done"] is False
# Cloud question should mention emails or messages
assert any(w in body["message"].lower() for w in ("email", "message", "communication"))
def test_start_journey_with_agent_id(client: TestClient, db_session: AsyncSession):
"""When agent_id is provided, session should be created even if agent doesn't exist."""
fake_agent_id = str(uuid.uuid4())
resp = _start(client, agent_type="local", agent_id=fake_agent_id)
# Should succeed gracefully even if the agent_id doesn't exist
assert resp.status_code == 200
body = resp.json()
assert body["done"] is False
def test_start_journey_with_existing_agent(client: TestClient, db_session: AsyncSession):
"""When a real local agent is provided, session is seeded with its prompt_template."""
import asyncio
user_id = TEST_USER_IDS["power"]
agent = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=user_id,
name="Test Agent",
device_id="device-1",
directory_paths=["/home/user/emails"],
data_types=["tasks"],
prompt_template="Extract tasks from .eml files.",
file_extensions=[".eml"],
schedule_cron="0 */6 * * *",
enabled=True,
)
async def _seed():
db_session.add(agent)
await db_session.commit()
asyncio.get_event_loop().run_until_complete(_seed())
resp = _start(client, agent_type="local", agent_id=agent.id)
assert resp.status_code == 200
body = resp.json()
assert body["done"] is False
# The session should be stored
assert body["session_id"] in _sessions
def test_start_journey_requires_auth(client: TestClient):
resp = client.post("/api/v1/agents/journey/start", json={"agent_type": "local"})
assert resp.status_code == 401
# ── Message ───────────────────────────────────────────────────────────────
def test_message_continues_conversation(client: TestClient):
"""A mid-journey reply (no template markers) returns done=False."""
follow_up = "That looks good. Can you tell me more about priority rules?"
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
start_resp = _start(client, agent_type="local")
assert start_resp.status_code == 200
session_id = start_resp.json()["session_id"]
msg_resp = _message(client, session_id, "I have .eml and .txt files")
assert msg_resp.status_code == 200
body = msg_resp.json()
assert body["done"] is False
assert body["prompt_template"] is None
assert body["message"] == follow_up
assert body["session_id"] == session_id
def test_message_produces_template(client: TestClient):
"""When the LLM includes PROMPT_TEMPLATE markers, done=True and prompt_template is set."""
final_template = "Extract tasks from email. Subject → title. 'urgent' → high priority."
llm_response = (
"Great, I have all the information I need.\n"
f"{_TEMPLATE_START}\n{final_template}\n{_TEMPLATE_END}\n"
)
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=llm_response)):
start_resp = _start(client, agent_type="cloud")
assert start_resp.status_code == 200
session_id = start_resp.json()["session_id"]
msg_resp = _message(client, session_id, "Only invoices from clients")
assert msg_resp.status_code == 200
body = msg_resp.json()
assert body["done"] is True
assert body["prompt_template"] == final_template
# Session should be cleaned up
assert session_id not in _sessions
def test_message_invalid_session(client: TestClient):
resp = _message(client, "nonexistent-session-id", "hello")
assert resp.status_code == 404
def test_message_wrong_owner(client: TestClient):
"""User B cannot access user A's session."""
start_resp = _start(client, agent_type="local", tier="power")
session_id = start_resp.json()["session_id"]
# user with "pro" tier (different user_id) tries to send a message
resp = client.post(
"/api/v1/agents/journey/message",
json={"session_id": session_id, "message": "hello"},
headers=auth_header("pro"), # different user
)
assert resp.status_code == 404
def test_message_expired_session(client: TestClient):
"""Expired sessions return 404."""
start_resp = _start(client, agent_type="local")
session_id = start_resp.json()["session_id"]
# Manually expire the session
_sessions[session_id].created_at = time.monotonic() - _SESSION_TTL_SECONDS - 1
resp = _message(client, session_id, "hello")
assert resp.status_code == 404
def test_message_requires_auth(client: TestClient):
resp = client.post(
"/api/v1/agents/journey/message",
json={"session_id": "any", "message": "hello"},
)
assert resp.status_code == 401
def test_message_max_turns_nudge(client: TestClient):
"""After _MAX_TURNS user messages, a system nudge is appended but no crash occurs."""
from app.api.routes.agent_setup import _MAX_TURNS
follow_up = "Tell me more about priority rules."
with patch("app.api.routes.agent_setup._call_llm", new=AsyncMock(return_value=follow_up)):
start_resp = _start(client, agent_type="local")
session_id = start_resp.json()["session_id"]
for i in range(_MAX_TURNS):
resp = _message(client, session_id, f"Answer {i + 1}")
assert resp.status_code == 200
# While no template produced, session must still exist
if resp.json()["done"]:
break # LLM decided to wrap up early — also fine

View File

@@ -1,4 +1,4 @@
"""Tests for auth routes: register, login, refresh, me. """Tests for auth routes: register, login, refresh, me, OAuth social login.
Exercises the full auth lifecycle through the FastAPI TestClient against the Exercises the full auth lifecycle through the FastAPI TestClient against the
in-memory SQLite test database seeded by ``conftest.py``. in-memory SQLite test database seeded by ``conftest.py``.
@@ -7,9 +7,11 @@ in-memory SQLite test database seeded by ``conftest.py``.
from __future__ import annotations from __future__ import annotations
import time import time
from unittest.mock import AsyncMock, patch
from jose import jwt from jose import jwt
from app.auth.oauth_providers import GoogleOAuthProvider, OAuthUserInfo
from app.config.settings import settings from app.config.settings import settings
from tests.conftest import auth_header, TEST_USER_IDS from tests.conftest import auth_header, TEST_USER_IDS
@@ -204,3 +206,153 @@ class TestMe:
token = jwt.encode(payload, "wrong-secret", algorithm="HS256") token = jwt.encode(payload, "wrong-secret", algorithm="HS256")
resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"}) resp = client.get("/api/v1/auth/me", headers={"Authorization": f"Bearer {token}"})
assert resp.status_code == 401 assert resp.status_code == 401
# ── TestOAuth ─────────────────────────────────────────────────────────
class TestOAuth:
"""GET /auth/oauth/google/authorize and POST /auth/oauth/google/callback."""
FAKE_PROVIDER_USER_ID = "google-sub-12345"
FAKE_EMAIL = "oauth@example.com"
FAKE_AVATAR = "https://lh3.googleusercontent.com/photo.jpg"
def _patch_google(self, monkeypatch) -> None:
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", "fake-client-id")
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_SECRET", "fake-client-secret")
def _userinfo(
self,
email: str | None = None,
email_verified: bool = True,
) -> OAuthUserInfo:
return OAuthUserInfo(
provider_user_id=self.FAKE_PROVIDER_USER_ID,
email=email or self.FAKE_EMAIL,
email_verified=email_verified,
avatar_url=self.FAKE_AVATAR,
name="OAuth User",
)
def _authorize(self, client) -> str:
"""Call /authorize and return the fresh state token."""
resp = client.get("/api/v1/auth/oauth/google/authorize")
assert resp.status_code == 200
return resp.json()["state"]
def _callback(self, client, state: str, userinfo: OAuthUserInfo):
"""POST /callback with mocked provider exchange_code + get_userinfo."""
with (
patch.object(
GoogleOAuthProvider,
"exchange_code",
new=AsyncMock(return_value={"access_token": "google-access-tok"}),
),
patch.object(
GoogleOAuthProvider,
"get_userinfo",
new=AsyncMock(return_value=userinfo),
),
):
return client.post(
"/api/v1/auth/oauth/google/callback",
json={"code": "auth-code", "state": state},
)
def _decode_sub(self, access_token: str) -> str:
return jwt.decode(
access_token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)["sub"]
# -- authorize --
def test_authorize_returns_url_and_state(self, client, monkeypatch) -> None:
self._patch_google(monkeypatch)
resp = client.get("/api/v1/auth/oauth/google/authorize")
assert resp.status_code == 200
data = resp.json()
assert "url" in data and "state" in data
assert "accounts.google.com" in data["url"]
assert len(data["state"]) > 0
def test_authorize_unconfigured_returns_503(self, client, monkeypatch) -> None:
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_ID", "")
monkeypatch.setattr(settings, "GOOGLE_AUTH_CLIENT_SECRET", "")
resp = client.get("/api/v1/auth/oauth/google/authorize")
assert resp.status_code == 503
# -- callback --
def test_callback_state_mismatch_returns_401(self, client, monkeypatch) -> None:
self._patch_google(monkeypatch)
resp = client.post(
"/api/v1/auth/oauth/google/callback",
json={"code": "code", "state": "not-a-real-state"},
)
assert resp.status_code == 401
def test_callback_creates_new_user(self, client, monkeypatch) -> None:
"""First-time Google login creates a new user and returns valid tokens."""
self._patch_google(monkeypatch)
state = self._authorize(client)
resp = self._callback(client, state, self._userinfo())
assert resp.status_code == 200
data = resp.json()
assert "access_token" in data and "refresh_token" in data
payload = jwt.decode(
data["access_token"], settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
assert payload["email"] == self.FAKE_EMAIL
def test_callback_existing_oauth_link_logs_in(self, client, monkeypatch) -> None:
"""Second Google login with the same account re-uses the existing user."""
self._patch_google(monkeypatch)
userinfo = self._userinfo()
# First login — creates user + oauth_accounts row
resp1 = self._callback(client, self._authorize(client), userinfo)
assert resp1.status_code == 200
sub1 = self._decode_sub(resp1.json()["access_token"])
# Second login — finds existing oauth_accounts row → same user
resp2 = self._callback(client, self._authorize(client), userinfo)
assert resp2.status_code == 200
sub2 = self._decode_sub(resp2.json()["access_token"])
assert sub1 == sub2
def test_callback_email_match_links_account(self, client, monkeypatch) -> None:
"""Verified Google email matching an existing password user links the accounts."""
email = "link-target@example.com"
reg_resp = client.post(
"/api/v1/auth/register",
json={"email": email, "password": "TestPass123!"},
)
assert reg_resp.status_code == 201
orig_sub = self._decode_sub(reg_resp.json()["access_token"])
self._patch_google(monkeypatch)
state = self._authorize(client)
resp = self._callback(client, state, self._userinfo(email=email, email_verified=True))
assert resp.status_code == 200
oauth_sub = self._decode_sub(resp.json()["access_token"])
# OAuth login must resolve to the same user as the original registration
assert orig_sub == oauth_sub
def test_callback_unverified_email_conflict_returns_409(self, client, monkeypatch) -> None:
"""Unverified Google email matching an existing account returns 409, not 500."""
email = "conflict@example.com"
reg_resp = client.post(
"/api/v1/auth/register",
json={"email": email, "password": "TestPass123!"},
)
assert reg_resp.status_code == 201
self._patch_google(monkeypatch)
state = self._authorize(client)
resp = self._callback(client, state, self._userinfo(email=email, email_verified=False))
assert resp.status_code == 409

View File

@@ -1,243 +0,0 @@
"""Tests for backup routes: upload, download, history, delete.
Exercises the backup lifecycle through the FastAPI TestClient against the
in-memory SQLite test database and moto-mocked S3 bucket.
"""
from __future__ import annotations
import hashlib
from tests.conftest import auth_header, TEST_USER_IDS
# ── Helpers ───────────────────────────────────────────────────────────
_BLOB = b"encrypted-backup-blob-opaque-bytes"
_CHECKSUM = hashlib.sha256(_BLOB).hexdigest()
_VERSION = 1
_TIMESTAMP = 1700000000000 # arbitrary ms timestamp
def _backup_headers(tier: str = "power", **overrides) -> dict[str, str]:
"""Return auth + backup metadata headers."""
headers = auth_header(tier)
headers["X-Backup-Version"] = str(overrides.get("version", _VERSION))
headers["X-Backup-Timestamp"] = str(overrides.get("timestamp", _TIMESTAMP))
headers["X-Backup-Checksum"] = overrides.get("checksum", _CHECKSUM)
headers["Content-Type"] = "application/octet-stream"
return headers
def _upload(client, tier="power", **overrides) -> "Response": # noqa: F821
"""Upload a backup blob and return the response."""
return client.put(
"/api/v1/backup",
content=overrides.pop("blob", _BLOB),
headers=_backup_headers(tier, **overrides),
)
# ── TestUploadBackup ──────────────────────────────────────────────────
class TestUploadBackup:
"""PUT /api/v1/backup"""
def test_upload_success(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power")
assert resp.status_code == 200
assert resp.json() == {"ok": True}
def test_upload_creates_history_entry(self, client, s3_bucket) -> None:
_upload(client, tier="power")
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 1
assert history[0]["version"] == _VERSION
assert history[0]["timestamp"] == _TIMESTAMP
assert history[0]["checksum"] == _CHECKSUM
def test_upload_bad_checksum(self, client, s3_bucket) -> None:
resp = _upload(client, tier="power", checksum="0" * 64)
assert resp.status_code == 400
def test_upload_free_tier_blocked(self, client, s3_bucket) -> None:
"""Free tier has backup_gb=0 → should return 402."""
resp = _upload(client, tier="free")
assert resp.status_code == 402
def test_upload_pro_tier_allowed(self, client, s3_bucket) -> None:
"""Pro tier has backup_gb=5 → small blob succeeds."""
resp = _upload(client, tier="pro")
assert resp.status_code == 200
# ── TestDownloadBackup ────────────────────────────────────────────────
class TestDownloadBackup:
"""GET /api/v1/backup"""
def test_download_latest(self, client, s3_bucket) -> None:
_upload(client, tier="power")
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == _BLOB
assert resp.headers["X-Checksum"] == _CHECKSUM
assert resp.headers["X-Backup-Version"] == str(_VERSION)
def test_download_no_backup_returns_404(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 404
def test_download_if_modified_since_returns_304(self, client, s3_bucket) -> None:
"""When If-Modified-Since is after the backup timestamp → 304."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2099 00:00:00 GMT",
},
)
assert resp.status_code == 304
def test_download_if_modified_since_returns_200(self, client, s3_bucket) -> None:
"""When If-Modified-Since is before the backup timestamp → serve blob."""
_upload(client, tier="power", timestamp=1700000000000)
resp = client.get(
"/api/v1/backup",
headers={
**auth_header("power"),
"If-Modified-Since": "Thu, 01 Jan 2000 00:00:00 GMT",
},
)
assert resp.status_code == 200
assert resp.content == _BLOB
def test_download_multiple_returns_latest(self, client, s3_bucket) -> None:
"""When multiple backups exist, GET returns the one with the highest timestamp."""
_upload(client, tier="power", timestamp=1000)
blob2 = b"second-encrypted-backup"
checksum2 = hashlib.sha256(blob2).hexdigest()
_upload(client, tier="power", timestamp=2000, blob=blob2, checksum=checksum2)
resp = client.get("/api/v1/backup", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.content == blob2
# ── TestBackupHistory ─────────────────────────────────────────────────
class TestBackupHistory:
"""GET /api/v1/backup/history"""
def test_history_empty(self, client, s3_bucket) -> None:
resp = client.get("/api/v1/backup/history", headers=auth_header("power"))
assert resp.status_code == 200
assert resp.json() == []
def test_history_returns_entries(self, client, s3_bucket) -> None:
_upload(client, tier="power", timestamp=1000)
_upload(client, tier="power", timestamp=2000)
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert len(history) == 2
# Ordered by timestamp descending
assert history[0]["timestamp"] == 2000
assert history[1]["timestamp"] == 1000
def test_history_isolated_per_user(self, client, s3_bucket) -> None:
"""One user's backups should not appear in another user's history."""
_upload(client, tier="power")
resp = client.get("/api/v1/backup/history", headers=auth_header("team"))
assert resp.json() == []
# ── TestDeleteBackup ──────────────────────────────────────────────────
class TestDeleteBackup:
"""DELETE /api/v1/backup/{backup_id}"""
def _get_backup_id(self, client, tier="power") -> str:
"""Upload a backup and return its DB id from history."""
_upload(client, tier=tier)
client.get(
"/api/v1/backup/history", headers=auth_header(tier)
).json()
# History returns BackupMetadata schema which doesn't have `id`.
# We need to look it up via a different means.
# Since there's only 1 backup, find via history length.
# Actually the schema doesn't return id — let's verify via re-download.
# We'll use a workaround: upload, then list history to confirm it exists,
# then try to delete — but we need the id...
# Let's check if history includes an id field.
# The schema is: version, timestamp, checksum, chunk_count — no id.
# We'll need to query the DB directly or use a known ID.
# For testing, we'll search history then use the DB.
return None # pragma: no cover — overridden below
def test_delete_success(self, client, s3_bucket, db_session) -> None:
_upload(client, tier="power")
# Discover the backup_id via direct DB query
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("power")
)
assert resp.status_code == 200
assert resp.json() == {"ok": True}
# History should now be empty
history = client.get(
"/api/v1/backup/history", headers=auth_header("power")
).json()
assert history == []
def test_delete_nonexistent(self, client, s3_bucket) -> None:
resp = client.delete(
"/api/v1/backup/no-such-id", headers=auth_header("power")
)
assert resp.status_code == 404
def test_delete_other_users_backup(self, client, s3_bucket, db_session) -> None:
"""Cannot delete another user's backup (ownership check returns 404)."""
_upload(client, tier="power")
import asyncio
from sqlalchemy import select
from app.models import BackupMetadata
async def _get_id():
result = await db_session.execute(
select(BackupMetadata.id).where(
BackupMetadata.user_id == TEST_USER_IDS["power"]
)
)
return result.scalar_one()
backup_id = asyncio.get_event_loop().run_until_complete(_get_id())
# team user tries to delete power user's backup → 404
resp = client.delete(
f"/api/v1/backup/{backup_id}", headers=auth_header("team")
)
assert resp.status_code == 404

163
tests/test_brief_agent.py Normal file
View File

@@ -0,0 +1,163 @@
"""Tests for Phase 3: brief agent WS frame + REST fallback.
Coverage:
- run_home_brief streams non-empty text (mocked _run_single_agent_stream)
- run_project_brief with bogus UUID → WS returns stream_end with error, no crash
- _build_read_tools uses read-only subset only (no mutating tools)
- POST /chat/brief home mode returns {response: "..."}
- POST /chat/brief project mode with invalid UUID → 422
"""
from __future__ import annotations
import uuid
from collections.abc import AsyncGenerator
from typing import Any
from unittest.mock import AsyncMock, patch
import pytest
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
_USER_ID = TEST_USER_IDS["pro"]
_EMPTY_CONTEXT: dict[str, Any] = {"core_memory": {}}
async def _fake_token_stream(*_args, **_kwargs) -> AsyncGenerator[tuple[str, Any], None]:
"""Fake _run_single_agent_stream that yields two token events."""
yield ("token", "Hello")
yield ("token", " world")
# ---------------------------------------------------------------------------
# Unit: run_home_brief streams non-empty text
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_home_brief_streams_text():
with patch(
"app.core.brief_agent._run_single_agent_stream",
side_effect=_fake_token_stream,
):
from app.core.brief_agent import run_home_brief
chunks: list[str] = []
async for event_type, data in run_home_brief(_USER_ID, _EMPTY_CONTEXT):
if event_type == "token":
chunks.append(str(data))
assert "".join(chunks) == "Hello world"
# ---------------------------------------------------------------------------
# Unit: run_project_brief streams text with valid UUID
# ---------------------------------------------------------------------------
@pytest.mark.asyncio
async def test_run_project_brief_streams_text():
project_id = str(uuid.uuid4())
with patch(
"app.core.brief_agent._run_single_agent_stream",
side_effect=_fake_token_stream,
):
from app.core.brief_agent import run_project_brief
chunks: list[str] = []
async for event_type, data in run_project_brief(_USER_ID, project_id, _EMPTY_CONTEXT):
if event_type == "token":
chunks.append(str(data))
assert "".join(chunks) == "Hello world"
# ---------------------------------------------------------------------------
# Unit: _build_read_tools uses read-only subset (no write tools)
# ---------------------------------------------------------------------------
def test_build_read_tools_read_only_subset():
from app.agents.note_agent import NOTE_READ_TOOLS
from app.agents.project_agent import PROJECT_READ_TOOLS
from app.agents.task_agent import TASK_READ_TOOLS
from app.agents.timeline_agent import TIMELINE_READ_TOOLS
from app.core.brief_agent import _build_read_tools
tools = _build_read_tools(_USER_ID, None)
tool_names = {getattr(t, "name", None) or getattr(t, "__name__", str(t)) for t in tools}
# Read-only exports must be present.
for read_list in (TASK_READ_TOOLS, PROJECT_READ_TOOLS, TIMELINE_READ_TOOLS, NOTE_READ_TOOLS):
for t in read_list:
name = getattr(t, "name", None) or getattr(t, "__name__", str(t))
assert name in tool_names, f"Read tool {name!r} missing from _build_read_tools"
# No mutating tools (e.g. create_task, update_task, delete_task).
mutating = {"create_task", "update_task", "delete_task", "create_project",
"update_project", "delete_project", "create_note", "update_note",
"delete_note", "memory_add", "memory_update", "memory_delete"}
overlap = tool_names & mutating
assert not overlap, f"Mutating tools in brief read-only subset: {overlap}"
# ---------------------------------------------------------------------------
# Integration: POST /chat/brief — home mode
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def _override_db(db_session):
from app.db import get_session
from app.main import app
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
@pytest.mark.asyncio
async def test_rest_brief_home_returns_response(client):
async def _fake_home_brief(user_id, context):
yield ("token", "Today looks light.")
with (
patch("app.api.routes.chat.run_home_brief", side_effect=_fake_home_brief),
patch(
"app.api.routes.chat.MemoryMiddleware.enrich_context",
new=AsyncMock(return_value={}),
),
):
res = client.post(
"/api/v1/chat/brief",
json={"mode": "home"},
headers=auth_header("pro"),
)
assert res.status_code == 200
data = res.json()
assert data["response"] == "Today looks light."
@pytest.mark.asyncio
async def test_rest_brief_project_invalid_uuid_returns_422(client):
res = client.post(
"/api/v1/chat/brief",
json={"mode": "project", "project_id": "not-a-uuid"},
headers=auth_header("pro"),
)
assert res.status_code == 422
@pytest.mark.asyncio
async def test_rest_brief_project_missing_uuid_returns_422(client):
res = client.post(
"/api/v1/chat/brief",
json={"mode": "project"},
headers=auth_header("pro"),
)
assert res.status_code == 422

View File

@@ -1,184 +0,0 @@
"""Unit tests for Step 1 file classification (_classify_file).
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
Run with: pytest tests/test_classify_file.py -v
To run a quick manual check against a real file without the full UI:
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
"""
from __future__ import annotations
import asyncio
import sys
import pytest
from app.core.agent_runner import _classify_file
# ── Fixtures ──────────────────────────────────────────────────────────────
PROJECTS_SAMPLE = [
{
"id": "aaaa-0001-0000-0000-000000000001",
"name": "ARPA Sicilia POC",
"status": "active",
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
},
{
"id": "bbbb-0002-0000-0000-000000000002",
"name": "SNAM AI Meeting Prep",
"status": "active",
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
},
{
"id": "cccc-0003-0000-0000-000000000003",
"name": "SFERA+ Wave 2",
"status": "active",
"aiSummary": "Second wave of the SFERA+ whitelist project.",
},
]
ARPA_EMAIL = """\
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
isImportance: normal
hasAttachment: True
---
## Body
Buongiorno,
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
dei deliverable concordati:
- Preparare demo entro il 30 marzo
- Condividere documentazione tecnica con il team ARPA
- Fissare call di follow-up la prossima settimana
Cordiali saluti
Roberto Marchetti
"""
SNAM_EMAIL = """\
to: roberto.musso@hpe.com
isImportance: high
hasAttachment: False
---
## Body
Ciao,
ti invio l'agenda per la riunione SNAM di domani.
Per favore conferma la tua presenza.
"""
UNRELATED_EMAIL = """\
to: roberto.musso@hpe.com
isImportance: normal
---
## Body
Benvenuto nel programma HPE Employee Learning Series.
Completa la formazione richiesta entro la fine del trimestre.
"""
# ── Tests ─────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_classify_arpa_matches_existing():
project_id, domains, new_name = await _classify_file(
file_path="arpa_email.txt",
file_content=ARPA_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes", "timelines"],
)
assert project_id == "aaaa-0001-0000-0000-000000000001", (
f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
)
assert new_name is None
@pytest.mark.asyncio
async def test_classify_snam_matches_existing():
project_id, domains, new_name = await _classify_file(
file_path="snam_email.txt",
file_content=SNAM_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes"],
)
assert project_id == "bbbb-0002-0000-0000-000000000002", (
f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
)
@pytest.mark.asyncio
async def test_classify_unrelated_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="learning_email.txt",
file_content=UNRELATED_EMAIL,
projects=PROJECTS_SAMPLE,
config_data_types=["tasks", "notes"],
)
assert project_id == "new"
assert new_name is not None # LLM should suggest a name
@pytest.mark.asyncio
async def test_classify_empty_file_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="empty.txt",
file_content=" ",
projects=PROJECTS_SAMPLE,
config_data_types=["tasks"],
)
assert project_id == "new"
@pytest.mark.asyncio
async def test_classify_no_projects_returns_new():
project_id, domains, new_name = await _classify_file(
file_path="arpa_email.txt",
file_content=ARPA_EMAIL,
projects=[],
config_data_types=["tasks", "notes"],
)
assert project_id == "new"
assert new_name is not None
# ── CLI quick-test runner ─────────────────────────────────────────────────
async def _cli_test(file_path: str, project_names: list[str]) -> None:
"""Run Step 1 classification against a real file from the CLI."""
import json
from pathlib import Path
content = Path(file_path).read_text(encoding="utf-8", errors="replace")
projects = [
{"id": f"test-id-{i:04d}", "name": name, "status": "active", "aiSummary": ""}
for i, name in enumerate(project_names)
]
print(f"\nClassifying: {file_path}")
print(f"Projects in context: {[p['name'] for p in projects]}\n")
project_id, domains, new_name = await _classify_file(
file_path=file_path,
file_content=content,
projects=projects,
config_data_types=["tasks", "notes", "timelines"],
)
result = {
"project_id": project_id,
"matched_name": next((p["name"] for p in projects if p["id"] == project_id), None),
"new_project_name": new_name,
"domains": domains,
}
print(json.dumps(result, indent=2, ensure_ascii=False))
if __name__ == "__main__":
if len(sys.argv) < 2:
print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
sys.exit(1)
asyncio.run(_cli_test(sys.argv[1], sys.argv[2:]))

View File

@@ -10,8 +10,11 @@ import pytest
from langchain_core.messages import AIMessage, ToolMessage from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import ( from app.core.deep_agent import (
_build_system_prompt,
_datetime_context_injection,
_infer_floating_domain, _infer_floating_domain,
_normalize_tagged_list_lines, _normalize_tagged_list_lines,
_request_context_block,
run_floating, run_floating,
run_floating_stream, run_floating_stream,
run_home, run_home,
@@ -63,7 +66,7 @@ class _FakeLLM:
async def test_run_home_uses_mocked_tool_result(): async def test_run_home_uses_mocked_tool_result():
fake_llm = _FakeLLM() fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch( with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()] "app.core.deep_agent._all_tools", return_value=[_FakeTool()]
): ):
out = await run_home("user-1", "list my tasks", {}) out = await run_home("user-1", "list my tasks", {})
@@ -76,7 +79,7 @@ async def test_run_home_uses_mocked_tool_result():
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result(): async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
fake_llm = _FakeLLM() fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch( with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()] "app.core.deep_agent._all_tools", return_value=[_FakeTool()]
): ):
events = [] events = []
@@ -91,8 +94,12 @@ async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_res
"floating_domain", "floating_domain",
{"type": "timeline", "id": "tl-1", "section": None}, {"type": "timeline", "id": "tl-1", "section": None},
) )
assert ("token", "stream-") in events # _run_single_agent_stream uses ainvoke (not astream); the final token is
assert ("token", "ok") in events # the second LLM response which echoes the tool result.
token_events = [e for e in events if e[0] == "token"]
assert token_events, "Expected at least one token event"
combined = "".join(str(e[1]) for e in token_events)
assert "Mock Task" in combined
@pytest.mark.asyncio @pytest.mark.asyncio
@@ -103,7 +110,7 @@ async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}' content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
) )
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()): with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain( domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X", "Quali sono i miei task per il progetto X",
{ {
@@ -165,7 +172,7 @@ async def test_run_floating_strips_xml_like_tags_from_final_text():
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>" "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
) )
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch( with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent "app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
): ):
text, _domain = await run_floating( text, _domain = await run_floating(
@@ -187,7 +194,7 @@ async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
yield "token", "Hai 1 task:\\n" yield "token", "Hai 1 task:\\n"
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>" yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch( with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream "app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
): ):
events = [] events = []
@@ -233,7 +240,7 @@ async def test_run_floating_stream_falls_back_to_final_response_content_when_ast
if False: if False:
yield None yield None
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch( with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()] "app.core.deep_agent._all_tools", return_value=[_FakeTool()]
): ):
events = [] events = []
@@ -255,7 +262,7 @@ async def test_run_floating_returns_fallback_when_sanitization_would_empty_text(
async def _fake_run_single_agent(**_kwargs): async def _fake_run_single_agent(**_kwargs):
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>" return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch( with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent "app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
): ):
text, _domain = await run_floating( text, _domain = await run_floating(
@@ -274,7 +281,7 @@ async def test_run_floating_stream_returns_fallback_when_sanitization_would_empt
async def _fake_stream(**_kwargs): async def _fake_stream(**_kwargs):
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>" yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch( with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream "app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
): ):
events = [] events = []
@@ -286,3 +293,213 @@ async def test_run_floating_stream_returns_fallback_when_sanitization_would_empt
events.append(event) events.append(event)
assert ("token", "No results found.") in events assert ("token", "No results found.") in events
# ── _datetime_context_injection ────────────────────────────────────────────────
def _fp(tz: str, now_iso: str) -> dict:
return {"timezone": tz, "now_iso": now_iso, "date_format": "dd/MM/yyyy", "time_format": "24h"}
def _parse_ms(block: str, key: str) -> tuple[int, int]:
"""Extract [start, end] from a 'key [start, end]' line in the DATE CONTEXT block."""
import re
m = re.search(rf"^{key}\s+\[(\d+),\s*(\d+)\]", block, re.MULTILINE)
assert m, f"Key '{key}' not found in block:\n{block}"
return int(m.group(1)), int(m.group(2))
def test_datetime_context_injection_europe_rome_late_evening():
"""22:16 CEST on 2026-04-26 — 'tomorrow' must be 2026-04-27 00:00→23:59:59.999 CEST."""
from zoneinfo import ZoneInfo
from datetime import datetime, timezone
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z")})
assert "DATE CONTEXT" in block
assert "Europe/Rome" in block
tz = ZoneInfo("Europe/Rome")
today_start = int(datetime(2026, 4, 26, tzinfo=tz).timestamp() * 1000)
today_end = int(datetime(2026, 4, 27, tzinfo=tz).timestamp() * 1000) - 1
tomorrow_start = today_end + 1
tomorrow_end = int(datetime(2026, 4, 28, tzinfo=tz).timestamp() * 1000) - 1
t_s, t_e = _parse_ms(block, "today")
assert t_s == today_start
assert t_e == today_end
tm_s, tm_e = _parse_ms(block, "tomorrow")
assert tm_s == tomorrow_start
assert tm_e == tomorrow_end
# Sanity: window is exactly 86 400 000 ms (1 day, CEST has no DST jump on this date)
assert today_end - today_start + 1 == 86_400_000
assert tomorrow_end - tomorrow_start + 1 == 86_400_000
def test_datetime_context_injection_utc():
"""UTC timezone: boundaries are clean UTC midnights."""
from datetime import datetime, timezone
block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-01-15T10:00:00Z")})
t_s, t_e = _parse_ms(block, "today")
expected_start = int(datetime(2026, 1, 15, tzinfo=timezone.utc).timestamp() * 1000)
assert t_s == expected_start
assert t_e == expected_start + 86_400_000 - 1
def test_datetime_context_injection_dst_spring_forward():
"""Europe/Rome DST spring-forward 2026-03-29: that day is 23h, not 24h."""
from zoneinfo import ZoneInfo
from datetime import datetime
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-03-29T08:00:00Z")})
tz = ZoneInfo("Europe/Rome")
day_start = int(datetime(2026, 3, 29, tzinfo=tz).timestamp() * 1000)
day_end = int(datetime(2026, 3, 30, tzinfo=tz).timestamp() * 1000) - 1
t_s, t_e = _parse_ms(block, "today")
assert t_s == day_start
assert t_e == day_end
assert t_e - t_s + 1 == 23 * 3_600_000 # 23-hour day
def test_datetime_context_injection_dst_fall_back():
"""Europe/Rome DST fall-back 2026-10-25: that day is 25h."""
from zoneinfo import ZoneInfo
from datetime import datetime
block = _datetime_context_injection({"format_prefs": _fp("Europe/Rome", "2026-10-25T08:00:00Z")})
tz = ZoneInfo("Europe/Rome")
day_start = int(datetime(2026, 10, 25, tzinfo=tz).timestamp() * 1000)
day_end = int(datetime(2026, 10, 26, tzinfo=tz).timestamp() * 1000) - 1
t_s, t_e = _parse_ms(block, "today")
assert t_s == day_start
assert t_e == day_end
assert t_e - t_s + 1 == 25 * 3_600_000 # 25-hour day
def test_datetime_context_injection_year_boundary():
"""Dec 31 → Jan 1: last_year, this_year, next_month cross year boundary correctly."""
from zoneinfo import ZoneInfo
from datetime import datetime
block = _datetime_context_injection({"format_prefs": _fp("UTC", "2026-12-31T23:00:00Z")})
tz = ZoneInfo("UTC")
yr_s, yr_e = _parse_ms(block, "this_year")
assert yr_s == int(datetime(2026, 1, 1, tzinfo=tz).timestamp() * 1000)
assert yr_e == int(datetime(2027, 1, 1, tzinfo=tz).timestamp() * 1000) - 1
ly_s, ly_e = _parse_ms(block, "last_year")
assert ly_s == int(datetime(2025, 1, 1, tzinfo=tz).timestamp() * 1000)
assert ly_e == yr_s - 1
nm_s, _ = _parse_ms(block, "next_month")
assert nm_s == int(datetime(2027, 1, 1, tzinfo=tz).timestamp() * 1000)
def test_datetime_context_injection_missing_format_prefs():
assert _datetime_context_injection({}) == ""
assert _datetime_context_injection({"format_prefs": None}) == ""
assert _datetime_context_injection({"format_prefs": "bad"}) == ""
# ── _request_context_block ─────────────────────────────────────────────────────
def test_request_context_block_scope_and_project():
ctx = {"scope": {"type": "task", "id": "t-1"}, "resolved_project_id": "proj-uuid"}
block = _request_context_block(ctx)
assert "scope" in block
assert "resolved_project_id: proj-uuid" in block
def test_request_context_block_empty():
assert _request_context_block({}) == ""
assert _request_context_block({"scope": None}) == ""
# ── _build_system_prompt ───────────────────────────────────────────────────────
def test_build_system_prompt_substitutes_all_slots(monkeypatch):
"""All five slots must appear in the compiled output; no raw placeholder remains."""
# Patch get_prompt_or_fallback to return None prompt_obj so we use fallback .format() path
import app.core.deep_agent as da
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))
ctx = {
"format_prefs": _fp("Europe/Rome", "2026-04-26T20:16:02.155Z"),
"core_memory": {"language": "it"},
"relational_memory": ["Alice — client"],
"proactive_hints": ["User prefers morning meetings"],
"scope": {"type": "task"},
"resolved_project_id": "proj-1",
}
from app.core.deep_agent import _HOME_SYSTEM_PROMPT
text, _ = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, ctx)
# No unresolved placeholders
assert "{date_context}" not in text
assert "{language_instruction}" not in text
assert "{relational_memory}" not in text
assert "{proactive_hints}" not in text
assert "{request_context}" not in text
# Content was injected
assert "DATE CONTEXT" in text
assert "Italian" in text
assert "Alice" in text
assert "morning meetings" in text
assert "proj-1" in text
def test_build_system_prompt_empty_format_prefs(monkeypatch):
"""Missing format_prefs must not raise — date_context slot renders empty string."""
import app.core.deep_agent as da
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda name, fallback: (fallback, None))
from app.core.deep_agent import _HOME_SYSTEM_PROMPT
text, _ = _build_system_prompt("home_system", _HOME_SYSTEM_PROMPT, {})
# Prompt renders without error; date section is empty but structure holds
assert "# Date filtering" in text
assert "{date_context}" not in text
def test_human_message_is_bare_message(monkeypatch):
"""After the refactor HumanMessage content must equal the raw user message exactly."""
import app.core.deep_agent as da
from langchain_core.messages import HumanMessage as LCHumanMessage
captured: list[list] = []
class _CaptureLLM:
def bind_tools(self, _):
return self
async def ainvoke(self, messages):
captured.append(list(messages))
return AIMessage(content="risposta")
monkeypatch.setattr(da, "get_prompt_or_fallback", lambda n, f: (f, None))
monkeypatch.setattr(da, "get_agent_llm", lambda _: _CaptureLLM())
monkeypatch.setattr(da, "_all_tools_for_user", lambda *_: [])
monkeypatch.setattr(da, "get_langfuse", lambda: None)
monkeypatch.setattr(da, "set_tool_result_collector", lambda _: None)
monkeypatch.setattr(da, "clear_tool_result_collector", lambda: None)
import asyncio
async def _run():
chunks = []
ctx = {"format_prefs": _fp("UTC", "2026-04-27T10:00:00Z")}
async for ev in da.run_home_stream("u1", "Cosa devo fare domani?", ctx):
chunks.append(ev)
asyncio.get_event_loop().run_until_complete(_run())
assert captured, "LLM was never called"
messages = captured[0]
human = next(m for m in messages if isinstance(m, LCHumanMessage))
assert human.content == "Cosa devo fare domani?"
assert "Context:" not in human.content

View File

@@ -18,13 +18,12 @@ from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pytest_asyncio
from app.core.device_manager import DeviceConnection, DeviceConnectionManager from app.core.device_manager import DeviceConnectionManager
from app.db import get_session from app.db import get_session
from app.main import app from app.main import app
from app.models import AgentRunLog from app.models import AgentRunLog
from tests.conftest import TEST_USER_IDS, auth_header, make_jwt from tests.conftest import TEST_USER_IDS, make_jwt
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Helpers # Helpers
@@ -157,40 +156,6 @@ async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
assert fut.cancelled() assert fut.cancelled()
@pytest.mark.asyncio
async def test_manager_agent_data_queue(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
q = manager.get_agent_data_queue("user1", "run-xyz")
# Put a frame and get it back.
frame = {"type": "agent_data", "run_id": "run-xyz", "files": []}
await q.put(frame)
assert await q.get() == frame
@pytest.mark.asyncio
async def test_manager_agent_data_queue_creates_once(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
q1 = manager.get_agent_data_queue("user1", "run-1")
q2 = manager.get_agent_data_queue("user1", "run-1")
assert q1 is q2
@pytest.mark.asyncio
async def test_manager_agent_data_queue_raises_when_offline(manager):
with pytest.raises(RuntimeError, match="not connected"):
manager.get_agent_data_queue("ghost", "run-1")
@pytest.mark.asyncio
async def test_manager_cleanup_agent_data_queue(manager, mock_ws):
manager.register("user1", "dev-A", mock_ws)
manager.get_agent_data_queue("user1", "run-1")
manager.cleanup_agent_data_queue("user1", "run-1")
# After cleanup a new queue is created (not the same object).
q_new = manager.get_agent_data_queue("user1", "run-1")
assert q_new is not None
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Integration tests — /api/v1/ws/device endpoint # Integration tests — /api/v1/ws/device endpoint
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -236,7 +201,6 @@ def test_ws_device_invalid_first_frame_closes(client):
def test_ws_device_tool_result_dispatched(client): def test_ws_device_tool_result_dispatched(client):
"""tool_result frame is routed to the DeviceConnectionManager.""" """tool_result frame is routed to the DeviceConnectionManager."""
token = make_jwt(tier="free") token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
from app.core.device_manager import device_manager as dm from app.core.device_manager import device_manager as dm
@@ -267,43 +231,6 @@ def test_ws_device_tool_result_dispatched(client):
assert any(c["call_id"] == "call-123" for c in captured) assert any(c["call_id"] == "call-123" for c in captured)
def test_ws_device_agent_data_enqueued(client):
"""agent_data frame is placed in the per-run queue by the message loop."""
from app.core.device_manager import device_manager as dm
token = make_jwt(tier="free")
user_id = TEST_USER_IDS["free"]
# Capture the queue object the message loop accesses.
captured_queue: list[asyncio.Queue] = []
original_get_queue = dm.get_agent_data_queue
def _spy_get_queue(uid, run_id):
q = original_get_queue(uid, run_id)
if not captured_queue:
captured_queue.append(q)
return q
with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(_device_hello("dev-001"))
ws.send_text(
json.dumps(
{
"type": "agent_data",
"run_id": "run-XYZ",
"files": [{"path": "/tmp/file.txt", "content": "hello"}],
}
)
)
ws.close()
# The queue should have received exactly one frame.
assert captured_queue, "queue was never accessed"
assert not captured_queue[0].empty()
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session): def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
"""On disconnect, _mark_runs_disconnected is called with the correct user_id.""" """On disconnect, _mark_runs_disconnected is called with the correct user_id."""
from app.api.routes import device_ws as _dws from app.api.routes import device_ws as _dws

View File

@@ -0,0 +1,139 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.agents.folder_agent import (
read_project_folder_file,
search_project_folder_file,
)
pytestmark = pytest.mark.asyncio
async def test_happy_path():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "file body", "kind": "text", "totalSize": 9}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "docs/x.md"})
assert "file body" in out
assert "kind=text" in out
async def test_traversal_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "../../etc/passwd"})
assert out == "Access denied"
async def test_absolute_path_rejected():
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "C:\\Windows\\foo"})
assert out == "Access denied"
async def test_missing_file():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "", "kind": "missing", "totalSize": 0}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "ghost.md"})
assert "not found" in out.lower()
async def test_pagination_signals_more_available():
# Electron returned the first slice, totalSize larger than slice length.
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "first chunk", "kind": "text", "totalSize": 1000}),
):
out = await read_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "big.txt",
"offset": 0,
"length": 11,
})
assert "first chunk" in out
assert "More content available" in out
assert "offset=11" in out
async def test_pdf_extracted_then_sliced(monkeypatch):
from app.agents import folder_agent
monkeypatch.setattr(folder_agent, "_extract_pdf_text", lambda b: "ABC " * 100)
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "JVBERi0xLg==", "kind": "pdf", "totalSize": 12}),
):
out = await read_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "doc.pdf",
"offset": 0,
"length": 8,
})
assert "kind=pdf" in out
assert "ABC ABC " in out
assert "More content available" in out
async def test_image_returns_placeholder():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "iVBORw0K", "kind": "image", "totalSize": 1024}),
):
out = await read_project_folder_file.ainvoke({"project_id": "p1", "relative_path": "logo.png"})
assert "image" in out.lower()
async def test_search_finds_match_with_context():
body = "alpha\nbeta\nthe needle is here\ngamma\ndelta"
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": body, "kind": "text", "totalSize": len(body)}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "log.txt",
"query": "needle",
"context_lines": 1,
})
assert "needle" in out
assert "matches=1" in out
# Context lines included
assert "beta" in out
assert "gamma" in out
async def test_search_no_match():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "nothing here", "kind": "text", "totalSize": 12}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "x.txt",
"query": "zzz",
})
assert "No matches" in out
async def test_search_rejects_traversal():
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "../etc/passwd",
"query": "root",
})
assert out == "Access denied"
async def test_search_image_rejected():
with patch(
"app.agents.folder_agent.execute_on_client",
new=AsyncMock(return_value={"content": "b64data", "kind": "image", "totalSize": 100}),
):
out = await search_project_folder_file.ainvoke({
"project_id": "p1",
"relative_path": "logo.png",
"query": "anything",
})
assert "Cannot search" in out

View File

@@ -0,0 +1,83 @@
"""Folder indexer LLM helpers."""
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.folder_indexer import summarize_text, summarize_image, IndexResult
pytestmark = pytest.mark.asyncio
async def test_summarize_text_returns_summary_and_tokens():
mock_resp = AsyncMock()
mock_resp.content = "Kickoff notes covering scope and deadlines."
mock_resp.usage_metadata = {"input_tokens": 320, "output_tokens": 18, "total_tokens": 338}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="hello world", ext=".md", name="kickoff.md")
assert isinstance(result, IndexResult)
assert result.summary == "Kickoff notes covering scope and deadlines."
assert result.tokens_used == 338
async def test_summarize_text_truncates_summary_at_500_chars():
mock_resp = AsyncMock()
mock_resp.content = "x" * 1000
mock_resp.usage_metadata = {"total_tokens": 100}
with patch("app.core.folder_indexer._llm_text", new=AsyncMock(return_value=mock_resp)):
result = await summarize_text(content="x", ext=".md", name="x.md")
assert len(result.summary) <= 500
async def test_summarize_image_uses_vision_content_blocks():
mock_resp = AsyncMock()
mock_resp.content = "Final logo on white background."
mock_resp.usage_metadata = {"total_tokens": 500}
captured = {}
async def fake_llm_vision(messages):
captured["messages"] = messages
return mock_resp
with patch("app.core.folder_indexer._llm_vision", new=fake_llm_vision):
result = await summarize_image(image_b64="iVBORw0KG", mime="image/png")
assert "Final logo" in result.summary
assert result.tokens_used == 500
# last message contains an image content block
last = captured["messages"][-1]
assert any(
isinstance(p, dict) and p.get("type") == "image_url"
for p in (last.content if isinstance(last.content, list) else [])
)
async def test_summarize_pdf_extracts_then_summarizes(monkeypatch):
# pypdf.PdfReader returns text from pages
from app.core import folder_indexer
class FakePage:
def extract_text(self): return "PDF page content with project info."
class FakeReader:
pages = [FakePage(), FakePage()]
monkeypatch.setattr(folder_indexer, "PdfReader", lambda buf: FakeReader())
mock_resp = AsyncMock(); mock_resp.content = "Project info doc."; mock_resp.usage_metadata = {"total_tokens": 50}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_pdf(pdf_b64="SGVsbG8=", name="doc.pdf")
assert "Project info" in result.summary
assert result.tokens_used == 50
async def test_summarize_docx_extracts_then_summarizes(monkeypatch):
from app.core import folder_indexer
class FakePara:
def __init__(self, t): self.text = t
class FakeDoc:
paragraphs = [FakePara("Heading"), FakePara("Body paragraph one.")]
monkeypatch.setattr(folder_indexer, "DocxDocument", lambda buf: FakeDoc())
mock_resp = AsyncMock(); mock_resp.content = "Heading and body."; mock_resp.usage_metadata = {"total_tokens": 30}
async def fake_llm(messages): return mock_resp
with patch("app.core.folder_indexer._llm_text", new=fake_llm):
result = await folder_indexer.summarize_docx(docx_b64="UEsDBBQ=", name="doc.docx")
assert result.summary == "Heading and body."

View File

@@ -0,0 +1,94 @@
"""Folder quota helpers."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from sqlalchemy import select
from app.billing.quota import (
check_folder_quota,
add_token_usage,
QuotaExceeded,
)
from app.models import MonthlyTokenUsage
pytestmark = pytest.mark.asyncio
async def test_check_folder_quota_free_rejects_above_file_cap(db, test_user_free):
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=500, db=db
)
assert exc.value.reason == "max_files"
async def test_check_folder_quota_free_passes_under_cap(db, test_user_free):
# No raise
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=50, db=db
)
async def test_check_folder_quota_rejects_when_monthly_exhausted(db, test_user_free):
ym = datetime.now(timezone.utc).strftime("%Y-%m")
db.add(MonthlyTokenUsage(
user_id=test_user_free.id, year_month=ym, feature="folder_index", tokens_used=100_000
))
await db.commit()
with pytest.raises(QuotaExceeded) as exc:
await check_folder_quota(
user_id=test_user_free.id, tier="free", estimated_files=10, db=db
)
assert exc.value.reason == "monthly_tokens"
async def test_check_folder_quota_power_unlimited(db, test_user_power):
await check_folder_quota(
user_id=test_user_power.id, tier="power", estimated_files=999_999, db=db
)
async def test_add_token_usage_atomic_increment(db, test_user_free):
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=1500, db=db)
await add_token_usage(user_id=test_user_free.id, feature="folder_index", tokens=2500, db=db)
ym = datetime.now(timezone.utc).strftime("%Y-%m")
row = (await db.execute(
select(MonthlyTokenUsage).where(
MonthlyTokenUsage.user_id == test_user_free.id,
MonthlyTokenUsage.year_month == ym,
MonthlyTokenUsage.feature == "folder_index",
)
)).scalar_one()
assert row.tokens_used == 4000
async def test_add_token_usage_returns_exhausted_when_over_cap(db, test_user_free):
result = await add_token_usage(
user_id=test_user_free.id, feature="folder_index", tokens=150_000, db=db, cap=100_000
)
assert result.exhausted is True
assert result.tokens_used == 150_000
def test_quota_check_endpoint_rejects(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 500},
headers=auth_headers_free,
)
assert res.status_code == 402
body = res.json()
assert body["detail"]["reason"] == "max_files"
def test_quota_check_endpoint_passes(client, auth_headers_free):
res = client.post(
"/api/v1/billing/quota/check",
json={"feature": "folder_index", "estimated_files": 50},
headers=auth_headers_free,
)
assert res.status_code == 200
assert res.json() == {"ok": True}

View File

@@ -40,11 +40,9 @@ Coverage:
from __future__ import annotations from __future__ import annotations
import asyncio
import json import json
import uuid
from datetime import datetime, timezone from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, Mock, PropertyMock, patch from unittest.mock import AsyncMock, MagicMock, PropertyMock, patch
import pytest import pytest
@@ -330,7 +328,7 @@ def _make_gmail_message(
class TestGmailClientFetchMessages: class TestGmailClientFetchMessages:
"""GmailClient.fetch_messages tests with mocked Google API.""" """GmailClient.fetch_messages tests with mocked Google API."""
def _make_client(self) -> "GmailClient": def _make_client(self):
from app.integrations.gmail import GmailClient from app.integrations.gmail import GmailClient
return GmailClient(_TOKEN_DICT) return GmailClient(_TOKEN_DICT)
@@ -511,7 +509,7 @@ def _make_graph_teams_message(
class TestMSGraphClientFetchEmails: class TestMSGraphClientFetchEmails:
"""MSGraphClient.fetch_emails tests with mocked httpx.""" """MSGraphClient.fetch_emails tests with mocked httpx."""
def _make_client(self) -> "MSGraphClient": def _make_client(self):
from app.integrations.ms_graph import MSGraphClient from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(_MS_TOKEN_DICT) return MSGraphClient(_MS_TOKEN_DICT)
@@ -610,7 +608,7 @@ class TestMSGraphClientFetchEmails:
class TestMSGraphClientFetchMessages: class TestMSGraphClientFetchMessages:
"""MSGraphClient.fetch_messages (Teams) tests.""" """MSGraphClient.fetch_messages (Teams) tests."""
def _make_client(self) -> "MSGraphClient": def _make_client(self):
from app.integrations.ms_graph import MSGraphClient from app.integrations.ms_graph import MSGraphClient
return MSGraphClient(_MS_TOKEN_DICT) return MSGraphClient(_MS_TOKEN_DICT)

View File

@@ -12,16 +12,17 @@ Unit tests (no LLM)
4.6e Session not found → done=True, agent_config=None 4.6e Session not found → done=True, agent_config=None
4.6f Nudge uses AGENT_CONFIG_START/END markers (not old PROMPT_TEMPLATE) 4.6f Nudge uses AGENT_CONFIG_START/END markers (not old PROMPT_TEMPLATE)
Eval tests (real LLM + Langfuse scoring) Eval test (real LLM + Langfuse scoring)
----------------------------------------- ----------------------------------------
Cases are defined in tests/fixtures/journey_v2/cases.yaml. 4.1 Journey start explores directory → first reply contains a question
Email HTML files live in tests/fixtures/journey_v2/data/.
Use --journey-dir to point at a custom folder (same structure required). Cases 4.24.5 (multi-turn conversations producing a full AgentConfig) are
non-deterministic and tested manually — results tracked in Langfuse.
Run: Run:
pytest tests/test_journey_v2.py -v pytest tests/test_journey_v2.py -v
pytest tests/test_journey_v2.py -v -k "4_6" # unit only pytest tests/test_journey_v2.py -v -k "4_6" # unit only
pytest tests/test_journey_v2.py -v -k "eval" # LLM evals only pytest tests/test_journey_v2.py -v -k "eval" # single LLM eval
pytest tests/test_journey_v2.py -v --journey-dir /p # custom fixtures pytest tests/test_journey_v2.py -v --journey-dir /p # custom fixtures
""" """
@@ -170,57 +171,6 @@ def _evaluate_case(case: dict, reply: dict) -> tuple[float, str]:
has_q = "?" in reply.get("message", "") has_q = "?" in reply.get("message", "")
return (1.0 if has_q else 0.0), f"first_reply_has_question={has_q}" return (1.0 if has_q else 0.0), f"first_reply_has_question={has_q}"
if case.get("expect_done") and not reply.get("done"):
return 0.0, "expected done=True but journey did not complete"
agent_config_raw = reply.get("agent_config")
if case.get("expect_valid_config"):
if not agent_config_raw:
return 0.0, "agent_config is None"
try:
parsed = AgentConfig.model_validate_json(agent_config_raw)
valid = len(parsed.content_types) > 0
return (1.0 if valid else 0.0), f"content_types={len(parsed.content_types)}"
except Exception as exc:
return 0.0, f"parse error: {exc}"
if case.get("expect_content_type_id"):
expected_id = case["expect_content_type_id"]
if not agent_config_raw:
return 0.0, "agent_config is None"
try:
parsed = AgentConfig.model_validate_json(agent_config_raw)
ids = [ct.id for ct in parsed.content_types]
found = expected_id in ids
return (1.0 if found else 0.0), f"content_type_ids={ids}, expected={expected_id}"
except Exception as exc:
return 0.0, f"parse error: {exc}"
if case.get("expect_extraction_contains"):
keyword = case["expect_extraction_contains"].lower()
if not agent_config_raw:
return 0.0, "agent_config is None"
try:
parsed = AgentConfig.model_validate_json(agent_config_raw)
if not parsed.content_types:
return 0.0, "no content_types in config"
prompt = parsed.content_types[0].extraction_prompt.lower()
found = keyword in prompt
return (1.0 if found else 0.0), f"keyword='{keyword}' in extraction_prompt={found}"
except Exception as exc:
return 0.0, f"parse error: {exc}"
if case.get("expect_global_rules"):
if not agent_config_raw:
return 0.0, "agent_config is None"
try:
parsed = AgentConfig.model_validate_json(agent_config_raw)
has_rules = len(parsed.global_rules) > 0
return (1.0 if has_rules else 0.0), f"global_rules={parsed.global_rules}"
except Exception as exc:
return 0.0, f"parse error: {exc}"
return 1.0, "no specific assertion" return 1.0, "no specific assertion"

View File

@@ -0,0 +1,69 @@
from __future__ import annotations
from unittest.mock import AsyncMock, patch
import pytest
from app.core.deep_agent import format_folder_manifest, MANIFEST_TOKEN_BUDGET
pytestmark = pytest.mark.asyncio
def test_format_folder_manifest_basic():
manifest = {
"folderPath": "D:\\Acme",
"lastScannedAt": "2h ago",
"files": [
{"relPath": "briefs/kickoff.md", "kind": "text", "summary": "Kickoff notes; scope and deadlines."},
{"relPath": "logos/logo-v3.png", "kind": "image", "summary": "Final logo on white."},
],
}
out = format_folder_manifest(manifest)
assert "<linked_folder>" in out
assert "/briefs/kickoff.md" in out or "briefs/kickoff.md" in out
assert "[text]" in out
assert "[image]" in out
def test_format_folder_manifest_truncates_past_budget():
files = [
{"relPath": f"f{i}.md", "kind": "text", "summary": "x" * 100, "mtimeMs": i}
for i in range(2000)
]
out = format_folder_manifest({"folderPath": "p", "lastScannedAt": "now", "files": files})
assert "more files omitted" in out
# Rough token check
assert len(out) // 4 < MANIFEST_TOKEN_BUDGET + 200
def test_format_folder_manifest_null_returns_empty():
assert format_folder_manifest(None) == ""
assert format_folder_manifest({"files": []}) == ""
async def test_brief_multi_project_manifest_top_5_per_project():
fake_response = [
{
"projectId": "p1", "projectName": "Acme", "folderPath": "/a",
"lastScannedAt": "now",
"files": [
{"relPath": f"f{i}.md", "kind": "text", "summary": "s", "mtimeMs": i}
for i in range(10)
],
},
{
"projectId": "p2", "projectName": "Beta", "folderPath": "/b",
"lastScannedAt": "now",
"files": [{"relPath": "x.md", "kind": "text", "summary": "s", "mtimeMs": 1}],
},
]
with patch(
"app.core.deep_agent.execute_on_client",
new=AsyncMock(return_value={"projects": fake_response}),
):
from app.core.deep_agent import build_brief_multi_project_manifest
out = await build_brief_multi_project_manifest()
# Project 1 has 10 files, only top 5 by mtimeMs should appear
assert out.count("[p1]") <= 5
# Project 2 has 1 file, must appear
assert "[p2]" in out or "Beta" in out

405
tests/test_memory_audit.py Normal file
View File

@@ -0,0 +1,405 @@
"""Tests for Phase 7 — weekly audit_memory job.
Coverage:
1. audit_memory never raises even if inner work fails.
2. _scan_associative_contradictions skips when < 2 decryptable facts.
3. _scan_associative_contradictions calls LLM and deletes flagged rows.
4. _scan_associative_contradictions is a no-op when LLM fails.
5. _scan_associative_contradictions is a no-op when LLM returns non-list.
6. _canonicalize_relation_labels skips when no relation rows.
7. _canonicalize_relation_labels rewrites variant labels to canonical form.
8. _canonicalize_relation_labels is a no-op when LLM fails.
9. _canonicalize_relation_labels is a no-op when remap is empty.
10. Both helpers work correctly when Langfuse is unavailable (lf=None).
11. get_prompt_or_fallback called with correct Langfuse prompt names.
"""
from __future__ import annotations
import json
import uuid
from contextlib import contextmanager, ExitStack
from datetime import datetime, timezone
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_maintenance import (
_canonicalize_relation_labels,
_scan_associative_contradictions,
audit_memory,
)
from app.db import get_session
from app.main import app
from app.models import MemoryAssociative, MemoryRelation, User
from tests.conftest import TEST_USER_IDS
PRO_USER_ID = TEST_USER_IDS["pro"]
_FERNET_KEY = Fernet.generate_key().decode()
_FERNET = Fernet(_FERNET_KEY.encode())
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
# ── Helpers ───────────────────────────────────────────────────────────────────
@pytest_asyncio.fixture
async def pro_user(db_session):
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
def _enc(text: str) -> str:
return _FERNET.encrypt(text.encode()).decode()
def _assoc_row(user_id: str, text: str) -> MemoryAssociative:
return MemoryAssociative(
id=str(uuid.uuid4()),
user_id=user_id,
content_encrypted=_enc(text),
updated_at=datetime.now(timezone.utc),
)
def _relation_row(user_id: str, subject: str, predicate: str, obj: str) -> MemoryRelation:
return MemoryRelation(
id=str(uuid.uuid4()),
user_id=user_id,
subject_label=subject,
subject_type="person",
predicate=predicate,
object_label=obj,
object_type="company",
confidence=0.8,
)
def _llm_response(content: str) -> MagicMock:
msg = MagicMock()
msg.content = content
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
return msg
def _mock_llm(content: str) -> MagicMock:
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=_llm_response(content))
return llm
@contextmanager
def _patch_audit(llm_mock, lf=None, prompt_text: str = "fallback {facts}"):
"""Context manager that patches all external deps for audit helpers."""
with ExitStack() as stack:
stack.enter_context(
patch("app.core.llm.get_agent_llm", return_value=llm_mock)
)
stack.enter_context(
patch("app.core.llm.model_for_agent", return_value="memory-auditor")
)
stack.enter_context(
patch("app.core.memory_maintenance.get_langfuse", return_value=lf)
)
stack.enter_context(
patch(
"app.core.memory_maintenance.get_prompt_or_fallback",
return_value=(prompt_text, None),
)
)
stack.enter_context(
patch(
"app.core.memory_maintenance.compile_prompt",
side_effect=lambda tmpl, obj, **kw: tmpl.format(**kw) if "{" in tmpl else tmpl,
)
)
yield
# ── Test 1: audit_memory never raises ────────────────────────────────────────
@pytest.mark.asyncio
async def test_audit_memory_never_raises_on_missing_user(db_session):
"""audit_memory with a non-existent user_id must not raise."""
await audit_memory(db_session, str(uuid.uuid4()))
@pytest.mark.asyncio
async def test_audit_memory_never_raises_on_llm_failure(db_session, pro_user):
"""audit_memory must swallow inner exceptions."""
llm = MagicMock()
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
with (
patch("app.core.llm.get_agent_llm", return_value=llm),
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
patch(
"app.core.memory_maintenance.get_prompt_or_fallback",
return_value=("p {facts}", None),
),
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
):
await audit_memory(db_session, PRO_USER_ID)
# ── Test 2: _scan skips when < 2 facts ───────────────────────────────────────
@pytest.mark.asyncio
async def test_scan_contradictions_skips_with_one_fact(db_session, pro_user):
row = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
db_session.add(row)
await db_session.commit()
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
with _patch_audit(llm):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
llm.ainvoke.assert_not_called()
# ── Test 3: _scan deletes flagged contradiction ───────────────────────────────
@pytest.mark.asyncio
async def test_scan_contradictions_deletes_flagged_row(db_session, pro_user):
keep = _assoc_row(PRO_USER_ID, "Prefers morning meetings")
drop = _assoc_row(PRO_USER_ID, "Never schedules before noon")
db_session.add(keep)
db_session.add(drop)
await db_session.commit()
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts morning pref"}])
llm = _mock_llm(deletion_payload)
with _patch_audit(llm, prompt_text="p {facts}"):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
)
remaining = result.scalars().all()
remaining_ids = {r.id for r in remaining}
assert keep.id in remaining_ids
assert drop.id not in remaining_ids
# ── Test 4: _scan is no-op on LLM failure ────────────────────────────────────
@pytest.mark.asyncio
async def test_scan_contradictions_noop_on_llm_failure(db_session, pro_user):
for text in ("Fact A", "Fact B"):
db_session.add(_assoc_row(PRO_USER_ID, text))
await db_session.commit()
llm = MagicMock()
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
with _patch_audit(llm, prompt_text="p {facts}"):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
)
assert len(result.scalars().all()) == 2
# ── Test 5: _scan is no-op when LLM returns non-list ─────────────────────────
@pytest.mark.asyncio
async def test_scan_contradictions_noop_on_non_list_response(db_session, pro_user):
for text in ("Fact A", "Fact B"):
db_session.add(_assoc_row(PRO_USER_ID, text))
await db_session.commit()
llm = _mock_llm('"unexpected string"')
with _patch_audit(llm, prompt_text="p {facts}"):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
)
assert len(result.scalars().all()) == 2
# ── Test 6: _canonicalize skips when no relations ────────────────────────────
@pytest.mark.asyncio
async def test_canonicalize_skips_when_no_relations(db_session, pro_user):
llm = MagicMock()
llm.ainvoke = AsyncMock(return_value=_llm_response("[]"))
with _patch_audit(llm, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
llm.ainvoke.assert_not_called()
# ── Test 7: _canonicalize rewrites variant labels ────────────────────────────
@pytest.mark.asyncio
async def test_canonicalize_rewrites_variant_labels(db_session, pro_user):
row_a = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
row_b = _relation_row(PRO_USER_ID, "Giulia R.", "reports_to", "Marco")
row_c = _relation_row(PRO_USER_ID, "Marco", "manages", "Giulia")
db_session.add(row_a)
db_session.add(row_b)
db_session.add(row_c)
await db_session.commit()
groups = json.dumps([
{"canonical": "Giulia", "variants": ["giulia", "Giulia R."]}
])
llm = _mock_llm(groups)
with _patch_audit(llm, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
await db_session.refresh(row_a)
await db_session.refresh(row_b)
await db_session.refresh(row_c)
assert row_a.subject_label == "Giulia"
assert row_b.subject_label == "Giulia"
assert row_c.object_label == "Giulia"
assert row_c.subject_label == "Marco"
# ── Test 8: _canonicalize is no-op on LLM failure ────────────────────────────
@pytest.mark.asyncio
async def test_canonicalize_noop_on_llm_failure(db_session, pro_user):
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
db_session.add(row)
await db_session.commit()
llm = MagicMock()
llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
with _patch_audit(llm, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
await db_session.refresh(row)
assert row.subject_label == "giulia"
# ── Test 9: _canonicalize is no-op when remap is empty ───────────────────────
@pytest.mark.asyncio
async def test_canonicalize_noop_when_remap_empty(db_session, pro_user):
row = _relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme")
db_session.add(row)
await db_session.commit()
llm = _mock_llm("[]")
with _patch_audit(llm, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
await db_session.refresh(row)
assert row.subject_label == "Giulia"
# ── Test 10: both helpers work without Langfuse ───────────────────────────────
@pytest.mark.asyncio
async def test_scan_works_without_langfuse(db_session, pro_user):
keep = _assoc_row(PRO_USER_ID, "Prefers dark mode")
drop = _assoc_row(PRO_USER_ID, "Prefers light mode")
db_session.add(keep)
db_session.add(drop)
await db_session.commit()
deletion_payload = json.dumps([{"delete": drop.id, "reason": "contradicts dark mode"}])
llm = _mock_llm(deletion_payload)
with _patch_audit(llm, lf=None, prompt_text="p {facts}"):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
result = await db_session.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == PRO_USER_ID)
)
remaining_ids = {r.id for r in result.scalars().all()}
assert keep.id in remaining_ids
assert drop.id not in remaining_ids
@pytest.mark.asyncio
async def test_canonicalize_works_without_langfuse(db_session, pro_user):
row = _relation_row(PRO_USER_ID, "giulia", "works_at", "Acme")
db_session.add(row)
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Giulia"))
await db_session.commit()
groups = json.dumps([{"canonical": "Giulia", "variants": ["giulia"]}])
llm = _mock_llm(groups)
with _patch_audit(llm, lf=None, prompt_text="p {labels}"):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
await db_session.refresh(row)
assert row.subject_label == "Giulia"
# ── Test 11: correct Langfuse prompt names used ───────────────────────────────
@pytest.mark.asyncio
async def test_scan_uses_correct_langfuse_prompt_name(db_session, pro_user):
for text in ("Fact A", "Fact B"):
db_session.add(_assoc_row(PRO_USER_ID, text))
await db_session.commit()
llm = _mock_llm("[]")
mock_get_prompt = MagicMock(return_value=("p {facts}", None))
with (
patch("app.core.llm.get_agent_llm", return_value=llm),
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
):
await _scan_associative_contradictions(db_session, PRO_USER_ID, _FERNET)
mock_get_prompt.assert_called_once()
assert mock_get_prompt.call_args[0][0] == "memory_audit_contradictions"
@pytest.mark.asyncio
async def test_canonicalize_uses_correct_langfuse_prompt_name(db_session, pro_user):
db_session.add(_relation_row(PRO_USER_ID, "Giulia", "works_at", "Acme"))
db_session.add(_relation_row(PRO_USER_ID, "Marco", "manages", "Acme"))
await db_session.commit()
llm = _mock_llm("[]")
mock_get_prompt = MagicMock(return_value=("p {labels}", None))
with (
patch("app.core.llm.get_agent_llm", return_value=llm),
patch("app.core.llm.model_for_agent", return_value="memory-auditor"),
patch("app.core.memory_maintenance.get_langfuse", return_value=None),
patch("app.core.memory_maintenance.get_prompt_or_fallback", mock_get_prompt),
patch("app.core.memory_maintenance.compile_prompt", return_value="compiled"),
):
await _canonicalize_relation_labels(db_session, PRO_USER_ID)
mock_get_prompt.assert_called_once()
assert mock_get_prompt.call_args[0][0] == "memory_audit_canonicalize"

View File

@@ -0,0 +1,345 @@
"""Tests for Phase 2 — Mem0-style Extract/Update pipeline.
Coverage:
2.1 extract_candidates returns valid ExtractionResult with mocked LLM.
2.2 decide_action — all 4 branches (ADD/UPDATE/DELETE/NOOP + empty existing).
2.3 run_extraction end-to-end with mocked LLM writes expected rows.
2.4 _dispatch_extraction — Pro user triggers realtime task; Free enqueues row.
"""
from __future__ import annotations
import json
import uuid
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_extraction import (
ExtractionResult,
MemoryCandidate,
decide_action,
extract_candidates,
run_extraction,
)
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import ExtractionQueue, MemoryCore, User
from tests.conftest import TEST_USER_IDS
PRO_USER_ID = TEST_USER_IDS["pro"]
FREE_USER_ID = TEST_USER_IDS["free"]
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
# ── Helpers ───────────────────────────────────────────────────────────────────
@pytest_asyncio.fixture
async def pro_user(db_session):
"""Update the seeded pro user to have an encryption_key."""
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
@pytest_asyncio.fixture
async def free_user(db_session):
"""Update the seeded free user to have an encryption_key."""
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
def _make_llm_response(content: str) -> MagicMock:
msg = MagicMock()
msg.content = content
msg.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
return msg
# ── TASK 2.1 — extract_candidates ────────────────────────────────────────────
@pytest.mark.asyncio
async def test_extract_candidates_returns_valid_result():
payload = {
"candidates": [
{
"type": "fact",
"content": "User's CFO is Giulia",
"target_tier": "core",
"subject": None,
"predicate": None,
"object": None,
"confidence": 0.85,
}
]
}
mock_response = _make_llm_response(json.dumps(payload))
with (
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
patch("app.core.memory_extraction.get_langfuse", return_value=None),
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
):
mock_prompt.return_value = (
"system prompt {last_turn} {core_memory} {recent_episodes}",
None,
)
llm_instance = MagicMock()
llm_instance.bind.return_value = llm_instance
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
mock_get_llm.return_value = llm_instance
result = await extract_candidates(
last_turn="User: My CFO is Giulia\nAssistant: Noted.",
core_memory={},
recent_episodes=[],
)
assert isinstance(result, ExtractionResult)
assert len(result.candidates) == 1
assert result.candidates[0].type == "fact"
assert "Giulia" in result.candidates[0].content
assert result.candidates[0].confidence == 0.85
@pytest.mark.asyncio
async def test_extract_candidates_returns_empty_on_llm_failure():
with (
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
patch("app.core.memory_extraction.get_langfuse", return_value=None),
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
):
mock_prompt.return_value = ("prompt {last_turn} {core_memory} {recent_episodes}", None)
llm_instance = MagicMock()
llm_instance.bind.return_value = llm_instance
llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
mock_get_llm.return_value = llm_instance
result = await extract_candidates("turn", {}, [])
assert isinstance(result, ExtractionResult)
assert result.candidates == []
# ── TASK 2.2 — decide_action ─────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_decide_action_add_when_no_existing():
candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
action = await decide_action(candidate, existing=[])
assert action == "ADD"
@pytest.mark.asyncio
async def test_decide_action_noop():
candidate = MemoryCandidate(type="fact", content="CFO is Giulia", target_tier="core")
mock_response = _make_llm_response("NOOP")
with (
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
patch("app.core.memory_extraction.get_langfuse", return_value=None),
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
):
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
llm_instance = MagicMock()
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
mock_get_llm.return_value = llm_instance
action = await decide_action(candidate, existing=["CFO is Giulia"])
assert action == "NOOP"
@pytest.mark.asyncio
async def test_decide_action_update():
candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
mock_response = _make_llm_response("UPDATE")
with (
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
patch("app.core.memory_extraction.get_langfuse", return_value=None),
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
):
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
llm_instance = MagicMock()
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
mock_get_llm.return_value = llm_instance
action = await decide_action(candidate, existing=["CFO is Giulia"])
assert action == "UPDATE"
@pytest.mark.asyncio
async def test_decide_action_delete():
candidate = MemoryCandidate(type="fact", content="No longer have a CFO", target_tier="core")
mock_response = _make_llm_response("DELETE")
with (
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
patch("app.core.memory_extraction.get_langfuse", return_value=None),
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
):
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
llm_instance = MagicMock()
llm_instance.ainvoke = AsyncMock(return_value=mock_response)
mock_get_llm.return_value = llm_instance
action = await decide_action(candidate, existing=["CFO is Giulia"])
assert action == "DELETE"
@pytest.mark.asyncio
async def test_decide_action_defaults_add_on_llm_failure():
candidate = MemoryCandidate(type="fact", content="CFO is Marco", target_tier="core")
with (
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
patch("app.core.memory_extraction.get_langfuse", return_value=None),
patch("app.core.memory_extraction.get_prompt_or_fallback") as mock_prompt,
):
mock_prompt.return_value = ("p {candidate} {existing_memories}", None)
llm_instance = MagicMock()
llm_instance.ainvoke = AsyncMock(side_effect=RuntimeError("LLM down"))
mock_get_llm.return_value = llm_instance
action = await decide_action(candidate, existing=["old memory"])
assert action == "ADD"
# ── TASK 2.3 — run_extraction end-to-end ─────────────────────────────────────
@pytest.mark.asyncio
async def test_run_extraction_writes_core_candidate(db_session, pro_user):
"""'My CFO is Giulia' → fact candidate → core row written."""
fact_payload = {
"candidates": [
{
"type": "fact",
"content": "User prefers morning meetings",
"target_tier": "core",
"confidence": 0.8,
}
]
}
def _mock_llm_response(content: str):
msg = MagicMock()
msg.content = content
msg.usage_metadata = {}
return msg
call_count = 0
async def _ainvoke_side_effect(messages):
nonlocal call_count
call_count += 1
if call_count == 1:
# extract_candidates call
return _mock_llm_response(json.dumps(fact_payload))
# decide_action — no existing → short-circuits to ADD without LLM
return _mock_llm_response("ADD")
with (
patch("app.core.memory_extraction.get_agent_llm") as mock_get_llm,
patch("app.core.memory_extraction.get_langfuse", return_value=None),
patch(
"app.core.memory_extraction.get_prompt_or_fallback",
side_effect=lambda name, fb: (
("p {last_turn} {core_memory} {recent_episodes}", None)
if name == "memory_extraction"
else ("p {candidate} {existing_memories}", None)
),
),
):
llm_instance = MagicMock()
llm_instance.bind.return_value = llm_instance
llm_instance.ainvoke = AsyncMock(side_effect=_ainvoke_side_effect)
mock_get_llm.return_value = llm_instance
await run_extraction(
db=db_session,
user_id=PRO_USER_ID,
last_user_msg="My CFO is Giulia",
last_assistant_msg="Noted, I will remember that.",
session_id="test-session",
)
# core row should exist
result = await db_session.execute(
select(MemoryCore).where(MemoryCore.user_id == PRO_USER_ID)
)
rows = result.scalars().all()
assert len(rows) >= 1
fernet = Fernet(_FERNET_KEY.encode())
values = [fernet.decrypt(r.value_encrypted.encode()).decode() for r in rows]
assert any("morning meetings" in v for v in values)
# ── TASK 2.4 — dispatch ───────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_dispatch_realtime_for_pro(db_session, pro_user):
"""Pro user: asyncio.create_task called (not queue row)."""
middleware = MemoryMiddleware(db_session)
with (
patch("app.core.memory_middleware.asyncio.create_task") as mock_task,
patch("app.billing.tier_manager.tier_manager.check_feature", return_value=True),
):
await middleware._dispatch_extraction(
user_id=PRO_USER_ID,
episode_id=str(uuid.uuid4()),
last_user_msg="hello",
last_assistant_msg="hi",
session_id=None,
)
mock_task.assert_called_once()
@pytest.mark.asyncio
async def test_dispatch_queue_for_free(db_session, free_user):
"""Free user: ExtractionQueue row inserted."""
middleware = MemoryMiddleware(db_session)
ep_id = str(uuid.uuid4())
with patch("app.billing.tier_manager.tier_manager.check_feature", return_value=False):
await middleware._dispatch_extraction(
user_id=FREE_USER_ID,
episode_id=ep_id,
last_user_msg="hello",
last_assistant_msg="hi",
session_id=None,
)
result = await db_session.execute(
select(ExtractionQueue).where(ExtractionQueue.user_id == FREE_USER_ID)
)
rows = result.scalars().all()
assert len(rows) == 1
assert rows[0].episode_id == ep_id

View File

@@ -12,14 +12,15 @@ from __future__ import annotations
import json import json
import uuid import uuid
from unittest.mock import patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
import pytest_asyncio import pytest_asyncio
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from sqlalchemy import select from sqlalchemy import select
from app.core.memory_middleware import MemoryMiddleware, _PROACTIVE_CONFIDENCE_THRESHOLD from app.core.embeddings import embed_text
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session from app.db import get_session
from app.main import app from app.main import app
from app.models import ( from app.models import (
@@ -341,3 +342,33 @@ def test_home_request_calls_memory_middleware(client):
stored_session_id, stored_message = store_calls[0][1], store_calls[0][2] stored_session_id, stored_message = store_calls[0][1], store_calls[0][2]
assert stored_session_id == session_id assert stored_session_id == session_id
assert stored_message == "Show tasks" assert stored_message == "Show tasks"
# ── embed_text ─────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_embed_text_returns_1536_floats():
"""embed_text returns a 1536-dim float list when OpenAI responds successfully."""
fake_embedding = [0.1] * 1536
mock_response = MagicMock()
mock_response.data = [MagicMock(embedding=fake_embedding)]
mock_client = MagicMock()
mock_client.embeddings.create = AsyncMock(return_value=mock_response)
with patch("app.core.embeddings.AsyncOpenAI", return_value=mock_client):
result = await embed_text("test text")
assert result is not None
assert len(result) == 1536
assert all(isinstance(x, float) for x in result)
@pytest.mark.asyncio
async def test_embed_text_returns_none_on_failure():
"""embed_text returns None when OpenAI raises; must not propagate the exception."""
with patch("app.core.embeddings.AsyncOpenAI", side_effect=Exception("no key")):
result = await embed_text("test text")
assert result is None

View File

@@ -7,10 +7,9 @@ column is stored as JSON in tests (SQLite-compatible).
from __future__ import annotations from __future__ import annotations
import uuid import uuid
from datetime import datetime, timezone from datetime import datetime
import pytest import pytest
import pytest_asyncio
from cryptography.fernet import Fernet from cryptography.fernet import Fernet
from sqlalchemy import select from sqlalchemy import select

View File

@@ -0,0 +1,153 @@
"""Tests for Phase 5 — proactive hints surfacing.
Coverage:
1. _proactive_hints_injection returns correct section for seeded hints
2. _proactive_hints_injection returns empty string when no hints
3. enrich_context includes proactive_hints key from MemoryProactive row
4. System prompt includes proactive line when row exists + confidence >= threshold
5. TierManager.check_feature returns True for power/team, False for free/pro
"""
from __future__ import annotations
import uuid
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.billing.tier_manager import tier_manager
from app.core.deep_agent import _proactive_hints_injection
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import MemoryProactive, User
from tests.conftest import TEST_USER_IDS
USER_ID = TEST_USER_IDS["power"]
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
# ── Fixtures ──────────────────────────────────────────────────────────────────
@pytest_asyncio.fixture
async def user_with_key(db_session):
result = await db_session.execute(select(User).where(User.id == USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
def _enc(plaintext: str) -> str:
return Fernet(_FERNET_KEY.encode()).encrypt(plaintext.encode()).decode()
# ── _proactive_hints_injection unit tests ─────────────────────────────────────
def test_proactive_hints_injection_with_hints():
context = {"proactive_hints": ["Works late on Thursdays", "Prefers bullet points"]}
result = _proactive_hints_injection(context)
assert "I noticed" in result
assert "Works late on Thursdays" in result
assert "Prefers bullet points" in result
def test_proactive_hints_injection_empty():
assert _proactive_hints_injection({}) == ""
assert _proactive_hints_injection({"proactive_hints": []}) == ""
assert _proactive_hints_injection({"proactive_hints": None}) == ""
def test_proactive_hints_injection_truncates_long_hints():
hints = ["x" * 200] * 10
result = _proactive_hints_injection({"proactive_hints": hints})
assert len(result) <= 600
assert result.endswith("...")
# ── enrich_context includes proactive hints ───────────────────────────────────
@pytest.mark.asyncio
async def test_enrich_context_returns_proactive_hints(db_session, user_with_key):
pattern = "Always checks tasks before meetings"
db_session.add(MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=_enc(pattern),
confidence=0.8,
source="inferred",
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "test message")
assert "proactive_hints" in ctx
assert pattern in ctx["proactive_hints"]
@pytest.mark.asyncio
async def test_enrich_context_excludes_low_confidence_proactive(db_session, user_with_key):
pattern = "Low confidence pattern"
db_session.add(MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=_enc(pattern),
confidence=0.1,
source="inferred",
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "test message")
hints = ctx.get("proactive_hints", [])
assert pattern not in hints
# ── proactive hints appear in system prompt string ───────────────────────────
@pytest.mark.asyncio
async def test_proactive_hints_in_system_prompt_string(db_session, user_with_key):
pattern = "Frequently requests end-of-day summaries"
db_session.add(MemoryProactive(
id=str(uuid.uuid4()),
user_id=USER_ID,
pattern_encrypted=_enc(pattern),
confidence=0.75,
source="inferred",
))
await db_session.commit()
middleware = MemoryMiddleware(db_session)
ctx = await middleware.enrich_context(USER_ID, "summarize my day")
system_prompt_suffix = _proactive_hints_injection(ctx)
assert pattern in system_prompt_suffix
# ── Tier gate ─────────────────────────────────────────────────────────────────
@pytest.mark.parametrize("tier,expected", [
("free", False),
("pro", False),
("power", True),
("team", True),
])
def test_proactive_mining_tier_gate(tier, expected):
assert tier_manager.check_feature(tier, "proactive_mining") == expected

View File

@@ -0,0 +1,220 @@
"""Tests for Phase 3 — relational tier (Mem0g-light).
Coverage:
1. upsert_relation inserts a row and query_relations returns it
2. upsert_relation updates existing row on duplicate (subject/predicate/object)
3. tier gating: Free user gets empty list from query_relations + enrich_context
4. enrich_context includes relational_memory key for Pro user
5. decay_relations decays confidence and prunes rows below threshold
"""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import patch
import pytest
import pytest_asyncio
from cryptography.fernet import Fernet
from sqlalchemy import select
from app.core.memory_maintenance import decay_relations
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.main import app
from app.models import MemoryRelation, User
from tests.conftest import TEST_USER_IDS
PRO_USER_ID = TEST_USER_IDS["pro"]
FREE_USER_ID = TEST_USER_IDS["free"]
_FERNET_KEY = Fernet.generate_key().decode()
# ── DB override ───────────────────────────────────────────────────────────────
@pytest.fixture(autouse=True)
def _override_db(db_session):
async def _gen():
yield db_session
app.dependency_overrides[get_session] = _gen
yield
app.dependency_overrides.pop(get_session, None)
@pytest_asyncio.fixture
async def pro_user_with_key(db_session):
"""Set encryption_key on the pro test user so Fernet works."""
result = await db_session.execute(select(User).where(User.id == PRO_USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
@pytest_asyncio.fixture
async def free_user_with_key(db_session):
"""Set encryption_key on the free test user."""
result = await db_session.execute(select(User).where(User.id == FREE_USER_ID))
user = result.scalar_one()
user.encryption_key = _FERNET_KEY
await db_session.commit()
return user
# ── Tests ─────────────────────────────────────────────────────────────────────
@pytest.mark.asyncio
async def test_upsert_relation_inserts_and_queries(db_session, pro_user_with_key):
"""upsert_relation inserts a row; query_relations returns it."""
mm = MemoryMiddleware(db_session)
await mm.upsert_relation(
PRO_USER_ID,
subject="Giulia",
subject_type="person",
predicate="works_at",
object_="Acme Corp",
object_type="company",
confidence=0.9,
)
rows = await mm.query_relations(PRO_USER_ID, subject="Giulia")
assert len(rows) == 1
assert rows[0].subject_label == "Giulia"
assert rows[0].predicate == "works_at"
assert rows[0].object_label == "Acme Corp"
assert abs(rows[0].confidence - 0.9) < 0.001
@pytest.mark.asyncio
async def test_upsert_relation_updates_on_duplicate(db_session, pro_user_with_key):
"""Second upsert on same triple updates confidence and last_confirmed_at."""
mm = MemoryMiddleware(db_session)
await mm.upsert_relation(
PRO_USER_ID,
subject="Marco",
subject_type="person",
predicate="stakeholder_of",
object_="Project Nexus",
object_type="project",
confidence=0.7,
)
await mm.upsert_relation(
PRO_USER_ID,
subject="Marco",
subject_type="person",
predicate="stakeholder_of",
object_="Project Nexus",
object_type="project",
confidence=0.95,
)
rows = await mm.query_relations(PRO_USER_ID, subject="Marco")
# Only one row despite two upserts
assert len(rows) == 1
assert abs(rows[0].confidence - 0.95) < 0.001
assert rows[0].last_confirmed_at is not None
@pytest.mark.asyncio
async def test_free_tier_relation_skipped(db_session, free_user_with_key):
"""Free user: upsert_relation is silently skipped (no row created)."""
mm = MemoryMiddleware(db_session)
await mm.upsert_relation(
FREE_USER_ID,
subject="Alice",
subject_type="person",
predicate="reports_to",
object_="Bob",
object_type="person",
confidence=0.8,
)
rows = await mm.query_relations(FREE_USER_ID, subject="Alice")
assert rows == []
@pytest.mark.asyncio
async def test_enrich_context_includes_relational_memory(db_session, pro_user_with_key):
"""enrich_context includes relational_memory key for Pro user."""
mm = MemoryMiddleware(db_session)
await mm.upsert_relation(
PRO_USER_ID,
subject="Elena",
subject_type="person",
predicate="cfo_of",
object_="StartupXYZ",
object_type="company",
confidence=0.85,
)
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
ctx = await mm.enrich_context(PRO_USER_ID, "who is Elena?")
assert "relational_memory" in ctx
assert any("Elena" in r for r in ctx["relational_memory"])
@pytest.mark.asyncio
async def test_enrich_context_relational_empty_for_free(db_session, free_user_with_key):
"""Free user: relational_memory is empty list in enrich_context."""
mm = MemoryMiddleware(db_session)
with patch("app.core.memory_middleware.MemoryMiddleware._load_associative", return_value=[]):
ctx = await mm.enrich_context(FREE_USER_ID, "test message")
assert ctx.get("relational_memory") == []
@pytest.mark.asyncio
async def test_decay_relations_reduces_confidence(db_session, pro_user_with_key):
"""decay_relations reduces confidence on stale rows."""
old_date = datetime.now(timezone.utc) - timedelta(days=35)
row = MemoryRelation(
id=str(uuid.uuid4()),
user_id=PRO_USER_ID,
subject_label="OldContact",
subject_type="person",
predicate="knows",
object_label="SomeProject",
object_type="project",
confidence=0.8,
last_confirmed_at=old_date,
)
db_session.add(row)
await db_session.commit()
await decay_relations(db_session, PRO_USER_ID)
result = await db_session.execute(
select(MemoryRelation).where(MemoryRelation.subject_label == "OldContact")
)
updated = result.scalar_one_or_none()
assert updated is not None
assert updated.confidence < 0.8
@pytest.mark.asyncio
async def test_decay_relations_prunes_low_confidence(db_session, pro_user_with_key):
"""decay_relations deletes rows whose confidence drops below 0.2 threshold."""
# Start at 0.21 with 60-day-old last_confirmed_at → two decay periods → 0.21 * 0.95^2 ≈ 0.19 → pruned
old_date = datetime.now(timezone.utc) - timedelta(days=65)
row = MemoryRelation(
id=str(uuid.uuid4()),
user_id=PRO_USER_ID,
subject_label="ExpiredContact",
subject_type="person",
predicate="used_to_work_with",
object_label="OldCorp",
object_type="company",
confidence=0.21,
last_confirmed_at=old_date,
)
db_session.add(row)
await db_session.commit()
await decay_relations(db_session, PRO_USER_ID)
result = await db_session.execute(
select(MemoryRelation).where(MemoryRelation.subject_label == "ExpiredContact")
)
pruned = result.scalar_one_or_none()
assert pruned is None

View File

@@ -1,400 +0,0 @@
"""Tests for Step 10+12: Plugin Marketplace (DB-backed).
Covers:
- PluginRegistry: catalog management, filtering, sorting, install counts (PostgreSQL)
- ReviewQueue: pending queue, review decisions, manifest security checklist
- RevenueShare: install event recording, earnings aggregation (PostgreSQL)
- Route integration: tier gate, list/get/install/uninstall via TestClient
"""
from __future__ import annotations
import uuid
import pytest
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import PluginRegistry
from app.marketplace.plugin_review import ReviewQueue, validate_manifest
from app.marketplace.revenue_share import RevenueShare
from app.models import Plugin, PluginReview as PluginReviewModel, RevenueEvent
from app.schemas import PluginManifest
from tests.conftest import TEST_USER_IDS, auth_header
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _fresh_manifest(
plugin_id: str | None = None,
category: str = "productivity",
price_cents: int = 0,
permissions: list[str] | None = None,
) -> PluginManifest:
pid = plugin_id or f"plugin-{uuid.uuid4().hex[:8]}"
return PluginManifest(
id=pid,
name=f"Plugin {pid}",
description=f"Description for {pid}",
version="1.0.0",
author="test-author",
permissions=permissions or ["read:tasks"],
category=category,
price_cents=price_cents,
)
# ---------------------------------------------------------------------------
# PluginRegistry (DB-backed)
# ---------------------------------------------------------------------------
class TestPluginRegistry:
"""Each test uses the conftest db_session fixture with a fresh in-memory DB."""
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.mark.asyncio
async def test_seed_plugins_are_listed(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session)
assert result.total == 3
assert all(p.id.startswith("plugin-") for p in result.plugins)
@pytest.mark.asyncio
async def test_list_approved_only(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "plugins/key.zip")
result = await reg.list_plugins(db_session)
ids = [p.id for p in result.plugins]
assert manifest.id not in ids # still pending
@pytest.mark.asyncio
async def test_list_filter_by_category(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, category="communication")
assert result.total == 1
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_list_filter_by_query(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
result = await reg.list_plugins(db_session, query="time")
assert result.total == 1
assert result.plugins[0].id == "plugin-time-tracker"
@pytest.mark.asyncio
async def test_list_sort_by_installs(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-slack-notify")
await reg.record_install(db_session, "plugin-slack-notify")
result = await reg.list_plugins(db_session, sort="installs")
assert result.plugins[0].id == "plugin-slack-notify"
@pytest.mark.asyncio
async def test_get_plugin_found(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["manifest"].id == "plugin-github-sync"
assert "install_count" in entry
@pytest.mark.asyncio
async def test_get_plugin_not_found(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
entry = await reg.get_plugin(db_session, "no-such-plugin")
assert entry is None
@pytest.mark.asyncio
async def test_submit_sets_pending(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
plugin_id = await reg.submit_plugin(db_session, manifest, "key.zip")
assert plugin_id == manifest.id
result = await db_session.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one()
assert row.status == "pending_review"
@pytest.mark.asyncio
async def test_approve_makes_visible(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.approve_plugin(db_session, manifest.id)
result = await reg.list_plugins(db_session)
assert manifest.id in [p.id for p in result.plugins]
@pytest.mark.asyncio
async def test_reject_stores_reason(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await reg.reject_plugin(db_session, manifest.id, reason="Unsafe permissions")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
assert row.rejection_reason == "Unsafe permissions"
listed = await reg.list_plugins(db_session)
assert manifest.id not in [p.id for p in listed.plugins]
@pytest.mark.asyncio
async def test_approve_unknown_raises_key_error(
self, reg: PluginRegistry, db_session: AsyncSession
) -> None:
with pytest.raises(KeyError):
await reg.approve_plugin(db_session, "ghost-plugin")
@pytest.mark.asyncio
async def test_record_install_increments_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_decrements_count(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_install(db_session, "plugin-github-sync")
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_record_uninstall_floors_at_zero(
self, reg: PluginRegistry, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await reg.record_uninstall(db_session, "plugin-github-sync")
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 0
# ---------------------------------------------------------------------------
# ReviewQueue (DB-backed)
# ---------------------------------------------------------------------------
class TestReviewQueue:
@pytest.fixture
def reg(self) -> PluginRegistry:
return PluginRegistry()
@pytest.fixture
def queue(self) -> ReviewQueue:
return ReviewQueue()
@pytest.mark.asyncio
async def test_get_pending_returns_submitted_plugins(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
pending = await queue.get_pending(db_session)
assert any(p["plugin_id"] == manifest.id for p in pending)
@pytest.mark.asyncio
async def test_submit_review_approved(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(db_session, manifest.id, TEST_USER_IDS["power"], "approved", "Looks good")
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "approved"
# Check review row was persisted
review_result = await db_session.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == manifest.id)
)
review = review_result.scalar_one()
assert review.decision == "approved"
@pytest.mark.asyncio
async def test_submit_review_rejected(
self, reg: PluginRegistry, queue: ReviewQueue, db_session: AsyncSession
) -> None:
manifest = _fresh_manifest()
await reg.submit_plugin(db_session, manifest, "key.zip")
await queue.submit_review(
db_session, manifest.id, TEST_USER_IDS["power"], "rejected", "Bad permissions"
)
result = await db_session.execute(select(Plugin).where(Plugin.id == manifest.id))
row = result.scalar_one()
assert row.status == "rejected"
def test_validate_manifest_ok(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "write:notes"])
validate_manifest(manifest) # should not raise
def test_validate_manifest_unknown_permission(self) -> None:
manifest = _fresh_manifest(permissions=["read:tasks", "read:secrets"])
with pytest.raises(ValueError, match="Unknown permission"):
validate_manifest(manifest)
def test_validate_manifest_invalid_id_format(self) -> None:
manifest = _fresh_manifest(plugin_id="Plugin_ID_Invalid")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
def test_validate_manifest_id_with_uppercase(self) -> None:
manifest = _fresh_manifest(plugin_id="UpperCase")
with pytest.raises(ValueError, match="Invalid plugin id format"):
validate_manifest(manifest)
# ---------------------------------------------------------------------------
# RevenueShare (DB-backed)
# ---------------------------------------------------------------------------
class TestRevenueShare:
@pytest.fixture
def rs(self) -> RevenueShare:
return RevenueShare()
@pytest.mark.asyncio
async def test_record_install_free_plugin(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-github-sync")
)
event = result.scalar_one()
assert event.developer_share_cents == 0
@pytest.mark.asyncio
async def test_record_install_paid_plugin_no_stripe(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(
db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499
)
result = await db_session.execute(
select(RevenueEvent).where(RevenueEvent.plugin_id == "plugin-slack-notify")
)
event = result.scalar_one()
assert event.amount_cents == 499
assert event.developer_share_cents == int(499 * 0.70)
@pytest.mark.asyncio
async def test_record_install_increments_registry_count(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
reg = PluginRegistry()
await rs.record_install(db_session, "plugin-github-sync", TEST_USER_IDS["power"], amount_cents=0)
entry = await reg.get_plugin(db_session, "plugin-github-sync")
assert entry is not None
assert entry["install_count"] == 1
@pytest.mark.asyncio
async def test_get_earnings_empty(
self, rs: RevenueShare, db_session: AsyncSession
) -> None:
result = await rs.get_earnings(db_session, "unknown-dev")
assert result["total_installs"] == 0
assert result["total_revenue_cents"] == 0
assert result["developer_share_cents"] == 0
@pytest.mark.asyncio
async def test_get_earnings_aggregates(
self, rs: RevenueShare, db_session: AsyncSession, seed_plugins: list[Plugin]
) -> None:
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["power"], amount_cents=499)
await rs.record_install(db_session, "plugin-slack-notify", TEST_USER_IDS["pro"], amount_cents=499)
result = await rs.get_earnings(db_session, "Adiuva")
assert result["total_installs"] == 2
assert result["total_revenue_cents"] == 998
assert result["developer_share_cents"] == int(499 * 0.70) * 2
# ---------------------------------------------------------------------------
# Route integration tests
# ---------------------------------------------------------------------------
class TestPluginRoutes:
def test_list_plugins_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("free"))
assert resp.status_code == 403
def test_list_plugins_pro_tier_blocked(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("pro"))
assert resp.status_code == 403
def test_list_plugins_power_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("power"))
assert resp.status_code == 200
data = resp.json()
assert "plugins" in data
assert data["total"] == 3
def test_list_plugins_team_tier_ok(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins", headers=auth_header("team"))
assert resp.status_code == 200
def test_get_plugin_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/plugin-github-sync", headers=auth_header())
assert resp.status_code == 200
data = resp.json()
assert data["plugin"]["id"] == "plugin-github-sync"
assert "install_count" in data
def test_get_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.get("/api/v1/plugins/no-such-plugin", headers=auth_header())
assert resp.status_code == 404
def test_install_plugin_free(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header(),
)
assert resp.status_code == 200
data = resp.json()
assert data["ok"] is True
assert "download_url" in data
def test_install_plugin_not_found(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/ghost/install",
json={"plugin_id": "ghost"},
headers=auth_header(),
)
assert resp.status_code == 404
def test_uninstall_plugin_ok(self, client, seed_plugins) -> None:
resp = client.delete(
"/api/v1/plugins/plugin-github-sync/install",
headers=auth_header(),
)
assert resp.status_code == 200
assert resp.json()["ok"] is True
def test_install_requires_power_tier(self, client, seed_plugins) -> None:
resp = client.post(
"/api/v1/plugins/plugin-github-sync/install",
json={"plugin_id": "plugin-github-sync"},
headers=auth_header("free"),
)
assert resp.status_code == 403

View File

@@ -12,7 +12,6 @@ from __future__ import annotations
import re import re
from pathlib import Path from pathlib import Path
import pytest
import yaml import yaml
from app.core.preprocessors import detect_content_type, preprocess from app.core.preprocessors import detect_content_type, preprocess

View File

@@ -45,9 +45,6 @@ def test_v2_frame_types_still_exist():
"tool_result", "tool_result",
"final", "final",
"ping", "ping",
"agent_run",
"agent_data",
"agent_complete",
"device_hello", "device_hello",
] ]
for name in v2_types: for name in v2_types:

Some files were not shown because too many files have changed in this diff Show More