31 Commits

Author SHA1 Message Date
Roberto Musso
e668e3fd20 update setting page 2026-04-15 11:43:56 +02:00
Roberto Musso
7ccdad431f feat(i18n): inject user language into AI agent system prompts
- Add _language_instruction() to deep_agent.py, reads language from core memory
- Append language directive to all 4 run_* functions (task/project/checkpoint/note)
- Minor fixes: alembic env, route imports, test cleanup
2026-04-12 00:35:23 +02:00
Roberto Musso
4073863dc6 feat: add onboarding wizard backend - migration, schema, memory routes 2026-04-11 23:38:53 +02:00
Roberto Musso
a85f8fde29 feat(langfuse): propagate user_id and session_id to all traces
- Add hash_user_id() to SHA-256 hash user IDs before sending to Langfuse
- Add langfuse_context() helper wrapping propagate_attributes()
- deep_agent: extract session_id from _debug context, wrap all agent
  runs and classifier with langfuse_context(user_id, session_id)
- agent_runner: add session_id param, pass run_id as session for batch
- agent_setup: wrap journey LLM calls with langfuse_context
- Remove redundant metadata dicts (now handled by propagate_attributes)
2026-04-10 22:44:05 +02:00
Roberto Musso
90500a3462 fix: return 409 when unverified OAuth email conflicts with existing account
Before: branch 3 of oauth_callback attempted to INSERT a user with a
duplicate email → DB constraint violation → 500.

After: if email_verified=False and the email already exists, raise 409
with a message directing the user to sign in with their password.

Also adds test_callback_unverified_email_conflict_returns_409.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 13:46:15 +02:00
Roberto Musso
c1a8ac7669 test: add TestOAuth suite for Google OAuth routes
6 tests covering the authorize and callback endpoints:
- authorize returns URL + state, 503 when unconfigured
- callback: state mismatch → 401, new user creation, existing OAuth
  link re-login (same user sub), email-match auto-linking to password user

Provider methods (exchange_code, get_userinfo) are mocked via AsyncMock
so tests run without hitting Google APIs.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 13:42:11 +02:00
Roberto Musso
c510cbaae5 feat: add OAuth web-callback route and update OAUTH_REDIRECT_URI default
GET /auth/oauth/{provider}/web-callback receives the Google redirect and
bounces immediately to adiuvai://oauth/callback deep link. Google Cloud
Console only accepts http/https redirect URIs — adiuvai:// is not valid.
Default OAUTH_REDIRECT_URI now points to localhost:8000 for dev; override
with the API domain env var in production.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 13:03:05 +02:00
Roberto Musso
ce139bbac3 feat: add OAuth DB schema — oauth_accounts table, nullable password_hash, avatar_url on User
Step 1 of Google login integration: Alembic migration for oauth_accounts +
avatar_url on users, OAuthAccount model with User relationship, UserProfile
schema extended with avatar_url, get_current_user updated to include avatar_url.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-10 09:20:52 +02:00
Roberto Musso
3cf067faea feat: enhance agent configuration and model management with per-agent overrides 2026-04-10 08:45:14 +02:00
Roberto Musso
7253f6fe72 testing journey agent creation 2026-04-09 00:40:16 +02:00
Roberto Musso
41db3a7089 update env variables 2026-04-08 23:52:52 +02:00
Roberto Musso
cc94194fd1 update app name 2026-04-08 23:27:34 +02:00
Roberto Musso
96c91e386d remove deprecated docs 2026-04-08 23:23:14 +02:00
Roberto Musso
c0aef71141 refactor(tests): remove non-deterministic journey eval cases 4.2–4.5
Keep only 4.1 (first reply contains question) as automated eval.
Multi-turn cases (4.2–4.5) are non-deterministic and tested manually
with results tracked in Langfuse.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 09:41:43 +02:00
Roberto Musso
467abc8d42 Merge branch 'develop' into feature/batch-agent-v2 2026-04-08 00:48:23 +02:00
Roberto Musso
5753f8def9 refactor: remove storage, backup, plugin/marketplace features
- Delete app/storage/ (blob_store, vector_store, encryption)
- Delete app/marketplace/ (plugin_registry, plugin_review, revenue_share)
- Delete routes: backup.py, plugins.py, storage.py, vectors.py
- Relocate embed endpoint to POST /chat/embed
- Rewrite migration 001 (remove storage/plugin tables)
- Delete migration 002 (seed_plugins)
- Remove S3/Pinecone/Qdrant env vars from settings
- Remove storage/backup quotas from tier_manager
- Remove MinIO and Qdrant from docker-compose
- Delete tests: test_backup, test_plugins, test_storage
- Update README.md and clean .env.example
2026-04-08 00:47:37 +02:00
Roberto Musso
e672b58b6f fix(langfuse): remove invalid user_id/session_id kwargs from start_as_current_observation
Langfuse V3 does not accept user_id/session_id on observation-level calls.
Moved to metadata dict in agent_runner, deep_agent, and agent_setup.

refactor(tests): fixture-based pattern for agent_runner_v2 eval tests

- cases.yaml + data/ fixtures under tests/fixtures/agent_runner_v2/
- pytest_generate_tests parametrizes test_eval_runner from YAML
- _resolve_projects() handles symbolic names and inline dicts
- _evaluate_case() centralizes all assertion logic
- --runner-dir CLI option for custom fixture folders

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:45:15 +02:00
Roberto Musso
d8add7e8cb feat(local-agent-v2): step 4 — journey produces structured AgentConfig JSON
Replace freeform prompt_template output with validated AgentConfig JSON:
- agent_setup.py: new system prompt (journey_system_v2), AGENT_CONFIG_START/END
  markers, _extract_agent_config() with Pydantic validation, updated handlers
  returning agent_config key; import AgentConfig from schemas
- tests/test_journey_v2.py: 6 unit tests + 5 parametrized LLM eval cases
  following test_agent_runner_v2.py pattern; _run_journey uses
  set_client_executor/clear_client_executor mirroring device_ws
- tests/fixtures/journey_v2/: cases.yaml + email_action.html + email_info.html
- tests/conftest.py: add --journey-dir CLI option; remove S3/plugin fixtures
  (cleanup from microservices migration, already present in working tree)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-08 00:23:58 +02:00
Roberto Musso
c6c4578f9a fix(tests): migrate eval tests to Langfuse V3 API
lf.trace() and lf.score(trace_id=...) are V2 API removed in V3.

V3 pattern:
  lf.start_as_current_observation(name=...) as context manager → obs
  obs.score(name=..., value=...)
  contextlib.nullcontext() when lf is None so structure stays the same

Updated tests 2.1–2.7 in test_agent_runner_v2.py accordingly.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 23:04:24 +02:00
Roberto Musso
3aa0b36a6c fix(langfuse): use compile() instead of .format() for prompt variable injection
Langfuse uses {{variable}} syntax in its prompt management UI, while the
hardcoded fallbacks use {variable} (Python str.format). The previous code
always called .format() which silently failed/errored when a real Langfuse
prompt was fetched.

- langfuse_client.py: add compile_prompt(template, prompt_obj, **vars)
  → uses prompt_obj.compile(**vars) when Langfuse is available
  → falls back to template.format(**vars) when using the hardcoded fallback
- agent_runner.py: replace .format() with compile_prompt() for
  unified_processing (V2 local) and batch_cloud_processing (cloud agent)
- agent_setup.py: replace .format() with compile_prompt() for journey_system

deep_agent.py prompts have no variables, so no change needed there.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 16:49:26 +02:00
Roberto Musso
fa231a3642 feat(local-agent-v2): step 2+3 — unified runner + AgentConfig schema
Step 3 (prerequisite):
- app/schemas.py: add ContentTypeConfig + AgentConfig Pydantic models
- app/models.py: add agent_config (JSON, nullable) to LocalAgentConfig
- alembic migration a3b9c0d1e2f3: ADD COLUMN agent_config

Step 2 (runner refactor):
- Remove _classify_file() and _BATCH_FILE_CLASSIFIER_PROMPT (LLM classification step)
- Add Phase A: detect_content_type + preprocess (zero LLM, per file)
- Add _UNIFIED_PROCESSING_PROMPT (hot-swappable via Langfuse "unified_processing")
- Add helper functions: _format_projects, _format_metadata, _get_extraction_rules,
  _get_no_match_behavior
- Single LLM call per file with tools (classify + extract + create)
- Fix items_created: count create_* tool calls via _tool_calls_out param
- test_agent_runner_v2.py: 10 cases (2.1-2.10) with Langfuse eval scoring

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 15:00:32 +02:00
Roberto Musso
d91c98f86d chore(tests): remove Langfuse from all preprocessor tests
I test del preprocessor sono deterministici — nessun LLM coinvolto,
nessuno score da tracciare.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:26:33 +02:00
Roberto Musso
c0619f5c4d fix(tests): move pytest_addoption after __future__ import in conftest
SyntaxError: from __future__ imports must occur at the beginning of the file.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:21:50 +02:00
Roberto Musso
da282229ff refactor(tests): remove redundant filename field
file: serve sia come path da leggere che come nome passato a detect_content_type.
Non c'è motivo di averli separati.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 14:13:14 +02:00
Roberto Musso
7fa6ad5760 feat(tests): add --preprocess-dir CLI option to pytest
- conftest.py: registra --preprocess-dir via pytest_addoption
- test_preprocessors.py: usa pytest_generate_tests per leggere i casi
  a collection time con accesso a config; _content e _fixtures_dir
  accettano path dinamico

Usage: pytest tests/test_preprocessors.py --preprocess-dir /my/folder

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 13:59:32 +02:00
Roberto Musso
dcd14220ca refactor(tests): simplify YAML fixture schema and test runner
YAML: rimosse op/description/score_name/assertions block — ora detect/process
come chiave diretta, assertions piatte sullo stesso livello del caso.

Runner: eliminato _run_assertions engine, assertions inline in test_preprocess.
Riduzione da ~170 a ~75 righe totali tra YAML + test.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 11:30:38 +02:00
Roberto Musso
3cc32569d9 chore(tests): remove Langfuse scoring from preprocess tests
Scoring is only meaningful for LLM-backed steps. Preprocess tests are
deterministic Python, so scores add no value. Kept only for detect tests.

- test_preprocess: drop _lf_score call, simplify _run_assertions return type
- cases.yaml: remove score_name from all op=preprocess entries

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 11:21:42 +02:00
Roberto Musso
bf445ac2ce refactor(tests): YAML-driven fixtures for preprocessor tests
- cases.yaml: 10 test cases con schema dichiarativo (op, assertions)
- data/: 7 file reali (email_action.html, email_thread.html, email_single.html,
  email_heavy.html, generic_page.html, notes.txt, fallback.txt)
- test_preprocessors.py: parametrize da YAML via test_detect / test_preprocess;
  assertion engine generico (no_html_tags, min_length, compression_ratio,
  metadata_keys, contains, not_contains, content_type)
- requirements.txt: add PyYAML

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 10:44:41 +02:00
Roberto Musso
a2d6d689e4 feat: add preprocessor system (Step 1 — Local Agent V2)
- app/core/preprocessors/__init__.py: detect_content_type + preprocess dispatcher
- app/core/preprocessors/base.py: PreprocessResult dataclass
- app/core/preprocessors/email_html.py: BeautifulSoup HTML stripping, metadata extraction, thread splitting
- requirements.txt: add beautifulsoup4 and lxml
- tests/test_preprocessors.py: 10 tests with Langfuse scoring (preprocess.* scores)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 10:19:02 +02:00
Roberto Musso
aa8bcbf0d8 Refactor system prompt variables for clarity and consistency across agent setup and runner modules 2026-04-07 00:23:41 +02:00
Roberto Musso
1ce1d492b0 Add Langfuse observability: traces, prompt management, prompt-to-generation linking
- New app/core/langfuse_client.py: lazy singleton client, get_prompt_or_fallback()
  helper (returns raw template + prompt obj for linking), extract_usage() for token
  counts. No-ops when LANGFUSE_* env vars are not set.
- deep_agent.py: home-agent and floating-agent runs wrapped in spans; each ainvoke
  wrapped in a generation with model/input/output/usage; prompts fetched from
  Langfuse (adiuva-home-agent, adiuva-floating-agent, adiuva-floating-classifier)
  with hardcoded fallback.
- agent_runner.py: step1-classifier and step2-processor LLM calls traced; batch
  agent _run_agent_with_tools spans + generations; cloud-processor included.
  Prompts: adiuva-step1-classifier, adiuva-step2-processor, adiuva-cloud-processor.
- agent_setup.py: journey-setup span + generation per ainvoke; prompt_obj stored
  on JourneySession and reused across turns. Prompt: journey_system.
- settings.py: LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, LANGFUSE_HOST added.
- .env.example: Langfuse section with EU/US/self-hosted host comments.
- requirements.txt: langfuse>=2.0.0.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-04-07 00:19:20 +02:00
134 changed files with 3869 additions and 10688 deletions

View File

@@ -2,55 +2,69 @@
ENV=dev
# ── Database ──────────────────────────────────────────────────────────────────
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai
# ── Redis ─────────────────────────────────────────────────────────────────────
REDIS_URL=redis://localhost:6379/0
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
# Public key for optional local JWT verification (Traefik ForwardAuth handles
# this in production — services trust X-User-* headers from Traefik).
# Generate keypair:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# openssl rsa -in private.pem -pubout -out public.pem
# Paste PEM content with literal \n for newlines.
JWT_PUBLIC_KEY=
# ── Auth ──────────────────────────────────────────────────────────────────────
JWT_SECRET=replace-with-a-long-random-secret
JWT_ALGORITHM=HS256
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
# ── LLM ───────────────────────────────────────────────────────────────────────
# LiteLLM model identifiers — change to swap providers without code changes.
# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3
#
# API keys — only the key(s) matching your chosen provider(s) are required.
# The correct key is picked automatically from the model prefix (e.g.
# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY).
OPENAI_API_KEY=
ANTHROPIC_API_KEY=
GOOGLE_API_KEY=
LLM_MODEL=gpt-4o
CEREBRAS_API_KEY=
# Default model used by any agent that does not have a specific override below.
LLM_MODEL=gpt-5-mini
LLM_EMBED_MODEL=text-embedding-3-small
# GitHub Copilot — leave empty to use the LiteLLM default token directory.
# In Docker, point this to a named-volume path so tokens survive restarts.
# GITHUB_COPILOT_TOKEN_DIR=
# ── Per-agent model overrides ─────────────────────────────────────────────────
# Leave a value empty to fall back to LLM_MODEL.
# Each agent resolves its API key from the model prefix automatically.
#
# Intent classifier — routes user messages to the right domain agent.
# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here.
LLM_MODEL_CLASSIFIER=
# Home-agent — handles chat from the home screen (all tools available).
LLM_MODEL_HOME_AGENT=
# Floating-agent — handles contextual chat triggered from a task/project/note.
LLM_MODEL_FLOATING_AGENT=
# Unified-processor — processes local directory files (local agent runner).
LLM_MODEL_UNIFIED_PROCESSOR=
# Cloud-processor — fetches and processes data from cloud connectors.
LLM_MODEL_CLOUD_PROCESSOR=
# Setup-agent — guided journey to build an AgentConfig via WebSocket chat.
LLM_MODEL_SETUP_AGENT=
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=
# ── AWS / S3 ──────────────────────────────────────────────────────────────────
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=
AWS_ACCESS_KEY_ID=
AWS_SECRET_ACCESS_KEY=
# For MinIO (homelab): S3_ENDPOINT_URL=http://minio:9000
# ── Vector Store ──────────────────────────────────────────────────────────────
# Pinecone is used when PINECONE_API_KEY is set; otherwise falls back to Qdrant.
PINECONE_API_KEY=
PINECONE_INDEX=adiuva
QDRANT_URL=
QDRANT_API_KEY=
# For local Qdrant (homelab): QDRANT_URL=http://qdrant:6333
# ── Langfuse (leave empty to disable observability) ───────────────────────────
LANGFUSE_SECRET_KEY=
LANGFUSE_PUBLIC_KEY=
# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default)
# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US
# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted
# ── CORS ──────────────────────────────────────────────────────────────────────
# Comma-separated list parsed by Settings (override default if needed)
# CORS_ORIGINS=["app://.","http://localhost:3000"]
# ── Langfuse (observability) ─────────────────────────────────────────────────
LANGFUSE_SECRET_KEY=sk-lf-...
LANGFUSE_PUBLIC_KEY=pk-lf-...
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL

View File

@@ -48,23 +48,23 @@ jobs:
key: ${{ secrets.SSH_KEY }}
script: |
set -e
DEPLOY_DIR="/opt/adiuva-api"
DEPLOY_DIR="/opt/adiuvai-api"
REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git"
TAG="${{ gitea.ref_name }}"
# ── Pull latest code ──
cd /tmp && rm -rf adiuva-api-deploy
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuva-api-deploy
cd /tmp && rm -rf adiuvai-api-deploy
git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy
# ── Sync source (preserve .env) ──
cp -rf /tmp/adiuva-api-deploy/app/ \
/tmp/adiuva-api-deploy/alembic/ \
/tmp/adiuva-api-deploy/alembic.ini \
/tmp/adiuva-api-deploy/Dockerfile \
/tmp/adiuva-api-deploy/docker-compose.yml \
/tmp/adiuva-api-deploy/requirements.txt \
cp -rf /tmp/adiuvai-api-deploy/app/ \
/tmp/adiuvai-api-deploy/alembic/ \
/tmp/adiuvai-api-deploy/alembic.ini \
/tmp/adiuvai-api-deploy/Dockerfile \
/tmp/adiuvai-api-deploy/docker-compose.yml \
/tmp/adiuvai-api-deploy/requirements.txt \
"$DEPLOY_DIR/"
rm -rf /tmp/adiuva-api-deploy
rm -rf /tmp/adiuvai-api-deploy
# ── Verify .env ──
if [ ! -f "$DEPLOY_DIR/.env" ]; then

View File

@@ -58,7 +58,7 @@ jobs:
- uses: actions/checkout@v4
- name: Build image
run: docker build -t adiuva-api:ci .
run: docker build -t adiuvai-api:ci .
- name: Verify gunicorn installed
run: docker run --rm adiuva-api:ci gunicorn --version
run: docker run --rm adiuvai-api:ci gunicorn --version

4
.gitignore vendored
View File

@@ -13,9 +13,6 @@ env/
# Environment variables
.env
# Cryptographic keys
*.pem
# IDE
.vscode/
.idea/
@@ -24,6 +21,7 @@ env/
.pytest_cache/
htmlcov/
.coverage
tests/fixtures/private*/
# Docker
*.log

793
README.md
View File

@@ -1,793 +0,0 @@
# Adiuva Cloud API
**AI-powered project management backend with E2E encrypted cloud storage, LLM orchestration, and a plugin marketplace.**
Built with FastAPI · Python 3.12 · PostgreSQL · LangChain · Stripe · AWS S3
---
## Table of Contents
- [Overview](#overview)
- [Architecture](#architecture)
- [Key Features](#key-features)
- [Tech Stack](#tech-stack)
- [Getting Started](#getting-started)
- [Docker Deployment](#docker-deployment)
- [Environment Variables](#environment-variables)
- [API Reference](#api-reference)
- [Data Model](#data-model)
- [AI Agent System](#ai-agent-system)
- [Orchestration & Execution Plans](#orchestration--execution-plans)
- [Middleware](#middleware)
- [Storage Layer](#storage-layer)
- [Billing & Tiers](#billing--tiers)
- [Plugin Marketplace](#plugin-marketplace)
- [Testing](#testing)
- [Project Structure](#project-structure)
- [License](#license)
---
## Overview
Adiuva Cloud API is the FastAPI backend that powers the **Adiuva Electron desktop app**. It provides LLM-powered chat orchestration, end-to-end encrypted cloud storage, a vector search engine, an encrypted backup system, a plugin marketplace with revenue sharing, and Stripe-based subscription billing across four tiers.
### Design Principles
1. **Never persist user data in plaintext** — the database stores only auth, billing, storage metadata, and marketplace data. All user content is E2E encrypted by the client before reaching the server.
2. **Never expose prompts** — system prompts stay server-side; responses are sanitized to strip any leaked prompt fragments.
3. **Never decrypt user blobs** — the backend performs only checksum verification; no decryption keys ever reach the server.
4. **Stateless request handling** — all context comes from the client and JWT; no server-side session state.
5. **Tier gates enforced server-side** — the server always reads the current tier from the database, never trusting client-reported values.
---
## Architecture
```
┌──────────────┐ ┌────────────────────────────────────────────────────────┐
│ Electron │ │ FastAPI (Uvicorn / Gunicorn) │
│ Desktop App │────▶│ │
│ (Client) │◀────│ Middleware: RateLimit → Sanitizer → CORS → Router │
└──────────────┘ │ │
│ ┌──────────────────┐ ┌────────────────────────────┐ │
│ │ Auth Routes │ │ Chat Routes │ │
│ │ Billing Routes │ │ ↓ │ │
│ │ Storage Routes │ │ Orchestrator (GPT-4o-mini)│ │
│ │ Backup Routes │ │ ↓ classify intent │ │
│ │ Plugin Routes │ │ Agent Registry │ │
│ │ Vector Routes │ │ ↓ │ │
│ │ Plans Routes │ │ TaskAgent | ProjectAgent │ │
│ └──────────────────┘ │ NoteAgent | CheckptAgent │ │
│ │ (GPT-4o + LangChain) │ │
│ └────────────────────────────┘ │
└────────────────────────────────────────────────────────┘
│ │ │
┌────────▼───┐ ┌───────▼───────┐ ┌──▼─────────────┐
│ PostgreSQL │ │ AWS S3 │ │ Pinecone / │
│ (Auth, │ │ (E2E blobs, │ │ Qdrant │
│ Billing, │ │ backups) │ │ (Vectors) │
│ Metadata) │ └───────────────┘ └────────────────┘
└────────────┘
┌────────▼───┐
│ Stripe │
│ (Billing, │
│ Connect) │
└────────────┘
```
---
## Key Features
1. **LLM-powered orchestration** — GPT-4o-mini classifies user intent and routes to the appropriate domain agent.
2. **4 specialized AI agents** — Tasks (8 tools), Projects (6 tools), Timelines (4 tools), Notes (5 tools), all powered by GPT-4o via LangChain.
3. **Execution plans & playbooks** — Server-side prompt template registry; clients receive only opaque template IDs, never raw prompts.
4. **E2E encrypted cloud storage** — The backend never decrypts user data; SHA-256 checksum verification uses constant-time comparison to prevent timing attacks.
5. **Cloud vector store** — Pinecone or Qdrant with user-isolated namespaces and encrypted blob payloads.
6. **Encrypted backup system** — Tiered storage limits with `If-Modified-Since` support for efficient syncing.
7. **Plugin marketplace** — Catalog, admin review/approval workflow, security checklist, and 70/30 revenue sharing via Stripe Connect.
8. **Stripe billing** — Four-tier subscription model (Free / Pro / Power / Team) with checkout sessions and full webhook lifecycle handling.
9. **JWT authentication** — Access + refresh tokens with bcrypt password hashing, SHA-256 token hashing, and automatic rotation.
10. **Prompt IP protection** — Sanitizer middleware strips system prompts, reasoning markers, tool schemas, and agent routing metadata from all chat responses.
11. **Tier-based rate limiting** — Sliding-window per-user limiter scaling from 20 to 200 requests/min by subscription tier.
12. **Zero-trust data model** — User content is never stored in plaintext; the database holds only authentication, billing, and metadata records.
13. **WebSocket streaming** — Real-time chat with 30-second heartbeat keep-alive and chunked text delivery.
14. **Alembic migrations** — Versioned schema management with seed data for the plugin marketplace.
15. **Comprehensive test suite** — In-memory SQLite + moto S3 mocks, per-tier test fixtures, and full API coverage without external dependencies.
---
## Tech Stack
| Package | Version | Purpose |
|---|---|---|
| `fastapi` | ≥ 0.115.0 | Web framework |
| `uvicorn[standard]` | ≥ 0.34.0 | ASGI development server |
| `gunicorn` | ≥ 22.0.0 | Production process manager |
| `langchain` | ≥ 0.3.0 | LLM orchestration framework |
| `langchain-openai` | ≥ 0.3.0 | OpenAI LLM provider integration |
| `litellm` | ≥ 1.50.0 | Universal LLM gateway (100+ providers) |
| `pydantic` | ≥ 2.10.0 | Data validation and serialization |
| `pydantic-settings` | ≥ 2.7.0 | Environment-based configuration |
| `python-jose[cryptography]` | ≥ 3.3.0 | JWT encoding and decoding |
| `stripe` | ≥ 11.0.0 | Billing and payment integration |
| `boto3` | ≥ 1.35.0 | AWS S3 client |
| `slowapi` | ≥ 0.1.9 | Rate limiting utilities |
| `sqlalchemy` | ≥ 2.0.0 | Async ORM and query builder |
| `asyncpg` | ≥ 0.30.0 | PostgreSQL async driver |
| `alembic` | ≥ 1.14.0 | Database migration management |
| `bcrypt` | ≥ 4.2.0 | Password hashing |
| `python-dotenv` | ≥ 1.0.0 | `.env` file loading |
| `httpx` | ≥ 0.28.0 | Async HTTP client (used in tests) |
| `websockets` | ≥ 14.0 | WebSocket protocol support |
| `psycopg2-binary` | ≥ 2.9.0 | Synchronous PostgreSQL driver (Alembic) |
| `pinecone` | ≥ 5.0.0 | Pinecone vector store client |
| `qdrant-client` | ≥ 1.7.0 | Qdrant vector store client |
| `pytest` | ≥ 8.0.0 | Test framework |
| `pytest-asyncio` | ≥ 0.24.0 | Async test support |
| `aiosqlite` | ≥ 0.20.0 | In-memory SQLite for tests |
| `moto[s3]` | ≥ 5.0.0 | AWS S3 mock for tests |
| `ruff` | ≥ 0.8.0 | Linter and formatter |
---
## Getting Started
### Prerequisites
- Python 3.12+
- PostgreSQL 16+
- An OpenAI API key (for LLM features)
- Stripe API keys (optional — billing stubs gracefully when unconfigured)
- AWS credentials (optional — needed for S3 storage in production)
### Installation
```bash
# Clone the repository
git clone <repo-url> && cd adiuva-api
# Create a virtual environment
python -m venv .venv && source .venv/bin/activate
# Install dependencies
pip install -r requirements.txt
# Configure environment
cp .env.example .env
# Edit .env with your DATABASE_URL, OPENAI_API_KEY, etc.
```
### Database Setup
```bash
# Start PostgreSQL (or use the Docker Compose database)
docker compose up db -d
# Run migrations
alembic upgrade head
```
### Run the Development Server
```bash
uvicorn app.main:app --reload --host 0.0.0.0 --port 8000
```
Interactive API docs are available at [http://localhost:8000/docs](http://localhost:8000/docs) in development mode (`ENV=dev`). The `/docs` endpoint is disabled in production.
---
## Docker Deployment
### Quick Start
```bash
docker compose up --build
```
This starts two services:
- **app** — FastAPI server on port `8000`
- **db** — PostgreSQL 16 (Alpine) on port `5432` with a persistent volume and health checks
The compose file also includes optional services for fully local deployments:
- **minio** — S3-compatible object storage on ports `9000` (API) and `9001` (console)
- **qdrant** — Vector search engine on ports `6333` (HTTP) and `6334` (gRPC)
### Dockerfile Details
The Dockerfile uses a multi-stage build:
1. **Builder stage** — Installs Python dependencies into a virtual environment.
2. **Runtime stage** — Copies only the venv, app source, and Alembic migrations. Runs as a non-root user (`appuser`).
3. **Production server** — Gunicorn with 4 Uvicorn workers, 120-second timeout, listening on port 8000.
```bash
# Production command (run by the container)
gunicorn app.main:app -k uvicorn.workers.UvicornWorker -w 4 --timeout 120 -b 0.0.0.0:8000
```
---
## Homelab / Self-Hosted Deployment
You can run the entire stack locally on a homelab with **no cloud dependencies except the LLM provider**. The compose file includes MinIO (S3 replacement) and Qdrant (vector store) out of the box.
### 1. Start all services
```bash
docker compose up -d
```
This starts PostgreSQL, MinIO, and Qdrant alongside the app.
### 2. Create the MinIO bucket
Open the MinIO console at [http://localhost:9001](http://localhost:9001) (login: `minioadmin` / `minioadmin`) and create a bucket named `adiuva`, or use the CLI:
```bash
docker compose exec minio mc alias set local http://localhost:9000 minioadmin minioadmin
docker compose exec minio mc mb local/adiuva
```
### 3. Configure your `.env`
```bash
# Database (uses the compose PostgreSQL)
DATABASE_URL=postgresql+asyncpg://postgres:postgres@db:5432/adiuva
# S3 → MinIO
S3_BUCKET=adiuva
S3_REGION=us-east-1
S3_ENDPOINT_URL=http://minio:9000
AWS_ACCESS_KEY_ID=minioadmin
AWS_SECRET_ACCESS_KEY=minioadmin
# Vector store → local Qdrant (leave PINECONE_API_KEY empty)
QDRANT_URL=http://qdrant:6333
QDRANT_API_KEY=
PINECONE_API_KEY=
# Billing — leave empty to stub (no Stripe needed)
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=
# LLM — the only external service
OPENAI_API_KEY=sk-...
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# Auth
JWT_SECRET=your-secret-here
ENV=dev
```
### 4. Run migrations
```bash
docker compose exec app alembic upgrade head
```
### What runs where
| Service | Runs on | Port | Notes |
|---|---|---|---|
| FastAPI app | Docker | 8000 | API server |
| PostgreSQL | Docker | 5432 | Auth, billing, metadata |
| MinIO | Docker | 9000 / 9001 | S3-compatible blob & backup storage |
| Qdrant | Docker | 6333 / 6334 | Vector search (replaces Pinecone) |
| Stripe | — | — | Stubbed when keys are empty |
| OpenAI / LLM | Cloud | — | Only external dependency |
> **Want fully offline AI too?** Set `LLM_MODEL=ollama/llama3` and `LLM_ROUTER_MODEL=ollama/llama3`, then add an Ollama container or point at a local Ollama instance. See the [LLM provider switching](#switching-llm-providers) section.
---
## Environment Variables
All variables are loaded from a `.env` file via Pydantic Settings. Source: `app/config/settings.py`
| Variable | Type | Default | Description |
|---|---|---|---|
| `DATABASE_URL` | `str` | `postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva` | Async SQLAlchemy connection string |
| `JWT_SECRET` | `str` | `change-me-in-production` | HMAC secret for JWT signing |
| `JWT_ALGORITHM` | `str` | `HS256` | JWT signing algorithm |
| `JWT_ACCESS_TOKEN_EXPIRE_MINUTES` | `int` | `30` | Access token time-to-live |
| `JWT_REFRESH_TOKEN_EXPIRE_DAYS` | `int` | `30` | Refresh token time-to-live |
| `STRIPE_SECRET_KEY` | `str` | `""` | Stripe API key (empty = stub mode) |
| `STRIPE_WEBHOOK_SECRET` | `str` | `""` | Stripe webhook signature secret |
| `S3_BUCKET` | `str` | `""` | S3 bucket for encrypted blobs and backups |
| `S3_REGION` | `str` | `us-east-1` | AWS region |
| `S3_ENDPOINT_URL` | `str` | `""` | Custom S3 endpoint (e.g. `http://minio:9000` for MinIO). Leave empty for AWS. |
| `AWS_ACCESS_KEY_ID` | `str` | `""` | AWS credentials |
| `AWS_SECRET_ACCESS_KEY` | `str` | `""` | AWS credentials |
| `PINECONE_API_KEY` | `str` | `""` | Pinecone API key (if set, Pinecone is used for vectors) |
| `PINECONE_INDEX` | `str` | `adiuva` | Pinecone index name |
| `QDRANT_URL` | `str` | `""` | Qdrant URL (used when Pinecone is not configured) |
| `QDRANT_API_KEY` | `str` | `""` | Qdrant API key |
| `OPENAI_API_KEY` | `str` | `""` | OpenAI key for LLM agent calls |
| `LLM_MODEL` | `str` | `gpt-4o` | LiteLLM model identifier for agents (e.g. `anthropic/claude-3.5-sonnet`, `gemini/gemini-pro`, `ollama/llama3`) |
| `LLM_ROUTER_MODEL` | `str` | `gpt-4o-mini` | Lighter model used for intent classification / routing |
| `CORS_ORIGINS` | `list[str]` | `["app://.", "http://localhost:3000", "http://localhost:5173"]` | Allowed CORS origins |
| `ENV` | `Literal` | `dev` | `dev` or `prod` — controls `/docs` visibility and SQL echo |
---
## API Reference
All routes are prefixed with `/api/v1`. **27 endpoints** total (25 REST + 1 WebSocket + 1 health check).
### Health
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/health` | No | Returns `{"status": "ok", "version": "0.1.0"}` |
### Auth
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/auth/register` | No | Create account with bcrypt-hashed password, returns `AuthTokens` |
| `POST` | `/api/v1/auth/login` | No | Validate credentials, returns `AuthTokens` |
| `POST` | `/api/v1/auth/refresh` | No | Rotate refresh token, returns new `AuthTokens` |
| `GET` | `/api/v1/auth/me` | JWT | Returns `UserProfile` for the authenticated user |
### Chat
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/chat` | JWT | Route message through the orchestrator; returns `ChatResponse` or `ExecutionPlan` depending on execution mode |
| `WS` | `/api/v1/chat/stream` | JWT (query param `?token=`) | Streaming chat — first frame is a `ChatRequest`, server yields text chunks, final frame is `{"done": true, "response": "...", "actions": [...]}`. 30-second heartbeat ping. |
### Plans
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plans/playbook` | JWT | List all cached execution plan playbooks |
| `GET` | `/api/v1/plans/playbook/{plan_id}` | JWT | Retrieve a specific playbook by ID |
### Storage (Cloud Records)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/records` | JWT | Upload an E2E encrypted record (verifies checksum, enforces storage quota) |
| `GET` | `/api/v1/storage/records` | JWT | List record metadata with pagination (`?table`, `?page`, `?limit`); no blob bytes returned |
| `GET` | `/api/v1/storage/records/{id}` | JWT | Download encrypted blob with `X-Checksum` response header |
| `PUT` | `/api/v1/storage/records/{id}` | JWT | Replace an existing blob (verifies checksum, enforces quota) |
| `DELETE` | `/api/v1/storage/records/{id}` | JWT | Delete a record and its S3 blob |
### Vectors (Cloud Vector Store)
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/storage/vectors/upsert` | JWT | Verify checksums and upsert encrypted vectors |
| `POST` | `/api/v1/storage/vectors/search` | JWT | Search user-scoped vector namespace |
| `DELETE` | `/api/v1/storage/vectors` | JWT | Delete vectors by ID list |
### Backup
| Method | Path | Auth | Description |
|---|---|---|---|
| `PUT` | `/api/v1/backup` | JWT | Upload encrypted backup blob with custom headers (`X-Backup-Version`, `X-Backup-Timestamp`, `X-Backup-Checksum`). Tier quota enforced. |
| `GET` | `/api/v1/backup` | JWT | Download latest backup blob. Supports `If-Modified-Since`. |
| `GET` | `/api/v1/backup/history` | JWT | List backup metadata (no blob content) |
| `DELETE` | `/api/v1/backup/{backup_id}` | JWT | Delete a specific backup |
### Plugins (Marketplace)
| Method | Path | Auth | Description |
|---|---|---|---|
| `GET` | `/api/v1/plugins` | JWT (Power+) | Browse the marketplace (`?category`, `?q`, `?page`, `?sort=rating\|installs\|newest`) |
| `GET` | `/api/v1/plugins/{id}` | JWT (Power+) | Plugin detail with install count and ratings |
| `POST` | `/api/v1/plugins/{id}/install` | JWT (Power+) | Install plugin; triggers Stripe Connect revenue split for paid plugins |
| `DELETE` | `/api/v1/plugins/{id}/install` | JWT | Uninstall plugin |
### Billing
| Method | Path | Auth | Description |
|---|---|---|---|
| `POST` | `/api/v1/billing/checkout` | JWT | Create a Stripe checkout session, returns `{"checkout_url": "..."}` |
| `POST` | `/api/v1/billing/webhook` | Stripe signature | Handle Stripe events: `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, `invoice.payment_failed` |
| `GET` | `/api/v1/billing/subscription` | JWT | Get current subscription information |
| `DELETE` | `/api/v1/billing/subscription` | JWT | Cancel subscription and revert to free tier |
---
## Data Model
9 tables managed by Alembic migrations. Source: `app/models.py`
### Tables
| Table | Primary Key | Key Columns | Purpose |
|---|---|---|---|
| `users` | `id` (UUID) | `email` (unique), `password_hash`, `tier`, `stripe_customer_id`, timestamps | User accounts |
| `refresh_tokens` | `id` (UUID) | `user_id` (FK), `token_hash` (SHA-256, unique), `expires_at` | Hashed refresh tokens for rotation |
| `subscriptions` | `id` (UUID) | `user_id` (FK, unique), `stripe_subscription_id`, `tier`, `status`, `current_period_end` | Stripe subscription records |
| `storage_records` | `id` (UUID) | `user_id` (FK), `table_name`, `s3_key`, `checksum`, `size_bytes`, timestamps | S3 blob metadata (no plaintext content) |
| `backup_metadata` | `id` (UUID) | `user_id` (FK), `s3_key`, `version`, `timestamp`, `checksum`, `size_bytes` | Backup manifests |
| `plugins` | `id` (String) | `name`, `description`, `version`, `author_id` (FK), `category`, `price_cents`, `permissions` (JSON), `status`, `s3_package_key`, `install_count`, `avg_rating` | Marketplace plugin catalog |
| `plugin_installations` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), unique constraint on (`plugin_id`, `user_id`) | Per-user install tracking |
| `plugin_reviews` | `id` (UUID) | `plugin_id` (FK), `reviewer_id` (FK), `decision`, `notes`, `reviewed_at` | Admin review decisions |
| `revenue_events` | `id` (UUID) | `plugin_id` (FK), `user_id` (FK), `amount_cents`, `developer_share_cents`, `stripe_transfer_id` | 70/30 revenue split ledger |
### Enum Types
| Enum | Values |
|---|---|
| `billing_tier` | `free`, `pro`, `power`, `team` |
| `plugin_status` | `pending_review`, `approved`, `rejected` |
| `review_decision` | `approved`, `rejected` |
### Migrations
| Version | Description |
|---|---|
| `001_initial_schema` | Creates all 9 tables with indexes and foreign key constraints |
| `002_seed_plugins` | Seeds 3 approved plugins: GitHub Sync (free), Slack Notifier (€4.99), Time Tracker (€9.99) |
---
## AI Agent System
The agent system uses a registry pattern with LangChain tool-calling agents powered by GPT-4o. Source: `app/agents/`, `app/core/agent_registry.py`
### Architecture
- **`BaseAgent`** — Abstract base with `user_id`, `shared_memory`, and `vector_store_context`.
- **`ChatAgent(BaseAgent)`** — Abstract `handle(query, context)` and `get_tools()` methods, plus a shared `_tool_loop(llm, messages, tools, max_iter=5)` for iterative tool calling.
- **`AgentRegistry`** — Singleton registry with `@register` decorator, `get(name)`, `list_agents()`, and `call_agent(name, query, context)`.
### Registered Agents
| Agent | Registry Name | Tools | Description |
|---|---|---|---|
| **TaskAgent** | `task_agent` | 8 | Full task and comment CRUD. Status: `todo` / `in_progress` / `done`. Priority: `high` / `medium` / `low`. Tools: `list_tasks`, `create_task`, `update_task`, `delete_task`, `list_tasks_due_today`, `list_task_comments`, `add_task_comment`, `delete_task_comment` |
| **ProjectAgent** | `project_agent` | 6 | Project lifecycle management. Status: `active` / `archived`. Prefers archiving over deletion. Tools: `list_projects`, `list_all_projects`, `get_project`, `create_project`, `update_project`, `delete_project` |
| **TimelineAgent** | `timeline_agent` | 4 | Project milestones. Requires `project_id` for creation. Supports AI-suggestion and approval workflows. Tools: `list_timelines`, `create_timeline`, `update_timeline`, `delete_timeline` |
| **NoteAgent** | `note_agent` | 5 | Markdown note management. Optionally linked to projects. Tools: `list_notes`, `get_note`, `create_note`, `update_note`, `delete_note` |
All agents use the model configured by `LLM_MODEL` (default: GPT-4o) with `temperature=0` via LiteLLM. Tools return JSON action descriptors that the Electron client interprets and applies locally.
### Switching LLM Providers
The backend uses **LiteLLM** as a universal LLM gateway. All agents and the orchestrator instantiate models through a centralized factory in `app/core/llm.py`. To switch providers, change environment variables — no code changes required:
```bash
# OpenAI (default)
LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# Anthropic
LLM_MODEL=anthropic/claude-3.5-sonnet
LLM_ROUTER_MODEL=anthropic/claude-3-haiku
# Google Gemini
LLM_MODEL=gemini/gemini-pro
LLM_ROUTER_MODEL=gemini/gemini-flash
# Local Ollama
LLM_MODEL=ollama/llama3
LLM_ROUTER_MODEL=ollama/llama3
# AWS Bedrock
LLM_MODEL=bedrock/anthropic.claude-v2
LLM_ROUTER_MODEL=bedrock/anthropic.claude-instant-v1
```
See the [LiteLLM provider docs](https://docs.litellm.ai/docs/providers) for the full list of 100+ supported providers and model naming conventions.
---
## Orchestration & Execution Plans
Source: `app/core/orchestrator.py`, `app/core/execution_plan.py`
### Orchestrator
1. **`classify_intent(message, context, registry)`** — Uses the router model (`LLM_ROUTER_MODEL`, default: GPT-4o-mini) to determine which agent should handle a message. Falls back to `task_agent` when classification is ambiguous.
2. **`route_single(agent_name, message, context)`** — Routes to a single agent and returns a `ChatResponse`.
3. **`route_pipeline(agent_names, message, context)`** — Executes agents sequentially; each receives `previous_results` from earlier agents. A final LLM synthesis step merges all results.
4. **`orchestrate(request)`** — Main entry point. In `direct` mode, returns a `ChatResponse`. In `plan` mode, returns an `ExecutionPlan`.
5. **`orchestrate_stream(request)`** — Streaming variant that yields 50-character text chunks with a final JSON frame.
### Execution Plans
- **`PromptTemplateRegistry`** — Maps template IDs to server-side prompt text. Clients only ever see opaque IDs, never raw prompts.
- **`ExecutionPlanBuilder`** — Fluent builder API: `add_step()`, `add_llm_step(template_id, vars)`, `add_data_step(action, data_from_step)`. Validates step references on `build()`.
- **`PlanCache`** — LRU cache (maxsize 1000) for storing plans as reusable playbooks.
### Built-in Templates (6)
`tpl_task_agent_default`, `tpl_timeline_agent_default`, `tpl_project_agent_default`, `tpl_note_agent_default`, `tpl_task_extract_from_project`, `tpl_note_weekly_summary`
### Built-in Playbooks (2)
| Playbook | Description |
|---|---|
| `create_tasks_from_project` | LLM extracts actionable tasks from project context, then creates task records |
| `generate_weekly_note` | LLM generates a weekly summary, then creates a note record |
---
## Middleware
Middleware executes in this order on each request: **TierRateLimit → Sanitizer → CORS → Router**
### JWT Authentication
Source: `app/api/middleware/auth.py`
- FastAPI dependency `get_current_user` validates the `Bearer` JWT and extracts `user_id` and `email`.
- **Live tier lookup** — The current tier is fetched from the `subscriptions` table on every request (not cached in the JWT), so upgrades and downgrades take immediate effect.
- Falls back to `free` when no subscription row exists.
- Raises `401 Unauthorized` on invalid or expired tokens.
- **Exempt paths:** `/api/v1/auth/register`, `/api/v1/auth/login`, `/api/v1/billing/webhook`
### Tier-Based Rate Limiter
Source: `app/api/middleware/rate_limit.py`
- `TierRateLimitMiddleware` — Sliding-window in-process rate limiter (no Redis dependency).
- Per-user 60-second window sized by subscription tier:
| Tier | Requests / Minute |
|---|---|
| Free | 20 |
| Pro | 60 |
| Power | 120 |
| Team | 200 |
- Returns `429 Too Many Requests` with a `Retry-After` header when the limit is exceeded.
- **Exempt paths:** register, login, webhook, health
### Response Sanitizer
Source: `app/api/middleware/sanitizer.py`
- Runs only on `/api/v1/chat` endpoints.
- Scans JSON response bodies and replaces leaked prompt IP fragments with `[REDACTED]`.
- Detects: system prompt openers, agent routing metadata, LangChain tool schemas, internal reasoning markers (`<thinking>`, `[INST]`), and known prompt fingerprints.
- Logs sanitization events as `WARNING`.
- Binary responses (storage, backup) are never touched.
---
## Storage Layer
### Blob Store
Source: `app/storage/blob_store.py`
- S3-backed storage for E2E encrypted blobs.
- Object keys follow the pattern: `{user_id}/{table}/{record_id}`
- Server-side SSE-S3 encryption at rest (additional layer on top of client-side E2E encryption).
- Methods: `upload()`, `download()`, `delete()` (idempotent), `list_keys()`
- The backend **never inspects or decrypts blob content**.
### Vector Store
Source: `app/storage/vector_store.py`
- Runtime-configurable: **Pinecone** (when `PINECONE_API_KEY` is set) or **Qdrant** (fallback).
- User isolation: Pinecone uses `namespace=user_id`; Qdrant filters by `user_id` payload field.
- 32-dimensional SHA-256-derived float vectors (deterministic, not semantically meaningful on encrypted data — a documented trade-off for privacy).
- Encrypted blobs are stored as base64 in metadata/payload for verbatim retrieval.
- Methods: `upsert()`, `search()`, `delete()`
### Encryption Utilities
Source: `app/storage/encryption.py`
- `verify_checksum(blob, checksum)` — SHA-256 hash comparison using `hmac.compare_digest` (constant-time to prevent timing attacks).
- `reject_if_tampered(blob, checksum)` — Raises HTTP 400 on checksum mismatch.
- **No decryption key ever reaches the backend.**
---
## Billing & Tiers
Source: `app/billing/stripe_service.py`, `app/billing/tier_manager.py`
### Feature Matrix
| Feature | Free | Pro | Power | Team |
|---|---|---|---|---|
| AI Agents | 3 | Unlimited | Unlimited | Unlimited |
| Batch Active | 2 | 10 | Unlimited | Unlimited |
| Cloud Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| Backup Storage | 0 GB | 5 GB | 25 GB | Unlimited |
| LLM Providers | 1 | Unlimited | Unlimited | Unlimited |
| Batch Builder | — | — | ✓ | ✓ |
| Plugin Marketplace | — | — | ✓ | ✓ |
| SSO | — | — | — | ✓ |
| Rate Limit | 20 req/min | 60 req/min | 120 req/min | 200 req/min |
### Stripe Integration
- **Checkout** — `create_checkout_session(user_id, tier)` creates a Stripe Checkout session. Returns a stub URL when Stripe is not configured.
- **Webhooks** — Handles `checkout.session.completed`, `customer.subscription.updated`, `customer.subscription.deleted`, and `invoice.payment_failed`.
- **Subscription management** — `get_subscription()` returns the current subscription record; `cancel_subscription()` cancels via the Stripe API and reverts the user to the free tier.
- **Price IDs:** `price_pro_monthly`, `price_power_monthly`, `price_team_monthly`
### Tier Manager
- `get_tier(user_id)` — Returns the user's current billing tier.
- `check_feature(tier, feature)` — Boolean feature gate check.
- `require_feature(tier, feature)` — Raises HTTP 403 if the feature is not available.
- `enforce_quota(user_id, tier)` / `enforce_backup_quota(user_id, tier)` — Raises HTTP 402 if storage limits are exceeded.
---
## Plugin Marketplace
Source: `app/marketplace/`
### Plugin Registry
- PostgreSQL-backed catalog of submitted and approved plugins.
- `list_plugins(db, category, query, page, sort)` — Paginated listing (page size: 20) with optional filtering by category, text search, and sorting by `rating`, `installs`, or `newest`.
- `get_plugin(db, plugin_id)` — Full manifest with install count and ratings.
- `submit_plugin(db, manifest, s3_key)` — Submits a plugin with `pending_review` status.
- `approve_plugin()` / `reject_plugin(reason)` — Admin workflow for plugin approval.
- `record_install()` / `record_uninstall()` — Tracks per-user installations and updates install counts.
### Review Queue
- Automated security checklist before human review:
- Plugin ID must match `^[a-z0-9-]+$`
- Permissions must be from the allowed set only
- No binary blobs in the manifest
- **Allowed permissions:** `read:tasks`, `write:tasks`, `read:projects`, `write:projects`, `read:notes`, `write:notes`, `read:timelines`, `write:timelines`, `read:calendar`, `write:calendar`
- `get_pending(db)` — Lists plugins awaiting review.
- `submit_review(db, plugin_id, reviewer_id, decision, notes)` — Records the review decision.
### Revenue Sharing
- **70% developer / 30% platform** split on all paid plugin sales.
- `record_install(db, plugin_id, user_id, amount_cents)` — Records the revenue event and triggers a Stripe Connect transfer for the developer share.
- `get_earnings(db, developer_id, period)` — Aggregated earnings report for plugin developers.
- Gracefully stubs transfers when Stripe is not configured.
### Seed Plugins
| Plugin | Category | Price |
|---|---|---|
| GitHub Sync | Productivity | Free |
| Slack Notifier | Communication | €4.99 |
| Time Tracker | Productivity | €9.99 |
---
## Testing
### Running Tests
```bash
# Run all tests
pytest
# Run a specific test file
pytest tests/test_auth.py
# Run with verbose output
pytest -v
```
### Test Infrastructure
- **Database:** Async SQLite in-memory via `aiosqlite` + `StaticPool` — fast, no PostgreSQL needed.
- **S3 mock:** `moto[s3]` with a fixture that patches `BlobStore` settings.
- **Auth helpers:** `make_jwt(tier)` and `auth_header(tier)` generate per-tier test tokens.
- **Seed data:** Auto-creates one `User` + `Subscription` per tier (free/pro/power/team) before each test.
- **Plugin seeds:** Fixture adds 3 approved plugins for marketplace tests.
- **FK enforcement:** SQLite `PRAGMA foreign_keys=ON`.
- **No external dependencies** — all tests run fully offline.
### Test Coverage
| File | Coverage |
|---|---|
| `test_auth.py` | Register, login, token access, refresh, expiration |
| `test_orchestrator.py` | Intent classification, single agent routing, pipeline, plan mode |
| `test_agents.py` | Each agent with mocked LLM: registration, tools, handle method |
| `test_storage.py` | Create, list, download, update, delete records; checksum rejection; quota enforcement |
| `test_backup.py` | Upload, download, history, delete; tier-based storage limits |
| `test_plugins.py` | List, install, uninstall, revenue events, tier gate enforcement |
| `test_agent_registry.py` | Registry singleton, registration, lookup, listing |
| `test_execution_plan.py` | Plan builder, template registry, plan cache |
| `test_middleware.py` | Rate limiting by tier, sanitizer prompt leak detection |
---
## Project Structure
```
adiuva-api/
├── alembic.ini # Alembic configuration
├── BACKEND_PLAN.md # Architecture & design decisions
├── docker-compose.yml # Docker Compose (app + PostgreSQL)
├── Dockerfile # Multi-stage production build
├── requirements.txt # Python dependencies
├── alembic/ # Database migrations
│ ├── env.py # Alembic environment config
│ ├── script.py.mako # Migration template
│ └── versions/
│ ├── 001_initial_schema.py # Tables, indexes, FKs
│ └── 002_seed_plugins.py # Seed marketplace plugins
├── app/ # Application source
│ ├── main.py # FastAPI app factory, middleware, routes
│ ├── db.py # Async SQLAlchemy engine & session
│ ├── models.py # SQLAlchemy ORM models (9 tables)
│ ├── schemas.py # Pydantic request/response schemas
│ │
│ ├── config/
│ │ └── settings.py # Pydantic Settings (env vars)
│ │
│ ├── agents/ # LLM-powered domain agents
│ │ ├── task_agent.py # Task & comment CRUD (8 tools)
│ │ ├── project_agent.py # Project lifecycle (6 tools)
│ │ ├── timeline_agent.py # Milestones (4 tools)
│ │ └── note_agent.py # Markdown notes (5 tools)
│ │
│ ├── core/ # Orchestration engine
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
│ │ ├── llm.py # LiteLLM factory (get_llm)
│ │ ├── orchestrator.py # Intent classification & routing
│ │ └── execution_plan.py # Plan builder, templates, cache
│ │
│ ├── api/ # HTTP layer
│ │ ├── deps.py # Shared FastAPI dependencies
│ │ ├── middleware/
│ │ │ ├── auth.py # JWT validation, live tier lookup
│ │ │ ├── rate_limit.py # Sliding-window tier rate limiter
│ │ │ └── sanitizer.py # Prompt IP leak protection
│ │ └── routes/
│ │ ├── auth.py # Register, login, refresh, me
│ │ ├── chat.py # Chat + WebSocket streaming
│ │ ├── plans.py # Execution plan playbooks
│ │ ├── storage.py # E2E encrypted record CRUD
│ │ ├── vectors.py # Vector upsert, search, delete
│ │ ├── backup.py # Encrypted backup management
│ │ ├── plugins.py # Marketplace browse & install
│ │ └── billing.py # Stripe checkout & webhooks
│ │
│ ├── storage/ # Storage backends
│ │ ├── blob_store.py # S3 blob storage
│ │ ├── vector_store.py # Pinecone / Qdrant vector store
│ │ └── encryption.py # Checksum verification utilities
│ │
│ ├── billing/ # Subscription management
│ │ ├── stripe_service.py # Stripe API integration
│ │ └── tier_manager.py # Feature matrix & quota enforcement
│ │
│ └── marketplace/ # Plugin ecosystem
│ ├── plugin_registry.py # Catalog CRUD & search
│ ├── plugin_review.py # Security checklist & review queue
│ └── revenue_share.py # 70/30 split & Stripe Connect
└── tests/ # Test suite
├── conftest.py # Fixtures: DB, S3, auth, seeds
├── test_auth.py
├── test_orchestrator.py
├── test_agents.py
├── test_storage.py
├── test_backup.py
├── test_plugins.py
├── test_agent_registry.py
├── test_execution_plan.py
└── test_middleware.py
```
---
## License
*To be determined.*

View File

@@ -16,7 +16,7 @@ import re
from logging.config import fileConfig
from alembic import context
from sqlalchemy import engine_from_config, pool
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import create_async_engine
# Alembic Config object (gives access to alembic.ini values).

View File

@@ -1,5 +1,4 @@
"""Initial schema: users, refresh_tokens, subscriptions, storage_records,
backup_metadata, plugins, plugin_installations, plugin_reviews, revenue_events.
"""Initial schema: users, refresh_tokens, subscriptions.
Revision ID: 001
Revises:
@@ -28,18 +27,6 @@ def upgrade() -> None:
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE plugin_status AS ENUM ('pending_review', 'approved', 'rejected');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE review_decision AS ENUM ('approved', 'rejected');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
# ── users ─────────────────────────────────────────────────────────────
op.create_table(
@@ -88,122 +75,10 @@ def upgrade() -> None:
op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"])
op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"])
# ── storage_records ───────────────────────────────────────────────────
op.create_table(
"storage_records",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("table_name", sa.String(100), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_storage_records_user_id", "storage_records", ["user_id"])
# ── backup_metadata ───────────────────────────────────────────────────
op.create_table(
"backup_metadata",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("s3_key", sa.String(500), nullable=False),
sa.Column("version", sa.Integer, nullable=False),
sa.Column("timestamp", sa.BigInteger, nullable=False),
sa.Column("checksum", sa.String(64), nullable=False),
sa.Column("size_bytes", sa.Integer, nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_backup_metadata_user_id", "backup_metadata", ["user_id"])
# ── plugins ───────────────────────────────────────────────────────────
op.create_table(
"plugins",
sa.Column("id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("description", sa.Text, nullable=False, server_default=""),
sa.Column("version", sa.String(50), nullable=False, server_default="1.0.0"),
sa.Column("author_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("author_name", sa.String(255), nullable=False, server_default=""),
sa.Column("category", sa.String(100), nullable=False, server_default=""),
sa.Column("price_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("permissions", sa.Text, nullable=False, server_default="[]"),
sa.Column("status", postgresql.ENUM("pending_review", "approved", "rejected", name="plugin_status", create_type=False), nullable=False, server_default="pending_review"),
sa.Column("s3_package_key", sa.String(500), nullable=True),
sa.Column("install_count", sa.Integer, nullable=False, server_default="0"),
sa.Column("avg_rating", sa.Float, nullable=False, server_default="0.0"),
sa.Column("rejection_reason", sa.Text, nullable=True),
sa.Column("submitted_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["author_id"], ["users.id"], ondelete="SET NULL"),
)
# ── plugin_installations ──────────────────────────────────────────────
op.create_table(
"plugin_installations",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("installed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),
)
op.create_index("ix_plugin_installations_plugin_id", "plugin_installations", ["plugin_id"])
op.create_index("ix_plugin_installations_user_id", "plugin_installations", ["user_id"])
# ── plugin_reviews ────────────────────────────────────────────────────
op.create_table(
"plugin_reviews",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("reviewer_id", postgresql.UUID(as_uuid=False), nullable=True),
sa.Column("decision", postgresql.ENUM("approved", "rejected", name="review_decision", create_type=False), nullable=False),
sa.Column("notes", sa.Text, nullable=True),
sa.Column("reviewed_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["reviewer_id"], ["users.id"], ondelete="SET NULL"),
)
op.create_index("ix_plugin_reviews_plugin_id", "plugin_reviews", ["plugin_id"])
# ── revenue_events ────────────────────────────────────────────────────
op.create_table(
"revenue_events",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("plugin_id", sa.String(255), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("amount_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("developer_share_cents", sa.Integer, nullable=False, server_default="0"),
sa.Column("stripe_transfer_id", sa.String(255), nullable=True),
sa.Column("paid_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["plugin_id"], ["plugins.id"], ondelete="CASCADE"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_revenue_events_plugin_id", "revenue_events", ["plugin_id"])
op.create_index("ix_revenue_events_user_id", "revenue_events", ["user_id"])
def downgrade() -> None:
op.drop_table("revenue_events")
op.drop_table("plugin_reviews")
op.drop_table("plugin_installations")
op.drop_table("plugins")
op.drop_table("backup_metadata")
op.drop_table("storage_records")
op.drop_table("subscriptions")
op.drop_table("refresh_tokens")
op.drop_table("users")
op.execute("DROP TYPE IF EXISTS review_decision")
op.execute("DROP TYPE IF EXISTS plugin_status")
op.execute("DROP TYPE IF EXISTS billing_tier")

View File

@@ -1,92 +0,0 @@
"""Seed approved plugins: GitHub Sync, Slack Notifier, Time Tracker.
Revision ID: 002
Revises: 001
Create Date: 2026-03-03
"""
from __future__ import annotations
import json
from datetime import datetime, timezone
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
_SEED_PLUGINS = [
{
"id": "plugin-github-sync",
"name": "GitHub Sync",
"description": "Sync tasks with GitHub Issues and pull requests.",
"version": "1.0.0",
"author_name": "Adiuva",
"category": "productivity",
"price_cents": 0,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-github-sync/1.0.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-slack-notify",
"name": "Slack Notifier",
"description": "Post task and timeline updates to Slack channels.",
"version": "1.2.0",
"author_name": "Adiuva",
"category": "communication",
"price_cents": 499,
"permissions": json.dumps(["read:tasks", "read:timelines"]),
"status": "approved",
"s3_package_key": "plugins/plugin-slack-notify/1.2.0/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
{
"id": "plugin-time-tracker",
"name": "Time Tracker",
"description": "Track time spent on tasks with automatic reporting.",
"version": "0.9.1",
"author_name": "Third Party",
"category": "productivity",
"price_cents": 999,
"permissions": json.dumps(["read:tasks", "write:tasks"]),
"status": "approved",
"s3_package_key": "plugins/plugin-time-tracker/0.9.1/package.zip",
"install_count": 0,
"avg_rating": 0.0,
},
]
def upgrade() -> None:
plugins = sa.table(
"plugins",
sa.column("id", sa.String),
sa.column("name", sa.String),
sa.column("description", sa.Text),
sa.column("version", sa.String),
sa.column("author_name", sa.String),
sa.column("category", sa.String),
sa.column("price_cents", sa.Integer),
sa.column("permissions", sa.Text),
sa.column("status", sa.Enum("pending_review", "approved", "rejected", name="plugin_status")),
sa.column("s3_package_key", sa.String),
sa.column("install_count", sa.Integer),
sa.column("avg_rating", sa.Float),
)
op.bulk_insert(plugins, _SEED_PLUGINS)
def downgrade() -> None:
op.execute(
"DELETE FROM plugins WHERE id IN ("
"'plugin-github-sync', 'plugin-slack-notify', 'plugin-time-tracker'"
")"
)

View File

@@ -14,7 +14,7 @@ from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "003"
down_revision: Union[str, None] = "002"
down_revision: Union[str, None] = "001"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None

View File

@@ -0,0 +1,107 @@
"""Restore agent config tables and add agent_config column.
9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both
ORM models are still active. This migration recreates them with agent_config
added to local_agent_configs.
Revision ID: a3b9c0d1e2f3
Revises: 9a1f2d0b6c7e
Create Date: 2026-04-07 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "a3b9c0d1e2f3"
down_revision: Union[str, None] = "9a1f2d0b6c7e"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Recreate enum types (idempotent — they may already exist from migration 003)
op.execute("""
DO $$ BEGIN
CREATE TYPE agent_type AS ENUM ('local', 'cloud');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
op.execute("""
DO $$ BEGIN
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
""")
bind = op.get_bind()
inspector = sa.inspect(bind)
existing = set(inspector.get_table_names())
# ── local_agent_configs (with agent_config column) ────────────────────
if "local_agent_configs" not in existing:
op.create_table(
"local_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("agent_config", sa.JSON, nullable=True),
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
# ── cloud_agent_configs ───────────────────────────────────────────────
if "cloud_agent_configs" not in existing:
op.create_table(
"cloud_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"provider",
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
nullable=False,
),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
sa.Column("filter_config", sa.JSON, nullable=True),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
def downgrade() -> None:
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
op.drop_table("cloud_agent_configs")
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
op.drop_table("local_agent_configs")

View File

@@ -0,0 +1,56 @@
"""Add oauth_accounts table, nullable password_hash, avatar_url to users.
Revision ID: b4c0d1e2f3a4
Revises: a3b9c0d1e2f3
Create Date: 2026-04-10 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision: str = "b4c0d1e2f3a4"
down_revision: Union[str, None] = "a3b9c0d1e2f3"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# ── users: make password_hash nullable (social users have no password) ──
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True)
# ── users: add avatar_url ─────────────────────────────────────────────
op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True))
# ── oauth_accounts ────────────────────────────────────────────────────
op.create_table(
"oauth_accounts",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("provider", sa.String(50), nullable=False),
sa.Column("provider_user_id", sa.String(255), nullable=False),
sa.Column("provider_email", sa.String(255), nullable=True),
sa.Column(
"created_at",
sa.DateTime(timezone=True),
nullable=False,
server_default=sa.text("now()"),
),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"),
)
op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"])
def downgrade() -> None:
op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts")
op.drop_table("oauth_accounts")
op.drop_column("users", "avatar_url")
op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False)

View File

@@ -0,0 +1,31 @@
"""Add onboarding_completed_at column to users table.
Revision ID: c5d1e2f3a4b5
Revises: b4c0d1e2f3a4
Create Date: 2026-04-11 00:00:00.000000
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "c5d1e2f3a4b5"
down_revision: Union[str, None] = "b4c0d1e2f3a4"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.add_column(
"users",
sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True),
)
def downgrade() -> None:
op.drop_column("users", "onboarding_completed_at")

View File

@@ -0,0 +1,34 @@
"""avatar_url_varchar_to_text
Revision ID: e04100e88ace
Revises: c5d1e2f3a4b5
Create Date: 2026-04-13 09:13:06.733674
"""
from __future__ import annotations
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision: str = 'e04100e88ace'
down_revision: Union[str, None] = 'c5d1e2f3a4b5'
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.alter_column('users', 'avatar_url',
existing_type=sa.VARCHAR(length=2048),
type_=sa.Text(),
existing_nullable=True)
def downgrade() -> None:
op.alter_column('users', 'avatar_url',
existing_type=sa.Text(),
type_=sa.VARCHAR(length=2048),
existing_nullable=True)

View File

@@ -7,12 +7,31 @@ handles actual disk I/O and responds with ``tool_result`` frames.
from __future__ import annotations
import os
import re
from pathlib import Path
from typing import Any
from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
# Max characters returned by read_file_content in journey (exploration) tools.
# The journey only needs to understand file structure, not full content.
_JOURNEY_READ_MAX_CHARS: int = 4000
def _resolve_path(path: str, base: str) -> str:
"""Resolve *path* against *base* when *path* is relative.
The LLM often passes ``"."`` meaning "the configured directory".
Without this, Electron resolves ``"."`` relative to its own CWD instead
of the user's chosen directory.
"""
if os.path.isabs(path):
return path
return str(Path(base) / path)
@tool
async def list_directory(path: str) -> str:
@@ -83,3 +102,93 @@ FILESYSTEM_TOOLS: list[Any] = [
read_file_content,
get_file_metadata,
]
def make_directory_tools(base_directory: str) -> list[Any]:
"""Return filesystem tools that resolve relative paths against *base_directory*.
Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target
directory upfront (e.g., journey setup sessions). Relative paths like ``"."``
from the LLM are resolved to the correct absolute path before being sent to
the Electron client, preventing it from falling back to its own CWD.
"""
def _compact_for_journey(raw: str) -> str:
"""Strip HTML noise and truncate for journey exploration.
The journey LLM only needs to understand file structure (headers,
first paragraphs). Full CSS/style blocks are pure noise that eat
up context window budget.
"""
text = re.sub(r"<style[^>]*>.*?</style>", "", raw, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r"<script[^>]*>.*?</script>", "", text, flags=re.DOTALL | re.IGNORECASE)
text = re.sub(r"<!--.*?-->", "", text, flags=re.DOTALL)
if len(text) > _JOURNEY_READ_MAX_CHARS:
text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]"
return text
@tool
async def list_directory(path: str) -> str: # noqa: F811
"""List files and folders in a local directory on the user's device.
Returns a formatted listing of entries with name, type (file/directory),
and full path.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="list_directory",
data={"path": resolved},
)
entries: list[dict[str, Any]] = result.get("entries", [])
if not entries:
return f"Directory '{resolved}' is empty or does not exist."
lines: list[str] = []
for entry in entries:
entry_type = entry.get("type", "unknown")
entry_name = entry.get("name", "")
entry_path = entry.get("path", "")
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines)
@tool
async def read_file_content(path: str) -> str: # noqa: F811
"""Read the text content of a local file on the user's device.
Returns the file content as a string. Large files may be truncated
by the Electron client.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="read_file_content",
data={"path": resolved},
)
content: str = result.get("content", "")
if not content:
return f"File '{resolved}' is empty or could not be read."
return _compact_for_journey(content)
@tool
async def get_file_metadata(path: str) -> str: # noqa: F811
"""Get metadata for a local file: size, creation date, modification date, extension.
Returns a formatted summary of the file's metadata.
"""
resolved = _resolve_path(path, base_directory)
result = await execute_on_client(
action="get_file_metadata",
data={"path": resolved},
)
size = result.get("size", "unknown")
created = result.get("createdAt", "unknown")
modified = result.get("modifiedAt", "unknown")
extension = result.get("extension", "unknown")
name = result.get("name", resolved)
return (
f"File: {name}\n"
f" Extension: {extension}\n"
f" Size: {size} bytes\n"
f" Created: {created}\n"
f" Modified: {modified}"
)
return [list_directory, read_file_content, get_file_metadata]

View File

@@ -18,21 +18,6 @@ _UUID_RE = re.compile(
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool
async def list_notes(project_id: str = "") -> str:

View File

@@ -8,22 +8,6 @@ from langchain_core.tools import tool
from app.core.ws_context import execute_on_client
PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool
async def list_projects(

View File

@@ -18,23 +18,6 @@ _UUID_RE = re.compile(
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ────────────────────────────────────────────────────────

View File

@@ -17,20 +17,6 @@ _UUID_RE = re.compile(
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_timelines(project_id: str = "") -> str:

View File

@@ -65,16 +65,39 @@ async def get_current_user(
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname from user row.
# Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row.
user_result = await db.execute(
select(User.name, User.surname).where(User.id == user_id)
select(
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
User.password_hash,
).where(User.id == user_id)
)
user_row = user_result.one_or_none()
# Convert onboarding_completed_at to epoch ms (int) or None.
onboarding_ms: int | None = None
if user_row and user_row.onboarding_completed_at is not None:
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
# Load decrypted core memory.
from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415
memory_dict: dict[str, str] = {}
try:
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(user_id)
memory_dict = {b["label"]: b["value"] for b in blocks}
except Exception:
pass # Non-critical — return empty memory on failure
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
avatar_url=user_row.avatar_url if user_row else None,
has_password=bool(user_row.password_hash) if user_row else False,
tier=tier,
onboarding_completed_at=onboarding_ms,
memory=memory_dict,
) # type: ignore[arg-type]

View File

@@ -8,8 +8,7 @@ that could reveal server-side prompt IP:
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints
Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.
The middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.

View File

@@ -1,11 +1,11 @@
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template.
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
The journey is driven entirely through WebSocket frames (no REST endpoints).
The device WS handler dispatches ``journey_start`` and ``journey_message``
frames to the functions exported here.
Journey flow:
1. FE sends ``journey_start`` frame with basic agent config (directory,
1. FE sends ``journey_start`` frame with basic agent info (directory,
data_types, schedule).
2. Server creates an in-memory session, sets up a WS executor so the
setup LLM can use file-system tools, does a first directory scrape,
@@ -13,10 +13,11 @@ Journey flow:
3. FE sends ``journey_message`` frames for each user reply.
4. Server appends the user message, calls the LLM (which may read files
via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
6. Server parses the block, sends ``journey_reply`` with ``done=True``
and the template. FE stores it locally.
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
6. Server parses and validates the JSON with Pydantic, sends
``journey_reply`` with ``done=True`` and the serialised config.
FE stores it locally.
"""
from __future__ import annotations
@@ -30,8 +31,10 @@ from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.core.llm import get_llm
from app.agents.filesystem_agent import make_directory_tools
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
from app.schemas import AgentConfig
logger = logging.getLogger(__name__)
@@ -39,9 +42,9 @@ logger = logging.getLogger(__name__)
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced prompt_template.
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
_CONFIG_START = "AGENT_CONFIG_START"
_CONFIG_END = "AGENT_CONFIG_END"
# Minimum turns before we consider nudging the LLM to wrap up.
_MIN_TURNS_BEFORE_NUDGE: int = 3
@@ -62,6 +65,7 @@ class JourneySession:
data_types: list[str]
history: list[dict[str, Any]] = field(default_factory=list)
system_prompt: str = ""
langfuse_prompt: Any = None
created_at: float = field(default_factory=time.monotonic)
def is_expired(self) -> bool:
@@ -83,61 +87,76 @@ def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
return s
# ── System prompt builder ─────────────────────────────────────────────────
# ── System prompt ─────────────────────────────────────────────────────────
_SYSTEM_PROMPT_TEMPLATE = """\
_JOURNEY_SYSTEM_PROMPT = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand exactly what data the user wants to extract from their
local directory and produce a detailed prompt_template that a separate AI will use
as its instruction set.
The extraction agent already has this base behaviour built in:
- Reads each file using file-system tools.
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
- Sets isAiSuggested=1 on every new record.
- Only extracts data explicitly present in the files — it never invents information.
The user's custom prompt is appended AFTER this base behaviour, so focus on
what to look for and how to map it — not on the general extraction mechanics.
Your job is to understand what files the user has in their directory and produce a
structured AgentConfig JSON that the extraction agent will use as its instruction set.
You have access to file-system tools to explore the user's directory:
- list_directory: to see folder structure
- read_file_content: to peek at file contents
- get_file_metadata: to check file info
- list_directory: see folder structure and file names
- read_file_content: peek at a file's content
- get_file_metadata: check file size, extension, dates
The user's configured directory is: {directory}
Target data types: {data_types}
IMPORTANT — project assignment is handled automatically by the main agent runner
before the custom prompt is ever used. You MUST NOT ask the user about projects,
projectId, or how to link records to projects. Never include projectId logic or
project creation instructions in the generated prompt_template.
## Your process
Start by exploring the directory to understand its structure. Then ask concise,
focused questions one at a time. Cover these topics (not necessarily in this order):
1. The type and format of the source content (confirmed by your exploration).
2. How fields should be mapped (e.g. filename → task title).
3. Priority or status rules (e.g. "urgent" keyword → high priority).
4. Any special handling, date extraction, or exclusions.
### Step 1 — Explore the directory
Use list_directory and read_file_content to understand what types of files are present
(HTML emails, plain-text documents, CSVs, etc.).
Once you reach 90% confidence, output the final prompt_template between these exact
markers on their own lines:
### Step 2 — Identify content types
For each distinct file type found, decide:
- A short id (e.g. "email_html", "plain_text", "csv")
- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else
- A human-readable label and optional detection_hint
{template_start}
<the complete extraction prompt here>
{template_end}
### Step 3 — Ask focused questions (one at a time)
Cover these topics based on what you discovered:
1. How to map content to entity types (task / note / timeline entry)
2. Field mapping rules (e.g. email Subject → task title, filename → note title)
3. Priority or status rules (e.g. "urgent" in subject → high priority)
4. Date extraction (e.g. "by Friday" → dueDate)
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
The prompt_template must be a self-contained instruction for an AI that reads files
and must perform CRUD operations using tools to create records. It should specify:
- What entity types to create (tasks, notes, timelines) — never projects.
- How to map file content to record fields (camelCase: title, status, priority,
dueDate, content, etc.) — never include projectId.
- That isAiSuggested must be set to 1 on every new record.
- Concrete examples of mappings based on what you discovered in the directory.
### Step 4 — Produce the AgentConfig JSON
Once you are ≥ 90% confident, output the final config between these exact markers
(each on its own line):
{config_start}
{{
"content_types": [
{{
"id": "email_html",
"label": "Email HTML",
"detection_hint": "HTML file with From/To/Subject headers",
"preprocessing": "email_html",
"extraction_prompt": "Detailed extraction instructions for this content type..."
}}
],
"global_rules": [
"If the file cannot be matched to any project, do not create any entity."
],
"data_types": {data_types_json}
}}
{config_end}
## Rules for the extraction_prompt field
- Describe when to create a task vs note vs timeline entry (be specific and concrete)
- Include field mapping rules based on what you found in the directory
- Include priority/status/date rules if applicable
- Do NOT include projectId logic — the runner handles project assignment automatically
- Do NOT mention isAiSuggested — the runner always sets it to 1
## Constraints
- Never ask about projects, projectId, or how to link records to projects
- Never include projectId or project creation logic in the generated config
- Keep asking questions until ≥ 90% confident, then output the JSON immediately
{existing_section}\
Keep asking clarifying questions until you are at least 90% confident you have
enough information to generate an accurate prompt_template. Once you reach that
confidence level, stop asking and produce the final template immediately.
Begin by exploring the directory, then ask your first question.\
"""
@@ -145,33 +164,53 @@ Begin by exploring the directory, then ask your first question.\
def _build_system_prompt(
directory: str,
data_types: list[str],
existing_template: str | None = None,
) -> str:
existing_config: str | None = None,
) -> tuple[str, Any]:
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
existing_section = (
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
f"---\n{existing_template}\n---\n"
if existing_template
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
f"```json\n{existing_config}\n```\n"
if existing_config
else ""
)
return _SYSTEM_PROMPT_TEMPLATE.format(
template, prompt_obj = get_prompt_or_fallback(
"journey_system", _JOURNEY_SYSTEM_PROMPT
)
compiled = compile_prompt(
template,
prompt_obj,
directory=directory,
data_types=", ".join(data_types),
template_start=_TEMPLATE_START,
template_end=_TEMPLATE_END,
data_types_json=json.dumps(data_types),
config_start=_CONFIG_START,
config_end=_CONFIG_END,
existing_section=existing_section,
)
return compiled, prompt_obj
# ── Template extraction ───────────────────────────────────────────────────
# ── AgentConfig extraction ────────────────────────────────────────────────
def _extract_template(text: str) -> str | None:
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
def _extract_agent_config(text: str) -> str | None:
"""Return validated AgentConfig JSON string from between markers, or None.
Parses the JSON with Pydantic to ensure it conforms to the schema before
returning. Returns None if markers are absent or JSON is invalid.
"""
if _CONFIG_START not in text or _CONFIG_END not in text:
return None
start_idx = text.index(_CONFIG_START) + len(_CONFIG_START)
end_idx = text.index(_CONFIG_END)
raw = text[start_idx:end_idx].strip()
if not raw:
return None
try:
parsed = AgentConfig.model_validate_json(raw)
return parsed.model_dump_json()
except Exception as exc:
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
return None
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
end_idx = text.index(_TEMPLATE_END)
return text[start_idx:end_idx].strip() or None
# ── LLM call with tool support ───────────────────────────────────────────
@@ -199,12 +238,17 @@ async def _call_llm_with_tools(
system_prompt: str,
history: list[dict[str, Any]],
tools: list[Any],
*,
user_id: str = "",
session_id: str = "",
langfuse_prompt: Any = None,
) -> str:
"""Build LangChain messages from history and invoke the LLM with tools.
Handles tool-calling loops: if the LLM calls tools, execute them and
continue until a final text response is produced.
"""
lf = get_langfuse()
messages: list[Any] = [SystemMessage(content=system_prompt)]
for turn in history:
if turn["role"] == "user":
@@ -212,16 +256,59 @@ async def _call_llm_with_tools(
else:
messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4)
llm = get_agent_llm("setup", temperature=0.4)
llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(_MAX_TOOL_STEPS):
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
_lf_ctx.__enter__()
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name="journey-setup",
input=history[-1]["content"] if history else "",
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
try:
for step in range(_MAX_TOOL_STEPS):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name="journey-setup-llm",
model=model_for_agent("setup"),
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
resp_text = _as_text(response.content)
# Guard against empty responses (e.g. model returned finish_reason
# 'error' which LiteLLM maps to 'stop' with empty content).
if not response.tool_calls and not resp_text.strip():
logger.warning(
"agent_setup: journey LLM returned empty response at step %d — retrying",
step,
)
# Drop the empty AIMessage so we don't pollute history, and retry.
continue
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
if _span:
_span.update(output=resp_text)
return resp_text
for call in response.tool_calls:
call_name = str(call.get("name", ""))
@@ -247,7 +334,19 @@ async def _call_llm_with_tools(
# Fallback: exceeded max steps.
final = await llm.ainvoke(messages)
return _as_text(final.content)
final_text = _as_text(final.content)
if _span:
_span.update(output=final_text)
return final_text or (
"Sorry, I had trouble processing the files. "
"Could you try again? If the issue persists, the files might be too large for me to analyse."
)
finally:
if _span_ctx:
_span_ctx.__exit__(None, None, None)
_lf_ctx.__exit__(None, None, None)
if lf:
lf.flush()
# ── Journey handlers (called from device_ws.py) ──────────────────────────
@@ -265,12 +364,12 @@ async def handle_journey_start(
agent_type = frame.get("agent_type", "local")
directory = frame.get("directory", "")
data_types = frame.get("data_types", [])
existing_template = frame.get("existing_template")
existing_config = frame.get("existing_config")
# Use the session_id provided by the FE so the reply matches the
# listener key; fall back to a generated one if absent.
session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt = _build_system_prompt(directory, data_types, existing_template)
system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config)
session = JourneySession(
session_id=session_id,
@@ -279,19 +378,21 @@ async def handle_journey_start(
directory=directory,
data_types=data_types,
system_prompt=system_prompt,
langfuse_prompt=langfuse_prompt,
)
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
# ws_context executor (already set by the WS handler before calling us).
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
# require at least one user/input message to be present.
# Seed with an initial user message — some providers require at least one
# user/input message to be present.
seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
]
ai_reply = await _call_llm_with_tools(
system_prompt=system_prompt,
history=seed_history,
tools=list(FILESYSTEM_TOOLS),
tools=make_directory_tools(directory),
user_id=user_id,
session_id=session_id,
langfuse_prompt=langfuse_prompt,
)
session.history.extend(seed_history)
@@ -305,14 +406,14 @@ async def handle_journey_start(
directory,
)
# Check if the LLM produced the template on the first turn (unlikely but possible).
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
# Check if the LLM produced the config on the first turn (unlikely but possible).
agent_config = _extract_agent_config(ai_reply)
done = agent_config is not None
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
or "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
@@ -322,7 +423,7 @@ async def handle_journey_start(
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
"agent_config": agent_config,
}
@@ -345,53 +446,59 @@ async def handle_journey_message(
"session_id": session_id,
"message": "Journey session not found or expired. Please start a new setup.",
"done": True,
"prompt_template": None,
"agent_config": None,
}
# Append user turn.
session.history.append({"role": "user", "content": message})
# Call the LLM with tools.
session_tools = make_directory_tools(session.directory)
ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
tools=session_tools,
user_id=session.user_id,
session_id=session_id,
langfuse_prompt=session.langfuse_prompt,
)
session.history.append({"role": "assistant", "content": ai_reply})
# Check if the LLM produced the final template.
prompt_template = _extract_template(ai_reply)
done = prompt_template is not None
# Check if the LLM produced the final config.
agent_config = _extract_agent_config(ai_reply)
done = agent_config is not None
# If the LLM didn't produce a template, nudge it once it has asked enough
# questions (>= _MIN_TURNS_BEFORE_NUDGE) or hits the hard safety cap.
# If the LLM didn't produce a config, nudge it once it hits the hard safety cap.
if not done:
turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})
nudge_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt,
history=session.history,
tools=list(FILESYSTEM_TOOLS),
tools=session_tools,
user_id=session.user_id,
session_id=session_id,
langfuse_prompt=session.langfuse_prompt,
)
session.history.append({"role": "assistant", "content": nudge_reply})
prompt_template = _extract_template(nudge_reply)
if prompt_template is not None:
agent_config = _extract_agent_config(nudge_reply)
if agent_config is not None:
done = True
ai_reply = nudge_reply
display_message = ai_reply
if done:
display_message = (
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
if _TEMPLATE_START in ai_reply
ai_reply[: ai_reply.index(_CONFIG_START)].strip()
if _CONFIG_START in ai_reply
else "Here is your agent configuration. You can save it or continue refining."
)
_sessions.pop(session_id, None)
@@ -402,5 +509,5 @@ async def handle_journey_message(
"session_id": session_id,
"message": display_message,
"done": done,
"prompt_template": prompt_template,
"agent_config": agent_config,
}

View File

@@ -12,8 +12,11 @@ in backend agent-config tables.
from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import datetime, timedelta, timezone
from datetime import datetime, timezone
logger = logging.getLogger(__name__)
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
@@ -177,6 +180,11 @@ async def trigger_agent_run(
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
last_run_dt = (
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
if body.last_run_at
else None
)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
@@ -184,10 +192,12 @@ async def trigger_agent_run(
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt,
prompt_template=body.custom_agent_prompt or "",
agent_config=body.agent_config,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
last_run_at=last_run_dt,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.

View File

@@ -1,34 +1,68 @@
"""Auth routes: register, login, refresh, me.
"""Auth routes: register, login, refresh, me, OAuth social login, onboarding.
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB.
OAuth (Google):
GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state
POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens
"""
from __future__ import annotations
import hashlib
import json
import time
import urllib.parse
import uuid
from datetime import datetime, timedelta, timezone
from typing import Literal
import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from jose import jwt
from pydantic import BaseModel
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair
from app.config.settings import settings
from app.core.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware
from app.db import get_session
from app.models import RefreshToken, User
from app.models import OAuthAccount, RefreshToken, User
from app.schemas import AuthTokens, UserProfile
router = APIRouter(prefix="/auth", tags=["auth"])
# ── OAuth provider registry ───────────────────────────────────────────
def _get_google_provider() -> GoogleOAuthProvider:
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
raise HTTPException(
status.HTTP_503_SERVICE_UNAVAILABLE,
"Google login is not configured on this server",
)
return GoogleOAuthProvider(
client_id=settings.GOOGLE_AUTH_CLIENT_ID,
client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET,
redirect_uri=settings.OAUTH_REDIRECT_URI,
)
_PROVIDERS = {"google": _get_google_provider}
# In-memory state store: state → (code_verifier, expires_at_epoch_s)
# Production note: replace with Redis for multi-process deployments.
_pending_states: dict[str, tuple[str, float]] = {}
_STATE_TTL_SECONDS = 600 # 10 minutes
# ── Internal helpers ─────────────────────────────────────────────────
@@ -231,5 +265,531 @@ async def update_profile(
email=user.email,
name=user.name,
surname=user.surname,
avatar_url=user.avatar_url,
tier=current_user.tier,
)
# ── OAuth helpers ─────────────────────────────────────────────────────
async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]:
"""Create a refresh token row and return (plain_token, AuthTokens)."""
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return plain_token, AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
# ── OAuth request/response schemas ───────────────────────────────────
class _OAuthAuthorizeResponse(BaseModel):
url: str
state: str
class _OAuthCallbackRequest(BaseModel):
code: str
state: str
# ── OAuth routes ──────────────────────────────────────────────────────
@router.get(
"/oauth/{provider}/web-callback",
summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link",
include_in_schema=False,
)
async def oauth_web_callback(
provider: Literal["google"],
code: str,
state: str,
) -> RedirectResponse:
"""Google redirects here after user consent.
This endpoint immediately redirects to the Electron deep-link URI so the
desktop app receives the authorization code. It is intentionally simple —
no state validation here (the Electron app + backend callback do that).
Registered in Google Cloud Console as:
http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev)
https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod)
"""
params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider})
deep_link = f"adiuvai://oauth/callback?{params}"
return RedirectResponse(url=deep_link, status_code=302)
@router.get(
"/oauth/{provider}/authorize",
response_model=_OAuthAuthorizeResponse,
summary="Start OAuth flow — returns the provider consent-screen URL",
)
async def oauth_authorize(
provider: Literal["google"],
) -> _OAuthAuthorizeResponse:
"""Generate a PKCE state + code_challenge and return the authorization URL.
The client opens this URL in the system browser. After the user grants
consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback)
with ``code`` and ``state`` query params. The client then calls
``POST /auth/oauth/{provider}/callback`` with those values.
"""
provider_factory = _PROVIDERS.get(provider)
if provider_factory is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
oauth_provider = provider_factory()
state = str(uuid.uuid4())
code_verifier, code_challenge = generate_pkce_pair()
# Purge expired states to prevent unbounded growth.
now = time.time()
expired = [s for s, (_, exp) in _pending_states.items() if exp < now]
for s in expired:
del _pending_states[s]
_pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS)
url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge)
return _OAuthAuthorizeResponse(url=url, state=state)
@router.post(
"/oauth/{provider}/callback",
response_model=AuthTokens,
summary="Complete OAuth flow — exchange code and issue JWT tokens",
)
async def oauth_callback(
provider: Literal["google"],
body: _OAuthCallbackRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate state, exchange the authorization code, and sign in (or register) the user.
Resolution order:
1. ``oauth_accounts`` row match → existing user, log in.
2. Email match + ``email_verified=True`` → link OAuth account to existing user.
3. No match → create new user (password_hash=None, avatar from provider).
"""
provider_factory = _PROVIDERS.get(provider)
if provider_factory is None:
raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}")
# Validate state (CSRF protection).
now = time.time()
entry = _pending_states.pop(body.state, None)
if entry is None or entry[1] < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
code_verifier, _ = entry
oauth_provider = provider_factory()
# Exchange code for tokens.
try:
token_data = await oauth_provider.exchange_code(
code=body.code,
code_verifier=code_verifier,
redirect_uri=settings.OAUTH_REDIRECT_URI,
)
except Exception:
raise HTTPException(
status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code"
)
access_token_google = token_data.get("access_token")
if not access_token_google:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response")
# Fetch user identity.
try:
userinfo = await oauth_provider.get_userinfo(access_token_google)
except Exception:
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider")
# ── Resolution order ──────────────────────────────────────────────
# 1. Existing OAuth link?
oauth_result = await db.execute(
select(OAuthAccount).where(
OAuthAccount.provider == provider,
OAuthAccount.provider_user_id == userinfo.provider_user_id,
)
)
oauth_account = oauth_result.scalar_one_or_none()
if oauth_account is not None:
user_result = await db.execute(select(User).where(User.id == oauth_account.user_id))
user = user_result.scalar_one()
# Backfill avatar if the user doesn't have one yet.
if user.avatar_url is None and userinfo.avatar_url:
user.avatar_url = userinfo.avatar_url
await db.commit()
plain_token, tokens = await _issue_refresh_token(user, db)
await db.commit()
return tokens
# 2. Email match with a verified Google email → link accounts.
if userinfo.email_verified:
email_result = await db.execute(select(User).where(User.email == userinfo.email))
existing_user = email_result.scalar_one_or_none()
if existing_user is not None:
new_link = OAuthAccount(
user_id=existing_user.id,
provider=provider,
provider_user_id=userinfo.provider_user_id,
provider_email=userinfo.email,
)
db.add(new_link)
if existing_user.avatar_url is None and userinfo.avatar_url:
existing_user.avatar_url = userinfo.avatar_url
plain_token, tokens = await _issue_refresh_token(existing_user, db)
await db.commit()
return tokens
# Guard: if the email is already taken but we couldn't auto-link (e.g.
# email_verified=False), refuse with 409 instead of hitting a DB constraint.
if not userinfo.email_verified:
conflict = await db.execute(select(User).where(User.email == userinfo.email))
if conflict.scalar_one_or_none() is not None:
raise HTTPException(
status.HTTP_409_CONFLICT,
"An account with this email already exists. "
"Please sign in with your password.",
)
# 3. New user — social-only account (no password).
new_user = User(
id=str(uuid.uuid4()),
email=userinfo.email,
name=userinfo.name,
password_hash=None,
avatar_url=userinfo.avatar_url,
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(new_user)
await db.flush() # populate new_user.id
new_oauth = OAuthAccount(
user_id=new_user.id,
provider=provider,
provider_user_id=userinfo.provider_user_id,
provider_email=userinfo.email,
)
db.add(new_oauth)
plain_token, tokens = await _issue_refresh_token(new_user, db)
await db.commit()
return tokens
# ── Onboarding helpers ────────────────────────────────────────────────
async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile:
"""Re-fetch and return a full UserProfile (reuses get_current_user logic)."""
# We can't call the FastAPI dependency directly, but we can replicate
# the core logic inline. Instead, we just re-query the same way.
from app.models import Subscription # noqa: PLC0415
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
user_result = await db.execute(
select(
User.name, User.surname, User.avatar_url, User.onboarding_completed_at,
User.password_hash,
).where(User.id == user_id)
)
user_row = user_result.one_or_none()
onboarding_ms: int | None = None
if user_row and user_row.onboarding_completed_at is not None:
onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000)
memory_dict: dict[str, str] = {}
try:
mw = MemoryMiddleware(db)
blocks = await mw.list_core_blocks(user_id)
memory_dict = {b["label"]: b["value"] for b in blocks}
except Exception:
pass
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
avatar_url=user_row.avatar_url if user_row else None,
has_password=bool(user_row.password_hash) if user_row else False,
tier=tier,
onboarding_completed_at=onboarding_ms,
memory=memory_dict,
)
# ── Onboarding routes ────────────────────────────────────────────────
class _UpdateMemoryRequest(BaseModel):
memory: dict[str, str] = Field(default_factory=dict)
mark_onboarded: bool = False
@router.put("/me/memory", response_model=UserProfile)
async def update_memory(
body: _UpdateMemoryRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update core memory key/value pairs and optionally mark onboarding complete."""
mw = MemoryMiddleware(db)
for key, value in body.memory.items():
await mw.update_core(current_user.id, key, value)
if body.mark_onboarded:
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.onboarding_completed_at = datetime.now(timezone.utc)
await db.commit()
return await _build_profile(current_user.id, current_user.email, db)
@router.post("/me/onboarding/reset")
async def reset_onboarding(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
):
"""Reset onboarding so the wizard runs again on next login."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.onboarding_completed_at = None
await db.commit()
return {"status": "reset"}
class _NormalizeRequest(BaseModel):
inputs: dict[str, str]
class _NormalizeResponse(BaseModel):
normalized: dict[str, str]
@router.post("/onboarding/normalize", response_model=_NormalizeResponse)
async def normalize_onboarding(
body: _NormalizeRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _NormalizeResponse:
"""One-shot LLM normalization for free-text onboarding answers."""
if not body.inputs:
return _NormalizeResponse(normalized={})
try:
llm = get_llm(model="gpt-4o-mini", temperature=0)
prompt = (
"You normalize user onboarding answers into clean, ≤3-word canonical labels.\n"
"Return a JSON object with the same keys and normalized values.\n"
"Examples: 'i build websites''Web Developer', 'tech-ish stuff''Technology'\n"
f"Input: {json.dumps(body.inputs)}"
)
response = await llm.ainvoke(
[
{"role": "system", "content": "You normalize user inputs. Return JSON only."},
{"role": "user", "content": prompt},
],
)
normalized = json.loads(response.content)
return _NormalizeResponse(normalized=normalized)
except Exception:
# LLM failure must never block onboarding — return inputs unchanged
return _NormalizeResponse(normalized=body.inputs)
# ── Password management ───────────────────────────────────────────────
class _ChangePasswordRequest(BaseModel):
current_password: str = Field(min_length=1)
new_password: str = Field(min_length=8)
@router.put("/me/password", status_code=status.HTTP_200_OK)
async def change_password(
body: _ChangePasswordRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Change the authenticated user's password.
Requires the current password for verification.
Returns 400 for social-only users (no password set).
"""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if user.password_hash is None:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"This account uses social login and has no password to change",
)
if not _verify_password(body.current_password, user.password_hash):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect")
user.password_hash = _hash_password(body.new_password)
await db.commit()
return {"ok": True}
# ── OAuth account management ─────────────────────────────────────────
@router.get("/me/oauth-accounts", response_model=list[dict])
async def list_oauth_accounts(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[dict]:
"""List all OAuth providers linked to the authenticated user."""
result = await db.execute(
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
)
accounts = result.scalars().all()
return [
{
"provider": a.provider,
"provider_email": a.provider_email,
"created_at": int(a.created_at.timestamp() * 1000),
}
for a in accounts
]
@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK)
async def unlink_oauth_account(
provider: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unlink an OAuth provider from the authenticated user.
Refuses if the user has no password and this is their only login method.
"""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
oauth_result = await db.execute(
select(OAuthAccount).where(
OAuthAccount.user_id == current_user.id,
OAuthAccount.provider == provider,
)
)
account = oauth_result.scalar_one_or_none()
if account is None:
raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found")
# Safety: don't let users lock themselves out.
all_oauth = await db.execute(
select(OAuthAccount).where(OAuthAccount.user_id == current_user.id)
)
oauth_count = len(all_oauth.scalars().all())
if user.password_hash is None and oauth_count <= 1:
raise HTTPException(
status.HTTP_400_BAD_REQUEST,
"Cannot unlink the only login method. Set a password first.",
)
await db.delete(account)
await db.commit()
return {"ok": True}
# ── Avatar update ─────────────────────────────────────────────────────
class _UpdateAvatarRequest(BaseModel):
avatar_url: str = Field(min_length=1)
@router.put("/me/avatar", response_model=UserProfile)
async def update_avatar(
body: _UpdateAvatarRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's avatar URL.
Accepts {"avatar_url": "https://..."} — the client uploads the image
to its own storage and passes the resulting URL here.
"""
if not body.avatar_url.startswith(("https://", "http://", "data:image/")):
raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL")
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
user.avatar_url = body.avatar_url
await db.commit()
return await _build_profile(current_user.id, current_user.email, db)
# ── Account deletion ─────────────────────────────────────────────────
@router.delete("/me", status_code=status.HTTP_200_OK)
async def delete_account(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Permanently delete the authenticated user's account.
Cascades: refresh tokens, OAuth accounts, subscription, and all memory
rows are deleted via SQLAlchemy relationship cascades. Stripe subscription
is cancelled if active.
"""
# Cancel Stripe subscription if present.
try:
from app.billing.stripe_service import stripe_service # noqa: PLC0415
await stripe_service.cancel_subscription(current_user.id, db)
except HTTPException:
pass # No subscription — that's fine
# Delete all memory rows (core, associative, episodic, proactive).
try:
from app.models import ( # noqa: PLC0415
MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive,
)
for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive):
await db.execute(
model.__table__.delete().where(model.user_id == current_user.id)
)
except Exception:
pass # Non-critical — cascade on User will handle most
# Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription.
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
await db.delete(user)
await db.commit()
return {"ok": True}

View File

@@ -1,171 +0,0 @@
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
PostgreSQL ``backup_metadata`` table.
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
treating "history" as a ``{backup_id}`` path parameter.
"""
from __future__ import annotations
import uuid
from email.utils import parsedate_to_datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import BackupMetadata as BackupMetadataModel
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/backup", tags=["backup"])
_blob_store = BlobStore()
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total backup bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
BackupMetadataModel.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_backup_quota(
user: UserProfile, size_bytes: int, db: AsyncSession
) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
current = await _current_backup_bytes(user.id, db)
tier_manager.enforce_backup_quota(
user.tier, current_bytes=current, additional_bytes=size_bytes
)
@router.put("")
async def upload_backup(
request: Request,
x_backup_version: int = Header(..., alias="X-Backup-Version"),
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Upload an E2E-encrypted backup blob.
Metadata is passed via custom headers; the raw body is the encrypted blob.
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
await _check_backup_quota(current_user, len(blob), db)
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
)
row = BackupMetadataModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
s3_key=s3_key,
version=x_backup_version,
timestamp=x_backup_timestamp,
checksum=x_backup_checksum,
size_bytes=len(blob),
)
db.add(row)
await db.commit()
return {"ok": True}
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[BackupMetadata]:
"""Return backup metadata records for the authenticated user (no blob bytes)."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
)
rows = result.scalars().all()
return [
BackupMetadata(
version=r.version,
timestamp=r.timestamp,
checksum=r.checksum,
chunk_count=1,
)
for r in rows
]
@router.get("")
async def download_backup(
request: Request,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
.limit(1)
)
latest = result.scalar_one_or_none()
if latest is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
ims_header = request.headers.get("If-Modified-Since")
if ims_header:
try:
ims_dt = parsedate_to_datetime(ims_header)
ims_ms = int(ims_dt.timestamp() * 1000)
if latest.timestamp <= ims_ms:
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
except Exception:
pass # malformed header — ignore and serve the blob
blob = await _blob_store.download(current_user.id, latest.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={
"X-Backup-Version": str(latest.version),
"X-Backup-Timestamp": str(latest.timestamp),
"X-Checksum": latest.checksum,
},
)
@router.delete("/{backup_id}", response_model=dict)
async def delete_backup(
backup_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a specific backup by ID."""
result = await db.execute(
select(BackupMetadataModel).where(
BackupMetadataModel.id == backup_id,
BackupMetadataModel.user_id == current_user.id,
)
)
target = result.scalar_one_or_none()
if target is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
await _blob_store.delete(current_user.id, target.s3_key)
await db.delete(target)
await db.commit()
return {"ok": True}

View File

@@ -83,3 +83,16 @@ async def cancel_subscription(
"""Cancel the active subscription."""
await stripe_service.cancel_subscription(current_user.id, db)
return {"ok": True}
@router.get("/invoices", response_model=list[dict])
async def list_invoices(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[dict[str, Any]]:
"""Return billing history (invoices) from Stripe.
Returns an empty list when Stripe is not configured.
"""
invoices = await stripe_service.list_invoices(current_user.id, db)
return invoices

View File

@@ -1,4 +1,4 @@
"""Chat routes: POST /chat (REST fallback).
"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector).
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
"""
@@ -7,14 +7,30 @@ from __future__ import annotations
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.core.deep_agent import run_home
from app.core.llm import embed
from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
# ── Embed helpers ─────────────────────────────────────────────────────────
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
# ── Endpoints ─────────────────────────────────────────────────────────────
@router.post("")
async def chat(
body: ChatRequest,
@@ -27,3 +43,17 @@ async def chat(
context=body.context.model_dump(),
)
return JSONResponse(content={"response": response})
@router.post("/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by Electron (vectordb.ts) for local note search.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

View File

@@ -1,148 +0,0 @@
"""Plugins routes: browse and install plugins from the marketplace.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
"""
from __future__ import annotations
from typing import Any, Literal
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db import get_session
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.models import PluginInstallation, PluginReview as PluginReviewModel
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
# ── Tier gate ─────────────────────────────────────────────────────────
def _require_plugin_tier(user: UserProfile) -> None:
"""Raise HTTP 403 for users below Power tier."""
if user.tier not in ("power", "team"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Plugin marketplace requires Power tier or above",
)
# ── Local detail schema ────────────────────────────────────────────────
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any]
# ── Routes ────────────────────────────────────────────────────────────
@router.get("", response_model=PluginListResponse)
async def list_plugins(
category: str | None = Query(default=None),
q: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Fetch review ratings for this plugin
review_result = await db.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
)
reviews = review_result.scalars().all()
ratings = [
{
"reviewer_id": r.reviewer_id,
"decision": r.decision,
"notes": r.notes,
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
}
for r in reviews
]
return _PluginDetail(
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=ratings,
)
@router.post("/{plugin_id}/install", response_model=dict)
async def install_plugin(
plugin_id: str,
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Record the installation in plugin_installations
installation = PluginInstallation(
plugin_id=plugin_id,
user_id=current_user.id,
)
db.add(installation)
await db.flush()
await revenue_share.record_install(
db,
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
)
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
return {"ok": True, "download_url": download_url}
@router.delete("/{plugin_id}/install", response_model=dict)
async def uninstall_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
result = await db.execute(
select(PluginInstallation).where(
PluginInstallation.plugin_id == plugin_id,
PluginInstallation.user_id == current_user.id,
)
)
installation = result.scalar_one_or_none()
if installation is not None:
await db.delete(installation)
await db.commit()
await registry.record_uninstall(db, plugin_id)
return {"ok": True}

View File

@@ -1,195 +0,0 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# ── Local response schemas ─────────────────────────────────────────────
class _CreateResponse(BaseModel):
id: str
created_at: int
class _RecordMeta(BaseModel):
id: str
table: str
checksum: str
created_at: int
updated_at: int
# ── Helpers ────────────────────────────────────────────────────────────
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
async def list_records(
table: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in rows
]
@router.get("/records/{record_id}")
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record.checksum},
)
@router.put("/records/{record_id}", response_model=dict)
async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record.size_bytes
if delta > 0:
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@router.delete("/records/{record_id}", response_model=dict)
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}

View File

@@ -1,79 +0,0 @@
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
from __future__ import annotations
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.core.llm import embed
from app.schemas import (
UserProfile,
VectorSearchRequest,
VectorSearchResponse,
VectorUpsertRequest,
)
from app.storage.encryption import reject_if_tampered
from app.storage.vector_store import VectorStore
router = APIRouter(prefix="/storage", tags=["vectors"])
_vector_store = VectorStore()
class _VectorDeleteRequest(BaseModel):
ids: list[str]
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
@router.post("/vectors/upsert", response_model=dict)
async def upsert_vectors(
body: VectorUpsertRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, int]:
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
for item in body.vectors:
reject_if_tampered(item.blob, item.checksum)
await _vector_store.upsert(current_user.id, body.vectors)
return {"upserted": len(body.vectors)}
@router.post("/vectors/search", response_model=VectorSearchResponse)
async def search_vectors(
body: VectorSearchRequest,
current_user: UserProfile = Depends(get_current_user),
) -> VectorSearchResponse:
"""Search the user-scoped vector namespace with an encrypted query blob."""
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
return VectorSearchResponse(results=results)
@router.delete("/vectors", response_model=dict)
async def delete_vectors(
body: _VectorDeleteRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Delete vectors by ID, scoped to the authenticated user."""
await _vector_store.delete(current_user.id, body.ids)
return {"ok": True}
@router.post("/vectors/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

1
app/auth/__init__.py Normal file
View File

@@ -0,0 +1 @@
"OAuth provider abstractions and utilities."

135
app/auth/oauth_providers.py Normal file
View File

@@ -0,0 +1,135 @@
"""OAuth 2.0 + PKCE provider abstractions.
Each provider implements a three-step flow designed for a desktop (public) client:
1. get_authorization_url(state, code_challenge) → str
Build the provider's consent-screen URL. State and code_challenge are
generated server-side; the client opens this URL in the system browser.
2. exchange_code(code, code_verifier, redirect_uri) → dict
Exchange the short-lived authorization code for an access token.
The code_verifier proves ownership of the PKCE challenge.
3. get_userinfo(access_token) → OAuthUserInfo
Fetch the canonical user identity from the provider.
Currently supported providers:
- GoogleOAuthProvider (scope: openid email profile)
Adding a new provider:
- Implement the three methods above.
- Register in _PROVIDERS inside routes/auth.py.
"""
from __future__ import annotations
import base64
import hashlib
import os
import urllib.parse
from dataclasses import dataclass
import httpx
# ── Data transfer objects ─────────────────────────────────────────────
@dataclass
class OAuthUserInfo:
"""Normalized user identity returned by any provider."""
provider_user_id: str
email: str
email_verified: bool
avatar_url: str | None
name: str | None
# ── PKCE helpers ──────────────────────────────────────────────────────
def generate_pkce_pair() -> tuple[str, str]:
"""Generate a (code_verifier, code_challenge) pair for PKCE S256.
The code_verifier is a random 32-byte URL-safe base64 string.
The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding).
"""
code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode()
digest = hashlib.sha256(code_verifier.encode()).digest()
code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode()
return code_verifier, code_challenge
# ── Google provider ───────────────────────────────────────────────────
class GoogleOAuthProvider:
"""Google OAuth 2.0 provider (openid email profile scope).
Uses Google's standard authorization endpoint with PKCE S256.
Does NOT use google-auth-oauthlib to keep the flow generic and async.
"""
name = "google"
_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None:
self.client_id = client_id
self.client_secret = client_secret
self.redirect_uri = redirect_uri
def get_authorization_url(self, state: str, code_challenge: str) -> str:
"""Build the Google consent-screen URL."""
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
"response_type": "code",
"scope": "openid email profile",
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "select_account",
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
async def exchange_code(
self, code: str, code_verifier: str, redirect_uri: str
) -> dict:
"""Exchange authorization code for an access token."""
async with httpx.AsyncClient() as client:
response = await client.post(
self._TOKEN_URL,
data={
"client_id": self.client_id,
"client_secret": self.client_secret,
"code": code,
"code_verifier": code_verifier,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
},
)
response.raise_for_status()
return response.json()
async def get_userinfo(self, access_token: str) -> OAuthUserInfo:
"""Fetch the authenticated user's identity from Google."""
async with httpx.AsyncClient() as client:
response = await client.get(
self._USERINFO_URL,
headers={"Authorization": f"Bearer {access_token}"},
)
response.raise_for_status()
data = response.json()
return OAuthUserInfo(
provider_user_id=data["sub"],
email=data["email"],
email_verified=data.get("email_verified", False),
avatar_url=data.get("picture"),
name=data.get("name"),
)

View File

@@ -43,8 +43,8 @@ class StripeService:
self,
user_id: str,
tier: str,
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuva.app/billing/cancel",
success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuvai.app/billing/cancel",
) -> str:
"""Create a Stripe checkout session and return the URL.
@@ -200,6 +200,45 @@ class StripeService:
sub.status = "canceled"
await db.commit()
async def list_invoices(
self, user_id: str, db: AsyncSession, limit: int = 24
) -> list[dict[str, Any]]:
"""Return recent invoices for the user from Stripe.
Returns an empty list when Stripe is not configured or the user has
no ``stripe_customer_id``.
"""
if not self._configured():
return []
from app.models import User # noqa: PLC0415
result = await db.execute(
select(User.stripe_customer_id).where(User.id == user_id)
)
customer_id = result.scalar_one_or_none()
if not customer_id:
return []
try:
s = self._client()
invoices = s.Invoice.list(customer=customer_id, limit=limit)
return [
{
"id": inv.id,
"amount_due": inv.amount_due,
"amount_paid": inv.amount_paid,
"currency": inv.currency,
"status": inv.status,
"created": inv.created * 1000, # epoch ms
"invoice_url": inv.hosted_invoice_url,
"invoice_pdf": inv.invoice_pdf,
}
for inv in invoices.auto_paging_iter()
]
except Exception:
return []
# ── Private DB helpers ───────────────────────────────────────────────
async def _upsert_subscription(

View File

@@ -22,44 +22,32 @@ FEATURES: dict[str, dict[str, Any]] = {
"agents": 3,
"batch_active": 2,
"batch_runs_per_day": 5,
"cloud_storage_gb": 0,
"backup_gb": 0,
"providers": 1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"pro": {
"agents": -1, # unlimited
"batch_active": 10,
"batch_runs_per_day": 50,
"cloud_storage_gb": 5,
"backup_gb": 5,
"providers": -1,
"batch_builder": False,
"plugin_marketplace": False,
"sso": False,
},
"power": {
"agents": -1,
"batch_active": -1, # unlimited
"batch_runs_per_day": -1, # unlimited
"cloud_storage_gb": 25,
"backup_gb": 25,
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": False,
},
"team": {
"agents": -1,
"batch_active": -1,
"batch_runs_per_day": -1, # unlimited
"cloud_storage_gb": -1, # unlimited
"backup_gb": -1, # unlimited
"providers": -1,
"batch_builder": True,
"plugin_marketplace": True,
"sso": True,
},
}
@@ -125,71 +113,6 @@ class TierManager:
"""Return the requests-per-minute limit for ``tier``."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# ── Storage quota ────────────────────────────────────────────────────
def enforce_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Cloud storage is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Storage quota exceeded for tier '{tier}'",
)
def enforce_backup_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> None:
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup is not available on the '{tier}' tier",
)
if limit_gb == -1:
return # unlimited
limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Backup quota exceeded for tier '{tier}'",
)
def check_quota(
self,
tier: BillingTier,
current_bytes: int = 0,
additional_bytes: int = 0,
) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0:
return False
if limit_gb == -1:
return True
limit_bytes = limit_gb * 1024 ** 3
return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app.
tier_manager = TierManager()

View File

@@ -3,7 +3,7 @@ from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai"
JWT_SECRET: str = "change-me-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
@@ -12,17 +12,6 @@ class Settings(BaseSettings):
STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = ""
@@ -31,6 +20,14 @@ class Settings(BaseSettings):
LLM_MODEL: str = "gpt-4o"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey
# GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
@@ -44,18 +41,37 @@ class Settings(BaseSettings):
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
MS_TENANT_ID: str = "common"
# Google Login OAuth credentials — scope: openid email profile.
# Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope).
GOOGLE_AUTH_CLIENT_ID: str = ""
GOOGLE_AUTH_CLIENT_SECRET: str = ""
# The redirect URI registered in Google Cloud Console.
# Google redirects here after consent; this backend route then bounces to
# the adiuvai:// deep link so the Electron app receives the code.
# Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
OAUTH_ENCRYPTION_KEY: str = ""
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
CORS_ORIGINS: list[str] = [
"app://.",
"http://localhost:3000",
"http://localhost:5173",
"http://localhost:4173", # Vite preview (web SPA)
"https://app.adiuvai.com", # Production web portal
]
LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_PUBLIC_KEY: str = ""
LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com"
ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(
env_file=".env", env_file_encoding="utf-8", extra="ignore"
)
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
settings = Settings()

View File

@@ -2,12 +2,12 @@
Drives two agent types:
* **Local directory agent** — two-step execution per file:
Step 1 (Classification) uses code to fetch all projects and asks the LLM
to identify which project the file belongs to and which domains are relevant.
Step 2 (Processing) fetches existing entities for that project/domains via
code and runs an LLM with tools — existing data in context enforces
update-first naturally.
* **Local directory agent** — V2 unified flow per file:
Phase A (Detect + Preprocess, zero LLM): Python detects the content type
and strips markup/noise, producing clean text + metadata.
Phase B (Single LLM call with tools): the LLM identifies the project,
checks for duplicates via list_* tools, and creates/updates records.
``items_created`` is counted from ``create_*`` tool calls.
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
Teams, Outlook) and pushes extracted items to Electron.
@@ -29,7 +29,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import uuid
import os
from datetime import datetime, timedelta, timezone
from typing import Any
@@ -43,7 +43,9 @@ from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.device_manager import DeviceConnectionManager
from app.core.llm import get_llm
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
from app.core.preprocessors import detect_content_type, preprocess
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
from app.db import async_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
@@ -70,97 +72,52 @@ _MAX_PROCESSING_STEPS: int = 12
_MAX_SCAN_DEPTH: int = 5
# ── Data-type to tool mapping ─────────────────────────────────────────────
# NOTE: "projects" is intentionally excluded — project creation/assignment is
# handled in code by the runner, never delegated to the Step 2 LLM.
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS,
"notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS,
"timelineEvents": TIMELINE_TOOLS,
"projects": PROJECT_TOOLS,
}
# ── Step 1: Classification prompt ─────────────────────────────────────────
# ── V2: Unified processing prompt (hot-swappable via Langfuse "unified_processing") ──
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
"tasks": (
"Action items, to-dos, deliverables — anything that describes work to be done, "
"assigned to someone, or tracked with a due date or status."
),
"notes": (
"Documentation, meeting notes, summaries, reference material — "
"written content meant to be read and referenced rather than acted on."
),
"timelines": (
"Project milestones, deadlines, scheduled events — "
"specific dates that mark a point in the progress of a project."
),
"projects": (
"High-level project entities — only relevant if the file clearly introduces "
"a new project or updates the scope of an existing one."
),
}
_STEP1_SYSTEM_PROMPT = """\
You are a file classifier for a freelance project management tool.
Your job is to match a file to an existing project and identify which data domains to extract.
## Project matching rules (STRICT — follow in order)
1. Search the file content for any mention of a project name, client name, acronym, or topic
that overlaps with the existing projects listed below.
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
when the file has zero meaningful connection to any listed project.
4. When in doubt, pick the closest match from the list.
## Response format
Respond ONLY with a JSON object — no markdown, no explanation:
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
## Domain definitions (only consider domains in the allowed list)
{domain_definitions}
## Existing projects
{projects_list}
"""
# ── Step 2: Processing prompt ─────────────────────────────────────────────
_PROCESSING_SYSTEM_PROMPT = """\
_UNIFIED_PROCESSING_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.
Your task: extract structured data from the file content and persist it using the available tools.
## Your process (follow this exact order)
## Mandatory process — follow this order for EVERY item you extract
### 1. Identify the project
File: {filename}
{metadata_section}
1. READ the existing records listed below for the relevant domain.
2. SEARCH for a match by title, topic, or semantic similarity.
3. If a match exists → call the update_* tool with the existing record's id.
4. If no match exists → call the create_* tool and set isAiSuggested=1.
Existing projects:
{projects_list}
NEVER call create_* without first checking the existing records.
NEVER duplicate a record that already exists under a different wording.
Match this file to an existing project using the filename and content clues.
If no project matches, {no_match_behavior}.
## Existing records (source of truth)
### 2. Check existing records
Once you identify the project, use list_tasks / list_notes / list_timelines
(filtered by projectId) to see what already exists.
NEVER create a record that already exists under the same or similar title.
{existing_context}
### 3. Extract and create / update
{extraction_rules}
## Context
Project: {project_context}
Domains to extract: {data_types}
{custom_prompt_section}
### Rules
- Set isAiSuggested=1 on every new record.
- Set projectId on every record (use the id from the project list above).
- Update existing records when a match is found by title or topic.
- Do NOT invent data — only extract what is clearly stated in the content.
- Target entity types: {data_types}.
{global_rules}
"""
# ── Cloud processing prompt (kept separate for cloud agent) ───────────────
_CLOUD_PROCESSING_PROMPT = """\
_BATCH_CLOUD_PROCESSING_PROMPT = """\
You are a data extraction and management assistant for a freelance project
management tool.
@@ -268,9 +225,19 @@ async def _run_agent_with_tools(
user_message: str,
tools: list[Any],
max_steps: int,
user_id: str = "",
session_id: str = "",
langfuse_prompt: Any = None,
agent_name: str = "batch-agent",
_tool_calls_out: list[str] | None = None,
) -> str:
"""Run an LLM agent with tool-calling, returning the final text response."""
llm = get_llm()
"""Run an LLM agent with tool-calling, returning the final text response.
If *_tool_calls_out* is provided, the name of every tool called during the
run is appended to it (used by the caller to count ``create_*`` calls).
"""
lf = get_langfuse()
llm = get_agent_llm(agent_name)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
@@ -279,12 +246,45 @@ async def _run_agent_with_tools(
tool_map = {tool_def.name: tool_def for tool_def in tools}
_lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None)
_lf_ctx.__enter__()
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name=agent_name,
metadata={"user_id": user_id} if user_id else None,
input=user_message,
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
try:
for _ in range(max_steps):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name=f"{agent_name}-llm",
model=model_for_agent(agent_name),
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
final_text = _as_text(response.content)
if _span:
_span.update(output=final_text)
return final_text
for call in response.tool_calls:
call_id = str(call.get("id", ""))
@@ -296,6 +296,9 @@ async def _run_agent_with_tools(
json.dumps(call_args, ensure_ascii=True)[:800],
)
if _tool_calls_out is not None:
_tool_calls_out.append(call_name)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
@@ -310,7 +313,16 @@ async def _run_agent_with_tools(
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
return _as_text(final.content)
final_text = _as_text(final.content)
if _span:
_span.update(output=final_text)
return final_text
finally:
if _span_ctx:
_span_ctx.__exit__(None, None, None)
_lf_ctx.__exit__(None, None, None)
if lf:
lf.flush()
# ── Tool list builder ─────────────────────────────────────────────────────
@@ -377,7 +389,8 @@ async def _scan_directories(
for file_path in all_files:
try:
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
modified_at = meta.get("modifiedAt")
# FE sends snake_case keys on the wire (toSnakeCase transform)
modified_at = meta.get("modified_at") or meta.get("modifiedAt")
if modified_at is None:
filtered.append(file_path)
continue
@@ -479,83 +492,66 @@ def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
return f"Existing {domain}:\n" + "\n".join(lines)
# ── Step 1: LLM file classifier ───────────────────────────────────────────
# ── V2 helper functions ───────────────────────────────────────────────────
async def _classify_file(
file_path: str,
file_content: str,
projects: list[dict],
config_data_types: list[str],
) -> tuple[str, list[str], str | None]:
"""Call the LLM to classify a file by project and relevant domains.
Returns ``(project_id_or_"new", domains, new_project_name_or_None)``.
- ``project_id`` is an existing project UUID, or ``"new"`` when no match found.
- ``new_project_name`` is only set when ``project_id == "new"``.
Falls back to ``("new", config_data_types, None)`` on any error.
"""
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
if not file_content.strip():
return fallback
valid_project_ids = {p["id"] for p in projects}
def _fmt_project(p: dict) -> str:
def _format_projects(projects: list[dict]) -> str:
"""Format the project list for the unified system prompt."""
if not projects:
return " (no projects yet)"
lines: list[str] = []
for p in projects:
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
summary_part = f"{summary[:100]}" if summary else ""
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
lines.append(
f" - id={p['id']} | name={p.get('name', '')} | "
f"status={p.get('status', '')}{summary_part}"
)
return "\n".join(lines)
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
domain_definitions = "\n".join(
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
for d in config_data_types
if d in _DOMAIN_DESCRIPTIONS
def _format_metadata(metadata: dict) -> str:
"""Format preprocessor metadata as a compact context block."""
if not metadata:
return ""
parts: list[str] = []
for key in ("subject", "from", "to", "date"):
if metadata.get(key):
parts.append(f"{key.capitalize()}: {metadata[key]}")
# any remaining keys
for key, val in metadata.items():
if key not in ("subject", "from", "to", "date") and val:
parts.append(f"{key}: {val}")
return "\n".join(parts)
def _get_extraction_rules(agent_config: dict, content_type: str) -> str:
"""Return the extraction_prompt for *content_type* from *agent_config*.
Falls back to a generic instruction when the type is not configured.
"""
for ct in agent_config.get("content_types", []):
if ct.get("id") == content_type:
prompt = ct.get("extraction_prompt", "").strip()
if prompt:
return prompt
return (
"Extract relevant information as tasks (action items), notes "
"(informational content), or timelines (dated events)."
)
system = _STEP1_SYSTEM_PROMPT.format(
domain_definitions=domain_definitions,
projects_list=projects_list,
)
llm = get_llm()
try:
response = await llm.ainvoke([
SystemMessage(content=system),
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
])
raw = _as_text(response.content).strip()
# Strip markdown fences if the model wraps the JSON.
if raw.startswith("```"):
raw = raw.split("```")[1]
if raw.startswith("json"):
raw = raw[4:]
parsed = json.loads(raw.strip())
raw_project_id: str = str(parsed.get("project_id") or "new")
# Reject hallucinated UUIDs — only accept ids that exist in the fetched list.
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
new_project_name: str | None = (
str(parsed["new_project_name"]).strip() or None
if project_id == "new" and parsed.get("new_project_name")
else None
)
domains: list[str] = [
d for d in parsed.get("domains", [])
if d in config_data_types
]
if not domains:
domains = list(config_data_types)
return project_id, domains, new_project_name
except Exception as exc:
logger.warning(
"agent_runner: step1 classification failed for %r: %s", file_path, exc
)
return fallback
def _get_no_match_behavior(agent_config: dict) -> str:
"""Derive the 'no project match' instruction from global_rules."""
rules = agent_config.get("global_rules", [])
for rule in rules:
lower = rule.lower()
if "no project" in lower or "no match" in lower or "skip" in lower:
return rule
return "create a new project with a concise name derived from the file content"
# ── Local agent runner (two-step per file) ────────────────────────────────
# ── Local agent runner (V2 — unified per-file flow) ───────────────────────
async def run_local_agent(
@@ -565,16 +561,17 @@ async def run_local_agent(
device_mgr: DeviceConnectionManager,
run_context: dict | None = None,
) -> None:
"""Execute a local directory agent run using a two-step approach per file.
"""Execute a local directory agent run — V2 unified flow.
Step 1 — Classification (code + 1 LLM call per file, no tools):
Code scans directories and fetches all projects via WS.
For each file, LLM identifies the project and relevant domains.
Phase A — Detect + Preprocess (zero LLM, per file):
Python detects the content type from filename + content patterns and
runs the appropriate handler (e.g. email_html) to produce clean text
and structured metadata.
Step 2 — Processing (code + 1 LLM call per file, with tools):
Code fetches existing entities for the identified project/domains.
LLM receives file content + existing entities in context and uses
tools to update existing records or create new ones.
Phase B — Single LLM call with tools (per file):
One LLM call handles project identification, duplicate checking, and
record creation/update. ``create_*`` tool calls are counted to
produce the accurate ``items_created`` metric.
"""
run_id = run_log.id
agent_id = (run_context or {}).get("agent_id") or config.id
@@ -609,16 +606,11 @@ async def run_local_agent(
errors: list[str] = []
items_processed = 0
items_created = 0
custom_section = (
f"User instructions:\n{config.prompt_template}"
if config.prompt_template
else ""
)
agent_config: dict = config.agent_config or {}
processing_tools = _build_processing_tools(config.data_types)
try:
# ── Code: scan directories ───────────────────────────────────
logger.info("agent_runner: run=%s scanning directories user=%s", run_id, user_id)
file_paths = await _scan_directories(
paths=config.directory_paths,
extensions=config.file_extensions or [],
@@ -634,108 +626,89 @@ async def run_local_agent(
# ── Code: fetch all projects once ────────────────────────────
projects = await _fetch_projects()
projects_block = _format_projects(projects)
# Prompt template + Langfuse version linking (hot-swappable from UI).
unified_template, prompt_obj = get_prompt_or_fallback(
"unified_processing", _UNIFIED_PROCESSING_PROMPT
)
for file_path in file_paths:
try:
# Read file content via code.
# ── Phase A: read + detect + preprocess ─────────────
file_result = await execute_on_client(
action="read_file_content", data={"path": file_path}
)
file_content: str = file_result.get("content", "")
if not file_content:
logger.debug("agent_runner: run=%s skipping empty file %r", run_id, file_path)
raw_content: str = file_result.get("content", "")
if not raw_content.strip():
logger.debug(
"agent_runner: run=%s skipping empty file %r", run_id, file_path
)
continue
items_processed += 1
filename = os.path.basename(file_path)
content_type = detect_content_type(filename, raw_content)
preprocessed = preprocess(content_type, raw_content)
# Step 1 — classify file.
project_id, domains, new_project_name = await _classify_file(
file_path=file_path,
file_content=file_content,
projects=projects,
config_data_types=config.data_types,
)
logger.info(
"agent_runner: run=%s file=%r → project=%s new_name=%r domains=%s",
run_id,
file_path,
project_id,
new_project_name,
domains,
"agent_runner: run=%s file=%r content_type=%s clean_len=%d",
run_id, file_path, content_type, len(preprocessed.clean_text),
)
# Step 2 — resolve project_id via CODE, then fetch entities.
# Project creation is NEVER delegated to the Step 2 LLM.
if project_id == "new":
proj_name = new_project_name or "Untitled Project"
try:
proj_result = await execute_on_client(
action="insert",
table="projects",
data={"name": proj_name, "clientId": None},
# ── Phase B: single LLM call ─────────────────────────
extraction_rules = _get_extraction_rules(agent_config, content_type)
no_match_behavior = _get_no_match_behavior(agent_config)
global_rules_lines = "\n".join(
f"- {r}" for r in agent_config.get("global_rules", [])
)
created = proj_result.get("row", {})
effective_project_id = created.get("id", "standalone")
# Add to local list so subsequent files can match it.
if "id" in created:
projects.append(created)
logger.info(
"agent_runner: run=%s created project %r id=%s",
run_id, proj_name, effective_project_id,
)
except Exception as exc:
logger.warning(
"agent_runner: run=%s failed to create project %r: %s",
run_id, proj_name, exc,
)
effective_project_id = "standalone"
proj_name = "unknown"
project_context = (
f"Project: {proj_name} (id: {effective_project_id}). "
"Always set projectId to this id on every record you create."
)
else:
effective_project_id = project_id
proj = next((p for p in projects if p["id"] == project_id), None)
proj_name = proj.get("name", project_id) if proj else project_id
project_context = (
f"Project: {proj_name} (id: {project_id}). "
"Always set projectId to this id on every record you create."
metadata_section = _format_metadata(preprocessed.metadata)
system_prompt = compile_prompt(
unified_template,
prompt_obj,
filename=filename,
metadata_section=metadata_section,
projects_list=projects_block,
no_match_behavior=no_match_behavior,
extraction_rules=extraction_rules,
global_rules=global_rules_lines,
data_types=", ".join(config.data_types),
)
# "projects" domain is never passed to Step 2 — handled above in code.
domains = [d for d in domains if d != "projects"]
existing_blocks: list[str] = []
for domain in domains:
rows = await _fetch_domain_entities(domain, effective_project_id)
existing_blocks.append(_format_entities_for_context(domain, rows))
existing_context = "\n\n".join(existing_blocks)
system_prompt = _PROCESSING_SYSTEM_PROMPT.format(
existing_context=existing_context,
project_context=project_context,
data_types=", ".join(domains),
custom_prompt_section=custom_section,
user_message = (
f"Process this file and extract relevant information.\n\n"
f"File: {file_path}\n\n"
f"Content:\n{preprocessed.clean_text}"
)
processing_tools = _build_processing_tools(domains)
file_tool_calls: list[str] = []
result_text = await _run_agent_with_tools(
system_prompt=system_prompt,
user_message=(
f"Process this file and extract relevant information.\n\n"
f"File: {file_path}\n\nContent:\n{file_content}"
),
user_message=user_message,
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
user_id=user_id,
session_id=run_id,
langfuse_prompt=prompt_obj,
agent_name="unified-processor",
_tool_calls_out=file_tool_calls,
)
file_created = sum(
1 for name in file_tool_calls if name.startswith("create_")
)
items_created += file_created
# Refresh project list when a project was created so
# subsequent files see it in the prompt context.
if "create_project" in file_tool_calls:
projects = await _fetch_projects()
projects_block = _format_projects(projects)
logger.info(
"agent_runner: run=%s file=%r result=%s",
run_id,
file_path,
result_text[:200],
"agent_runner: run=%s file=%r created=%d result=%s",
run_id, file_path, file_created, result_text[:200],
)
except Exception as exc:
@@ -767,10 +740,11 @@ async def run_local_agent(
errors=errors,
)
logger.info(
"agent_runner: run=%s done status=%s processed=%d errors=%d",
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
run_id,
final_status,
items_processed,
items_created,
len(errors),
)
@@ -928,7 +902,12 @@ async def run_cloud_agent(
continue
items_processed += 1
processing_prompt = _CLOUD_PROCESSING_PROMPT.format(
cloud_template, cloud_prompt_obj = get_prompt_or_fallback(
"batch_cloud_processing", _BATCH_CLOUD_PROCESSING_PROMPT
)
processing_prompt = compile_prompt(
cloud_template,
cloud_prompt_obj,
data_types=", ".join(config.data_types),
project_context="Determine the appropriate project from the message context.",
file_list=f"Message from {config.provider} (id: {msg.id})",
@@ -941,6 +920,10 @@ async def run_cloud_agent(
user_message=f"Process this message content:\n\n{content_text[:8000]}",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
user_id=user_id,
session_id=run_id,
langfuse_prompt=cloud_prompt_obj,
agent_name="cloud-processor",
)
except Exception as exc:
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")

View File

@@ -16,7 +16,8 @@ from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.llm import get_llm
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
from app.core.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session
@@ -26,7 +27,35 @@ logger = logging.getLogger(__name__)
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
_HOME_SINGLE_AGENT_SYSTEM = (
# Mapping of core-memory language values to natural-language names for prompts.
_LANGUAGE_NAMES: dict[str, str] = {
"en": "English", "it": "Italian", "es": "Spanish",
"fr": "French", "de": "German",
"english": "English", "italian": "Italian", "italiano": "Italian",
"spanish": "Spanish", "español": "Spanish",
"french": "French", "français": "French",
"german": "German", "deutsch": "German",
}
def _language_instruction(context: dict[str, Any]) -> str:
"""Return a system-prompt suffix that tells the LLM to respond in the user's language.
Returns an empty string when the language is English or unknown — saves tokens.
"""
core = context.get("core_memory") or {}
raw = (core.get("language") or "").strip().lower()
if not raw:
return ""
lang = _LANGUAGE_NAMES.get(raw, raw.title()) # best-effort capitalisation
if lang.lower() == "english":
return ""
return (
f"\n\nIMPORTANT: Always respond in {lang}. "
f"All your output text must be written in {lang}."
)
_HOME_SYSTEM_PROMPT = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
@@ -39,7 +68,7 @@ _HOME_SINGLE_AGENT_SYSTEM = (
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
)
_FLOATING_SINGLE_AGENT_SYSTEM = (
_FLOATING_SYSTEM_PROMPT = (
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Stay focused on the floating scope in context.scope and answer concisely. "
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
@@ -48,7 +77,7 @@ _FLOATING_SINGLE_AGENT_SYSTEM = (
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
)
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
@@ -147,6 +176,15 @@ def _trace_id_from_context(context: dict[str, Any]) -> str | None:
return None
def _session_id_from_context(context: dict[str, Any]) -> str | None:
debug = context.get("_debug")
if isinstance(debug, dict):
session_id = debug.get("session_id")
if isinstance(session_id, str) and session_id:
return session_id
return None
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
sanitized = dict(context)
sanitized.pop("_debug", None)
@@ -535,10 +573,9 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
}
try:
llm = get_llm()
response = await llm.ainvoke(
[
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM),
llm = get_agent_llm("classifier")
classifier_messages = [
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
@@ -546,7 +583,29 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
)
),
]
lf = get_langfuse()
_, classifier_prompt_obj = get_prompt_or_fallback(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
)
# Extract user/session from context for Langfuse attribution
_debug = context.get("_debug") if isinstance(context, dict) else None
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="floating-classifier",
model=model_for_agent("classifier"),
prompt=classifier_prompt_obj,
input=classifier_messages,
) as gen:
response = await llm.ainvoke(classifier_messages)
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
else:
response = await llm.ainvoke(classifier_messages)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
@@ -571,9 +630,13 @@ async def _run_single_agent(
message: str,
context: dict[str, Any],
max_steps: int = 6,
langfuse_prompt: Any = None,
agent_name: str = "agent",
) -> str:
trace_id = _trace_id_from_context(context)
llm = get_llm()
session_id = _session_id_from_context(context)
lf = get_langfuse()
llm = get_agent_llm(agent_name)
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
@@ -591,9 +654,39 @@ async def _run_single_agent(
tool_calls_count = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
_lf_ctx = langfuse_context(user_id=user_id, session_id=session_id)
_lf_ctx.__enter__()
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name=agent_name,
metadata={"user_id": user_id, "session_id": trace_id},
input=message,
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
try:
for _ in range(max_steps):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name=f"{agent_name}-llm",
model=model_for_agent(agent_name),
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
messages.append(response)
if not response.tool_calls:
@@ -605,6 +698,8 @@ async def _run_single_agent(
tool_calls_count,
len(final_text),
)
if _span:
_span.update(output=final_text)
return final_text
tool_map = {tool_def.name: tool_def for tool_def in tools}
@@ -644,9 +739,16 @@ async def _run_single_agent(
tool_calls_count,
len(final_text),
)
if _span:
_span.update(output=final_text)
return final_text
finally:
clear_tool_result_collector()
if _span_ctx:
_span_ctx.__exit__(None, None, None)
_lf_ctx.__exit__(None, None, None)
if lf:
lf.flush()
async def _run_single_agent_stream(
@@ -656,9 +758,13 @@ async def _run_single_agent_stream(
message: str,
context: dict[str, Any],
max_steps: int = 6,
langfuse_prompt: Any = None,
agent_name: str = "agent",
) -> AsyncGenerator[tuple[str, Any], None]:
trace_id = _trace_id_from_context(context)
llm = get_llm()
session_id = _session_id_from_context(context)
lf = get_langfuse()
llm = get_agent_llm(agent_name)
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
@@ -677,9 +783,40 @@ async def _run_single_agent_stream(
streamed_chars = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
_lf_ctx = langfuse_context(user_id=user_id, session_id=session_id)
_lf_ctx.__enter__()
_span_ctx = (
lf.start_as_current_observation(
as_type="span",
name=f"{agent_name}-stream",
metadata={"user_id": user_id, "session_id": trace_id},
input=message,
)
if lf else None
)
_span = _span_ctx.__enter__() if _span_ctx else None
streamed_text: list[str] = []
try:
for _ in range(max_steps):
_gen_ctx = (
lf.start_as_current_observation(
as_type="generation",
name=f"{agent_name}-llm",
model=model_for_agent(agent_name),
prompt=langfuse_prompt,
input=messages,
)
if lf else None
)
_gen = _gen_ctx.__enter__() if _gen_ctx else None
response: AIMessage = await llm_with_tools.ainvoke(messages)
if _gen_ctx:
_gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
_gen_ctx.__exit__(None, None, None)
messages.append(response)
if not response.tool_calls:
@@ -688,6 +825,7 @@ async def _run_single_agent_stream(
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
streamed_text.append(token)
emitted_any = True
yield "token", token
@@ -696,6 +834,7 @@ async def _run_single_agent_stream(
fallback_text = _as_text(response.content)
if fallback_text:
streamed_chars += len(fallback_text)
streamed_text.append(fallback_text)
yield "token", fallback_text
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
@@ -704,6 +843,8 @@ async def _run_single_agent_stream(
tool_calls_count,
streamed_chars,
)
if _span:
_span.update(output="".join(streamed_text))
return
tool_map = {tool_def.name: tool_def for tool_def in tools}
@@ -738,6 +879,7 @@ async def _run_single_agent_stream(
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
streamed_text.append(token)
yield "token", token
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
@@ -746,17 +888,30 @@ async def _run_single_agent_stream(
tool_calls_count,
streamed_chars,
)
if _span:
_span.update(output="".join(streamed_text))
finally:
clear_tool_result_collector()
if _span_ctx:
_span_ctx.__exit__(None, None, None)
_lf_ctx.__exit__(None, None, None)
if lf:
lf.flush()
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
prepared_context = await _prepare_context(message, context)
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"home_system", _HOME_SYSTEM_PROMPT
)
system_prompt += _language_instruction(context)
response = await _run_single_agent(
user_id=user_id,
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="home-agent",
)
return _normalize_tagged_list_lines(response, message)
@@ -764,11 +919,17 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"floating_system", _FLOATING_SYSTEM_PROMPT
)
system_prompt += _language_instruction(context)
response = await _run_single_agent(
user_id=user_id,
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
@@ -782,12 +943,18 @@ async def run_home_stream(
context: dict[str, Any],
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"home_system", _HOME_SYSTEM_PROMPT
)
system_prompt += _language_instruction(context)
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=_HOME_SINGLE_AGENT_SYSTEM,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="home-agent",
):
event_type, data = event
if event_type != "token":
@@ -809,14 +976,20 @@ async def run_floating_stream(
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain
system_prompt, langfuse_prompt = get_prompt_or_fallback(
"floating_system", _FLOATING_SYSTEM_PROMPT
)
system_prompt += _language_instruction(context)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
):
event_type, data = event
if event_type != "token":

190
app/core/langfuse_client.py Normal file
View File

@@ -0,0 +1,190 @@
"""Langfuse observability — singleton client and prompt helpers.
If LANGFUSE_SECRET_KEY / LANGFUSE_PUBLIC_KEY are not set,
all helpers are no-ops so the app works without Langfuse configured.
Usage
-----
Tracing::
from app.core.langfuse_client import get_langfuse
lf = get_langfuse()
if lf:
with lf.start_as_current_observation(as_type="span", name="my-agent") as span:
span.update(input=user_message)
# ... do work ...
span.update(output=result)
lf.flush()
Prompt management::
from app.core.langfuse_client import get_prompt_or_fallback
text, prompt_obj = get_prompt_or_fallback("home_system", FALLBACK_PROMPT)
# Use text as the system prompt; pass prompt_obj to generations for linking.
Linking a prompt to a generation::
with lf.start_as_current_observation(
as_type="generation",
name="llm-call",
model="gpt-4o",
prompt=prompt_obj, # links generation → prompt version in the UI
input=messages,
) as gen:
response = await llm.ainvoke(messages)
gen.update(output=response.content, usage=_usage(response))
"""
from __future__ import annotations
import hashlib
import logging
from contextlib import contextmanager
from typing import Any, Generator
logger = logging.getLogger(__name__)
_client: Any = None
_initialized: bool = False
def get_langfuse() -> Any | None:
"""Return the Langfuse singleton, or ``None`` when not configured."""
global _client, _initialized
if _initialized:
return _client
_initialized = True
from app.config.settings import settings # local import to avoid circular deps
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
logger.debug("langfuse: not configured — observability disabled")
return None
try:
from langfuse import Langfuse
_client = Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_BASE_URL,
)
logger.info("langfuse: client initialized host=%s", settings.LANGFUSE_BASE_URL)
except Exception as exc:
logger.warning("langfuse: failed to initialize: %s", exc)
_client = None
return _client
def get_prompt_or_fallback(name: str, fallback: str) -> tuple[str, Any]:
"""Fetch a text prompt from Langfuse; fall back to ``fallback`` on any error.
Returns ``(raw_template, prompt_obj_or_None)``.
* ``raw_template`` — the uncompiled template string. Do NOT call ``.format()``
on it directly; use :func:`compile_prompt` instead so the correct variable
syntax is applied (``{{var}}`` for Langfuse, ``{var}`` for the fallback).
* ``prompt_obj`` — the Langfuse prompt object, or ``None`` when Langfuse is
unavailable / the fetch failed. Pass this to generation observations so
Langfuse links the generation to the exact prompt version in the UI.
"""
lf = get_langfuse()
if lf is None:
return fallback, None
try:
prompt = lf.get_prompt(name, label="production", fallback=fallback)
# For text-type prompts .prompt holds the raw template string.
raw = prompt.prompt if hasattr(prompt, "prompt") and isinstance(prompt.prompt, str) else fallback
return raw, prompt
except Exception as exc:
logger.warning("langfuse: get_prompt %r failed: %s — using fallback", name, exc)
return fallback, None
def compile_prompt(template: str, prompt_obj: Any, **variables: Any) -> str:
"""Compile *template* with *variables*, choosing the right syntax.
* When *prompt_obj* is a real Langfuse prompt object, calls
``prompt_obj.compile(**variables)`` which handles ``{{variable}}``
substitution as defined in the Langfuse UI.
* When *prompt_obj* is ``None`` (Langfuse unavailable or fetch failed),
falls back to ``template.format(**variables)`` which handles the
``{variable}`` syntax used in the hardcoded fallback strings.
This keeps callers oblivious to which syntax is in use.
"""
if prompt_obj is not None:
try:
compiled = prompt_obj.compile(**variables)
# compile() returns a string for text prompts.
if isinstance(compiled, str):
return compiled
# Chat prompts return a list of dicts — join text parts.
if isinstance(compiled, list):
return "\n".join(
m.get("content", "") for m in compiled if isinstance(m, dict)
)
except Exception as exc:
logger.warning(
"langfuse: compile failed for prompt %r: %s — falling back to .format()",
getattr(prompt_obj, "name", "?"),
exc,
)
return template.format(**variables)
def extract_usage(response: Any) -> dict[str, int]:
"""Extract token usage from a LangChain AI message into Langfuse format."""
meta = getattr(response, "usage_metadata", None)
if not meta:
return {}
return {
"input": int(meta.get("input_tokens", 0)),
"output": int(meta.get("output_tokens", 0)),
"total": int(meta.get("total_tokens", 0)),
}
def hash_user_id(user_id: str) -> str:
"""Return a SHA-256 hash of *user_id* for use as Langfuse ``user_id``.
This avoids sending raw database UUIDs to external observability services
while still providing a stable, deterministic identifier for per-user
metrics in the Langfuse dashboard.
"""
return hashlib.sha256(user_id.encode()).hexdigest()
@contextmanager
def langfuse_context(
user_id: str | None = None,
session_id: str | None = None,
) -> Generator[None, None, None]:
"""Propagate ``user_id`` (hashed) and ``session_id`` to all Langfuse observations.
No-op when Langfuse is not configured or parameters are empty.
"""
lf = get_langfuse()
if lf is None or (not user_id and not session_id):
yield
return
try:
from langfuse import propagate_attributes
except ImportError:
logger.debug("langfuse: propagate_attributes not available — skipping context")
yield
return
attrs: dict[str, str] = {}
if user_id:
attrs["user_id"] = hash_user_id(user_id)
if session_id:
attrs["session_id"] = session_id
with propagate_attributes(**attrs):
yield

View File

@@ -19,6 +19,7 @@ from __future__ import annotations
import os
import warnings
from collections.abc import Callable
from openai import AsyncOpenAI
import litellm
@@ -95,6 +96,35 @@ def get_llm(
)
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"setup": lambda: settings.LLM_MODEL_SETUP_AGENT or settings.LLM_MODEL,
}
def model_for_agent(agent_name: str) -> str:
"""Return the resolved model string for *agent_name* (for Langfuse tracking)."""
return _AGENT_MODEL_SETTINGS.get(agent_name, lambda: settings.LLM_MODEL)()
def get_agent_llm(
agent_name: str,
*,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
"""Return an LLM configured for *agent_name*, respecting per-agent overrides.
Falls back to ``settings.LLM_MODEL`` for unknown agent names or when the
per-agent override is left empty in ``.env``.
"""
model = model_for_agent(agent_name)
return get_llm(model=model, temperature=temperature)
async def embed(text: str) -> list[float]:
"""Return an embedding vector for *text*.

View File

@@ -0,0 +1,104 @@
"""Preprocessor registry: detect content type and dispatch to handlers.
Public API
----------
detect_content_type(filename, raw_content) -> str
Heuristic detection based on file extension and content patterns.
preprocess(content_type, raw_content) -> PreprocessResult
Dispatch to the appropriate handler.
"""
from __future__ import annotations
import re
from app.core.preprocessors.base import PreprocessResult
# ── Heuristics ────────────────────────────────────────────────────────
# Patterns that strongly suggest an email HTML file
_EMAIL_SIGNALS = re.compile(
r"(Subject:|From:|To:|Date:|Sent:|MIME-Version:|Content-Type:\s*text/html)",
re.IGNORECASE,
)
# Patterns that suggest a generic HTML page (not an email)
_GENERIC_HTML_SIGNALS = re.compile(
r"<(nav|main|header|footer|article|section)\b",
re.IGNORECASE,
)
def detect_content_type(filename: str, raw_content: str) -> str:
"""Return a content-type string for the given file.
Supported types: ``"email_html"``, ``"generic_html"``,
``"plain_text"``, ``"unknown"``.
"""
ext = filename.rsplit(".", 1)[-1].lower() if "." in filename else ""
if ext == "txt":
return "plain_text"
if ext in ("html", "htm", "eml", "mhtml", "mht"):
# Prefer email detection over generic HTML
if _EMAIL_SIGNALS.search(raw_content[:4096]):
return "email_html"
if _GENERIC_HTML_SIGNALS.search(raw_content[:4096]) or "<html" in raw_content[:200].lower():
return "generic_html"
# .html without clear signals — check for any email header
if re.search(r"^(From|To|Subject|Date):", raw_content[:2048], re.MULTILINE | re.IGNORECASE):
return "email_html"
return "generic_html"
# Plain text files with email headers
if ext in ("", "txt") or not ext:
if _EMAIL_SIGNALS.search(raw_content[:4096]):
return "email_html"
# Detect binary content
try:
raw_content.encode("utf-8")
except (UnicodeEncodeError, AttributeError):
return "unknown"
# Non-text bytes heuristic: high ratio of non-printable chars
sample = raw_content[:512]
non_printable = sum(1 for c in sample if ord(c) < 32 and c not in "\r\n\t")
if len(sample) > 0 and non_printable / len(sample) > 0.1:
return "unknown"
return "unknown"
# ── Generic fallback handler ──────────────────────────────────────────
def _preprocess_generic(raw_content: str, content_type: str) -> PreprocessResult:
"""Strip HTML tags if present, return text as-is."""
try:
from bs4 import BeautifulSoup
text = BeautifulSoup(raw_content, "html.parser").get_text(separator="\n")
except ImportError:
# No BeautifulSoup — strip tags with a simple regex
text = re.sub(r"<[^>]+>", "", raw_content)
text = re.sub(r"\n{3,}", "\n\n", text).strip()
return PreprocessResult(content_type=content_type, clean_text=text, metadata={})
# ── Dispatch ──────────────────────────────────────────────────────────
def preprocess(content_type: str, raw_content: str) -> PreprocessResult:
"""Dispatch *raw_content* to the handler registered for *content_type*.
Falls back to the generic handler for unknown types.
"""
if content_type == "email_html":
from app.core.preprocessors.email_html import preprocess_email_html
return preprocess_email_html(raw_content)
return _preprocess_generic(raw_content, content_type)
__all__ = ["detect_content_type", "preprocess", "PreprocessResult"]

View File

@@ -0,0 +1,25 @@
"""Base types for the preprocessor system."""
from __future__ import annotations
from dataclasses import dataclass, field
@dataclass
class PreprocessResult:
"""Output of a preprocessor handler.
Attributes
----------
content_type:
The detected content type (e.g. ``"email_html"``, ``"plain_text"``).
clean_text:
Human-readable text stripped of markup/binary noise.
metadata:
Dict of extracted metadata (keys vary by handler).
Common keys: ``subject``, ``from``, ``to``, ``date``, ``filename``.
"""
content_type: str
clean_text: str
metadata: dict = field(default_factory=dict)

View File

@@ -0,0 +1,111 @@
"""Preprocessor for email HTML files.
Handles:
- HTML stripping via BeautifulSoup
- Metadata extraction (Subject, From, To, Date)
- Thread splitting — isolates the latest reply
"""
from __future__ import annotations
import re
from typing import TYPE_CHECKING
from app.core.preprocessors.base import PreprocessResult
if TYPE_CHECKING:
pass
# ── Thread split markers ──────────────────────────────────────────────
# Matches patterns like:
# "On Mon, Apr 7, 2026 at 10:00 AM, Alice <alice@co.com> wrote:"
# "-----Original Message-----"
# "> " (plain-text quote prefix)
_THREAD_PATTERNS = [
re.compile(r"^On\s+.+wrote\s*:", re.IGNORECASE | re.MULTILINE),
re.compile(r"^-{3,}\s*(original message|forwarded message)\s*-{3,}", re.IGNORECASE | re.MULTILINE),
re.compile(r"^>{1,}\s+\S", re.MULTILINE),
re.compile(r"^From:\s+.+\nSent:\s+", re.IGNORECASE | re.MULTILINE),
]
# ── Metadata patterns (applied on raw HTML / plain fallback) ──────────
_META_PATTERNS: dict[str, list[re.Pattern]] = {
"subject": [
re.compile(r"<title>(.+?)</title>", re.IGNORECASE | re.DOTALL),
re.compile(r"Subject:\s*(.+)", re.IGNORECASE),
],
"from": [
re.compile(r'<meta[^>]+name=["\']?from["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"From:\s*(.+)", re.IGNORECASE),
],
"to": [
re.compile(r'<meta[^>]+name=["\']?to["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"To:\s*(.+)", re.IGNORECASE),
],
"date": [
re.compile(r'<meta[^>]+name=["\']?date["\']?[^>]+content=["\']([^"\']+)["\']', re.IGNORECASE),
re.compile(r"Date:\s*(.+)", re.IGNORECASE),
re.compile(r"Sent:\s*(.+)", re.IGNORECASE),
],
}
def _extract_metadata(raw_html: str, text: str) -> dict:
"""Extract Subject/From/To/Date from raw HTML or plain text."""
metadata: dict[str, str] = {}
for field, patterns in _META_PATTERNS.items():
for pat in patterns:
m = pat.search(raw_html) or pat.search(text)
if m:
metadata[field] = m.group(1).strip()
break
return metadata
def _split_thread(text: str) -> str:
"""Return only the latest message in a threaded email."""
earliest_pos: int | None = None
for pat in _THREAD_PATTERNS:
m = pat.search(text)
if m and (earliest_pos is None or m.start() < earliest_pos):
earliest_pos = m.start()
if earliest_pos is not None and earliest_pos > 0:
return text[:earliest_pos].strip()
return text.strip()
def preprocess_email_html(raw_content: str) -> PreprocessResult:
"""Strip HTML, extract metadata, split thread from an email HTML file."""
try:
from bs4 import BeautifulSoup # lazy import — optional dep
except ImportError as exc:
raise ImportError(
"beautifulsoup4 is required for email_html preprocessing. "
"Install it with: pip install beautifulsoup4"
) from exc
# Parse with lxml if available, fall back to html.parser
try:
soup = BeautifulSoup(raw_content, "lxml")
except Exception:
soup = BeautifulSoup(raw_content, "html.parser")
# Remove noise tags
for tag in soup(["style", "script", "head", "noscript"]):
tag.decompose()
clean_text = soup.get_text(separator="\n")
# Collapse excessive blank lines
clean_text = re.sub(r"\n{3,}", "\n\n", clean_text).strip()
metadata = _extract_metadata(raw_content, clean_text)
latest_message = _split_thread(clean_text)
return PreprocessResult(
content_type="email_html",
clean_text=latest_message,
metadata=metadata,
)

View File

@@ -25,7 +25,7 @@ from __future__ import annotations
import logging
import re
from datetime import datetime, timedelta, timezone
from datetime import datetime, timezone
from typing import Any
import httpx

View File

@@ -30,7 +30,7 @@ async def lifespan(app: FastAPI):
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Cloud API",
title="AdiuvAI Cloud API",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
@@ -50,14 +50,10 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
from app.api.routes import agents, auth, billing, chat, device_ws
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")
app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")

View File

@@ -1,7 +0,0 @@
"""Plugin marketplace package.
Three service classes introduced in Step 10:
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
- ``ReviewQueue`` — approval workflow + security checklist
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
"""

View File

@@ -1,212 +0,0 @@
"""Plugin catalog registry backed by PostgreSQL.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. All data is persisted in the ``plugins`` table.
Module-level singleton::
from app.marketplace.plugin_registry import registry
"""
from __future__ import annotations
import json
from typing import Any, Literal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plugin
from app.schemas import PluginListResponse, PluginManifest
_PAGE_SIZE = 20
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
try:
permissions = json.loads(p.permissions) if p.permissions else []
except (json.JSONDecodeError, TypeError):
permissions = []
return PluginManifest(
id=p.id,
name=p.name,
description=p.description,
version=p.version,
author=p.author_name,
permissions=permissions,
category=p.category,
price_cents=p.price_cents,
)
class PluginRegistry:
"""PostgreSQL-backed plugin catalog.
All methods accept an ``AsyncSession`` parameter so the calling route
controls the session lifecycle.
"""
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
db: AsyncSession,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
base = select(Plugin).where(Plugin.status == "approved")
if category:
base = base.where(Plugin.category == category)
if query:
pattern = f"%{query}%"
base = base.where(
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
)
# Count
count_q = select(func.count()).select_from(base.subquery())
total = (await db.execute(count_q)).scalar_one()
# Sort
if sort == "installs":
base = base.order_by(Plugin.install_count.desc())
elif sort == "rating":
base = base.order_by(Plugin.avg_rating.desc())
else: # newest
base = base.order_by(Plugin.created_at.desc())
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
rows = (await db.execute(base)).scalars().all()
return PluginListResponse(
plugins=[_plugin_to_manifest(r) for r in rows],
total=total,
page=page,
)
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
p = result.scalar_one_or_none()
if p is None:
return None
return {
"manifest": _plugin_to_manifest(p),
"status": p.status,
"install_count": p.install_count,
"avg_rating": p.avg_rating,
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
db: AsyncSession,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
"""Add *manifest* to the catalog with ``status='pending_review'``.
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = existing.scalar_one_or_none()
if row is not None:
row.name = manifest.name
row.description = manifest.description
row.version = manifest.version
row.author_name = manifest.author
row.category = manifest.category
row.price_cents = manifest.price_cents
row.permissions = json.dumps(manifest.permissions)
row.status = "pending_review"
row.s3_package_key = package_s3_key
row.rejection_reason = None
else:
row = Plugin(
id=plugin_id,
name=manifest.name,
description=manifest.description,
version=manifest.version,
author_name=manifest.author,
category=manifest.category,
price_cents=manifest.price_cents,
permissions=json.dumps(manifest.permissions),
status="pending_review",
s3_package_key=package_s3_key,
install_count=0,
avg_rating=0.0,
)
db.add(row)
await db.commit()
return plugin_id
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "approved"
row.rejection_reason = None
await db.commit()
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "rejected"
row.rejection_reason = reason
await db.commit()
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = row.install_count + 1
await db.commit()
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = max(0, row.install_count - 1)
await db.commit()
# ── Internal helpers used by ReviewQueue ─────────────────────────
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review'."""
result = await db.execute(
select(Plugin).where(Plugin.status == "pending_review")
)
rows = result.scalars().all()
return [
{
"manifest": _plugin_to_manifest(r),
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
}
for r in rows
]
# Module-level singleton
registry = PluginRegistry()

View File

@@ -1,125 +0,0 @@
"""Plugin review workflow backed by PostgreSQL.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
Module-level singleton::
from app.marketplace.plugin_review import review_queue
"""
from __future__ import annotations
import re
from typing import Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import registry
from app.models import PluginReview as PluginReviewModel
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
{
"read:tasks",
"write:tasks",
"read:projects",
"write:projects",
"read:notes",
"write:notes",
"read:timelines",
"write:timelines",
"read:calendar",
"write:calendar",
}
)
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
def validate_manifest(manifest: PluginManifest) -> None:
"""Enforce the plugin security checklist.
Raises:
``ValueError`` on the first violation found. Callers should catch
this and return HTTP 422 / reject the submission.
Checks:
1. Plugin id matches ``^[a-z0-9-]+$``
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
3. No manifest field contains raw binary data
"""
if not _PLUGIN_ID_RE.match(manifest.id):
raise ValueError(
f"Invalid plugin id format: '{manifest.id}'. "
"Only lowercase letters, digits, and hyphens are allowed."
)
for perm in manifest.permissions:
if perm not in ALLOWED_PERMISSIONS:
raise ValueError(
f"Unknown permission: '{perm}'. "
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
)
for field_name, value in manifest.model_dump().items():
if isinstance(value, (bytes, bytearray)):
raise ValueError(
f"Binary content is not allowed in manifest field '{field_name}'."
)
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton.
Review records are persisted in the ``plugin_reviews`` table.
"""
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = await registry.get_pending_entries(db)
return [
{
"plugin_id": e["manifest"].id,
"manifest": e["manifest"],
"submitted_at": e["submitted_at"],
}
for e in entries
]
async def submit_review(
self,
db: AsyncSession,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
notes: str = "",
) -> None:
"""Record a review decision and update the plugin's status.
Raises:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(db, plugin_id)
else:
await registry.reject_plugin(db, plugin_id, reason=notes)
review = PluginReviewModel(
plugin_id=plugin_id,
reviewer_id=reviewer_id,
decision=decision,
notes=notes,
)
db.add(review)
await db.commit()
# Module-level singleton
review_queue = ReviewQueue()

View File

@@ -1,233 +0,0 @@
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
in the ``revenue_events`` table.
Module-level singleton::
from app.marketplace.revenue_share import revenue_share
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.models import Plugin, RevenueEvent
logger = logging.getLogger(__name__)
# ── Revenue split constants ───────────────────────────────────────────
DEVELOPER_SHARE: float = 0.70
PLATFORM_SHARE: float = 0.30
class RevenueShare:
"""Records installation revenue events and coordinates developer payouts.
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
is not configured, consistent with the rest of the billing layer.
"""
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
@staticmethod
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Core operations ──────────────────────────────────────────────
async def record_install(
self,
db: AsyncSession,
plugin_id: str,
user_id: str,
amount_cents: int,
) -> None:
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
For free plugins (``amount_cents == 0``) no payment is initiated but
the event is still recorded for analytics.
For paid plugins the developer receives 70 % via a Stripe Connect
destination charge. If Stripe is not configured or the charge fails
the installation still succeeds (the event is recorded and the install
count is incremented) — a warning is logged for monitoring.
"""
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
# Look up the plugin's author Stripe account from the DB
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = result.scalar_one_or_none()
developer_stripe_account: str | None = None
if plugin_row and plugin_row.author_id:
# Future: look up user.stripe_connect_account_id
developer_stripe_account = None # no real account yet
if developer_stripe_account:
try:
s = self._stripe()
transfer = s.Transfer.create(
amount=developer_share_cents,
currency="eur",
destination=developer_stripe_account,
description=f"Revenue share for plugin {plugin_id}",
metadata={"plugin_id": plugin_id, "user_id": user_id},
)
stripe_transfer_id = transfer["id"]
except Exception as exc:
logger.warning(
"Stripe Connect transfer failed for plugin %s: %s",
plugin_id,
exc,
)
else:
logger.debug(
"No Stripe account on file for plugin %s developer; "
"skipping transfer.",
plugin_id,
)
event = RevenueEvent(
plugin_id=plugin_id,
user_id=user_id,
amount_cents=amount_cents,
developer_share_cents=developer_share_cents,
stripe_transfer_id=stripe_transfer_id,
)
db.add(event)
await db.commit()
await registry.record_install(db, plugin_id)
async def get_earnings(
self,
db: AsyncSession,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
"""Return aggregated earnings for *developer_id*.
``period`` is an optional ``YYYY-MM`` string to restrict the window.
Returns::
{
"developer_id": str,
"period": str | None,
"total_installs": int,
"total_revenue_cents": int,
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer (by author_name match)
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
plugin_result = await db.execute(plugin_q)
developer_plugin_ids = [row[0] for row in plugin_result.all()]
if not developer_plugin_ids:
return {
"developer_id": developer_id,
"period": period,
"total_installs": 0,
"total_revenue_cents": 0,
"developer_share_cents": 0,
}
query = select(
func.count().label("total_installs"),
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
if period:
# Filter by YYYY-MM: extract year and month from created_at
try:
year, month = period.split("-")
query = query.where(
extract("year", RevenueEvent.created_at) == int(year),
extract("month", RevenueEvent.created_at) == int(month),
)
except ValueError:
pass # invalid period format — return all
result = await db.execute(query)
row = result.one()
return {
"developer_id": developer_id,
"period": period,
"total_installs": row.total_installs,
"total_revenue_cents": row.total_revenue,
"developer_share_cents": row.dev_share,
}
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
try:
year, month = period.split("-")
year_int, month_int = int(year), int(month)
except ValueError:
logger.warning("Invalid period format: %s", period)
return
result = await db.execute(
select(RevenueEvent).where(
RevenueEvent.plugin_id == plugin_id,
RevenueEvent.paid_at.is_(None),
extract("year", RevenueEvent.created_at) == year_int,
extract("month", RevenueEvent.created_at) == month_int,
)
)
unpaid = list(result.scalars().all())
total_dev_share = sum(e.developer_share_cents for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = plugin_result.scalar_one_or_none()
developer_stripe_account: str | None = None # Future: fetch from DB
if plugin_row and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
amount=total_dev_share,
currency="eur",
destination=developer_stripe_account,
description=f"Payout for plugin {plugin_id} period {period}",
)
except Exception as exc:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = datetime.now(timezone.utc)
for event in unpaid:
event.paid_at = paid_ts
await db.commit()
# Module-level singleton
revenue_share = RevenueShare()

View File

@@ -1,19 +1,15 @@
"""SQLAlchemy ORM models for all persistent tables.
Only auth, billing, storage metadata, and marketplace data live here.
User content (notes, tasks, etc.) is NEVER persisted server-side —
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
Only auth, billing, agent config, and memory data live here.
User content (notes, tasks, etc.) lives exclusively on the client.
Table inventory:
users — account credentials + tier
refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records
storage_records — S3 blob metadata (no plaintext)
backup_metadata — encrypted backup manifests
plugins — marketplace plugin catalog
plugin_installations — per-user install records
plugin_reviews — admin review decisions
revenue_events — Stripe Connect 70/30 split ledger
local_agent_configs — per-device batch agent configs
cloud_agent_configs — OAuth-backed cloud agent configs
agent_run_logs — execution history for all agents
memory_core — per-user persistent key/value preferences (encrypted)
memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted)
@@ -26,7 +22,6 @@ import uuid
from datetime import datetime, timezone
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Enum,
@@ -36,7 +31,6 @@ from sqlalchemy import (
JSON,
String,
Text,
UniqueConstraint,
Uuid,
func,
)
@@ -58,8 +52,6 @@ def _now() -> datetime:
# ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
@@ -77,7 +69,8 @@ class User(Base):
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
password_hash: Mapped[str | None] = mapped_column(String(255), nullable=True)
avatar_url: Mapped[str | None] = mapped_column(Text, nullable=True)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
@@ -86,6 +79,9 @@ class User(Base):
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
onboarding_completed_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, default=None
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
@@ -96,6 +92,9 @@ class User(Base):
subscription: Mapped[Subscription | None] = relationship(
back_populates="user", uselist=False, cascade="all, delete-orphan"
)
oauth_accounts: Mapped[list[OAuthAccount]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
class RefreshToken(Base):
@@ -116,6 +115,25 @@ class RefreshToken(Base):
user: Mapped[User] = relationship(back_populates="refresh_tokens")
class OAuthAccount(Base):
__tablename__ = "oauth_accounts"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
provider: Mapped[str] = mapped_column(String(50), nullable=False)
provider_user_id: Mapped[str] = mapped_column(String(255), nullable=False)
provider_email: Mapped[str | None] = mapped_column(String(255), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="oauth_accounts")
class Subscription(Base):
__tablename__ = "subscriptions"
@@ -137,151 +155,6 @@ class Subscription(Base):
user: Mapped[User] = relationship(back_populates="subscription")
class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
# nullable until developer account system is built
author_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
class LocalAgentConfig(Base):
__tablename__ = "local_agent_configs"
@@ -296,6 +169,7 @@ class LocalAgentConfig(Base):
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

View File

@@ -30,6 +30,16 @@ class UserProfile(BaseModel):
name: str | None = None
surname: str | None = None
tier: BillingTier
avatar_url: str | None = None
has_password: bool = True
onboarding_completed_at: int | None = None # epoch ms, null = not onboarded
memory: dict[str, str] = Field(default_factory=dict) # decrypted core memory k/v
class OAuthAccountInfo(BaseModel):
provider: str
provider_email: str | None = None
created_at: int # epoch ms
# ── Chat ─────────────────────────────────────────────────────────────
@@ -50,88 +60,6 @@ class ChatResponse(BaseModel):
response: str
# ── Backup ───────────────────────────────────────────────────────────
class BackupMetadata(BaseModel):
version: int
timestamp: int
checksum: str
chunk_count: int
# ── Cloud Storage (E2E encrypted blobs) ──────────────────────────────
class StorageRecord(BaseModel):
id: str
user_id: str
table: str
blob: bytes
checksum: str
created_at: int
updated_at: int
class StorageRecordCreate(BaseModel):
table: str
blob: bytes
checksum: str
class StorageRecordUpdate(BaseModel):
blob: bytes
checksum: str
# ── Cloud Vector Store (E2E encrypted vectors) ────────────────────────
class VectorItem(BaseModel):
id: str
blob: bytes # encrypted vector + metadata — backend never decrypts
checksum: str
class VectorUpsertRequest(BaseModel):
vectors: list[VectorItem]
class VectorSearchRequest(BaseModel):
query_blob: bytes # encrypted query — backend never decrypts
top_k: int = 10
class VectorSearchResult(BaseModel):
id: str
score: float
blob: bytes
class VectorSearchResponse(BaseModel):
results: list[VectorSearchResult]
# ── Plugin Marketplace ────────────────────────────────────────────────
class PluginManifest(BaseModel):
id: str
name: str
description: str
version: str
author: str
permissions: list[str]
category: str
price_cents: int = 0
class PluginListResponse(BaseModel):
plugins: list[PluginManifest]
total: int
page: int
class PluginInstallRequest(BaseModel):
plugin_id: str
# ── WebSocket Frame Protocol ──────────────────────────────────────────
class WsFrameType(str, Enum):
@@ -273,6 +201,27 @@ class WsFloatingDomain(BaseModel):
domain: WsDomain
# ── Agent Config V2 ───────────────────────────────────────────────────
class ContentTypeConfig(BaseModel):
"""Per-type extraction config produced by the journey chatbot."""
id: str
label: str = ""
detection_hint: str = ""
preprocessing: str = "generic" # handler name: "email_html", "plain_text", ...
extraction_prompt: str
class AgentConfig(BaseModel):
"""Structured agent configuration (replaces freeform prompt_template)."""
content_types: list[ContentTypeConfig] = []
global_rules: list[str] = []
data_types: list[str] = []
# ── Agent Catalog ─────────────────────────────────────────────────────
class AgentCatalogItem(BaseModel):
@@ -297,10 +246,11 @@ class AgentTriggerRequest(BaseModel):
device_id: str = Field(default="")
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
what_to_extract: list[str] = Field(min_length=1)
actions_by_type: dict[str, list[str]] | None = None
batch_interval: str = Field(min_length=1)
custom_agent_prompt: str = Field(min_length=1)
custom_agent_prompt: str | None = None
agent_config: dict | None = None
active_agents: int = Field(ge=0, default=0)
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
# ── Agent Run Log ─────────────────────────────────────────────────────

View File

@@ -1 +0,0 @@
"""Cloud storage layer — E2E encrypted blobs and vectors."""

View File

@@ -1,106 +0,0 @@
"""S3-backed store for E2E-encrypted blobs.
Keys are structured as ``{user_id}/{table}/{record_id}``.
The backend never inspects blob content — it stores and retrieves opaque bytes.
"""
from __future__ import annotations
from typing import Any
import boto3
from app.config.settings import settings
class BlobStore:
"""Thin wrapper around boto3 S3.
All blobs must be E2E encrypted by the client before upload.
The backend adds SSE-S3 as an extra layer of at-rest encryption
but cannot decrypt the inner client-side payload.
"""
def _client(self) -> Any:
kwargs: dict[str, Any] = {
"region_name": settings.S3_REGION,
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
}
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
return boto3.client("s3", **kwargs)
@staticmethod
def _key(user_id: str, table: str, record_id: str) -> str:
return f"{user_id}/{table}/{record_id}"
async def upload(
self,
user_id: str,
table: str,
record_id: str,
blob: bytes,
checksum: str,
) -> str:
"""Store *blob* in S3 and return the S3 key.
Args:
user_id: Owner of the blob (used as key prefix).
table: Logical table name (e.g. ``"tasks"``).
record_id: Record UUID.
blob: Raw bytes (pre-encrypted by client).
checksum: SHA-256 hex digest supplied by the client; stored as
object metadata for download-time verification.
Returns:
The S3 key under which the blob was stored.
"""
key = self._key(user_id, table, record_id)
self._client().put_object(
Bucket=settings.S3_BUCKET,
Key=key,
Body=blob,
ServerSideEncryption="AES256", # SSE-S3 at rest
Metadata={"checksum": checksum},
)
return key
async def download(self, user_id: str, s3_key: str) -> bytes:
"""Retrieve the blob stored at *s3_key*.
*user_id* is retained in the signature so higher-level code can
enforce ownership without re-parsing the key.
Raises:
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
object does not exist.
"""
response = self._client().get_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
return response["Body"].read()
async def delete(self, user_id: str, s3_key: str) -> None:
"""Delete the object at *s3_key*.
S3 ``delete_object`` is idempotent — it succeeds even if the key does
not exist.
"""
self._client().delete_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
async def list_keys(self, user_id: str, table: str) -> list[str]:
"""Return all S3 keys for a given user + table combination.
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
"""
prefix = f"{user_id}/{table}/"
response = self._client().list_objects_v2(
Bucket=settings.S3_BUCKET,
Prefix=prefix,
)
return [obj["Key"] for obj in response.get("Contents", [])]

View File

@@ -1,32 +0,0 @@
"""Integrity verification only — the backend NEVER decrypts user data."""
from __future__ import annotations
import hashlib
import hmac
from fastapi import HTTPException
def verify_checksum(blob: bytes, checksum: str) -> bool:
"""Return ``True`` if SHA-256(blob) matches *checksum*.
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
timing-based side-channel attacks.
"""
computed = hashlib.sha256(blob).hexdigest()
return hmac.compare_digest(computed, checksum)
def reject_if_tampered(blob: bytes, checksum: str) -> None:
"""Raise ``HTTP 400`` if the blob does not match its checksum.
Call this before storing or forwarding any client-provided blob.
The backend never holds decryption keys — this check only verifies
that the opaque bytes arrived intact.
"""
if not verify_checksum(blob, checksum):
raise HTTPException(
status_code=400,
detail="Checksum mismatch: blob integrity check failed",
)

View File

@@ -1,205 +0,0 @@
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
Vectors are pre-encrypted blobs from the client. The backend stores them
alongside a deterministic 32-dim float representation derived from the blob's
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
is a known trade-off documented in the backend plan.
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
``user_id`` payload field on a shared collection.
"""
from __future__ import annotations
import base64
import hashlib
from typing import Any
from pinecone import Pinecone
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
from app.config.settings import settings
from app.schemas import VectorItem, VectorSearchResult
_QDRANT_COLLECTION = "adiuva_vectors"
def _blob_to_vector(blob: bytes) -> list[float]:
"""Derive a 32-dim float vector from *blob* for storage purposes only.
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
normalises each byte to the range [-1.0, 1.0]. This vector carries no
semantic meaning on encrypted data.
"""
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
class VectorStore:
"""Thin wrapper around Pinecone or Qdrant.
The backend to use is selected at runtime:
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
"""
def _use_pinecone(self) -> bool:
return bool(settings.PINECONE_API_KEY)
# ── Pinecone helpers ──────────────────────────────────────────────
def _pinecone_index(self) -> Any:
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
return pc.Index(settings.PINECONE_INDEX)
# ── Qdrant helpers ────────────────────────────────────────────────
def _qdrant_client(self) -> Any:
return QdrantClient(
url=settings.QDRANT_URL,
api_key=settings.QDRANT_API_KEY or None,
)
# ── Public API ────────────────────────────────────────────────────
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
"""Store encrypted vectors in the backend.
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
so it can be returned verbatim during search.
Args:
user_id: Used as Pinecone namespace or Qdrant payload field.
vectors: List of encrypted vector items from the client.
"""
if self._use_pinecone():
await self._pinecone_upsert(user_id, vectors)
else:
await self._qdrant_upsert(user_id, vectors)
async def search(
self,
user_id: str,
query_blob: bytes,
top_k: int,
) -> list[VectorSearchResult]:
"""Query the vector store and return encrypted result blobs.
The query vector is derived from *query_blob* using the same
deterministic mapping as upsert.
Args:
user_id: Scopes the search to this user's namespace.
query_blob: Encrypted query from the client.
top_k: Maximum number of results to return.
Returns:
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
"""
if self._use_pinecone():
return await self._pinecone_search(user_id, query_blob, top_k)
return await self._qdrant_search(user_id, query_blob, top_k)
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
"""Remove vectors by ID, scoped to *user_id*.
Args:
user_id: Namespace / payload filter to prevent cross-user deletion.
vector_ids: List of vector IDs to remove.
"""
if self._use_pinecone():
await self._pinecone_delete(user_id, vector_ids)
else:
await self._qdrant_delete(user_id, vector_ids)
# ── Pinecone implementation ───────────────────────────────────────
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
index = self._pinecone_index()
records = [
{
"id": v.id,
"values": _blob_to_vector(v.blob),
"metadata": {
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
}
for v in vectors
]
index.upsert(vectors=records, namespace=user_id)
async def _pinecone_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
index = self._pinecone_index()
query_vector = _blob_to_vector(query_blob)
response = index.query(
vector=query_vector,
top_k=top_k,
namespace=user_id,
include_metadata=True,
)
results: list[VectorSearchResult] = []
for match in response.get("matches", []):
blob_bytes = base64.b64decode(match["metadata"]["blob"])
results.append(
VectorSearchResult(
id=match["id"],
score=match["score"],
blob=blob_bytes,
)
)
return results
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
index = self._pinecone_index()
index.delete(ids=vector_ids, namespace=user_id)
# ── Qdrant implementation ─────────────────────────────────────────
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
client = self._qdrant_client()
points = [
PointStruct(
id=v.id,
vector=_blob_to_vector(v.blob),
payload={
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
)
for v in vectors
]
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
async def _qdrant_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
client = self._qdrant_client()
query_vector = _blob_to_vector(query_blob)
hits = client.search(
collection_name=_QDRANT_COLLECTION,
query_vector=query_vector,
query_filter=Filter(
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
),
limit=top_k,
)
return [
VectorSearchResult(
id=str(hit.id),
score=hit.score,
blob=base64.b64decode(hit.payload["blob"]),
)
for hit in hits
]
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
client = self._qdrant_client()
client.delete(
collection_name=_QDRANT_COLLECTION,
points_selector=PointIdsList(points=vector_ids),
)

View File

@@ -7,7 +7,7 @@ services:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuvai
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
volumes:
- copilot_tokens:/root/.config/litellm/github_copilot
@@ -21,7 +21,7 @@ services:
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuva
POSTGRES_DB: adiuvai
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
@@ -36,37 +36,6 @@ services:
# image: redis:7-alpine
# restart: unless-stopped
# ── Local S3-compatible storage (MinIO) ──
minio:
image: minio/minio:latest
command: server /data --console-address ":9001"
ports:
- "9000:9000"
- "9001:9001"
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin
volumes:
- minio_data:/data
healthcheck:
test: ["CMD", "mc", "ready", "local"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
# ── Local vector store (Qdrant) ──
qdrant:
image: qdrant/qdrant:latest
ports:
- "6333:6333"
- "6334:6334"
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
volumes:
postgres_data:
minio_data:
qdrant_data:
copilot_tokens:

View File

@@ -1,941 +0,0 @@
# Adiuva — Architettura Microservizi (MVP)
## Panoramica
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
```
┌──────────────┐
│ Cloudflare │
│ (DNS + CDN) │
└──────┬───────┘
│ HTTPS / WSS
┌──────▼───────┐
│ Traefik │
│ API Gateway │
│ (routing, │
│ TLS, rate │
│ limiting) │
└──────┬───────┘
┌──────────┬───────────┼───────────┐
│ │ │ │
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
│ Auth │ │ Chat │ │ Agent │ │Billing │
│ Service │ │Service│ │ Service │ │Service │
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
│ │ │ │
┌─────▼──────────▼──────────▼───────────▼────┐
│ Infrastruttura │
│ PostgreSQL │ Redis │ Qdrant │
└─────────────────────────────────────────────┘
```
---
## 1. Suddivisione dei Servizi
### 1.1 Auth Service (`auth-service`)
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
| Endpoint originale | Metodo |
|---|---|
| `/api/v1/auth/register` | POST |
| `/api/v1/auth/login` | POST |
| `/api/v1/auth/refresh` | POST |
| `/api/v1/auth/me` | GET / PUT |
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
**Modifica chiave — JWT con RS256**:
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
- L'Auth Service firma i JWT con la **chiave privata**.
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
```python
# auth-service/app/auth/jwt.py
from cryptography.hazmat.primitives.asymmetric import rsa
from jose import jwt
PRIVATE_KEY = ... # Da env/secret
PUBLIC_KEY = ... # Derivata o da env
def create_access_token(user_id: str, tier: str) -> str:
return jwt.encode(
{"sub": user_id, "tier": tier, "exp": ...},
PRIVATE_KEY,
algorithm="RS256",
)
```
```python
# shared/auth.py (usato da tutti gli altri servizi)
from jose import jwt
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
def verify_token(token: str) -> dict:
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
```
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
---
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
| Endpoint | Tipo |
|---|---|
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
| `/api/v1/chat` | POST (REST fallback) |
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
**Scaling**: 2N repliche. Sticky cookies per le WS + Redis per cross-instance.
---
### 1.3 Agent Service (`agent-service`) ⭐ Batch
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
| Endpoint | Tipo |
|---|---|
| `/api/v1/agents/catalog` | GET |
| `/api/v1/agents/can-create` | POST |
| `/api/v1/agents/trigger` | POST |
| `/api/v1/agents/journey/start` | POST (o WS relay) |
| `/api/v1/agents/journey/message` | POST (o WS relay) |
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
```
┌──────────────┐ ┌──────────────┐ ┌──────────┐
│ Agent Service│ │ Redis │ │ Chat │
│ (batch run) │ │ │ │ Service │
│ │ │ │ │ (ha WS) │
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
│ from │ │ │ │ al │
│ device │ │ │ │ device│
│ │ │ │ │ via WS│
│ │ SUBSCRIBE │ │ PUBLISH │ │
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
│ risultato │ │ │ │ reply │
└──────────────┘ └──────────────┘ └──────────┘
```
**Scaling**: 1N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
---
### 1.4 Billing Service (`billing-service`)
**Responsabilità**: Stripe checkout, webhook, subscription management.
| Endpoint originale | Metodo |
|---|---|
| `/api/v1/billing/checkout` | POST |
| `/api/v1/billing/webhook` | POST |
| `/api/v1/billing/subscription` | GET / DELETE |
**Database**: Tabelle `subscriptions` (schema `billing`).
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
**Scaling**: 1 replica sufficiente. Basso traffico.
---
### 1.5 Servizi esclusi dall'MVP
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
| Servizio | Responsabilità | Note |
|---|---|---|
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
---
## 2. Tier Check — Dove e Come
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
### Strategia: Tier nel JWT
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
```json
{
"sub": "user_123",
"tier": "pro",
"exp": 1742515200,
"iat": 1742511600
}
```
Ogni servizio:
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
2. Legge `payload["tier"]`**zero chiamate extra**
3. Applica le sue regole di enforcement localmente
```python
# shared/auth.py — dependency FastAPI condivisa
from fastapi import Depends, HTTPException, Request
from jose import jwt
PUBLIC_KEY = ...
class CurrentUser:
def __init__(self, user_id: str, tier: str):
self.user_id = user_id
self.tier = tier
async def get_current_user(request: Request) -> CurrentUser:
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
def require_tier(*allowed_tiers: str):
"""Dependency che blocca se il tier non è tra quelli ammessi."""
async def check(user: CurrentUser = Depends(get_current_user)):
if user.tier not in allowed_tiers:
raise HTTPException(403, "Tier insufficient")
return user
return check
```
### Cosa succede quando il tier cambia (upgrade/downgrade)?
```
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
│ │ │ Service │ (Redis pub/sub) │ Service │
└──────────┘ └──────────┘ └────┬─────┘
UPDATE users
SET tier = 'power'
Al prossimo /refresh
il JWT conterrà tier='power'
```
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 1530 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
### Dove si applica in ciascun servizio
| Servizio | Enforcement |
|---|---|
| **Auth Service** | Nessuno (è lui che scrive il tier) |
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
### Rate-limit distribuito via Redis
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
```python
# shared/middleware/rate_limit.py
import redis.asyncio as aioredis
class DistributedRateLimiter:
def __init__(self, redis: aioredis.Redis):
self._redis = redis
async def check(self, user_id: str, tier: str, service: str) -> bool:
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
max_req = limits.get(tier, 20)
key = f"rate:{service}:{user_id}"
pipe = self._redis.pipeline()
pipe.incr(key)
pipe.expire(key, 60)
count, _ = await pipe.execute()
return count <= max_req
```
---
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
`DeviceConnectionManager` è un **singleton in-memory**:
```python
class DeviceConnectionManager:
def __init__(self):
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
```
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
### La soluzione: Redis Pub/Sub + Registry
```
┌──────────────────────────────────────────────────────────────┐
│ Redis │
│ │
│ Hash: ws:connections │
│ user_123 → instance_A │
│ user_456 → instance_B │
│ │
│ Pub/Sub channels: │
│ tool_call:{user_id} → tool call payloads │
│ tool_result:{call_id} → tool result payloads │
│ stream:{user_id} → text_chunk streaming │
└──────────────────────────────────────────────────────────────┘
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
┌───────────────────────┐ ┌───────────────────────┐
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
│ tool_call:user_123│ │ → user_123 è su A │
│ │ │ │
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
│ da Redis channel │ │ tool_call:user_123 │
│ │ │ {id, action, ...} │
│ 3. Invia al device │ │ │
│ via WS │ │ 4. SUBSCRIBE │
│ │ │ tool_result:{id} │
│ 4. Device risponde │ │ │
│ tool_result │──────────│► 5. Riceve risultato │
│ │ │ │
│ 5. PUBLISH │ │ │
│ tool_result:{id} │ │ │
└───────────────────────┘ └───────────────────────┘
```
### Implementazione: `RedisDeviceManager`
```python
# chat-service/app/core/device_manager.py
import asyncio
import json
import os
import redis.asyncio as aioredis
from dataclasses import dataclass, field
from fastapi import WebSocket
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
@dataclass
class LocalConnection:
ws: WebSocket
device_id: str
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
class RedisDeviceManager:
"""Device manager backed by Redis for cross-instance communication."""
def __init__(self, redis_url: str = "redis://redis:6379"):
self._redis = aioredis.from_url(redis_url)
self._pubsub = self._redis.pubsub()
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
async def start(self):
"""Avvia il listener Redis per tool_call in arrivo."""
asyncio.create_task(self._listen_tool_calls())
# ── Registrazione ──
async def register(self, user_id: str, device_id: str, ws: WebSocket):
# Registra localmente
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
# Registra in Redis quale istanza ha la connessione
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
# Sottoscrivi ai tool_call per questo utente
await self._pubsub.subscribe(f"tool_call:{user_id}")
async def unregister(self, user_id: str):
conn = self._local.pop(user_id, None)
if conn:
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
await self._redis.hdel("ws:connections", user_id)
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
# ── Presenza ──
async def is_online(self, user_id: str) -> bool:
return await self._redis.hexists("ws:connections", user_id)
# ── Tool-call round-trip (cross-instance) ──
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
"""
Invia un tool_call al device dell'utente.
Funziona sia che la WS sia locale che su un'altra istanza.
"""
call_id = payload["id"]
# Caso 1: connessione locale → invio diretto
if user_id in self._local:
conn = self._local[user_id]
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
return await asyncio.wait_for(fut, timeout=30.0)
# Caso 2: connessione remota → Redis pub/sub
loop = asyncio.get_event_loop()
fut = loop.create_future()
self._remote_futures[call_id] = fut
# Sottoscrivi al canale di risposta
result_channel = f"tool_result:{call_id}"
await self._pubsub.subscribe(result_channel)
# Pubblica il tool_call
await self._redis.publish(
f"tool_call:{user_id}",
json.dumps(payload),
)
try:
return await asyncio.wait_for(fut, timeout=30.0)
finally:
self._remote_futures.pop(call_id, None)
await self._pubsub.unsubscribe(result_channel)
# ── Risoluzione tool_result (da WS locale) ──
def resolve_local(self, user_id: str, call_id: str, result: dict):
conn = self._local.get(user_id)
if conn:
fut = conn.pending_calls.pop(call_id, None)
if fut and not fut.done():
fut.set_result(result)
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
"""Chiamato quando il device locale invia un tool_result."""
self.resolve_local(user_id, call_id, result)
# Pubblica anche su Redis per l'istanza remota che aspetta
await self._redis.publish(
f"tool_result:{call_id}",
json.dumps(result),
)
# ── Listener Redis ──
async def _listen_tool_calls(self):
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
async for message in self._pubsub.listen():
if message["type"] != "message":
continue
channel = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
data = json.loads(message["data"])
if channel.startswith("tool_call:"):
# Un'altra istanza vuole che inviamo un tool_call al nostro device
user_id = channel.split(":", 1)[1]
conn = self._local.get(user_id)
if conn:
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
elif channel.startswith("tool_result:"):
# Risposta a un tool_call che abbiamo inviato tramite Redis
call_id = channel.split(":", 1)[1]
fut = self._remote_futures.pop(call_id, None)
if fut and not fut.done():
fut.set_result(data)
# ── Stream cross-instance ──
async def publish_stream_chunk(self, user_id: str, chunk: dict):
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
```
---
## 4. Struttura Directory Proposta (MVP)
```
adiuva-api/
├── docker-compose.yml # Orchestrazione completa
├── docker-compose.dev.yml # Override per sviluppo locale
├── shared/ # Codice condiviso (montato come volume)
│ ├── auth.py # JWT verification (chiave pubblica)
│ ├── schemas.py # Pydantic schemas condivisi
│ ├── middleware/
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
│ │ └── sanitizer.py
│ └── models/
│ └── base.py # SQLAlchemy base condivisa
├── auth-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # users, refresh_tokens
│ ├── routes/
│ │ └── auth.py
│ └── services/
│ ├── jwt_service.py # RS256 signing
│ └── user_service.py
├── chat-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # memory_*
│ ├── routes/
│ │ ├── device_ws.py # WS connection owner
│ │ └── chat.py # REST fallback
│ ├── core/
│ │ ├── device_manager.py # RedisDeviceManager
│ │ ├── deep_agent.py # Home + floating chat
│ │ ├── memory_middleware.py
│ │ ├── ws_context.py
│ │ ├── output_formatter.py
│ │ └── llm.py
│ └── agents/ # Tool definitions (used by deep_agent)
│ ├── task_agent.py
│ ├── project_agent.py
│ ├── note_agent.py
│ └── timeline_agent.py
├── agent-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
│ ├── routes/
│ │ ├── agents.py # catalog, can-create, trigger
│ │ └── agent_setup.py # journey start/message
│ ├── core/
│ │ ├── agent_runner.py # Batch classify → process
│ │ ├── agent_registry.py
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
│ │ └── llm.py
│ └── agents/
│ ├── task_agent.py # Tool definitions (batch context)
│ ├── project_agent.py
│ ├── note_agent.py
│ ├── timeline_agent.py
│ └── filesystem_agent.py
├── billing-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # subscriptions
│ ├── routes/
│ │ └── billing.py
│ └── services/
│ ├── stripe_service.py
│ └── tier_manager.py
└── infra/
├── traefik/
│ └── traefik.yml
├── keys/
│ ├── jwt_private.pem # Solo auth-service
│ └── jwt_public.pem # Tutti i servizi
└── alembic/ # Migrazioni condivise o per-servizio
```
---
## 5. Docker Compose — Configurazione MVP
```yaml
# docker-compose.yml
services:
# ══════════════════════════════════════════════════════════
# API Gateway
# ══════════════════════════════════════════════════════════
traefik:
image: traefik:v3.2
command:
- "--api.insecure=true"
- "--providers.docker=true"
- "--providers.docker.exposedbydefault=false"
- "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443"
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
ports:
- "80:80"
- "443:443"
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro
- ./infra/certs:/certs:ro
restart: unless-stopped
# ══════════════════════════════════════════════════════════
# Auth Service (2 repliche)
# ══════════════════════════════════════════════════════════
auth-service:
build: ./auth-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
SERVICE_NAME: auth
secrets:
- jwt_private_key
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
- "traefik.http.services.auth.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Chat Service — Real-time WS + Chat (scalabile)
# ══════════════════════════════════════════════════════════
chat-service:
build: ./chat-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: chat
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
# REST chat endpoint
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
- "traefik.http.services.chat.loadbalancer.server.port=8000"
# WebSocket route con sticky session
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
- "traefik.http.routers.ws.service=chat-ws"
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Agent Service — Batch processing (scalabile indipendentemente)
# ══════════════════════════════════════════════════════════
agent-service:
build: ./agent-service
deploy:
replicas: 2
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: agent
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
- "traefik.http.services.agents.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Billing Service (1 replica)
# ══════════════════════════════════════════════════════════
billing-service:
build: ./billing-service
deploy:
replicas: 1
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: billing
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
- "traefik.http.services.billing.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════
# Infrastruttura
# ══════════════════════════════════════════════════════════
db:
image: pgvector/pgvector:pg16
environment:
POSTGRES_USER: postgres
POSTGRES_PASSWORD: postgres
POSTGRES_DB: adiuva
volumes:
- postgres_data:/var/lib/postgresql/data
healthcheck:
test: ["CMD-SHELL", "pg_isready -U postgres"]
interval: 5s
timeout: 5s
retries: 5
restart: unless-stopped
redis:
image: redis:7-alpine
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
volumes:
- redis_data:/data
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 5s
timeout: 3s
retries: 5
restart: unless-stopped
qdrant:
image: qdrant/qdrant:latest
volumes:
- qdrant_data:/qdrant/storage
restart: unless-stopped
secrets:
jwt_private_key:
file: ./infra/keys/jwt_private.pem
jwt_public_key:
file: ./infra/keys/jwt_public.pem
volumes:
postgres_data:
redis_data:
qdrant_data:
```
---
## 6. Configurazione Cloudflare + VPS
### 6.1 DNS
```
api.tuodominio.com → A record → IP del VPS
→ Proxy: ON (orange cloud)
```
### 6.2 Cloudflare Settings
| Setting | Valore | Motivo |
|---------|--------|--------|
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
| Under Attack Mode | Off (attivare se necessario) | |
### 6.3 TLS sul VPS
Due opzioni:
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
```yaml
# traefik.yml — con Cloudflare Origin Certificate
entryPoints:
websecure:
address: ":443"
tls:
certificates:
- certFile: /certs/origin.pem
keyFile: /certs/origin-key.pem
```
### 6.4 Rete VPS
```bash
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
# https://www.cloudflare.com/ips/
ufw default deny incoming
ufw allow from 173.245.48.0/20 to any port 443
ufw allow from 103.21.244.0/22 to any port 443
# ... (tutti gli IP range di Cloudflare)
ufw allow ssh
ufw enable
```
---
## 7. Comunicazione Inter-Servizio
### 7.1 Redis Pub/Sub — Event Bus
```
┌──────────┐ tier_changed:user_123 ┌──────────┐
│ Billing │ ────────────────────────► │ Auth │
│ Service │ │ Service │
└──────────┘ └──────────┘
┌──────────┐ tool_call:user_123 ┌──────────┐
│ Agent │ ────────────────────────► │ Chat │
│ Service │ │ Service │
│ (batch) │ ◄────────────────────────│ (ha WS) │
└──────────┘ tool_result:{call_id} └──────────┘
```
### 7.2 Health Checks e Service Discovery
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
- **Redis pub/sub** (tool-call cross-instance, tier events)
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
- **PostgreSQL** (dati persistenti condivisi)
---
## 8. Piano di Migrazione Incrementale (MVP)
### Fase 1 — Preparazione (nel monolite attuale)
1. Aggiungere Redis al `docker-compose.yml` attuale
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
4. Estrarre `shared/` con auth verification, schemas, middleware
### Fase 2 — Auth Service (primo split)
1. Estrarre `auth.py` routes + models in `auth-service/`
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
4. Il monolite continua a servire tutto il resto
### Fase 3 — Billing Service
1. Estrarre billing routes, Stripe service, tier manager
2. Configurare Redis pub/sub per `tier_changed` events
3. Routare via Traefik
### Fase 4 — Split Chat + Agent (il più delicato)
1. Il monolite residuo contiene WS + chat + agents
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
### Fase 5 — Scaling test
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
3. Monitoring (Prometheus + Grafana) per ogni servizio
---
## 9. Monitoraggio e Logging
```yaml
# Aggiungere al docker-compose.yml
prometheus:
image: prom/prometheus:latest
volumes:
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
restart: unless-stopped
grafana:
image: grafana/grafana:latest
ports:
- "3000:3000"
volumes:
- grafana_data:/var/lib/grafana
restart: unless-stopped
loki:
image: grafana/loki:latest
restart: unless-stopped
```
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
---
## 10. Sizing VPS Minimo Consigliato (MVP)
| Componente | CPU | RAM | Note |
|---|---|---|---|
| Traefik | 0.25 | 128MB | |
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
| Billing Service | 0.25 | 128MB | |
| PostgreSQL | 1.0 | 1GB | |
| Redis | 0.25 | 256MB | |
| Qdrant | 0.5 | 512MB | |
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
---
## Riepilogo Architettura MVP
| Servizio | Repliche | Proprietario di |
|---|---|---|
| **Traefik** | 1 | Routing, TLS, sticky sessions |
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
| **Chat Service** | 2N | WebSocket, home/floating chat, streaming |
| **Agent Service** | 2N | Batch processing, directory scan, agent setup |
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
| Decisione | Scelta | Motivazione |
|---|---|---|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
| TLS | Cloudflare Origin Certificate | Zero maintenance |
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
| Storage / Plugin | Post-MVP | Non critici per il lancio |

View File

@@ -32,6 +32,8 @@ google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0
msal>=1.28.0
cryptography>=42.0.0
redis>=5.0.0
langfuse>=3.0.0
langfuse>=2.0.0
beautifulsoup4>=4.12.0
lxml>=5.0.0
PyYAML>=6.0.0
ruff>=0.8.0

View File

@@ -1,19 +0,0 @@
# ── Auth Service ──────────────────────────────────────────────────────────────
# This file contains env vars specific to the Auth Service.
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
# or from docker-compose environment.
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
# Generate keypair:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# openssl rsa -in private.pem -pubout -out public.pem
#
# Paste PEM content with literal \n for newlines:
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
JWT_PRIVATE_KEY=
# PUBLIC KEY — used to VERIFY JWTs.
JWT_PUBLIC_KEY=

View File

@@ -1,36 +0,0 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
# Install shared + service deps in one layer
COPY services/auth/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Copy shared module (available to all services)
COPY shared/ shared/
# Copy service source
COPY services/auth/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "30"]

View File

@@ -1,16 +0,0 @@
# Auth Service
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
## Tables owned
- `users`
- `refresh_tokens`
- `subscriptions` (read; Billing Service writes)
## Endpoints
- `POST /auth/register`
- `POST /auth/login`
- `POST /auth/refresh`
- `GET /auth/me`
- `PUT /auth/me`
- `GET /auth/verify` (ForwardAuth for Traefik)

View File

@@ -1,34 +0,0 @@
"""Auth Service — local configuration.
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
These are NOT in shared/config.py to prevent other services from accessing them.
"""
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class AuthSettings(BaseSettings):
# RS256 private key (PEM format). Used to SIGN JWTs.
# Only the Auth Service has this. Generate with:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# Then set the env var (newlines as \n):
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
JWT_PRIVATE_KEY: str = ""
# RS256 public key (PEM format). Used to VERIFY JWTs.
# Derived from the private key:
# openssl rsa -in private.pem -pubout -out public.pem
JWT_PUBLIC_KEY: str = ""
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
@classmethod
def _expand_pem_newlines(cls, v: str) -> str:
if isinstance(v, str) and r"\n" in v:
return v.replace(r"\n", "\n")
return v
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
auth_settings = AuthSettings()

View File

@@ -1,69 +0,0 @@
"""Auth dependencies — JWT validation for the Auth Service.
This is the canonical get_current_user used by protected endpoints
within the Auth Service itself (/me, /me PUT).
"""
from __future__ import annotations
from fastapi import Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from jose import JWTError, jwt
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.config import settings
from shared.db import get_session
from shared.models import Subscription, User
from shared.schemas import UserProfile
from app.config import auth_settings
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
async def get_current_user(
token: str = Depends(oauth2_scheme),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user.
The JWT is used for identity and expiry. Tier is fetched live from the
subscriptions table so upgrades/downgrades take effect immediately.
"""
credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Could not validate credentials",
headers={"WWW-Authenticate": "Bearer"},
)
try:
payload = jwt.decode(
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
raise credentials_exc
except JWTError:
raise credentials_exc
# Live tier lookup
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname
user_result = await db.execute(
select(User.name, User.surname).where(User.id == user_id)
)
user_row = user_result.one_or_none()
return UserProfile(
id=user_id,
email=email,
name=user_row.name if user_row else None,
surname=user_row.surname if user_row else None,
tier=tier,
) # type: ignore[arg-type]

View File

@@ -1,62 +0,0 @@
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
Standalone FastAPI service extracted from the adiuva-api monolith.
Owns: users, refresh_tokens, subscriptions (read).
"""
import sys
from contextlib import asynccontextmanager
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable.
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
# In local dev, we need to add the repo root (two levels up from this file).
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from shared.config import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
from shared.db import engine
await engine.dispose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Auth Service",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from app.routes import router
from app.verify import router as verify_router
app.include_router(router, prefix="/api/v1")
app.include_router(verify_router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "auth", "version": app.version}
return app
app = create_app()

View File

@@ -1,249 +0,0 @@
"""Auth routes: register, login, refresh, me.
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
"""
from __future__ import annotations
import hashlib
import time
import uuid
from datetime import datetime, timedelta, timezone
import bcrypt
from cryptography.fernet import Fernet
from fastapi import APIRouter, Depends, HTTPException, status
from jose import jwt
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.config import settings
from shared.db import get_session
from shared.models import RefreshToken, Subscription, User
from shared.schemas import AuthTokens, UserProfile
from app.config import auth_settings
from app.deps import get_current_user
router = APIRouter(prefix="/auth", tags=["auth"])
# ── Internal helpers ─────────────────────────────────────────────────
def _hash_password(password: str) -> str:
return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode()
def _verify_password(password: str, hashed: str) -> bool:
return bcrypt.checkpw(password.encode(), hashed.encode())
def _hash_token(plain_token: str) -> str:
"""SHA-256 of the plain refresh token string."""
return hashlib.sha256(plain_token.encode()).hexdigest()
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (RS256-signed JWT, expires_at_ms)."""
now = int(time.time())
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = {
"sub": user_id,
"email": email,
"tier": tier,
"exp": exp,
"iat": now,
}
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
return token, exp * 1000 # ms for client
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
"""Fetch authoritative tier from subscriptions table."""
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
return result.scalar_one_or_none() or default_tier
# ── Request bodies ────────────────────────────────────────────────────
class _RegisterRequest(BaseModel):
email: str
password: str
name: str | None = None
surname: str | None = None
class _LoginRequest(BaseModel):
email: str
password: str
class _RefreshRequest(BaseModel):
refresh_token: str
class _UpdateProfileRequest(BaseModel):
name: str | None = None
surname: str | None = None
# ── Routes ────────────────────────────────────────────────────────────
@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED)
async def register(
body: _RegisterRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Create a new account and return JWT tokens."""
existing = await db.execute(select(User).where(User.email == body.email))
if existing.scalar_one_or_none() is not None:
raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered")
user = User(
id=str(uuid.uuid4()),
email=body.email,
name=body.name,
surname=body.surname,
password_hash=_hash_password(body.password),
tier="free",
encryption_key=Fernet.generate_key().decode(),
)
db.add(user)
await db.flush()
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/login", response_model=AuthTokens)
async def login(
body: _LoginRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Validate credentials and return JWT tokens."""
result = await db.execute(select(User).where(User.email == body.email))
user = result.scalar_one_or_none()
if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
# Fetch live tier for the JWT claim
tier = await _get_live_tier(db, user.id)
plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
)
rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=expires_at,
)
db.add(rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.post("/refresh", response_model=AuthTokens)
async def refresh(
body: _RefreshRequest,
db: AsyncSession = Depends(get_session),
) -> AuthTokens:
"""Rotate a refresh token and return a new token pair."""
token_hash = _hash_token(body.refresh_token)
result = await db.execute(
select(RefreshToken).where(RefreshToken.token_hash == token_hash)
)
rt = result.scalar_one_or_none()
now = datetime.now(timezone.utc)
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id))
user = user_result.scalar_one_or_none()
if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
# Fetch live tier for the new JWT
tier = await _get_live_tier(db, user.id)
plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken(
user_id=user.id,
token_hash=_hash_token(plain_token),
expires_at=new_expires,
)
db.add(new_rt)
await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
return AuthTokens(
access_token=access_token,
refresh_token=plain_token,
expires_at=expires_at_ms,
)
@router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user."""
return current_user
@router.put("/me", response_model=UserProfile)
async def update_profile(
body: _UpdateProfileRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> UserProfile:
"""Update the authenticated user's name and surname."""
result = await db.execute(select(User).where(User.id == current_user.id))
user = result.scalar_one()
if body.name is not None:
user.name = body.name
if body.surname is not None:
user.surname = body.surname
await db.commit()
await db.refresh(user)
return UserProfile(
id=user.id,
email=user.email,
name=user.name,
surname=user.surname,
tier=current_user.tier,
)

View File

@@ -1,66 +0,0 @@
"""ForwardAuth verification endpoint for Traefik.
Traefik calls GET /api/v1/auth/verify on every request to a protected
service. This endpoint validates the JWT from the Authorization header
and returns identity headers that Traefik injects into downstream requests.
Downstream services NEVER validate JWTs themselves — they trust the
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
"""
from __future__ import annotations
from fastapi import APIRouter, Request, Response
from fastapi import status as http_status
from jose import JWTError, jwt
from sqlalchemy import select
from shared.config import settings
from shared.db import async_session
from shared.models import Subscription
from app.config import auth_settings
router = APIRouter(tags=["auth"])
@router.get("/auth/verify")
async def verify(request: Request) -> Response:
"""Validate JWT and return identity headers for Traefik ForwardAuth.
Returns 200 with X-User-* headers on success, 401 on failure.
Traefik copies response headers to the downstream request.
"""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
token = auth_header[7:] # strip "Bearer "
try:
payload = jwt.decode(
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
except JWTError:
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
# Live tier lookup from subscriptions table
async with async_session() as db:
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
return Response(
status_code=http_status.HTTP_200_OK,
headers={
"X-User-Id": user_id,
"X-User-Email": email,
"X-User-Tier": tier,
},
)

View File

@@ -1,11 +0,0 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
bcrypt>=4.2.0
cryptography>=42.0.0
python-dotenv>=1.0.0

View File

@@ -1,23 +0,0 @@
# Batch Agent Service
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
## Tables owned
- `local_agent_configs`
- `cloud_agent_configs`
- `agent_run_logs`
## Endpoints
- `GET /agents/catalog`
- `POST /agents/can-create`
- `POST /agents/trigger`
- `GET /agents/{id}/history`
## Redis channels
- Subscribe: `batch:request:{user_id}`
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
- BRPOP: `tool:result:{call_id}` (30s timeout)
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
## TODO
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.

View File

@@ -1,15 +0,0 @@
# Billing Service
Owns: Stripe integration, tier management, subscription CRUD.
## Tables owned (write)
- `subscriptions`
## Endpoints
- `POST /billing/checkout`
- `POST /billing/webhook` (Stripe, no JWT auth)
- `GET /billing/subscription`
- `DELETE /billing/subscription`
## Redis channels
- Publish: `tier:changed:{user_id}` on tier change

View File

@@ -1,36 +0,0 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/chat/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/chat/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Chat service is CPU-bound (LLM calls) — use multiple workers
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "120"]

View File

@@ -1,21 +0,0 @@
# Chat Service
Owns: deep_agent (home + floating chat), memory middleware, domain agents
(task, note, project, timeline), LLM orchestration.
## Tables owned
- `memory_core`
- `memory_associative`
- `memory_episodic`
- `memory_proactive`
## Tables read (cross-service)
- `users` (for encryption_key — memory decryption)
## Endpoints
- `POST /chat` (REST fallback)
## Redis channels
- Subscribe: `chat:request:{user_id}`
- Publish: `ws:out:{user_id}` (stream frames + tool calls)
- BRPOP: `tool:result:{call_id}` (30s timeout)

View File

@@ -1 +0,0 @@
"""Chat Service domain agents."""

View File

@@ -1,142 +0,0 @@
"""Note agent — Markdown note management (list, get, create, update, delete).
Adapted for Chat Service: import from app.ws_context and app.llm.
"""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.llm import embed
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
NOTE_SYSTEM_PROMPT = (
"You are a note-taking assistant. You help users create, retrieve, update,\n"
"and delete Markdown notes in their workspace.\n\n"
"Rules:\n"
" - content is always Markdown; preserve formatting when updating\n"
" - project_id is optional; link a note to a project when mentioned\n"
" - When updating, call get_note first if you need to read existing content\n"
" before appending or replacing sections\n"
" - list_notes without project_id returns all notes; scope with project_id\n"
" when the user is working within a specific project\n"
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
" - Do not fabricate note content — reflect what the user provides or what\n"
" is already in the note (retrieved via get_note)."
)
@tool
async def list_notes(project_id: str = "") -> str:
"""List notes, optionally scoped to a project by project_id."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="notes",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No notes found."
lines = [f"- {r['title']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} note(s):\n" + "\n".join(lines)
@tool
async def get_note(note_id: str) -> str:
"""Fetch a single note by its UUID to read its full Markdown content."""
result = await execute_on_client(action="get", table="notes", data={"id": note_id})
row = result.get("row")
if not row:
return f"Note {note_id} not found."
return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}"
@tool
async def create_note(
title: str,
content: str,
project_id: str = "",
) -> str:
"""Create a new note.
title: note heading (required)
content: Markdown body text (required)
project_id: optional UUID linking this note to a project
"""
result = await execute_on_client(
action="insert",
table="notes",
data={
"title": title,
"content": content,
"projectId": project_id or None,
},
)
row = result["row"]
# Index the note content in the vector store.
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": row["id"], "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note created: '{row['title']}' (id: {row['id']})."
@tool
async def update_note(
note_id: str,
title: str = "",
content: str = "",
) -> str:
"""Update an existing note. Only pass fields that should change.
note_id: UUID of the note (required)
If you need to preserve existing content, call get_note first.
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if content:
updates["content"] = content
result = await execute_on_client(
action="update",
table="notes",
data={"id": note_id, "updates": updates},
)
row = result["row"]
# Re-index if content changed.
if content:
vector = await embed(content)
await execute_on_client(
action="vector_upsert",
data={"id": note_id, "projectId": row.get("projectId"), "content": content},
vector=vector,
)
return f"Note updated: '{row['title']}' (id: {row['id']})."
@tool
async def delete_note(note_id: str) -> str:
"""Delete a note permanently by its UUID."""
await execute_on_client(action="delete", table="notes", data={"id": note_id})
return f"Note {note_id} deleted."
NOTE_TOOLS: list[Any] = [
list_notes,
get_note,
create_note,
update_note,
delete_note,
]

View File

@@ -1,146 +0,0 @@
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
"""
from __future__ import annotations
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
PROJECT_SYSTEM_PROMPT = (
"You are a project management assistant. You help users create, find,\n"
"update, and archive projects in their workspace.\n\n"
"Rules:\n"
" - status must be one of: active, archived\n"
" - client_id is optional; link to a client only when explicitly mentioned\n"
" - ai_summary is populated only when the user asks for a project summary;\n"
" derive it from context data — do not fabricate content\n"
" - Use list_projects for scoped queries; list_all_projects only when the\n"
" user wants a complete cross-client view including archived projects\n"
" - get_project requires a project UUID; resolve the ID first by calling\n"
" list_projects if you only have a project name\n"
" - Prefer archiving (update_project status=archived) over deletion;\n"
" only call delete_project when the user explicitly confirms deletion."
)
@tool
async def list_projects(
client_id: str = "",
include_archived: int = 0,
) -> str:
"""List projects, optionally filtered by client_id.
include_archived: 1 to include archived projects, 0 for active only (default).
"""
result = await execute_on_client(
action="select",
table="projects",
filters={
"clientId": client_id or None,
"includeArchived": bool(include_archived),
},
)
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} project(s):\n" + "\n".join(lines)
@tool
async def list_all_projects() -> str:
"""List every project regardless of client or status.
Use only when the user wants a complete cross-client overview.
"""
result = await execute_on_client(action="select", table="projects")
rows = result.get("rows", [])
if not rows:
return "No projects found."
lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows]
return f"All projects ({len(rows)}):\n" + "\n".join(lines)
@tool
async def get_project(project_id: str) -> str:
"""Fetch a single project by its UUID."""
result = await execute_on_client(action="get", table="projects", data={"id": project_id})
row = result.get("row")
if not row:
return f"Project {project_id} not found."
return (
f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, "
f"clientId: {row.get('clientId', 'none')})"
)
@tool
async def create_project(
name: str,
client_id: str = "",
) -> str:
"""Create a new project.
name: human-readable project name (required)
client_id: optional UUID of the owning client
"""
result = await execute_on_client(
action="insert",
table="projects",
data={"name": name, "clientId": client_id or None},
)
row = result["row"]
return f"Project created: '{row['name']}' (id: {row['id']})"
@tool
async def update_project(
project_id: str,
name: str = "",
client_id: str = "",
status: str = "",
ai_summary: str = "",
) -> str:
"""Update a project. Only pass fields that should change.
project_id: UUID of the project (required)
status: active | archived
ai_summary: AI-generated summary text (populate only when explicitly requested)
"""
updates: dict[str, Any] = {}
if name:
updates["name"] = name
if client_id:
updates["clientId"] = client_id
if status:
updates["status"] = status
if ai_summary:
updates["aiSummary"] = ai_summary
result = await execute_on_client(
action="update",
table="projects",
data={"id": project_id, "updates": updates},
)
row = result["row"]
return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_project(project_id: str) -> str:
"""Permanently delete a project and orphan its tasks.
IMPORTANT: prefer update_project(status='archived') unless the user
has explicitly confirmed they want permanent deletion.
"""
await execute_on_client(action="delete", table="projects", data={"id": project_id})
return f"Project {project_id} permanently deleted."
PROJECT_TOOLS: list[Any] = [
list_projects,
list_all_projects,
get_project,
create_project,
update_project,
delete_project,
]

View File

@@ -1,240 +0,0 @@
"""Task agent — full CRUD for tasks and task comments.
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
"""
from __future__ import annotations
from datetime import datetime, timezone
import re
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TASK_SYSTEM_PROMPT = (
"You are a task management assistant for a project workspace.\n"
"You create, update, list, and track tasks and their comments.\n\n"
"Rules:\n"
" - status must be one of: todo, in_progress, done\n"
" - priority must be one of: high, medium, low\n"
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
" - project_id is optional; link to a project when the user mentions one\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
" did not explicitly request; 0 otherwise\n"
" - is_ai_suggested: 1 only when proactively proposing a task the user did not explicitly request; 0 otherwise\n"
" - Use list_tasks_due_today for 'what's due today' queries\n"
" - For update_task, use -1 for integer fields you do not want to change\n"
" - Always confirm the action in plain, user-friendly language."
)
# ── Task tools ────────────────────────────────────────────────────────
@tool
async def list_tasks(
project_id: str = "",
status: str = "",
search: str = "",
order_by: str = "",
) -> str:
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
a search string, or an order_by field name (dueDate|priority|createdAt)."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="tasks",
filters={
"projectId": normalized_project_id or None,
"status": status or None,
"search": search or None,
"orderBy": order_by or None,
},
)
rows = result.get("rows", [])
if not rows:
return "No tasks found matching the given filters."
lines = [
f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, id: {r['id']})"
for r in rows
]
return f"Found {len(rows)} task(s):\n" + "\n".join(lines)
@tool
async def create_task(
title: str,
description: str = "",
status: str = "todo",
priority: str = "medium",
assignees: str = "[]",
due_date: int = 0,
project_id: str = "",
is_ai_suggested: int = 0,
) -> str:
"""Create a new task.
title: task title (required)
description: optional details
status: todo | in_progress | done (default: todo)
priority: high | medium | low (default: medium)
assignees: JSON-encoded array of assignee names, e.g. '["Alice"]'
due_date: Unix timestamp in milliseconds; 0 means no due date
project_id: optional UUID of the parent project
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
"""
result = await execute_on_client(
action="insert",
table="tasks",
data={
"title": title,
"description": description or None,
"status": status,
"priority": priority,
"assignee": assignees,
"dueDate": due_date or None,
"projectId": project_id or None,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return (
f"Task created: '{row['title']}' "
f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']})"
)
@tool
async def update_task(
task_id: str,
title: str = "",
description: str = "",
status: str = "",
priority: str = "",
assignees: str = "",
due_date: int = -1,
project_id: str = "",
) -> str:
"""Update fields on an existing task. Only pass fields you want to change.
task_id: the task's UUID (required)
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if description:
updates["description"] = description
if status:
updates["status"] = status
if priority:
updates["priority"] = priority
if assignees:
updates["assignee"] = assignees
if due_date != -1:
updates["dueDate"] = due_date or None
if project_id:
updates["projectId"] = project_id
result = await execute_on_client(
action="update",
table="tasks",
data={"id": task_id, "updates": updates},
)
row = result["row"]
return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']})"
@tool
async def delete_task(task_id: str) -> str:
"""Delete a task permanently by its UUID."""
await execute_on_client(action="delete", table="tasks", data={"id": task_id})
return f"Task {task_id} deleted."
@tool
async def list_tasks_due_today() -> str:
"""List all tasks whose due date falls on today's date."""
now = datetime.now(tz=timezone.utc)
start_ms = int(datetime(now.year, now.month, now.day, tzinfo=timezone.utc).timestamp() * 1000)
end_ms = start_ms + 86_400_000 - 1 # last ms of today
result = await execute_on_client(
action="select",
table="tasks",
filters={"dueDateFrom": start_ms, "dueDateTo": end_ms},
)
rows = result.get("rows", [])
if not rows:
return "No tasks are due today."
lines = [
f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, id: {r['id']})"
for r in rows
]
return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines)
# ── Task comment tools ────────────────────────────────────────────────
@tool
async def list_task_comments(task_id: str) -> str:
"""List all comments on a task by its UUID."""
result = await execute_on_client(
action="select",
table="taskComments",
filters={"taskId": task_id},
)
rows = result.get("rows", [])
if not rows:
return f"No comments found for task {task_id}."
lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows]
return f"Found {len(rows)} comment(s):\n" + "\n".join(lines)
@tool
async def add_task_comment(task_id: str, author: str, content: str) -> str:
"""Add a comment to a task.
task_id: UUID of the task to comment on
author: name or ID of the comment author
content: comment text
"""
result = await execute_on_client(
action="insert",
table="taskComments",
data={"taskId": task_id, "author": author, "content": content},
)
row = result.get("row", {})
row_author = row.get("author", author)
row_task_id = row.get("taskId") or row.get("task_id") or task_id
row_comment_id = row.get("id", "unknown")
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
@tool
async def delete_task_comment(comment_id: str) -> str:
"""Delete a task comment by its UUID."""
await execute_on_client(action="delete", table="taskComments", data={"id": comment_id})
return f"Comment {comment_id} deleted."
# ── Agent ─────────────────────────────────────────────────────────────
TASK_TOOLS: list[Any] = [
list_tasks,
create_task,
update_task,
delete_task,
list_tasks_due_today,
list_task_comments,
add_task_comment,
delete_task_comment,
]

View File

@@ -1,117 +0,0 @@
"""Timeline agent — project milestone management (list, create, update, delete).
Adapted for Chat Service: import from app.ws_context instead of app.core.ws_context.
"""
from __future__ import annotations
import re
from typing import Any
from langchain_core.tools import tool
from app.ws_context import execute_on_client
_UUID_RE = re.compile(
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
)
def _is_uuid(value: str) -> bool:
return bool(_UUID_RE.match(value))
TIMELINE_SYSTEM_PROMPT = (
"You are a project timeline assistant. Timelines are milestone dates that\n"
"track progress on a project — they are not calendar events.\n\n"
"Rules:\n"
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
" - For update_timeline, use -1 for integer fields you do not want to change\n"
" - Listing without a project_id returns all timelines across projects\n"
" - Always echo the title and formatted date in your confirmation."
)
@tool
async def list_timelines(project_id: str = "") -> str:
"""List timelines. Provide project_id to scope to a specific project."""
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
result = await execute_on_client(
action="select",
table="timelines",
filters={"projectId": normalized_project_id or None},
)
rows = result.get("rows", [])
if not rows:
return "No timelines found."
lines = [f"- {r['title']} (date: {r['date']}, id: {r['id']})" for r in rows]
return f"Found {len(rows)} timeline(s):\n" + "\n".join(lines)
@tool
async def create_timeline(
project_id: str,
title: str,
date: int,
is_ai_suggested: int = 0,
) -> str:
"""Create a project timeline (milestone).
project_id: REQUIRED UUID of the parent project
title: descriptive name for the milestone
date: Unix timestamp in milliseconds
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
"""
result = await execute_on_client(
action="insert",
table="timelines",
data={
"projectId": project_id,
"title": title,
"date": date,
"isAiSuggested": is_ai_suggested,
},
)
row = result["row"]
return f"Timeline created: '{row['title']}' (id: {row['id']}, date: {row['date']})"
@tool
async def update_timeline(
timeline_id: str,
title: str = "",
date: int = -1,
) -> str:
"""Update a timeline. Only pass fields that should change.
timeline_id: UUID of the timeline (required)
date: -1 means unchanged; any other value sets the new date (ms timestamp)
"""
updates: dict[str, Any] = {}
if title:
updates["title"] = title
if date != -1:
updates["date"] = date
result = await execute_on_client(
action="update",
table="timelines",
data={"id": timeline_id, "updates": updates},
)
row = result["row"]
return f"Timeline updated: '{row['title']}' (id: {row['id']})"
@tool
async def delete_timeline(timeline_id: str) -> str:
"""Delete a timeline permanently by its UUID."""
await execute_on_client(action="delete", table="timelines", data={"id": timeline_id})
return f"Timeline {timeline_id} deleted."
TIMELINE_TOOLS: list[Any] = [
list_timelines,
create_timeline,
update_timeline,
delete_timeline,
]

View File

@@ -1,883 +0,0 @@
"""Single-agent runners for home and floating chat contexts.
Adapted from app/core/deep_agent.py for the Chat Service.
Import paths changed to use local app modules and shared/.
"""
from __future__ import annotations
import json
import logging
import re
from datetime import date
from collections.abc import AsyncGenerator
from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.llm import get_llm
from app.memory_middleware import MemoryMiddleware
from app.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app import tracing
from shared.db import async_session
logger = logging.getLogger(__name__)
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
_HOME_SINGLE_AGENT_SYSTEM = (
"You are the home assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
"Return markdown and use tags when relevant: <project>[ids]</project>, <task>[ids]</task>, "
"<note>[ids]</note>, <timeline>[ids]</timeline>, <chart>{json}</chart>. "
"When listing tasks or timelines, each id tag must be on its own line with no prefix/suffix text. "
"Never put titles, priorities, or dates on the same line as <task> or <timeline> tags. "
"For questions about upcoming timelines (e.g. 'prossimi eventi'), include only future items in the current month unless the user asks a different range. "
"For upcoming tasks, after tag lines add a short recommendation based on due date and priority."
)
_FLOATING_SINGLE_AGENT_SYSTEM = (
"You are the floating assistant with direct access to all tools: tasks, projects, notes, timelines, and memory tools. "
"Stay focused on the floating scope in context.scope and answer concisely. "
"Return plain text only. Do not output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed id tag wrappers. "
"Always use tools for factual data retrieval before answering. "
"When the user asks to remember, forget, or update what you know about them, use memory tools. "
"If context.context.resolved_project_id exists, use it as project_id for scoped list calls. "
)
_FLOATING_DOMAIN_CLASSIFIER_SYSTEM = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
)
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
def _candidate_tokens(message: str) -> list[str]:
tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower())
return [token for token in tokens if len(token) >= 3]
async def _resolve_project_id_from_message(message: str) -> str | None:
"""Resolve likely project UUID from user message using client project list."""
try:
result = await execute_on_client(action="select", table="projects")
except Exception as exc:
logger.warning("deep_agent: project resolve select failed: %s", exc)
return None
rows = result.get("rows", [])
if not isinstance(rows, list) or not rows:
return None
tokens = _candidate_tokens(message)
scored: list[tuple[int, dict[str, Any]]] = []
for row in rows:
if not isinstance(row, dict):
continue
name = str(row.get("name", "")).lower()
score = sum(1 for token in tokens if token in name)
if score > 0:
scored.append((score, row))
if not scored:
return None
scored.sort(key=lambda item: item[0], reverse=True)
top_score = scored[0][0]
top_rows = [row for score, row in scored if score == top_score]
if len(top_rows) != 1:
return None
project_id = top_rows[0].get("id")
return project_id if isinstance(project_id, str) else None
def _needs_project_resolution(message: str) -> bool:
lowered = message.lower()
return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"])
async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]:
prepared = dict(context)
if _needs_project_resolution(message):
resolved_project_id = await _resolve_project_id_from_message(message)
if resolved_project_id:
prepared["resolved_project_id"] = resolved_project_id
logger.info("deep_agent: resolved_project_id=%s", resolved_project_id)
return prepared
def _all_tools() -> list[Any]:
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
debug = context.get("_debug")
if isinstance(debug, dict):
request_id = debug.get("request_id")
if isinstance(request_id, str) and request_id:
return request_id
return None
def _context_for_model(context: dict[str, Any]) -> dict[str, Any]:
sanitized = dict(context)
sanitized.pop("_debug", None)
return sanitized
_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]</\1>")
_TIMELINE_DMY_RE = re.compile(r"(?P<d>\d{2})/(?P<m>\d{2})/(?P<y>\d{4})")
def _is_upcoming_timeline_query(message: str) -> bool:
lowered = message.lower()
has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered
has_timeline_topic = any(
token in lowered
for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden")
)
return has_upcoming and has_timeline_topic
def _timeline_date_in_current_month_or_future(dmy: str) -> bool:
match = _TIMELINE_DMY_RE.search(dmy)
if not match:
return True
try:
parsed = date(
int(match.group("y")),
int(match.group("m")),
int(match.group("d")),
)
except ValueError:
return True
today = date.today()
return parsed >= today and parsed.year == today.year and parsed.month == today.month
def _normalize_tagged_list_lines(text: str, message: str) -> str:
if not text:
return text
upcoming_timeline_only = _is_upcoming_timeline_query(message)
output_lines: list[str] = []
for line in text.splitlines():
matches = list(_TAG_LINE_RE.finditer(line))
if not matches:
output_lines.append(line)
continue
had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)")
if not had_non_tag_text and len(matches) == 1:
tag_text = matches[0].group(0)
if (
upcoming_timeline_only
and "<timeline>" in tag_text
and not _timeline_date_in_current_month_or_future(line)
):
continue
output_lines.append(tag_text)
continue
for match in matches:
tag_text = match.group(0)
if (
upcoming_timeline_only
and "<timeline>" in tag_text
and not _timeline_date_in_current_month_or_future(line)
):
continue
output_lines.append(tag_text)
return "\n".join(output_lines)
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
_FLOATING_EMPTY_FALLBACK = "No results found."
def _strip_floating_markup_fragment(text: str) -> str:
if not text:
return text
cleaned = _GENERIC_TAG_RE.sub("", text)
return _BRACKETED_ID_RE.sub("", cleaned)
def _strip_floating_markup(text: str) -> str:
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
if not text:
return text
cleaned = _strip_floating_markup_fragment(text)
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line)
def _fallback_from_raw_floating_text(raw_text: str) -> str:
fallback = _strip_floating_markup_fragment(raw_text or "")
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
return fallback or _FLOATING_EMPTY_FALLBACK
class _FloatingStreamSanitizer:
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
def __init__(self) -> None:
self._pending = ""
@staticmethod
def _split_safe_boundary(text: str) -> tuple[str, str]:
boundary = len(text)
last_lt = text.rfind("<")
if last_lt != -1 and ">" not in text[last_lt:]:
boundary = min(boundary, last_lt)
last_lb = text.rfind("[")
if last_lb != -1 and "]" not in text[last_lb:]:
boundary = min(boundary, last_lb)
if boundary == len(text):
return text, ""
return text[:boundary], text[boundary:]
def feed(self, chunk: str) -> str:
combined = f"{self._pending}{chunk}"
safe_text, self._pending = self._split_safe_boundary(combined)
return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str:
tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = ""
return _strip_floating_markup_fragment(tail)
def _normalize_memory_label(path_or_label: str) -> str:
value = path_or_label.strip()
if value.startswith("/memories/"):
value = value[len("/memories/"):]
value = value.strip("/")
return value
def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]:
@tool
async def memory_list_blocks() -> str:
"""List all core memory blocks currently stored for the user."""
logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id)
async with async_session() as db:
memory = MemoryMiddleware(db)
blocks = await memory.list_core_blocks(user_id)
if not blocks:
return "No memory blocks found."
lines = [f"- {b['label']}: {b['value']}" for b in blocks]
return "Memory blocks:\n" + "\n".join(lines)
@tool
async def memory_get(path_or_label: str) -> str:
"""Get one memory block by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_get trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
value = await memory.get_core_block(user_id, label)
if value is None:
return f"Memory block '{label}' not found."
return f"Memory block '{label}':\n{value}"
@tool
async def memory_create(path_or_label: str, value: str) -> str:
"""Create or overwrite a memory block value by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_create trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, label, value, trace_id=trace_id)
return f"Memory block '{label}' saved."
@tool
async def memory_append(path_or_label: str, content: str) -> str:
"""Append content to a memory block, creating it if missing."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_append trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.append_core(user_id, label, content)
return f"Memory block '{label}' appended."
@tool
async def memory_replace(path_or_label: str, old_string: str, new_string: str) -> str:
"""Replace one exact string in a memory block."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_replace trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
changed = await memory.replace_core(user_id, label, old_string, new_string)
if not changed:
return f"No replacement made in '{label}' (old string not found)."
return f"Memory block '{label}' updated."
@tool
async def memory_delete(path_or_label: str) -> str:
"""Delete a memory block by label or /memories/<label> path."""
label = _normalize_memory_label(path_or_label)
logger.info("deep_agent: memory_delete trace=%s user=%s label=%s", trace_id or "-", user_id, label)
if not label:
return "Invalid memory label."
async with async_session() as db:
memory = MemoryMiddleware(db)
deleted = await memory.delete_core(user_id, label)
if not deleted:
return f"Memory block '{label}' not found."
return f"Memory block '{label}' deleted."
@tool
async def archival_memory_insert(content: str) -> str:
"""Insert a long-term archival memory entry."""
logger.info("deep_agent: archival_memory_insert trace=%s user=%s", trace_id or "-", user_id)
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.insert_archival(user_id, content, source="assistant")
return "Archival memory saved."
@tool
async def archival_memory_search(query: str, top_k: int = 5) -> str:
"""Search long-term archival memory by semantic fallback (keyword currently)."""
logger.info("deep_agent: archival_memory_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_archival(user_id, query, top_k=top_k)
if not results:
return "No archival memory results found."
lines = [f"- {item}" for item in results]
return "Archival memory results:\n" + "\n".join(lines)
@tool
async def conversation_search(query: str, top_k: int = 5) -> str:
"""Search recall memory from prior episodic conversation summaries."""
logger.info("deep_agent: conversation_search trace=%s user=%s query=%s", trace_id or "-", user_id, query[:80])
async with async_session() as db:
memory = MemoryMiddleware(db)
results = await memory.search_recall(user_id, query, top_k=top_k)
if not results:
return "No recall memory results found."
lines = [f"- {item}" for item in results]
return "Recall memory results:\n" + "\n".join(lines)
return [
memory_list_blocks,
memory_get,
memory_create,
memory_append,
memory_replace,
memory_delete,
archival_memory_insert,
archival_memory_search,
conversation_search,
]
def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "note"
return None
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(
message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None,
) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
classifier_prompt = _get_system_prompt(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM,
)
callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
response = await llm.ainvoke(
[
SystemMessage(content=classifier_prompt),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
def _get_system_prompt(langfuse_name: str, fallback: str) -> str:
"""Fetch a managed prompt from Langfuse, falling back to the hardcoded string."""
managed = tracing.get_prompt(langfuse_name, fallback=None)
return managed if managed is not None else fallback
def _build_callbacks(langfuse_handler: Any | None) -> list[Any] | None:
"""Return a callbacks list if a Langfuse handler is available."""
if langfuse_handler is None:
return None
return [langfuse_handler]
async def _run_single_agent(
*,
user_id: str,
system_prompt: str,
message: str,
context: dict[str, Any],
max_steps: int = 6,
langfuse_handler: Any | None = None,
) -> str:
trace_id = _trace_id_from_context(context)
callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(
content=(
f"User message:\n{message}\n\n"
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
)
),
]
tool_calls_count = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
try:
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
final_text = _as_text(response.content)
logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-",
user_id,
tool_calls_count,
len(final_text),
)
return final_text
tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
final_text = _as_text(final.content)
logger.info(
"deep_agent: run_single_agent_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-",
user_id,
tool_calls_count,
len(final_text),
)
return final_text
finally:
clear_tool_result_collector()
async def _run_single_agent_stream(
*,
user_id: str,
system_prompt: str,
message: str,
context: dict[str, Any],
max_steps: int = 6,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
trace_id = _trace_id_from_context(context)
callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(
content=(
f"User message:\n{message}\n\n"
f"Context:\n{json.dumps({'context': model_context}, ensure_ascii=True)[:3500]}"
)
),
]
tool_calls_count = 0
streamed_chars = 0
collected: list[dict[str, Any]] = []
set_tool_result_collector(collected)
try:
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
emitted_any = False
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
emitted_any = True
yield "token", token
if not emitted_any:
fallback_text = _as_text(response.content)
if fallback_text:
streamed_chars += len(fallback_text)
yield "token", fallback_text
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d",
trace_id or "-",
user_id,
tool_calls_count,
streamed_chars,
)
return
tool_map = {tool_def.name: tool_def for tool_def in tools}
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"deep_agent: AI->Tool tool_call_id=%s tool=%s args=%s",
call_id,
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"deep_agent: Tool->AI tool_call_id=%s tool=%s output=%s",
call_id,
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
async for chunk in llm.astream(messages):
token = _as_text(getattr(chunk, "content", ""))
if token:
streamed_chars += len(token)
yield "token", token
logger.info(
"deep_agent: run_single_agent_stream_end trace=%s user=%s tool_calls=%d response_chars=%d fallback=1",
trace_id or "-",
user_id,
tool_calls_count,
streamed_chars,
)
finally:
clear_tool_result_collector()
async def run_home(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> str:
prepared_context = await _prepare_context(message, context)
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_handler=langfuse_handler,
)
return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_handler=langfuse_handler,
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
sanitized = _fallback_from_raw_floating_text(response)
return sanitized, domain
async def run_home_stream(
user_id: str,
message: str,
context: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
text_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_handler=langfuse_handler,
):
event_type, data = event
if event_type != "token":
yield event
continue
text_chunks.append(str(data or ""))
normalized = _normalize_tagged_list_lines("".join(text_chunks), message)
if normalized:
yield "token", normalized
async def run_floating_stream(
user_id: str,
message: str,
context: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
yield "floating_domain", domain
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_handler=langfuse_handler,
):
event_type, data = event
if event_type != "token":
yield event
continue
raw_chunk = str(data or "")
raw_chunks.append(raw_chunk)
sanitized_chunk = sanitizer.feed(raw_chunk)
if sanitized_chunk:
emitted_sanitized = True
yield "token", sanitized_chunk
tail = sanitizer.finalize()
if tail:
emitted_sanitized = True
yield "token", tail
if not emitted_sanitized and raw_chunks:
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
async def update_core_memory(user_id: str, key: str, value: str) -> None:
"""Compatibility helper kept for callers that expect explicit memory update API."""
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.update_core(user_id, key, value)

View File

@@ -1,72 +0,0 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Adapted from app/core/llm.py for the Chat Service.
Uses shared.config.settings instead of app.config.settings.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from shared.config import settings
litellm.drop_params = True
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github_copilot/"):
return None
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
callbacks: list | None = None,
) -> ChatOpenAI | ChatLiteLLM:
model = model or settings.LLM_MODEL
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
callbacks=callbacks,
)
async def embed(text: str) -> list[float]:
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -1,87 +0,0 @@
"""Chat Service — LLM orchestration, domain agents, memory.
Consumes chat requests from Redis, executes deep_agent (home/floating),
streams responses back via Redis pub/sub to WS Gateway.
Owns: memory_core, memory_associative, memory_episodic, memory_proactive tables.
"""
import sys
from contextlib import asynccontextmanager
import logging
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from shared.config import settings
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialise Langfuse tracing (no-op if keys are missing)
from app.tracing import init_langfuse
init_langfuse()
# Start Redis consumer in background
from app.redis_consumer import start_consumer
consumer_task = start_consumer()
yield
consumer_task.cancel()
from app.tracing import shutdown as shutdown_langfuse
shutdown_langfuse()
from shared.db import engine
await engine.dispose()
from shared.redis import redis_client
await redis_client.aclose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Chat Service",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from app.routes import router
app.include_router(router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "chat", "version": app.version}
return app
app = create_app()

View File

@@ -1,295 +0,0 @@
"""Memory Middleware — adapted for Chat Service.
Uses shared.models instead of app.models. Otherwise identical to the
monolith's app/core/memory_middleware.py.
"""
from __future__ import annotations
import logging
import uuid
from typing import Any
from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.models import (
MemoryAssociative,
MemoryCore,
MemoryEpisodic,
MemoryProactive,
User,
)
logger = logging.getLogger(__name__)
_ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware:
def __init__(self, db: AsyncSession) -> None:
self._db = db
async def enrich_context(
self,
user_id: str,
message: str,
trace_id: str | None = None,
session_id: str | None = None,
) -> dict[str, Any]:
fernet = await self._get_fernet(user_id)
if fernet is None:
return {}
core = await self._load_core(user_id, fernet)
associative = await self._load_associative(user_id, message, fernet)
episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet)
logger.info(
"memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d",
trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive),
)
return {
"core_memory": core,
"associative_memory": associative,
"episodic_memory": episodic,
"proactive_hints": proactive,
}
async def store_episode(
self, user_id: str, session_id: str, message: str, response: str,
trace_id: str | None = None,
) -> None:
fernet = await self._get_fernet(user_id)
if fernet is None:
return
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
encrypted = _encrypt(fernet, summary)
row = MemoryEpisodic(
id=str(uuid.uuid4()),
user_id=user_id,
summary_encrypted=encrypted,
session_id=session_id,
)
self._db.add(row)
try:
await self._db.commit()
except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, value)
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key)
)
existing = result.scalar_one_or_none()
if existing is not None:
existing.value_encrypted = encrypted
else:
self._db.add(MemoryCore(
id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted,
))
try:
await self._db.commit()
except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback()
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc())
)
out: list[dict[str, str]] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out.append({"label": row.key, "value": plaintext})
return out
async def get_core_block(self, user_id: str, label: str) -> str | None:
fernet = await self._get_fernet(user_id)
if fernet is None:
return None
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
)
row = result.scalar_one_or_none()
if row is None:
return None
return _safe_decrypt(fernet, row.value_encrypted)
async def delete_core(self, user_id: str, label: str) -> bool:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
)
row = result.scalar_one_or_none()
if row is None:
return False
await self._db.delete(row)
try:
await self._db.commit()
return True
except Exception as exc:
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
await self._db.rollback()
return False
async def append_core(self, user_id: str, label: str, content: str) -> None:
current = await self.get_core_block(user_id, label)
if current is None:
await self.update_core(user_id, label, content)
return
await self.update_core(user_id, label, f"{current}\n{content}")
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
current = await self.get_core_block(user_id, label)
if current is None or old not in current:
return False
await self.update_core(user_id, label, current.replace(old, new, 1))
return True
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
fernet = await self._get_fernet(user_id)
if fernet is None:
return
encrypted = _encrypt(fernet, content)
row = MemoryAssociative(
id=str(uuid.uuid4()), user_id=user_id,
content_encrypted=encrypted, embedding=None,
entity_type=source, entity_id=None,
)
self._db.add(row)
try:
await self._db.commit()
except Exception as exc:
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
await self._db.rollback()
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc()).limit(100)
)
needle = query.strip().lower()
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
return out
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
fernet = await self._get_fernet(user_id)
if fernet is None:
return []
result = await self._db.execute(
select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
.order_by(MemoryEpisodic.created_at.desc()).limit(100)
)
needle = query.strip().lower()
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is None:
continue
if not needle or needle in plaintext.lower():
out.append(plaintext)
if len(out) >= max(top_k, 1):
break
return out
# ── Private ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None:
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None or not user.encryption_key:
logger.warning("memory: no encryption_key for user=%s", user_id)
return None
return Fernet(user.encryption_key.encode())
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id)
)
out: dict[str, str] = {}
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None:
out[row.key] = plaintext
return out
async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
.order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K)
)
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_episodic(self, user_id: str, fernet: Fernet, session_id: str | None = None) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id:
query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute(
query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
)
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute(
select(MemoryProactive).where(
MemoryProactive.user_id == user_id,
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
).order_by(MemoryProactive.confidence.desc())
)
out: list[str] = []
for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
if plaintext is not None:
out.append(plaintext)
return out
def _encrypt(fernet: Fernet, plaintext: str) -> str:
return fernet.encrypt(plaintext.encode()).decode()
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
try:
return fernet.decrypt(ciphertext.encode()).decode()
except (InvalidToken, Exception) as exc:
logger.warning("memory: decrypt failed: %s", exc)
return None

View File

@@ -1,50 +0,0 @@
"""Output formatter for deep-agent stream events — Chat Service copy.
Converts (event_type, data) tuples into WebSocket frame Pydantic models.
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from typing import Any
from shared.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
class StreamFormatter:
"""Convert `(event_type, data)` stream events into websocket frame models."""
def __init__(self, request_id: str) -> None:
self.request_id = request_id
async def format(
self,
event_stream: AsyncGenerator[tuple[str, Any], None],
) -> AsyncGenerator[WsFrame, None]:
started = False
async for event_type, data in event_stream:
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue
if event_type != "token":
continue
if not started:
yield WsStreamStart(request_id=self.request_id)
started = True
text = str(data or "")
if text:
yield WsStreamText(request_id=self.request_id, chunk=text)
if not started:
yield WsStreamStart(request_id=self.request_id)
yield WsStreamEnd(request_id=self.request_id)

View File

@@ -1,209 +0,0 @@
"""Redis consumer — listens for chat requests and dispatches to deep_agent.
Subscribes to a Redis pattern channel chat:request:* so it receives
requests for ALL users. Each request is processed in a separate asyncio task.
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from shared.db import async_session
from shared.redis import redis_client, ws_out_channel
from app.deep_agent import run_floating_stream, run_home_stream
from app.memory_middleware import MemoryMiddleware
from app.output_formatter import StreamFormatter
from app.ws_context import clear_current_user, set_current_user
from app import tracing
logger = logging.getLogger(__name__)
def start_consumer() -> asyncio.Task:
"""Start the Redis consumer as a background asyncio task."""
return asyncio.create_task(_consumer_loop())
async def _consumer_loop() -> None:
"""Subscribe to chat:request:* and dispatch incoming frames."""
pubsub = redis_client.pubsub()
await pubsub.psubscribe("chat:request:*")
logger.info("redis_consumer: subscribed to chat:request:*")
try:
while True:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if message is not None and message["type"] == "pmessage":
frame = json.loads(message["data"])
asyncio.create_task(_dispatch(frame))
else:
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("redis_consumer: shutting down")
finally:
await pubsub.punsubscribe()
await pubsub.aclose()
async def _dispatch(frame: dict) -> None:
"""Route a chat request frame to the appropriate handler."""
frame_type = frame.get("type")
user_id = frame.get("user_id")
if not user_id:
logger.warning("redis_consumer: frame missing user_id: %s", frame.get("type"))
return
if frame_type == "home_request":
await _handle_home_request(user_id, frame)
elif frame_type == "floating_request":
await _handle_floating_request(user_id, frame)
else:
logger.debug("redis_consumer: unknown frame type %r", frame_type)
async def _publish_frame(user_id: str, frame_data: str) -> None:
"""Publish a frame to ws:out:{user_id} for the WS Gateway to forward."""
channel = ws_out_channel(user_id)
await redis_client.publish(channel, frame_data)
async def _handle_home_request(user_id: str, frame: dict) -> None:
"""Process a home_request — enrich with memory, run deep_agent, stream results."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
logger.info(
"redis_consumer: home_request user=%s req=%s msg=%s",
user_id, request_id, message[:200],
)
response_chunks: list[str] = []
with tracing.trace_span(
name="home_request",
user_id=user_id,
session_id=session_id,
trace_id=request_id,
input=message,
metadata={"message_preview": message[:200]},
tags=["home"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
set_current_user(user_id)
try:
event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
# Link prompt and attach output preview
tracing.link_prompt_to_trace(span, "home_system")
response_text = "".join(response_chunks)
span.update(output=response_text[:500] if response_text else None)
tracing.flush()
# Store episode
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks),
trace_id=request_id,
)
async def _handle_floating_request(user_id: str, frame: dict) -> None:
"""Process a floating_request — enrich with memory, run deep_agent, stream results."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
logger.info(
"redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
user_id, request_id, json.dumps(scope)[:200], message[:200],
)
response_chunks: list[str] = []
with tracing.trace_span(
name="floating_request",
user_id=user_id,
session_id=session_id,
trace_id=request_id,
input=message,
metadata={"message_preview": message[:200], "scope": scope},
tags=["floating"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
set_current_user(user_id)
try:
event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
# Link prompt and attach output preview
tracing.link_prompt_to_trace(span, "floating_system")
response_text = "".join(response_chunks)
span.update(output=response_text[:500] if response_text else None)
tracing.flush()
# Store episode
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks),
trace_id=request_id,
)

View File

@@ -1,37 +0,0 @@
"""Chat REST route — POST /chat fallback when WS is unavailable."""
from __future__ import annotations
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from shared.schemas import ChatRequest
from app.deep_agent import run_home
from app.ws_context import clear_current_user, set_current_user
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("")
async def chat(body: ChatRequest, request: Request) -> JSONResponse:
"""REST fallback for home chat.
In the microservices setup, Traefik ForwardAuth has already validated
the JWT and injected X-User-Id / X-User-Email / X-User-Tier headers.
"""
user_id = request.headers.get("X-User-Id", "")
if not user_id:
return JSONResponse(status_code=401, content={"detail": "Missing X-User-Id header"})
set_current_user(user_id)
try:
response = await run_home(
user_id=user_id,
message=body.message,
context=body.context.model_dump(),
)
finally:
clear_current_user()
return JSONResponse(content={"response": response})

View File

@@ -1,264 +0,0 @@
"""Langfuse tracing & prompt management for the Chat Service (v4 SDK).
Provides:
- ``init_langfuse()`` — initialise the singleton client at startup
- ``trace_span()`` — context manager that creates a trace + span
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
- ``flush()`` / ``shutdown()`` — lifecycle management
All functions gracefully degrade to no-ops when Langfuse is not configured,
so the service works identically with or without observability keys.
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import Any
from shared.config import settings
logger = logging.getLogger(__name__)
# ── State ────────────────────────────────────────────────────────────────
_initialised: bool = False
_disabled: bool = False
def _is_configured() -> bool:
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
def init_langfuse() -> None:
"""Initialise the Langfuse singleton. Call once at startup."""
global _initialised, _disabled
if _initialised or _disabled:
return
if not _is_configured():
_disabled = True
logger.info("tracing: Langfuse keys not set — tracing disabled")
return
try:
from langfuse import Langfuse
Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST,
)
_initialised = True
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
except Exception as exc:
_disabled = True
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
def _get_client() -> Any | None:
"""Return the singleton Langfuse client, or *None* if disabled."""
if _disabled:
return None
if not _initialised:
init_langfuse()
if _disabled:
return None
try:
from langfuse import get_client
return get_client()
except Exception:
return None
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
class _NullSpan:
"""Drop-in replacement when Langfuse is disabled."""
def update(self, **_: Any) -> None: ...
def set_trace_io(self, **_: Any) -> None: ...
def score_trace(self, **_: Any) -> None: ...
# ── Trace context manager ───────────────────────────────────────────────
@contextmanager
def trace_span(
*,
name: str,
user_id: str,
session_id: str | None = None,
trace_id: str | None = None,
input: Any = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
):
"""Context manager that creates a Langfuse trace/span.
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
A ``CallbackHandler`` created inside this block auto-inherits the trace
context, so there is no need to pass trace IDs manually.
"""
lf = _get_client()
if lf is None:
yield _NullSpan()
return
try:
from langfuse import Langfuse, propagate_attributes
trace_ctx: dict[str, str] = {}
if trace_id is not None:
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
with lf.start_as_current_observation(
as_type="span",
name=name,
input=input,
metadata=metadata or {},
**({"trace_context": trace_ctx} if trace_ctx else {}),
) as span:
with propagate_attributes(
user_id=user_id,
session_id=session_id,
tags=tags or [],
):
yield span
except Exception as exc:
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
yield _NullSpan()
# ── LangChain callback handler ──────────────────────────────────────────
def get_langfuse_callback() -> Any | None:
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
Must be called inside a ``trace_span()`` block for proper linking.
Returns *None* when Langfuse is disabled.
"""
if _disabled and not _initialised:
return None
try:
from langfuse.langchain import CallbackHandler
return CallbackHandler()
except Exception as exc:
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
return None
# ── Prompt management ────────────────────────────────────────────────────
def get_prompt(
name: str,
*,
version: int | None = None,
label: str | None = None,
fallback: str | None = None,
cache_ttl_seconds: int = 300,
) -> str | None:
"""Fetch a managed prompt from Langfuse by name.
Returns the compiled prompt string, or *fallback* if the prompt is not
found or Langfuse is disabled.
"""
lf = _get_client()
if lf is None:
return fallback
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
return prompt.prompt
except Exception as exc:
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
return fallback
def link_prompt_to_trace(
span: Any,
prompt_name: str,
*,
version: int | None = None,
label: str | None = None,
) -> None:
"""Attach prompt metadata to a span/trace."""
lf = _get_client()
if lf is None or isinstance(span, _NullSpan):
return
try:
kwargs: dict[str, Any] = {"name": prompt_name}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
span.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}})
except Exception as exc:
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
# ── Scoring helper ───────────────────────────────────────────────────────
def score_trace(
trace_id: str,
name: str,
value: float,
*,
comment: str | None = None,
) -> None:
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
lf = _get_client()
if lf is None:
return
try:
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
except Exception as exc:
logger.warning("tracing: score_trace failed: %s", exc)
# ── Shutdown ─────────────────────────────────────────────────────────────
def flush() -> None:
"""Flush pending Langfuse events."""
lf = _get_client()
if lf is not None:
try:
lf.flush()
except Exception as exc:
logger.warning("tracing: flush failed: %s", exc)
def shutdown() -> None:
"""Flush and close the Langfuse client."""
global _initialised, _disabled
lf = _get_client()
if lf is not None:
try:
lf.flush()
lf.shutdown()
except Exception as exc:
logger.warning("tracing: shutdown failed: %s", exc)
_initialised = False
_disabled = False

View File

@@ -1,115 +0,0 @@
"""WebSocket context for Chat Service — Redis-based tool call round-trip.
Replaces the monolith's ws_context.py. Instead of calling Electron directly
via WebSocket, this publishes tool_call frames to Redis (ws:out:{user_id})
and awaits the result via BRPOP on tool:result:{call_id}.
"""
from __future__ import annotations
import json
import logging
from contextvars import ContextVar
from typing import Any
from uuid import uuid4
from shared.redis import redis_client, tool_result_key, ws_out_channel
logger = logging.getLogger(__name__)
_TOOL_CALL_TIMEOUT = 30 # seconds — BRPOP timeout
# Per-request user_id context var (set before agent runs)
_current_user_id: ContextVar[str | None] = ContextVar("_current_user_id", default=None)
# Optional collector for debug
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_current_user(user_id: str) -> None:
_current_user_id.set(user_id)
def clear_current_user() -> None:
_current_user_id.set(None)
def set_tool_result_collector(lst: list[dict]) -> None:
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
_tool_result_collector.set(None)
async def execute_on_client(
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Send a tool_call to Electron via Redis and await the result.
1. Build tool_call payload
2. Publish to ws:out:{user_id} (WS Gateway forwards to Electron)
3. BRPOP on tool:result:{call_id} (WS Gateway pushes when Electron replies)
4. Return result dict
Raises RuntimeError if no user_id is set or if the call times out.
"""
user_id = _current_user_id.get()
if not user_id:
raise RuntimeError(
"execute_on_client() called without a user_id — "
"set_current_user() must be called first."
)
call_id = str(uuid4())
payload: dict[str, Any] = {
"type": "tool_call",
"id": call_id,
"action": action,
}
if table is not None:
payload["table"] = table
if data is not None:
payload["data"] = data
if filters is not None:
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
if vector is not None:
payload["vector"] = vector
if limit is not None:
payload["limit"] = limit
# Publish tool_call to WS Gateway → Electron
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps(payload))
# Wait for Electron's tool_result
result_key = tool_result_key(call_id)
response = await redis_client.brpop(result_key, timeout=_TOOL_CALL_TIMEOUT)
if response is None:
raise RuntimeError(
f"Tool call {call_id} timed out after {_TOOL_CALL_TIMEOUT}s — "
f"device may be offline or unresponsive."
)
# response is (key, value) tuple
_, raw = response
result = json.loads(raw)
# Collect for debug if requested
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": result,
})
return result

View File

@@ -1,17 +0,0 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
redis>=5.0.0
cryptography>=42.0.0
python-dotenv>=1.0.0
langchain-core>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.3.0
litellm>=1.50.0
openai>=1.50.0
httpx>=0.27.0
langfuse>=3.0.0

View File

@@ -1,36 +0,0 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/ws-gateway/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/ws-gateway/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Single worker — each instance handles many WS connections via asyncio
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "1", \
"--timeout", "0"]

View File

@@ -1,17 +0,0 @@
# WS Gateway
Stateless WebSocket proxy. Accepts Electron connections, authenticates JWT,
routes frames to Chat/Batch services via Redis pub/sub.
## No business logic
This service does NOT know what tasks, notes, or agents are.
It only routes JSON frames between Electron and downstream services.
## Scaling
Sticky sessions on `user_id` (Traefik consistent hashing).
## Redis channels used
- Subscribe: `ws:out:{user_id}` (frames to send to client)
- Publish: `chat:request:{user_id}`, `batch:request:{user_id}`
- LPUSH: `tool:result:{call_id}` (from client tool_result frames)
- HSET/HDEL: `ws:devices:{user_id}` (device registry)

View File

@@ -1,173 +0,0 @@
"""WebSocket handler — device connection lifecycle.
Accepts Electron WS connections, authenticates JWT, registers device in Redis,
and runs two concurrent loops:
1. Message loop: receive frames from Electron, route to Redis
2. Outbound loop: subscribe to Redis ws:out:{user_id}, forward to Electron
3. Heartbeat loop: ping every 30s
No business logic lives here — the handler is a JSON frame router.
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from shared.config import settings
from shared.schemas import WsFrameType
from app.redis_bridge import (
publish_batch_request,
publish_chat_request,
push_tool_result,
register_device,
set_gateway_id,
subscribe_outbound,
unregister_device,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["ws-gateway"])
_HEARTBEAT_INTERVAL = 30 # seconds
# Set a unique gateway instance ID on module load
set_gateway_id(str(uuid4()))
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
"""Persistent WebSocket endpoint for Electron device connections."""
# ── 1. Authenticate via ?token= query parameter ──────────────────
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(
token,
settings.JWT_PUBLIC_KEY,
algorithms=["RS256"],
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008)
return
await websocket.accept()
# ── 2. Await device_hello frame ──────────────────────────────────
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
except (asyncio.TimeoutError, WebSocketDisconnect):
await websocket.close(code=1008)
return
try:
hello = json.loads(raw)
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("handler: invalid device_hello user=%s: %s", user_id, exc)
await websocket.close(code=1008)
return
# ── 3. Register device in Redis ──────────────────────────────────
await register_device(user_id, device_id)
logger.info("handler: connected user=%s device=%s agents=%s", user_id, device_id, agent_ids)
# Notify downstream services that device is online (for agent trigger)
await publish_batch_request(user_id, {
"type": "device_online",
"user_id": user_id,
"device_id": device_id,
"agent_ids": agent_ids,
})
# ── 4. Subscribe to outbound Redis channel ───────────────────────
pubsub = await subscribe_outbound(user_id)
# ── 5. Run concurrent loops ──────────────────────────────────────
try:
await asyncio.gather(
_inbound_loop(websocket, user_id),
_outbound_loop(websocket, pubsub),
_heartbeat_loop(websocket),
)
except WebSocketDisconnect:
pass
except Exception as exc:
logger.warning("handler: unhandled exception user=%s: %s", user_id, exc)
finally:
await pubsub.unsubscribe()
await pubsub.aclose()
await unregister_device(user_id)
logger.info("handler: disconnected user=%s device=%s", user_id, device_id)
# ── Inbound: Electron → Redis ────────────────────────────────────────
async def _inbound_loop(websocket: WebSocket, user_id: str) -> None:
"""Receive frames from Electron and route to the appropriate Redis channel."""
async for raw in websocket.iter_text():
try:
frame: dict = json.loads(raw)
except json.JSONDecodeError:
logger.warning("handler: invalid JSON from user=%s", user_id)
continue
frame_type = frame.get("type")
# Inject user_id so downstream services know who sent it
frame["user_id"] = user_id
if frame_type == WsFrameType.tool_result:
call_id = frame.get("id")
if call_id:
await push_tool_result(call_id, frame)
else:
logger.warning("handler: tool_result missing id user=%s", user_id)
elif frame_type in (WsFrameType.home_request, WsFrameType.floating_request):
await publish_chat_request(user_id, frame)
elif frame_type in (WsFrameType.journey_start, WsFrameType.journey_message):
await publish_batch_request(user_id, frame)
elif frame_type == "pong":
pass # heartbeat ack
else:
logger.debug("handler: unknown frame type %r user=%s", frame_type, user_id)
# ── Outbound: Redis → Electron ───────────────────────────────────────
async def _outbound_loop(websocket: WebSocket, pubsub) -> None:
"""Subscribe to Redis ws:out:{user_id} and forward frames to Electron."""
while True:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=1.0)
if message is not None and message["type"] == "message":
await websocket.send_text(message["data"])
else:
# Brief sleep to avoid busy-wait when no messages
await asyncio.sleep(0.01)
# ── Heartbeat ────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
"""Send ping frames every 30s to keep the connection alive."""
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"type": "ping"}))

View File

@@ -1,56 +0,0 @@
"""WS Gateway — stateless WebSocket proxy.
Accepts Electron device connections, authenticates JWT (RS256 public key),
and routes frames between Electron and downstream services via Redis pub/sub.
This service has NO business logic — it only routes JSON frames.
"""
import sys
from contextlib import asynccontextmanager
import logging
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from shared.config import settings
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
from shared.redis import redis_client
await redis_client.aclose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva WS Gateway",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
from app.handler import router
app.include_router(router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "ws-gateway", "version": app.version}
return app
app = create_app()

View File

@@ -1,104 +0,0 @@
"""Redis bridge — device registry + pub/sub routing.
All inter-service communication passes through Redis:
- Device registry: HSET/HDEL ws:devices:{user_id}
- Outbound frames: Subscribe ws:out:{user_id}
- Chat requests: Publish chat:request:{user_id}
- Batch requests: Publish batch:request:{user_id}
- Tool results: LPUSH tool:result:{call_id}
"""
from __future__ import annotations
import json
import logging
from shared.redis import (
batch_request_channel,
chat_request_channel,
device_key,
redis_client,
tool_result_key,
ws_out_channel,
)
logger = logging.getLogger(__name__)
# Instance ID for this gateway replica (set on startup)
_GATEWAY_ID: str = ""
def set_gateway_id(gid: str) -> None:
global _GATEWAY_ID
_GATEWAY_ID = gid
def get_gateway_id() -> str:
return _GATEWAY_ID
# ── Device Registry ──────────────────────────────────────────────────
async def register_device(user_id: str, device_id: str) -> None:
"""Register a connected device in Redis."""
key = device_key(user_id)
await redis_client.hset(key, mapping={
"device_id": device_id,
"gateway_id": _GATEWAY_ID,
})
logger.info("redis_bridge: registered user=%s device=%s gateway=%s", user_id, device_id, _GATEWAY_ID)
async def unregister_device(user_id: str) -> None:
"""Remove device registration from Redis."""
key = device_key(user_id)
await redis_client.delete(key)
logger.info("redis_bridge: unregistered user=%s", user_id)
async def is_device_online(user_id: str) -> bool:
"""Check if a device is registered."""
key = device_key(user_id)
return await redis_client.exists(key) > 0
# ── Frame Routing ────────────────────────────────────────────────────
async def publish_chat_request(user_id: str, frame: dict) -> None:
"""Forward a chat request frame to the Chat Service via Redis."""
channel = chat_request_channel(user_id)
await redis_client.publish(channel, json.dumps(frame))
logger.debug("redis_bridge: published chat_request user=%s", user_id)
async def publish_batch_request(user_id: str, frame: dict) -> None:
"""Forward a batch request frame to the Batch Agent Service via Redis."""
channel = batch_request_channel(user_id)
await redis_client.publish(channel, json.dumps(frame))
logger.debug("redis_bridge: published batch_request user=%s", user_id)
async def push_tool_result(call_id: str, result: dict) -> None:
"""Push a tool_result to the Redis list for the waiting service.
Chat/Batch services do BRPOP on this key with a 30s timeout.
"""
key = tool_result_key(call_id)
await redis_client.lpush(key, json.dumps(result))
# Auto-expire after 60s to prevent stale keys
await redis_client.expire(key, 60)
logger.debug("redis_bridge: pushed tool_result call_id=%s", call_id)
async def subscribe_outbound(user_id: str):
"""Return an async pubsub subscription for frames to send to Electron.
Chat/Batch services publish to ws:out:{user_id} and this gateway
forwards them to the connected WebSocket.
"""
channel = ws_out_channel(user_id)
pubsub = redis_client.pubsub()
await pubsub.subscribe(channel)
return pubsub

View File

@@ -1,8 +0,0 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
redis>=5.0.0
websockets>=14.0

View File

@@ -1,5 +0,0 @@
"""Shared module — imported by all microservices.
Contains DB engine/session, ORM models, Pydantic schemas, config,
and Redis utilities. Changes here affect ALL services.
"""

View File

@@ -1,98 +0,0 @@
"""Shared configuration — Pydantic Settings loaded from environment.
All services import ``settings`` from here. Each service only uses a subset
of the vars, but keeping one Settings class avoids fragmentation.
"""
from pathlib import Path
from typing import Literal
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
# Locate the repo root (adiuva-api/) so we can load its .env as a fallback.
# Works whether cwd is adiuva-api/ (monolith) or adiuva-api/services/xyz/ (microservice).
_this_dir = Path(__file__).resolve().parent # shared/
_repo_root = _this_dir.parent # adiuva-api/
_root_env = _repo_root / ".env"
class Settings(BaseSettings):
# ── Database ─────────────────────────────────────────────────────
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
# ── JWT ────────────────────────────────────────────────────────
# RS256 public key (PEM). Used by any service that needs to verify
# JWTs locally (optional — Traefik ForwardAuth handles this in prod).
# The private key lives ONLY in the Auth Service config.
JWT_PUBLIC_KEY: str = ""
@field_validator("JWT_PUBLIC_KEY", mode="before")
@classmethod
def _expand_pem_newlines(cls, v: str) -> str:
if isinstance(v, str) and r"\n" in v:
return v.replace(r"\n", "\n")
return v
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
# ── Redis ────────────────────────────────────────────────────────
REDIS_URL: str = "redis://localhost:6379/0"
# ── Stripe ───────────────────────────────────────────────────────
STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
# ── S3 ───────────────────────────────────────────────────────────
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
# ── Vector stores ────────────────────────────────────────────────
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
# ── LLM providers ────────────────────────────────────────────────
OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = ""
CEREBRAS_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
GITHUB_COPILOT_TOKEN_DIR: str = ""
# ── OAuth (integrations) ─────────────────────────────────────────
GMAIL_CLIENT_ID: str = ""
GMAIL_CLIENT_SECRET: str = ""
MS_CLIENT_ID: str = ""
MS_CLIENT_SECRET: str = ""
MS_TENANT_ID: str = "common"
OAUTH_ENCRYPTION_KEY: str = ""
# ── Langfuse (observability) ─────────────────────────────────────
LANGFUSE_SECRET_KEY: str = ""
LANGFUSE_PUBLIC_KEY: str = ""
LANGFUSE_HOST: str = "https://cloud.langfuse.com"
# ── CORS ─────────────────────────────────────────────────────────
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
# ── Environment ──────────────────────────────────────────────────
ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(
# Local .env (cwd) takes priority; root .env is fallback.
env_file=(".env", str(_root_env)),
env_file_encoding="utf-8",
extra="ignore",
)
settings = Settings()

View File

@@ -1,32 +0,0 @@
"""Database engine, session factory, and declarative base.
All services use the async SQLAlchemy API via ``get_session()``.
Alembic migrations use the synchronous psycopg2 URL (see alembic/env.py).
"""
from __future__ import annotations
from collections.abc import AsyncGenerator
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
from sqlalchemy.orm import DeclarativeBase
from shared.config import settings
engine = create_async_engine(
settings.DATABASE_URL,
pool_pre_ping=True,
echo=False,
)
async_session = async_sessionmaker(engine, expire_on_commit=False)
class Base(DeclarativeBase):
"""Shared declarative base for all ORM models."""
async def get_session() -> AsyncGenerator[AsyncSession, None]:
"""FastAPI dependency that yields an async DB session per request."""
async with async_session() as session:
yield session

View File

@@ -1,455 +0,0 @@
"""SQLAlchemy ORM models for all persistent tables.
Centralized here so that Alembic migrations and all services share
the same model definitions. Each service only queries the tables it owns.
Ownership:
Auth Service → users, refresh_tokens, subscriptions
Chat Service → memory_core, memory_associative, memory_episodic, memory_proactive
Batch Agent → local_agent_configs, cloud_agent_configs, agent_run_logs
Billing Service → subscriptions (shared write with Auth)
(excluded MVP) → storage_records, backup_metadata, plugins, plugin_*, revenue_events
"""
from __future__ import annotations
import uuid
from datetime import datetime, timezone
from sqlalchemy import (
BigInteger,
Boolean,
DateTime,
Enum,
Float,
ForeignKey,
Integer,
JSON,
String,
Text,
UniqueConstraint,
Uuid,
func,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
from shared.db import Base
# ── Helpers ──────────────────────────────────────────────────────────────
def _uuid() -> str:
return str(uuid.uuid4())
def _now() -> datetime:
return datetime.now(timezone.utc)
# ── Enum types ────────────────────────────────────────────────────────────
TierEnum = Enum("free", "pro", "power", "team", name="billing_tier")
PluginStatusEnum = Enum("pending_review", "approved", "rejected", name="plugin_status")
ReviewDecisionEnum = Enum("approved", "rejected", name="review_decision")
AgentTypeEnum = Enum("local", "cloud", name="agent_type")
AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run_status")
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
# ── Auth models ───────────────────────────────────────────────────────────
class User(Base):
__tablename__ = "users"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
email: Mapped[str] = mapped_column(String(255), unique=True, nullable=False, index=True)
name: Mapped[str | None] = mapped_column(String(100), nullable=True)
surname: Mapped[str | None] = mapped_column(String(100), nullable=True)
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
refresh_tokens: Mapped[list[RefreshToken]] = relationship(
back_populates="user", cascade="all, delete-orphan"
)
subscription: Mapped[Subscription | None] = relationship(
back_populates="user", uselist=False, cascade="all, delete-orphan"
)
class RefreshToken(Base):
__tablename__ = "refresh_tokens"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
token_hash: Mapped[str] = mapped_column(String(64), unique=True, nullable=False, index=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="refresh_tokens")
class Subscription(Base):
__tablename__ = "subscriptions"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, unique=True, index=True
)
stripe_subscription_id: Mapped[str | None] = mapped_column(String(255), nullable=True, index=True)
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
status: Mapped[str] = mapped_column(String(50), nullable=False, default="free")
current_period_end: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
user: Mapped[User] = relationship(back_populates="subscription")
# ── Storage models (excluded from MVP, kept for Alembic) ──────────────
class StorageRecord(Base):
__tablename__ = "storage_records"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
table_name: Mapped[str] = mapped_column(String(100), nullable=False)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class BackupMetadata(Base):
__tablename__ = "backup_metadata"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
s3_key: Mapped[str] = mapped_column(String(500), nullable=False)
version: Mapped[int] = mapped_column(Integer, nullable=False)
timestamp: Mapped[int] = mapped_column(BigInteger, nullable=False)
checksum: Mapped[str] = mapped_column(String(64), nullable=False)
size_bytes: Mapped[int] = mapped_column(Integer, nullable=False)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
# ── Plugin models (excluded from MVP, kept for Alembic) ───────────────
class Plugin(Base):
__tablename__ = "plugins"
id: Mapped[str] = mapped_column(String(255), primary_key=True)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
author_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
avg_rating: Mapped[float] = mapped_column(Float, nullable=False, default=0.0)
rejection_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
submitted_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
installations: Mapped[list[PluginInstallation]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
reviews: Mapped[list[PluginReview]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
revenue_events: Mapped[list[RevenueEvent]] = relationship(
back_populates="plugin", cascade="all, delete-orphan"
)
class PluginInstallation(Base):
__tablename__ = "plugin_installations"
__table_args__ = (UniqueConstraint("plugin_id", "user_id", name="uq_plugin_user"),)
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
installed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="installations")
class PluginReview(Base):
__tablename__ = "plugin_reviews"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
reviewer_id: Mapped[str | None] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
)
decision: Mapped[str] = mapped_column(ReviewDecisionEnum, nullable=False)
notes: Mapped[str | None] = mapped_column(Text, nullable=True)
reviewed_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="reviews")
class RevenueEvent(Base):
__tablename__ = "revenue_events"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
plugin_id: Mapped[str] = mapped_column(
String(255), ForeignKey("plugins.id", ondelete="CASCADE"), nullable=False, index=True
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
amount_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
developer_share_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
stripe_transfer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
paid_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
# ── Agent models ──────────────────────────────────────────────────────────
class LocalAgentConfig(Base):
__tablename__ = "local_agent_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
device_id: Mapped[str] = mapped_column(String(255), nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="local_agent",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
cascade="all, delete-orphan",
overlaps="run_logs,cloud_agent",
)
class CloudAgentConfig(Base):
__tablename__ = "cloud_agent_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
provider: Mapped[str] = mapped_column(CloudProviderEnum, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
oauth_token_encrypted: Mapped[str | None] = mapped_column(Text, nullable=True)
filter_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
last_run_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="cloud_agent",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
cascade="all, delete-orphan",
overlaps="run_logs,local_agent",
)
class AgentRunLog(Base):
__tablename__ = "agent_run_logs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
status: Mapped[str] = mapped_column(AgentStatusEnum, nullable=False, default="running")
items_processed: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
items_created: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
errors: Mapped[list | None] = mapped_column(JSON, nullable=True)
started_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
local_agent: Mapped[LocalAgentConfig | None] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,cloud_agent",
)
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,local_agent",
)
# ── Memory models ─────────────────────────────────────────────────────────
class MemoryCore(Base):
"""Per-user persistent key/value preferences, encrypted at rest."""
__tablename__ = "memory_core"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
key: Mapped[str] = mapped_column(String(255), nullable=False)
value_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryAssociative(Base):
"""Per-user semantic memory: encrypted content + pgvector embedding."""
__tablename__ = "memory_associative"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
class MemoryEpisodic(Base):
"""Per-user session summaries, encrypted at rest."""
__tablename__ = "memory_episodic"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
summary_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
session_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)
class MemoryProactive(Base):
"""Per-user inferred behavioral patterns, encrypted at rest."""
__tablename__ = "memory_proactive"
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"),
nullable=False, index=True,
)
pattern_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
confidence: Mapped[float] = mapped_column(Float, nullable=False, default=0.5)
source: Mapped[str] = mapped_column(String(50), nullable=False, default="inferred")
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now()
)

View File

@@ -1,53 +0,0 @@
"""Redis client and pub/sub utilities for inter-service communication.
All services that need Redis import from here.
"""
from __future__ import annotations
import redis.asyncio as aioredis
from shared.config import settings
redis_client: aioredis.Redis = aioredis.from_url(
settings.REDIS_URL,
decode_responses=True,
)
# ── Channel naming conventions ────────────────────────────────────────
# See /memories/repo/microservices-architecture.md for full list.
def ws_out_channel(user_id: str) -> str:
"""Frames to forward to Electron via WS Gateway."""
return f"ws:out:{user_id}"
def chat_request_channel(user_id: str) -> str:
"""Chat requests (home + floating) from WS Gateway → Chat Service."""
return f"chat:request:{user_id}"
def batch_request_channel(user_id: str) -> str:
"""Batch requests (journey + triggers) from WS Gateway → Batch Agent."""
return f"batch:request:{user_id}"
def tool_result_key(call_id: str) -> str:
"""Tool result list: LPUSH by WS Gateway, BRPOP by Chat/Batch."""
return f"tool:result:{call_id}"
def device_key(user_id: str) -> str:
"""Device registry hash."""
return f"ws:devices:{user_id}"
def tier_changed_channel(user_id: str) -> str:
"""Billing tier change notifications."""
return f"tier:changed:{user_id}"
def journey_session_key(user_id: str) -> str:
"""Journey builder session (String + TTL 1800s)."""
return f"journey:{user_id}"

Some files were not shown because too many files have changed in this diff Show More