diff --git a/api/.env.example b/api/.env.example new file mode 100644 index 0000000..2c1990e --- /dev/null +++ b/api/.env.example @@ -0,0 +1,95 @@ +# ── Application ────────────────────────────────────────────────────────────── +ENV=dev + +# ── Database ────────────────────────────────────────────────────────────────── +DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai + +# ── Auth ────────────────────────────────────────────────────────────────────── +JWT_SECRET=replace-with-a-long-random-secret +JWT_ALGORITHM=HS256 +JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30 +JWT_REFRESH_TOKEN_EXPIRE_DAYS=30 + +# ── LLM ─────────────────────────────────────────────────────────────────────── +# LiteLLM model identifiers — change to swap providers without code changes. +# Examples: gpt-4o, anthropic/claude-sonnet-4-20250514, gemini/gemini-pro, ollama/llama3 +# +# API keys — only the key(s) matching your chosen provider(s) are required. +# The correct key is picked automatically from the model prefix (e.g. +# "anthropic/..." → ANTHROPIC_API_KEY, "gemini/..." → GOOGLE_API_KEY). +OPENAI_API_KEY= +ANTHROPIC_API_KEY= +GOOGLE_API_KEY= +CEREBRAS_API_KEY= +GROQ_API_KEY= +DEEPSEEK_API_KEY= + +# Default model used by any agent that does not have a specific override below. +LLM_MODEL=gpt-5-mini +LLM_EMBED_MODEL=text-embedding-3-small + +# GitHub Copilot — leave empty to use the LiteLLM default token directory. +# In Docker, point this to a named-volume path so tokens survive restarts. +# GITHUB_COPILOT_TOKEN_DIR= + +# ── Per-agent model overrides ───────────────────────────────────────────────── +# Leave a value empty to fall back to LLM_MODEL. +# Each agent resolves its API key from the model prefix automatically. +# +# Intent classifier — routes user messages to the right domain agent. +# A small/fast model (e.g. gpt-4o-mini) is usually sufficient here. +LLM_MODEL_CLASSIFIER= + +# Home-agent — handles chat from the home screen (all tools available). +LLM_MODEL_HOME_AGENT= + +# Floating-agent — handles contextual chat triggered from a task/project/note. +LLM_MODEL_FLOATING_AGENT= + +# Unified-processor — processes local directory files (local agent runner). +LLM_MODEL_UNIFIED_PROCESSOR= + +# Cloud-processor — fetches and processes data from cloud connectors. +LLM_MODEL_CLOUD_PROCESSOR= + +# Brief-agent — produces home and project text briefs. +# A small model (e.g. gpt-4o-mini) is sufficient. +# LLM_MODEL_BRIEF_AGENT= + +# Task-brief-agent — per-task deep research (Stage 1 executive assistant). +# Needs tool-use + reasoning; a capable model recommended (e.g. gpt-4o, gemini-2.5-flash). +# LLM_MODEL_TASK_BRIEF_AGENT= + +# Setup-agent — guided journey to build an AgentConfig via WebSocket chat. +LLM_MODEL_SETUP_AGENT= + +# Memory-extractor — Mem0-style extract/decide pipeline (Phase 2). +# Defaults to gpt-4o-mini when empty (fast + cheap, temperature=0). +LLM_MODEL_MEMORY_EXTRACTOR= + +# Memory-miner — proactive pattern mining from episodic history (Phase 5, Power+ only). +# Defaults to gpt-4o-mini when empty. +LLM_MODEL_MEMORY_MINER= + +# Memory-auditor — weekly contradiction scan + relation label canonicalization (Phase 7). +# Defaults to LLM_MODEL when empty (a reasoning-capable model is recommended). +LLM_MODEL_MEMORY_AUDITOR= + +# Scheduler — set to false to disable memory cron jobs (automatically false in tests). +SCHEDULER_ENABLED=true + +# ── Stripe (leave empty to stub billing) ────────────────────────────────────── +STRIPE_SECRET_KEY= +STRIPE_WEBHOOK_SECRET= + + +# ── Langfuse (leave empty to disable observability) ─────────────────────────── +LANGFUSE_SECRET_KEY= +LANGFUSE_PUBLIC_KEY= +# LANGFUSE_BASE_URL=https://cloud.langfuse.com # EU (default) +# LANGFUSE_BASE_URL=https://us.cloud.langfuse.com # US +# LANGFUSE_BASE_URL=http://localhost:3000 # Self-hosted + +# ── CORS ────────────────────────────────────────────────────────────────────── +# Comma-separated list parsed by Settings (override default if needed) +# CORS_ORIGINS=["app://.","http://localhost:3000"] diff --git a/api/.gitea/workflows/deploy.yaml b/api/.gitea/workflows/deploy.yaml new file mode 100644 index 0000000..cc6c5c9 --- /dev/null +++ b/api/.gitea/workflows/deploy.yaml @@ -0,0 +1,93 @@ +name: Test & Deploy API +run-name: ${{ gitea.ref_name }} → Docker LXC + +on: + push: + tags: + - 'v*' + +jobs: + # ── 1. Run tests in an isolated Python container ────────────────── + test: + runs-on: ubuntu-latest + container: + image: python:3.12-slim + + steps: + - name: Install git + run: apt-get update && apt-get install -y --no-install-recommends git + + - name: Checkout Code + run: | + git clone --depth 1 --branch "${GITHUB_REF_NAME}" \ + "http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . || \ + git clone --depth 1 "http://10.0.0.119:3000/${GITHUB_REPOSITORY}.git" . && \ + git checkout "${GITHUB_SHA}" + + - name: Install Dependencies + run: pip install --no-cache-dir -r requirements.txt + + - name: Run Linter + run: ruff check app/ tests/ + + - name: Run Tests + run: pytest tests/ -v --tb=short + + # ── 2. Deploy to Docker LXC via SSH ───────────────────────────────── + deploy: + needs: test + runs-on: ubuntu-latest + if: gitea.event_name == 'push' + + steps: + - name: Deploy via SSH + uses: appleboy/ssh-action@v1.0.0 + with: + host: ${{ secrets.SSH_HOST }} + username: ${{ secrets.SSH_USER }} + key: ${{ secrets.SSH_KEY }} + script: | + set -e + DEPLOY_DIR="/opt/adiuvai-api" + REPO_URL="http://10.0.0.119:3000/${{ gitea.repository }}.git" + TAG="${{ gitea.ref_name }}" + + # ── Pull latest code ── + cd /tmp && rm -rf adiuvai-api-deploy + git clone --depth 1 --branch "${TAG}" "${REPO_URL}" adiuvai-api-deploy + + # ── Sync source (preserve .env) ── + cp -rf /tmp/adiuvai-api-deploy/app/ \ + /tmp/adiuvai-api-deploy/alembic/ \ + /tmp/adiuvai-api-deploy/alembic.ini \ + /tmp/adiuvai-api-deploy/Dockerfile \ + /tmp/adiuvai-api-deploy/docker-compose.yml \ + /tmp/adiuvai-api-deploy/requirements.txt \ + "$DEPLOY_DIR/" + rm -rf /tmp/adiuvai-api-deploy + + # ── Verify .env ── + if [ ! -f "$DEPLOY_DIR/.env" ]; then + echo "❌ $DEPLOY_DIR/.env not found. Create it before deploying." + exit 1 + fi + + # ── Build & restart ── + cd "$DEPLOY_DIR" + docker compose down --remove-orphans || true + docker compose up -d --build + + # ── Migrations ── + docker compose exec -T app alembic upgrade head + + # ── Health check ── + echo "Waiting for app..." + sleep 5 + HTTP_CODE=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/api/v1/health) + if [ "$HTTP_CODE" -eq 200 ]; then + echo "✅ API is healthy (HTTP ${HTTP_CODE})" + else + echo "❌ Health check failed (HTTP ${HTTP_CODE})" + docker compose logs app --tail=50 + exit 1 + fi \ No newline at end of file diff --git a/api/.github/workflows/ci.yml b/api/.github/workflows/ci.yml new file mode 100644 index 0000000..0943da8 --- /dev/null +++ b/api/.github/workflows/ci.yml @@ -0,0 +1,64 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + branches: [main] + +jobs: + lint: + name: Lint + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Install ruff + run: pip install ruff>=0.8.0 + + - name: Ruff check + run: ruff check . + + - name: Ruff format check + run: ruff format --check . + + test: + name: Test + runs-on: ubuntu-latest + needs: lint + steps: + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.12" + + - name: Cache pip + uses: actions/cache@v4 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pip-${{ hashFiles('requirements.txt') }} + restore-keys: ${{ runner.os }}-pip- + + - name: Install dependencies + run: pip install -r requirements.txt + + - name: Run tests + run: pytest -v --tb=short + + docker: + name: Docker Build + runs-on: ubuntu-latest + needs: test + steps: + - uses: actions/checkout@v4 + + - name: Build image + run: docker build -t adiuvai-api:ci . + + - name: Verify gunicorn installed + run: docker run --rm adiuvai-api:ci gunicorn --version diff --git a/api/.gitignore b/api/.gitignore new file mode 100644 index 0000000..7a5d5e6 --- /dev/null +++ b/api/.gitignore @@ -0,0 +1,38 @@ +# Python +__pycache__/ +*.py[cod] +*.egg-info/ +dist/ +build/ + +# Virtual environment +.venv/ +venv/ +env/ + +# Environment variables +.env + +# IDE +.vscode/ +.idea/ + +# Testing / coverage +.pytest_cache/ +htmlcov/ +.coverage +tests/fixtures/private*/ + +# Docker +*.log + +# OS +.DS_Store + +# Smoke scripts (dev-only, not for CI) +scripts/smoke_*.py +Thumbs.db + +# Claude Code +.claude/ +logs/ diff --git a/api/Dockerfile b/api/Dockerfile new file mode 100644 index 0000000..32496db --- /dev/null +++ b/api/Dockerfile @@ -0,0 +1,39 @@ +# ── builder ────────────────────────────────────────────────────────────────── +FROM python:3.12-slim AS builder + +WORKDIR /build + +COPY requirements.txt . +RUN pip install --upgrade pip && \ + pip install --no-cache-dir --prefix=/install -r requirements.txt + +# ── runtime ────────────────────────────────────────────────────────────────── +FROM python:3.12-slim AS runtime + +# Non-root user +RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser + +WORKDIR /app + +# Copy installed packages from builder +COPY --from=builder /install /usr/local + +# Copy application source +COPY app/ app/ + +# Copy Alembic migration files +COPY alembic/ alembic/ +COPY alembic.ini . + +# Ensure appuser owns the working directory +RUN chown -R appuser:appgroup /app + +USER appuser + +EXPOSE 8000 + +CMD ["gunicorn", "app.main:app", \ + "-k", "uvicorn.workers.UvicornWorker", \ + "--bind", "0.0.0.0:8000", \ + "--workers", "4", \ + "--timeout", "120"] diff --git a/api/README.md b/api/README.md new file mode 100644 index 0000000..2565106 --- /dev/null +++ b/api/README.md @@ -0,0 +1,5 @@ +## DEV +Run in DEV with command: +``` +uvicorn app.main:app --host 0.0.0.0 --port 8000 --reload --log-config logging.conf +``` \ No newline at end of file diff --git a/api/alembic.ini b/api/alembic.ini new file mode 100644 index 0000000..1223deb --- /dev/null +++ b/api/alembic.ini @@ -0,0 +1,47 @@ +# Alembic configuration file. +# The async app uses postgresql+asyncpg:// at runtime. +# Alembic CLI uses the sync psycopg2 URL set in env.py (reads from DATABASE_URL env var). + +[alembic] +script_location = alembic +prepend_sys_path = . +version_path_separator = os + +# sqlalchemy.url is overridden in alembic/env.py — leave as placeholder. +sqlalchemy.url = driver://user:pass@localhost/dbname + +[post_write_hooks] + +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/api/alembic/env.py b/api/alembic/env.py new file mode 100644 index 0000000..0480ae2 --- /dev/null +++ b/api/alembic/env.py @@ -0,0 +1,93 @@ +"""Alembic migration environment — async-compatible. + +At runtime the app uses ``postgresql+asyncpg://``. Alembic's CLI is +synchronous, so we derive a *sync* psycopg2 URL from the same DATABASE_URL +env var by replacing the driver prefix. + +Run migrations with: + alembic upgrade head +""" + +from __future__ import annotations + +import asyncio +import os +import re +from logging.config import fileConfig + +from alembic import context +from sqlalchemy import pool +from sqlalchemy.ext.asyncio import create_async_engine + +# Alembic Config object (gives access to alembic.ini values). +config = context.config + +# Set up Python logging from alembic.ini. +if config.config_file_name is not None: + fileConfig(config.config_file_name) + +# Import the Base so that Alembic can detect model changes for --autogenerate. +from app.models import Base # noqa: E402 + +target_metadata = Base.metadata + + +def _sync_url(async_url: str) -> str: + """Convert an asyncpg URL to a psycopg2 URL for Alembic CLI.""" + return re.sub(r"postgresql\+asyncpg", "postgresql+psycopg2", async_url) + + +def _get_url() -> str: + db_url = os.environ.get("DATABASE_URL", "") + if not db_url: + # Fall back to settings if env var not set directly. + from app.config.settings import settings # noqa: PLC0415 + db_url = settings.DATABASE_URL + return _sync_url(db_url) + + +def run_migrations_offline() -> None: + """Emit SQL without a live DB connection.""" + url = _get_url() + context.configure( + url=url, + target_metadata=target_metadata, + literal_binds=True, + dialect_opts={"paramstyle": "named"}, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +def do_run_migrations(connection): # type: ignore[no-untyped-def] + context.configure( + connection=connection, + target_metadata=target_metadata, + compare_type=True, + ) + with context.begin_transaction(): + context.run_migrations() + + +async def run_migrations_online_async() -> None: + """Run migrations against a live DB using the async engine.""" + async_url = os.environ.get("DATABASE_URL", "") + if not async_url: + from app.config.settings import settings # noqa: PLC0415 + async_url = settings.DATABASE_URL + + connectable = create_async_engine(async_url, poolclass=pool.NullPool) + async with connectable.connect() as connection: + await connection.run_sync(do_run_migrations) + await connectable.dispose() + + +def run_migrations_online() -> None: + asyncio.run(run_migrations_online_async()) + + +if context.is_offline_mode(): + run_migrations_offline() +else: + run_migrations_online() diff --git a/api/alembic/script.py.mako b/api/alembic/script.py.mako new file mode 100644 index 0000000..ee746cf --- /dev/null +++ b/api/alembic/script.py.mako @@ -0,0 +1,28 @@ +"""${message} + +Revision ID: ${up_revision} +Revises: ${down_revision | comma,n} +Create Date: ${create_date} + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa +${imports if imports else ""} + +# revision identifiers, used by Alembic. +revision: str = ${repr(up_revision)} +down_revision: Union[str, None] = ${repr(down_revision)} +branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)} +depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)} + + +def upgrade() -> None: + ${upgrades if upgrades else "pass"} + + +def downgrade() -> None: + ${downgrades if downgrades else "pass"} diff --git a/api/alembic/versions/001_initial_schema.py b/api/alembic/versions/001_initial_schema.py new file mode 100644 index 0000000..ea9895b --- /dev/null +++ b/api/alembic/versions/001_initial_schema.py @@ -0,0 +1,84 @@ +"""Initial schema: users, refresh_tokens, subscriptions. + +Revision ID: 001 +Revises: +Create Date: 2026-03-02 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "001" +down_revision: Union[str, None] = None +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── Enum types — idempotent creation via exception handling ─────────── + op.execute(""" + DO $$ BEGIN + CREATE TYPE billing_tier AS ENUM ('free', 'pro', 'power', 'team'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """) + + # ── users ───────────────────────────────────────────────────────────── + op.create_table( + "users", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("email", sa.String(255), nullable=False), + sa.Column("password_hash", sa.String(255), nullable=False), + sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"), + sa.Column("stripe_customer_id", sa.String(255), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("email"), + ) + op.create_index("ix_users_email", "users", ["email"]) + + # ── refresh_tokens ──────────────────────────────────────────────────── + op.create_table( + "refresh_tokens", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("token_hash", sa.String(64), nullable=False), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.UniqueConstraint("token_hash"), + ) + op.create_index("ix_refresh_tokens_user_id", "refresh_tokens", ["user_id"]) + op.create_index("ix_refresh_tokens_token_hash", "refresh_tokens", ["token_hash"]) + + # ── subscriptions ───────────────────────────────────────────────────── + op.create_table( + "subscriptions", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("stripe_subscription_id", sa.String(255), nullable=True), + sa.Column("tier", postgresql.ENUM("free", "pro", "power", "team", name="billing_tier", create_type=False), nullable=False, server_default="free"), + sa.Column("status", sa.String(50), nullable=False, server_default="free"), + sa.Column("current_period_end", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.UniqueConstraint("user_id"), + ) + op.create_index("ix_subscriptions_user_id", "subscriptions", ["user_id"]) + op.create_index("ix_subscriptions_stripe_id", "subscriptions", ["stripe_subscription_id"]) + + +def downgrade() -> None: + op.drop_table("subscriptions") + op.drop_table("refresh_tokens") + op.drop_table("users") + + op.execute("DROP TYPE IF EXISTS billing_tier") diff --git a/api/alembic/versions/003_agent_tables.py b/api/alembic/versions/003_agent_tables.py new file mode 100644 index 0000000..455f03b --- /dev/null +++ b/api/alembic/versions/003_agent_tables.py @@ -0,0 +1,127 @@ +"""Add agent config and run log tables: local_agent_configs, cloud_agent_configs, agent_run_logs. + +Revision ID: 003 +Revises: 002 +Create Date: 2026-03-05 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "003" +down_revision: Union[str, None] = "001" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── Enum types — idempotent creation ────────────────────────────────── + op.execute(""" + DO $$ BEGIN + CREATE TYPE agent_type AS ENUM ('local', 'cloud'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """) + op.execute(""" + DO $$ BEGIN + CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """) + op.execute(""" + DO $$ BEGIN + CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """) + + # ── local_agent_configs ─────────────────────────────────────────────── + op.create_table( + "local_agent_configs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("device_id", sa.String(255), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"), + sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"), + sa.Column("prompt_template", sa.Text, nullable=False, server_default=""), + sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"), + sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"]) + + # ── cloud_agent_configs ─────────────────────────────────────────────── + op.create_table( + "cloud_agent_configs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column( + "provider", + postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False), + nullable=False, + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"), + sa.Column("prompt_template", sa.Text, nullable=False, server_default=""), + sa.Column("oauth_token_encrypted", sa.Text, nullable=True), + sa.Column("filter_config", sa.JSON, nullable=True), + sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"]) + + # ── agent_run_logs ───────────────────────────────────────────────────── + op.create_table( + "agent_run_logs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + # Plain string — not a FK because it references either local_agent_configs or + # cloud_agent_configs depending on agent_type. + sa.Column("agent_id", sa.String(255), nullable=False), + sa.Column( + "agent_type", + postgresql.ENUM("local", "cloud", name="agent_type", create_type=False), + nullable=False, + ), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column( + "status", + postgresql.ENUM("running", "success", "error", "partial", name="agent_run_status", create_type=False), + nullable=False, + server_default="running", + ), + sa.Column("items_processed", sa.Integer, nullable=False, server_default="0"), + sa.Column("items_created", sa.Integer, nullable=False, server_default="0"), + sa.Column("errors", sa.JSON, nullable=True), + sa.Column("started_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("completed_at", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_agent_run_logs_user_id", "agent_run_logs", ["user_id"]) + op.create_index("ix_agent_run_logs_agent_id", "agent_run_logs", ["agent_id"]) + + +def downgrade() -> None: + op.drop_table("agent_run_logs") + op.drop_table("cloud_agent_configs") + op.drop_table("local_agent_configs") + + op.execute("DROP TYPE IF EXISTS cloud_provider;") + op.execute("DROP TYPE IF EXISTS agent_run_status;") + op.execute("DROP TYPE IF EXISTS agent_type;") diff --git a/api/alembic/versions/004_add_memory_tables.py b/api/alembic/versions/004_add_memory_tables.py new file mode 100644 index 0000000..ebd2ae1 --- /dev/null +++ b/api/alembic/versions/004_add_memory_tables.py @@ -0,0 +1,144 @@ +"""Add memory tables and user encryption_key column. + +Memory tables: + memory_core — per-user key/value preferences (encrypted) + memory_associative — semantic memory with pgvector embedding (encrypted) + memory_episodic — session summaries (encrypted) + memory_proactive — behavioral patterns (encrypted) + +Also adds encryption_key column to users table. + +Revision ID: 004 +Revises: 003 +Create Date: 2026-03-08 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "004" +down_revision: Union[str, None] = "003" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── Enable pgvector extension (idempotent) ──────────────────────────────── + op.execute("CREATE EXTENSION IF NOT EXISTS vector;") + + # ── Add encryption_key to users ─────────────────────────────────────────── + op.add_column( + "users", + sa.Column("encryption_key", sa.String(64), nullable=True), + ) + + # ── memory_core ─────────────────────────────────────────────────────────── + op.create_table( + "memory_core", + sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=False), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("key", sa.String(255), nullable=False), + sa.Column("value_encrypted", sa.Text, nullable=False), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + ) + op.create_index("ix_memory_core_user_id", "memory_core", ["user_id"]) + + # ── memory_associative ──────────────────────────────────────────────────── + # The embedding column uses pgvector's vector(1536) type. + op.create_table( + "memory_associative", + sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=False), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("content_encrypted", sa.Text, nullable=False), + sa.Column("entity_type", sa.String(100), nullable=True), + sa.Column("entity_id", sa.String(255), nullable=True), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + ) + # Add the pgvector column separately (not supported by generic sa types) + op.execute( + "ALTER TABLE memory_associative ADD COLUMN embedding vector(1536);" + ) + op.create_index("ix_memory_associative_user_id", "memory_associative", ["user_id"]) + # IVFFlat index for approximate nearest-neighbour search + op.execute( + "CREATE INDEX ix_memory_associative_embedding " + "ON memory_associative USING ivfflat (embedding vector_cosine_ops) WITH (lists = 100);" + ) + + # ── memory_episodic ─────────────────────────────────────────────────────── + op.create_table( + "memory_episodic", + sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=False), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("summary_encrypted", sa.Text, nullable=False), + sa.Column("session_id", sa.String(255), nullable=False), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + ) + op.create_index("ix_memory_episodic_user_id", "memory_episodic", ["user_id"]) + op.create_index("ix_memory_episodic_session_id", "memory_episodic", ["session_id"]) + + # ── memory_proactive ────────────────────────────────────────────────────── + op.create_table( + "memory_proactive", + sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=False), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("pattern_encrypted", sa.Text, nullable=False), + sa.Column("confidence", sa.Float, nullable=False, server_default="0.5"), + sa.Column("source", sa.String(50), nullable=False, server_default="inferred"), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + ) + op.create_index("ix_memory_proactive_user_id", "memory_proactive", ["user_id"]) + + +def downgrade() -> None: + op.drop_table("memory_proactive") + op.drop_table("memory_episodic") + op.drop_index("ix_memory_associative_embedding", "memory_associative") + op.drop_table("memory_associative") + op.drop_table("memory_core") + op.drop_column("users", "encryption_key") diff --git a/api/alembic/versions/005_associative_pgvector.py b/api/alembic/versions/005_associative_pgvector.py new file mode 100644 index 0000000..d70f183 --- /dev/null +++ b/api/alembic/versions/005_associative_pgvector.py @@ -0,0 +1,54 @@ +"""Phase 1 — confirm pgvector activation on memory_associative. + +Migration 004 created the embedding column as vector(1536) and added the +IVFFlat index. This migration is the Phase-1 checkpoint: + 1. Ensures the pgvector extension is enabled (idempotent). + 2. Ensures the canonical Phase-1 IVFFlat index exists under the name + memory_associative_embedding_idx (creates it only if absent). + +Revision ID: 005 +Revises: 9a1f2d0b6c7e +Create Date: 2026-04-15 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op + +revision: str = "005" +down_revision: Union[str, None] = "e04100e88ace" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Ensure pgvector extension is enabled (also done in 004, idempotent). + op.execute("CREATE EXTENSION IF NOT EXISTS vector;") + + # Ensure the canonical Phase-1 IVFFlat index exists. + # 004 may have created ix_memory_associative_embedding; this adds the + # Phase-1 name memory_associative_embedding_idx if it is missing. + op.execute( + """ + DO $$ + BEGIN + IF NOT EXISTS ( + SELECT 1 + FROM pg_indexes + WHERE tablename = 'memory_associative' + AND indexname = 'memory_associative_embedding_idx' + ) THEN + CREATE INDEX memory_associative_embedding_idx + ON memory_associative + USING ivfflat (embedding vector_cosine_ops) + WITH (lists = 100); + END IF; + END $$; + """ + ) + + +def downgrade() -> None: + op.execute("DROP INDEX IF EXISTS memory_associative_embedding_idx;") diff --git a/api/alembic/versions/006_memory_relations.py b/api/alembic/versions/006_memory_relations.py new file mode 100644 index 0000000..1d9ce84 --- /dev/null +++ b/api/alembic/versions/006_memory_relations.py @@ -0,0 +1,74 @@ +"""Add memory_relations table (Phase 3 — relational tier). + +Revision ID: 006 +Revises: 1f5975a4f3f4 +Create Date: 2026-04-16 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "006" +down_revision: Union[str, None] = "1f5975a4f3f4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "memory_relations", + sa.Column("id", postgresql.UUID(as_uuid=False), primary_key=True), + sa.Column( + "user_id", + postgresql.UUID(as_uuid=False), + sa.ForeignKey("users.id", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("subject_label", sa.String(128), nullable=False), + sa.Column("subject_type", sa.String(32), nullable=False), + sa.Column("predicate", sa.String(64), nullable=False), + sa.Column("object_label", sa.String(128), nullable=False), + sa.Column("object_type", sa.String(32), nullable=False), + sa.Column("confidence", sa.Float, nullable=False, server_default="0.7"), + sa.Column( + "source_episode_id", + postgresql.UUID(as_uuid=False), + sa.ForeignKey("memory_episodic.id", ondelete="SET NULL"), + nullable=True, + ), + sa.Column("notes_encrypted", sa.LargeBinary, nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column( + "updated_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.func.now(), + ), + sa.Column("last_confirmed_at", sa.DateTime(timezone=True), nullable=True), + ) + op.create_index( + "memory_relations_user_subject_idx", + "memory_relations", + ["user_id", "subject_label"], + ) + op.create_index( + "memory_relations_user_predicate_idx", + "memory_relations", + ["user_id", "predicate"], + ) + + +def downgrade() -> None: + op.drop_index("memory_relations_user_predicate_idx", "memory_relations") + op.drop_index("memory_relations_user_subject_idx", "memory_relations") + op.drop_table("memory_relations") diff --git a/api/alembic/versions/007_rename_agents_to_scouts.py b/api/alembic/versions/007_rename_agents_to_scouts.py new file mode 100644 index 0000000..e826a46 --- /dev/null +++ b/api/alembic/versions/007_rename_agents_to_scouts.py @@ -0,0 +1,41 @@ +"""Rename agents to scouts. + +Revision ID: 007 +Revises: d6e3f4a5b6c7 +Create Date: 2026-05-15 + +Renames the entire agents subsystem identifiers to scouts. +Pre-1.0 — no data preservation concerns beyond ALTER TABLE rename. +""" + +from typing import Sequence, Union + +from alembic import op + + +revision: str = "007" +down_revision: Union[str, None] = "d6e3f4a5b6c7" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Tables + op.rename_table("local_agent_configs", "local_scout_configs") + op.rename_table("cloud_agent_configs", "cloud_scout_configs") + op.rename_table("agent_run_logs", "scout_run_logs") + + # Columns + op.alter_column("local_scout_configs", "agent_config", new_column_name="scout_config") + op.alter_column("scout_run_logs", "agent_id", new_column_name="scout_id") + op.alter_column("scout_run_logs", "agent_type", new_column_name="scout_type") + + +def downgrade() -> None: + op.alter_column("scout_run_logs", "scout_type", new_column_name="agent_type") + op.alter_column("scout_run_logs", "scout_id", new_column_name="agent_id") + op.alter_column("local_scout_configs", "scout_config", new_column_name="agent_config") + + op.rename_table("scout_run_logs", "agent_run_logs") + op.rename_table("cloud_scout_configs", "cloud_agent_configs") + op.rename_table("local_scout_configs", "local_agent_configs") diff --git a/api/alembic/versions/008_scout_triage_queue.py b/api/alembic/versions/008_scout_triage_queue.py new file mode 100644 index 0000000..a674140 --- /dev/null +++ b/api/alembic/versions/008_scout_triage_queue.py @@ -0,0 +1,59 @@ +"""Scout triage queue + cloud_scout_configs alterations. + +Revision ID: 008 +Revises: 007 +Create Date: 2026-05-16 +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +revision: str = "008" +down_revision: Union[str, None] = "007" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + "scout_triage_queue", + sa.Column("id", sa.Uuid(as_uuid=False), primary_key=True), + sa.Column("user_id", sa.Uuid(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True), + sa.Column("scout_id", sa.Uuid(as_uuid=False), sa.ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False), + sa.Column("source_type", sa.String(50), nullable=False), + sa.Column("source_msg_ref", sa.String(255), nullable=False), + sa.Column("triage_verdict", sa.String(20), nullable=False), + sa.Column("triage_reason", sa.Text, nullable=True), + sa.Column("status", sa.String(20), nullable=False, server_default="queued"), + sa.Column("triaged_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()), + sa.Column("delivered_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("acked_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False), + sa.UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"), + ) + op.create_index("ix_scout_triage_queue_user_status", "scout_triage_queue", ["user_id", "status"]) + op.create_index( + "ix_scout_triage_queue_expires_active", + "scout_triage_queue", + ["expires_at"], + postgresql_where=sa.text("status != 'acked'"), + ) + + op.add_column("cloud_scout_configs", sa.Column("auto_trash_spam", sa.Boolean(), nullable=False, server_default=sa.text("false"))) + op.add_column("cloud_scout_configs", sa.Column("gmail_history_id", sa.String(64), nullable=True)) + op.add_column("cloud_scout_configs", sa.Column("gmail_watch_expires_at", sa.DateTime(timezone=True), nullable=True)) + op.add_column("cloud_scout_configs", sa.Column("device_inactivity_pause_days", sa.Integer(), nullable=False, server_default="14")) + + +def downgrade() -> None: + op.drop_column("cloud_scout_configs", "device_inactivity_pause_days") + op.drop_column("cloud_scout_configs", "gmail_watch_expires_at") + op.drop_column("cloud_scout_configs", "gmail_history_id") + op.drop_column("cloud_scout_configs", "auto_trash_spam") + + op.drop_index("ix_scout_triage_queue_expires_active", table_name="scout_triage_queue") + op.drop_index("ix_scout_triage_queue_user_status", table_name="scout_triage_queue") + op.drop_table("scout_triage_queue") diff --git a/api/alembic/versions/009_cloud_scout_gmail_address.py b/api/alembic/versions/009_cloud_scout_gmail_address.py new file mode 100644 index 0000000..5891f1d --- /dev/null +++ b/api/alembic/versions/009_cloud_scout_gmail_address.py @@ -0,0 +1,25 @@ +"""Add gmail_address to cloud_scout_configs. + +Revision ID: 009 +Revises: 008 +Create Date: 2026-05-16 +""" + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +revision: str = "009" +down_revision: Union[str, None] = "008" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("cloud_scout_configs", sa.Column("gmail_address", sa.String(320), nullable=True)) + + +def downgrade() -> None: + op.drop_column("cloud_scout_configs", "gmail_address") diff --git a/api/alembic/versions/1f5975a4f3f4_add_extraction_queue.py b/api/alembic/versions/1f5975a4f3f4_add_extraction_queue.py new file mode 100644 index 0000000..e7e41ec --- /dev/null +++ b/api/alembic/versions/1f5975a4f3f4_add_extraction_queue.py @@ -0,0 +1,38 @@ +"""add extraction_queue + +Revision ID: 1f5975a4f3f4 +Revises: 005 +Create Date: 2026-04-16 17:26:25.790870 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '1f5975a4f3f4' +down_revision: Union[str, None] = '005' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.create_table( + 'extraction_queue', + sa.Column('id', sa.Uuid(as_uuid=False), nullable=False), + sa.Column('user_id', sa.Uuid(as_uuid=False), nullable=False), + sa.Column('episode_id', sa.Uuid(as_uuid=False), nullable=True), + sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='CASCADE'), + sa.PrimaryKeyConstraint('id'), + ) + op.create_index(op.f('ix_extraction_queue_user_id'), 'extraction_queue', ['user_id'], unique=False) + + +def downgrade() -> None: + op.drop_index(op.f('ix_extraction_queue_user_id'), table_name='extraction_queue') + op.drop_table('extraction_queue') diff --git a/api/alembic/versions/818478c251dc_add_name_and_surname_to_users_table.py b/api/alembic/versions/818478c251dc_add_name_and_surname_to_users_table.py new file mode 100644 index 0000000..164c246 --- /dev/null +++ b/api/alembic/versions/818478c251dc_add_name_and_surname_to_users_table.py @@ -0,0 +1,30 @@ +"""add name and surname to users table + +Revision ID: 818478c251dc +Revises: 004 +Create Date: 2026-03-10 15:10:42.811947 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = '818478c251dc' +down_revision: Union[str, None] = '004' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column('users', sa.Column('name', sa.String(length=100), nullable=True)) + op.add_column('users', sa.Column('surname', sa.String(length=100), nullable=True)) + + +def downgrade() -> None: + op.drop_column('users', 'surname') + op.drop_column('users', 'name') diff --git a/api/alembic/versions/9a1f2d0b6c7e_deprecate_backend_agent_config_tables.py b/api/alembic/versions/9a1f2d0b6c7e_deprecate_backend_agent_config_tables.py new file mode 100644 index 0000000..549c11c --- /dev/null +++ b/api/alembic/versions/9a1f2d0b6c7e_deprecate_backend_agent_config_tables.py @@ -0,0 +1,92 @@ +"""Deprecate backend agent config tables. + +The Electron client is now the source of truth for agent configuration +(directory, extract targets, batch interval, custom prompt). Backend keeps +billing checks and trigger/run logs only. + +Revision ID: 9a1f2d0b6c7e +Revises: 818478c251dc +Create Date: 2026-03-16 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "9a1f2d0b6c7e" +down_revision: Union[str, None] = "818478c251dc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + existing = set(inspector.get_table_names()) + + if "cloud_agent_configs" in existing: + op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs") + op.drop_table("cloud_agent_configs") + + if "local_agent_configs" in existing: + op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs") + op.drop_table("local_agent_configs") + + +def downgrade() -> None: + op.create_table( + "local_agent_configs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("device_id", sa.String(255), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"), + sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"), + sa.Column("prompt_template", sa.Text, nullable=False, server_default=""), + sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"), + sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"]) + + op.execute( + """ + DO $$ BEGIN + CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """ + ) + + op.create_table( + "cloud_agent_configs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column( + "provider", + postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False), + nullable=False, + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"), + sa.Column("prompt_template", sa.Text, nullable=False, server_default=""), + sa.Column("oauth_token_encrypted", sa.Text, nullable=True), + sa.Column("filter_config", sa.JSON, nullable=True), + sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"]) diff --git a/api/alembic/versions/a3b9c0d1e2f3_add_agent_config_to_local_agents.py b/api/alembic/versions/a3b9c0d1e2f3_add_agent_config_to_local_agents.py new file mode 100644 index 0000000..60a9b96 --- /dev/null +++ b/api/alembic/versions/a3b9c0d1e2f3_add_agent_config_to_local_agents.py @@ -0,0 +1,107 @@ +"""Restore agent config tables and add agent_config column. + +9a1f2d0b6c7e dropped local_agent_configs and cloud_agent_configs, but both +ORM models are still active. This migration recreates them with agent_config +added to local_agent_configs. + +Revision ID: a3b9c0d1e2f3 +Revises: 9a1f2d0b6c7e +Create Date: 2026-04-07 00:00:00.000000 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = "a3b9c0d1e2f3" +down_revision: Union[str, None] = "9a1f2d0b6c7e" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Recreate enum types (idempotent — they may already exist from migration 003) + op.execute(""" + DO $$ BEGIN + CREATE TYPE agent_type AS ENUM ('local', 'cloud'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """) + op.execute(""" + DO $$ BEGIN + CREATE TYPE agent_run_status AS ENUM ('running', 'success', 'error', 'partial'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """) + op.execute(""" + DO $$ BEGIN + CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """) + + bind = op.get_bind() + inspector = sa.inspect(bind) + existing = set(inspector.get_table_names()) + + # ── local_agent_configs (with agent_config column) ──────────────────── + if "local_agent_configs" not in existing: + op.create_table( + "local_agent_configs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("device_id", sa.String(255), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"), + sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"), + sa.Column("prompt_template", sa.Text, nullable=False, server_default=""), + sa.Column("agent_config", sa.JSON, nullable=True), + sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"), + sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"]) + + # ── cloud_agent_configs ─────────────────────────────────────────────── + if "cloud_agent_configs" not in existing: + op.create_table( + "cloud_agent_configs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column( + "provider", + postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False), + nullable=False, + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"), + sa.Column("prompt_template", sa.Text, nullable=False, server_default=""), + sa.Column("oauth_token_encrypted", sa.Text, nullable=True), + sa.Column("filter_config", sa.JSON, nullable=True), + sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"]) + + +def downgrade() -> None: + op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs") + op.drop_table("cloud_agent_configs") + op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs") + op.drop_table("local_agent_configs") diff --git a/api/alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py b/api/alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py new file mode 100644 index 0000000..8b9b34e --- /dev/null +++ b/api/alembic/versions/b4c0d1e2f3a4_add_oauth_and_avatar.py @@ -0,0 +1,56 @@ +"""Add oauth_accounts table, nullable password_hash, avatar_url to users. + +Revision ID: b4c0d1e2f3a4 +Revises: a3b9c0d1e2f3 +Create Date: 2026-04-10 00:00:00.000000 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + + +# revision identifiers, used by Alembic. +revision: str = "b4c0d1e2f3a4" +down_revision: Union[str, None] = "a3b9c0d1e2f3" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ── users: make password_hash nullable (social users have no password) ── + op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=True) + + # ── users: add avatar_url ───────────────────────────────────────────── + op.add_column("users", sa.Column("avatar_url", sa.String(2048), nullable=True)) + + # ── oauth_accounts ──────────────────────────────────────────────────── + op.create_table( + "oauth_accounts", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("provider", sa.String(50), nullable=False), + sa.Column("provider_user_id", sa.String(255), nullable=False), + sa.Column("provider_email", sa.String(255), nullable=True), + sa.Column( + "created_at", + sa.DateTime(timezone=True), + nullable=False, + server_default=sa.text("now()"), + ), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + sa.UniqueConstraint("provider", "provider_user_id", name="uq_oauth_provider_user"), + ) + op.create_index("ix_oauth_accounts_user_id", "oauth_accounts", ["user_id"]) + + +def downgrade() -> None: + op.drop_index("ix_oauth_accounts_user_id", table_name="oauth_accounts") + op.drop_table("oauth_accounts") + op.drop_column("users", "avatar_url") + op.alter_column("users", "password_hash", existing_type=sa.String(255), nullable=False) diff --git a/api/alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py b/api/alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py new file mode 100644 index 0000000..36d63bd --- /dev/null +++ b/api/alembic/versions/c5d1e2f3a4b5_add_onboarding_completed_at.py @@ -0,0 +1,31 @@ +"""Add onboarding_completed_at column to users table. + +Revision ID: c5d1e2f3a4b5 +Revises: b4c0d1e2f3a4 +Create Date: 2026-04-11 00:00:00.000000 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op + + +# revision identifiers, used by Alembic. +revision: str = "c5d1e2f3a4b5" +down_revision: Union[str, None] = "b4c0d1e2f3a4" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "users", + sa.Column("onboarding_completed_at", sa.DateTime(timezone=True), nullable=True), + ) + + +def downgrade() -> None: + op.drop_column("users", "onboarding_completed_at") diff --git a/api/alembic/versions/d6e3f4a5b6c7_folder_index_tables.py b/api/alembic/versions/d6e3f4a5b6c7_folder_index_tables.py new file mode 100644 index 0000000..c084f72 --- /dev/null +++ b/api/alembic/versions/d6e3f4a5b6c7_folder_index_tables.py @@ -0,0 +1,46 @@ +"""Add token tracking columns for folder integration. + +Revision ID: d6e3f4a5b6c7 +Revises: 006 +Create Date: 2026-05-11 00:00:00.000000 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import UUID + +# revision identifiers, used by Alembic. +revision: str = "d6e3f4a5b6c7" +down_revision: Union[str, None] = "006" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column( + "agent_run_logs", + sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"), + ) + op.create_table( + "monthly_token_usage", + sa.Column("user_id", UUID(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False), + sa.Column("year_month", sa.String(7), nullable=False), + sa.Column("feature", sa.String(64), nullable=False), + sa.Column("tokens_used", sa.Integer(), nullable=False, server_default="0"), + sa.PrimaryKeyConstraint("user_id", "year_month", "feature"), + ) + op.create_index( + "ix_monthly_token_usage_user_month", + "monthly_token_usage", + ["user_id", "year_month"], + ) + + +def downgrade() -> None: + op.drop_index("ix_monthly_token_usage_user_month", table_name="monthly_token_usage") + op.drop_table("monthly_token_usage") + op.drop_column("agent_run_logs", "tokens_used") diff --git a/api/alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py b/api/alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py new file mode 100644 index 0000000..0a1421c --- /dev/null +++ b/api/alembic/versions/e04100e88ace_avatar_url_varchar_to_text.py @@ -0,0 +1,34 @@ +"""avatar_url_varchar_to_text + +Revision ID: e04100e88ace +Revises: c5d1e2f3a4b5 +Create Date: 2026-04-13 09:13:06.733674 + +""" +from __future__ import annotations + +from typing import Sequence, Union + +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision: str = 'e04100e88ace' +down_revision: Union[str, None] = 'c5d1e2f3a4b5' +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.alter_column('users', 'avatar_url', + existing_type=sa.VARCHAR(length=2048), + type_=sa.Text(), + existing_nullable=True) + + +def downgrade() -> None: + op.alter_column('users', 'avatar_url', + existing_type=sa.Text(), + type_=sa.VARCHAR(length=2048), + existing_nullable=True) diff --git a/api/app/__init__.py b/api/app/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/app/agents/__init__.py b/api/app/agents/__init__.py new file mode 100644 index 0000000..a2dc4c6 --- /dev/null +++ b/api/app/agents/__init__.py @@ -0,0 +1,5 @@ +"""Expose tool modules used by deep orchestrator-worker graphs.""" + +from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent + +__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"] diff --git a/api/app/agents/client_agent.py b/api/app/agents/client_agent.py new file mode 100644 index 0000000..df1e945 --- /dev/null +++ b/api/app/agents/client_agent.py @@ -0,0 +1,52 @@ +"""Client agent — read-only tools for the clients table.""" + +from __future__ import annotations + +import json +from typing import Any + +from langchain_core.tools import tool + +from app.core.ws_context import execute_on_client + + +@tool +async def list_clients(search: str = "", limit: int = 20) -> str: + """List clients, optionally filtered by a name/email substring search. + + search: optional substring to match against client name or email. + limit: max rows to return (default 20). + """ + filters: dict[str, Any] = {"limit": limit} + if search: + filters["search"] = search + + result = await execute_on_client(action="select", table="clients", filters=filters) + rows = result.get("rows", []) + if not rows: + return "No clients found." + lines = [ + f"- {r.get('name', '?')} (id: {r.get('id')}, email: {r.get('email', '')}, " + f"company: {r.get('company', '')})" + for r in rows + ] + return f"Found {len(rows)} client(s):\n" + "\n".join(lines) + + +@tool +async def get_client(id: str) -> str: + """Get full details for one client by UUID. + + id: the client's UUID. + """ + if not id: + return "Client id is required." + + result = await execute_on_client(action="get", table="clients", data={"id": id}) + row = result.get("row") or result.get("rows", [None])[0] if result else None + if not row: + return f"Client '{id}' not found." + return f"Client details:\n{json.dumps(row, ensure_ascii=False, indent=2)}" + + +CLIENT_TOOLS: list[Any] = [list_clients, get_client] diff --git a/api/app/agents/filesystem_agent.py b/api/app/agents/filesystem_agent.py new file mode 100644 index 0000000..e7cf600 --- /dev/null +++ b/api/app/agents/filesystem_agent.py @@ -0,0 +1,194 @@ +"""Filesystem agent — tools for reading local directories and files on Electron. + +These tools delegate to the Electron client via ``execute_on_client()`` using +the same WS tool-call round-trip pattern as CRUD tools. The Electron app +handles actual disk I/O and responds with ``tool_result`` frames. +""" + +from __future__ import annotations + +import os +import re +from pathlib import Path +from typing import Any + +from langchain_core.tools import tool + +from app.core.ws_context import execute_on_client + +# Max characters returned by read_file_content in journey (exploration) tools. +# The journey only needs to understand file structure, not full content. +_JOURNEY_READ_MAX_CHARS: int = 4000 + + +def _resolve_path(path: str, base: str) -> str: + """Resolve *path* against *base* when *path* is relative. + + The LLM often passes ``"."`` meaning "the configured directory". + Without this, Electron resolves ``"."`` relative to its own CWD instead + of the user's chosen directory. + """ + if os.path.isabs(path): + return path + return str(Path(base) / path) + + +@tool +async def list_directory(path: str) -> str: + """List files and folders in a local directory on the user's device. + + Returns a formatted listing of entries with name, type (file/directory), + and full path. + """ + result = await execute_on_client( + action="list_directory", + data={"path": path}, + ) + entries: list[dict[str, Any]] = result.get("entries", []) + if not entries: + return f"Directory '{path}' is empty or does not exist." + lines: list[str] = [] + for entry in entries: + entry_type = entry.get("type", "unknown") + entry_name = entry.get("name", "") + entry_path = entry.get("path", "") + lines.append(f"- [{entry_type}] {entry_name} ({entry_path})") + return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines) + + +@tool +async def read_file_content(path: str) -> str: + """Read the text content of a local file on the user's device. + + Returns the file content as a string. Large files may be truncated + by the Electron client. + """ + result = await execute_on_client( + action="read_file_content", + data={"path": path}, + ) + content: str = result.get("content", "") + if not content: + return f"File '{path}' is empty or could not be read." + return content + + +@tool +async def get_file_metadata(path: str) -> str: + """Get metadata for a local file: size, creation date, modification date, extension. + + Returns a formatted summary of the file's metadata. + """ + result = await execute_on_client( + action="get_file_metadata", + data={"path": path}, + ) + size = result.get("size", "unknown") + created = result.get("createdAt", "unknown") + modified = result.get("modifiedAt", "unknown") + extension = result.get("extension", "unknown") + name = result.get("name", path) + return ( + f"File: {name}\n" + f" Extension: {extension}\n" + f" Size: {size} bytes\n" + f" Created: {created}\n" + f" Modified: {modified}" + ) + + +FILESYSTEM_TOOLS: list[Any] = [ + list_directory, + read_file_content, + get_file_metadata, +] + + +def make_directory_tools(base_directory: str) -> list[Any]: + """Return filesystem tools that resolve relative paths against *base_directory*. + + Use this instead of ``FILESYSTEM_TOOLS`` whenever you know the user's target + directory upfront (e.g., journey setup sessions). Relative paths like ``"."`` + from the LLM are resolved to the correct absolute path before being sent to + the Electron client, preventing it from falling back to its own CWD. + """ + + def _compact_for_journey(raw: str) -> str: + """Strip HTML noise and truncate for journey exploration. + + The journey LLM only needs to understand file structure (headers, + first paragraphs). Full CSS/style blocks are pure noise that eat + up context window budget. + """ + text = re.sub(r"]*>.*?", "", raw, flags=re.DOTALL | re.IGNORECASE) + text = re.sub(r"]*>.*?", "", text, flags=re.DOTALL | re.IGNORECASE) + text = re.sub(r"", "", text, flags=re.DOTALL) + if len(text) > _JOURNEY_READ_MAX_CHARS: + text = text[:_JOURNEY_READ_MAX_CHARS] + "\n[…truncated for exploration]" + return text + + @tool + async def list_directory(path: str) -> str: # noqa: F811 + """List files and folders in a local directory on the user's device. + + Returns a formatted listing of entries with name, type (file/directory), + and full path. + """ + resolved = _resolve_path(path, base_directory) + result = await execute_on_client( + action="list_directory", + data={"path": resolved}, + ) + entries: list[dict[str, Any]] = result.get("entries", []) + if not entries: + return f"Directory '{resolved}' is empty or does not exist." + lines: list[str] = [] + for entry in entries: + entry_type = entry.get("type", "unknown") + entry_name = entry.get("name", "") + entry_path = entry.get("path", "") + lines.append(f"- [{entry_type}] {entry_name} ({entry_path})") + return f"Directory listing for '{resolved}' ({len(entries)} entries):\n" + "\n".join(lines) + + @tool + async def read_file_content(path: str) -> str: # noqa: F811 + """Read the text content of a local file on the user's device. + + Returns the file content as a string. Large files may be truncated + by the Electron client. + """ + resolved = _resolve_path(path, base_directory) + result = await execute_on_client( + action="read_file_content", + data={"path": resolved}, + ) + content: str = result.get("content", "") + if not content: + return f"File '{resolved}' is empty or could not be read." + return _compact_for_journey(content) + + @tool + async def get_file_metadata(path: str) -> str: # noqa: F811 + """Get metadata for a local file: size, creation date, modification date, extension. + + Returns a formatted summary of the file's metadata. + """ + resolved = _resolve_path(path, base_directory) + result = await execute_on_client( + action="get_file_metadata", + data={"path": resolved}, + ) + size = result.get("size", "unknown") + created = result.get("createdAt", "unknown") + modified = result.get("modifiedAt", "unknown") + extension = result.get("extension", "unknown") + name = result.get("name", resolved) + return ( + f"File: {name}\n" + f" Extension: {extension}\n" + f" Size: {size} bytes\n" + f" Created: {created}\n" + f" Modified: {modified}" + ) + + return [list_directory, read_file_content, get_file_metadata] diff --git a/api/app/agents/folder_agent.py b/api/app/agents/folder_agent.py new file mode 100644 index 0000000..f6542d6 --- /dev/null +++ b/api/app/agents/folder_agent.py @@ -0,0 +1,168 @@ +"""Scoped file-read and search tools for the project folder feature.""" +from __future__ import annotations + +from langchain_core.tools import tool + +from app.core.folder_indexer import _extract_docx_text, _extract_pdf_text +from app.core.ws_context import execute_on_client + +# Cap returned slice size to keep tool output under control. +_MAX_RETURN_CHARS = 50_000 +_MAX_SEARCH_MATCHES = 20 + + +def _is_unsafe_path(rel: str) -> bool: + if not rel: + return True + norm = rel.replace("\\", "/") + if norm.startswith("/"): + return True + # Windows drive letter + if len(rel) >= 2 and rel[1] == ":": + return True + parts = norm.split("/") + return ".." in parts + + +async def _fetch_file(project_id: str, relative_path: str, offset: int, length: int) -> dict: + """Return the raw Electron tool_result dict for a file read.""" + return await execute_on_client( + action="read_project_folder_file", + data={ + "projectId": project_id, + "relativePath": relative_path, + "offset": offset, + "length": length, + }, + ) + + +def _decode(result: dict) -> tuple[str, str, int]: + """Decode a tool_result into (text, kind, total_size). For pdf/docx, + extracts text from base64. For images, returns a placeholder string. + For text, content is already a sliced utf-8 string. + """ + kind = result.get("kind", "text") + content = result.get("content", "") or "" + total = int(result.get("totalSize", 0) or 0) + if kind == "image": + return ("[Image file — cannot be navigated as text. See manifest summary.]", kind, total) + if kind == "pdf": + return (_extract_pdf_text(content), kind, total) + if kind == "docx": + return (_extract_docx_text(content), kind, total) + return (content, kind, total) + + +@tool +async def read_project_folder_file( + project_id: str, + relative_path: str, + offset: int = 0, + length: int = _MAX_RETURN_CHARS, +) -> str: + """Read a slice of a file inside the project's linked folder. + + Args: + project_id: project ID. + relative_path: path relative to the linked folder root. + offset: char offset to start reading from (0 = beginning). + length: max chars to return. Default 50000. Use smaller values to save tokens. + + Returns text content slice with a header showing position. Header tells you + when more content is available; call again with the suggested next offset. + + For PDF / DOCX files the backend extracts text first, then applies offset/length + on the extracted text. For images returns a placeholder; navigate with the + manifest summary instead. + """ + if _is_unsafe_path(relative_path): + return "Access denied" + + result = await _fetch_file(project_id, relative_path, offset, length) + text, kind, total_size = _decode(result) + + if not text and kind in ("missing", "error"): + return f"File not found or unreadable: {relative_path}" + + if kind in ("pdf", "docx"): + # Backend extracted full text — apply offset/length on chars. + sliced = text[offset:offset + length] + slice_end = min(offset + length, len(text)) + header = ( + f"[file={relative_path} kind={kind} offset={offset} end={slice_end} " + f"totalChars={len(text)}]" + ) + if slice_end < len(text): + header += f"\n[More content available — call again with offset={slice_end}.]" + return header + "\n" + sliced + + if kind == "text": + slice_end = offset + len(text) + header = ( + f"[file={relative_path} kind=text offset={offset} end={slice_end} " + f"totalBytes={total_size}]" + ) + if slice_end < total_size: + header += f"\n[More content available — call again with offset={slice_end}.]" + return header + "\n" + text + + # image or unknown + return text + + +@tool +async def search_project_folder_file( + project_id: str, + relative_path: str, + query: str, + context_lines: int = 3, +) -> str: + """Search a project folder file for a query string (case-insensitive substring). + + Args: + project_id: project ID. + relative_path: path relative to the linked folder root. + query: text to search for. + context_lines: number of lines of context around each match (default 3). + + Returns matching line ranges with surrounding context and 1-based line numbers. + Capped at 20 matches; if more exist the header shows the total. + + Works on text, code, markdown, PDF (extracted), and DOCX (extracted). + Images and binary files are not searchable. + """ + if _is_unsafe_path(relative_path): + return "Access denied" + if not query: + return "Empty query." + + # For text we still need full file; pass length=very large. + result = await _fetch_file(project_id, relative_path, offset=0, length=10_000_000) + text, kind, _ = _decode(result) + + if not text and kind in ("missing", "error"): + return f"File not found or unreadable: {relative_path}" + if kind == "image": + return "Cannot search inside images." + + lines = text.splitlines() + q = query.lower() + matches = [i for i, line in enumerate(lines) if q in line.lower()] + if not matches: + return f"No matches for '{query}' in {relative_path}." + + shown = matches[:_MAX_SEARCH_MATCHES] + snippets: list[str] = [] + for i in shown: + start = max(0, i - context_lines) + end = min(len(lines), i + context_lines + 1) + block = "\n".join(f"{n + 1:5d}: {lines[n]}" for n in range(start, end)) + snippets.append(block) + + header = f"[file={relative_path} matches={len(matches)} showing={len(shown)} query='{query}']" + body = "\n---\n".join(snippets) + return header + "\n" + body + + +FOLDER_TOOLS = [read_project_folder_file, search_project_folder_file] diff --git a/api/app/agents/note_agent.py b/api/app/agents/note_agent.py new file mode 100644 index 0000000..4cf75fb --- /dev/null +++ b/api/app/agents/note_agent.py @@ -0,0 +1,206 @@ +"""Note agent — Markdown note management (list, get, create, update, propose edit).""" + +from __future__ import annotations + +import asyncio +import re +from typing import Any + +from langchain_core.tools import tool + +from app.core.note_summarizer import generate_note_summary +from app.core.ws_context import execute_on_client + +_UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" +) + + +def _is_uuid(value: str) -> bool: + return bool(_UUID_RE.match(value)) + + +def _fmt_summary(row: dict) -> str: + summary = (row.get("aiSummary") or row.get("ai_summary") or "").strip() + if summary: + return f" — {summary}" + snippet = (row.get("content") or "")[:120].replace("\n", " ").strip() + return f" — {snippet}" if snippet else "" + + +@tool +async def list_notes(project_id: str = "") -> str: + """List notes with AI summaries, optionally scoped to a project by project_id. + + Returns id, title, and ai_summary for each note so you can decide which + note to read in full with get_note before creating or updating. + """ + normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" + result = await execute_on_client( + action="select", + table="notes", + filters={"projectId": normalized_project_id or None}, + ) + rows = result.get("rows", []) + if not rows: + return "No notes found." + lines = [f" - [{r['id']}] {r['title']}{_fmt_summary(r)}" for r in rows] + return f"Found {len(rows)} note(s):\n" + "\n".join(lines) + + +@tool +async def get_note(note_id: str) -> str: + """Fetch a single note by its UUID to read its full Markdown content.""" + result = await execute_on_client(action="get", table="notes", data={"id": note_id}) + row = result.get("row") + if not row: + return f"Note {note_id} not found." + return f"Note '{row['title']}' (id: {row['id']}):\n\n{row['content']}" + + +@tool +async def create_note( + title: str, + content: str, + project_id: str = "", +) -> str: + """Create a new note. + title: note heading (required) + content: Markdown body text (required) + project_id: optional UUID linking this note to a project + """ + result = await execute_on_client( + action="insert", + table="notes", + data={ + "title": title, + "content": content, + "projectId": project_id or None, + }, + ) + row = result["row"] + note_id: str = row["id"] + # Generate summary asynchronously — fire-and-forget. + asyncio.create_task(_refresh_summary(note_id, title, content)) + return f"Note created: '{row['title']}' (id: {note_id})." + + +@tool +async def update_note( + note_id: str, + title: str = "", + content: str = "", +) -> str: + """Update an existing note directly (no approval required). + Use propose_note_edit instead when human review is needed. + note_id: UUID of the note (required) + If you need to preserve existing content, call get_note first. + """ + updates: dict[str, Any] = {} + if title: + updates["title"] = title + if content: + updates["content"] = content + result = await execute_on_client( + action="update", + table="notes", + data={"id": note_id, "updates": updates}, + ) + row = result["row"] + if content: + new_title = title or row.get("title", "") + asyncio.create_task(_refresh_summary(note_id, new_title, content)) + return f"Note updated: '{row['title']}' (id: {row['id']})." + + +@tool +async def propose_note_edit( + note_id: str, + edit_type: str, + proposed_content: str, + reasoning: str = "", + anchor_before: str = "", + anchor_text: str = "", + agent_id: str = "", + run_id: str = "", +) -> str: + """Propose an AI edit to an existing note, pending human approval. + + Use this instead of update_note when review_required is true. + The user will see the proposal highlighted before it is merged. + + note_id: UUID of the target note (required) + edit_type: 'append' | 'insert' | 'replace' + - append: adds proposed_content at the end of the note + - insert: inserts proposed_content immediately after anchor_before text + - replace: replaces the first occurrence of anchor_text with proposed_content + proposed_content: the new Markdown text to add or substitute (required) + reasoning: brief explanation shown to the user (recommended) + anchor_before: for 'insert' — the text snippet that precedes the insertion point + anchor_text: for 'replace' — the exact text to be replaced + agent_id: agent identifier (for traceability) + run_id: run identifier (for traceability) + """ + if edit_type not in ("append", "insert", "replace"): + return f"Invalid edit_type '{edit_type}'. Use 'append', 'insert', or 'replace'." + + result = await execute_on_client( + action="propose_note_edit", + data={ + "noteId": note_id, + "type": edit_type, + "proposedContent": proposed_content, + "reasoning": reasoning or None, + "anchorBefore": anchor_before or None, + "anchorText": anchor_text or None, + "agentId": agent_id or None, + "runId": run_id or None, + }, + ) + edit_id = result.get("id", "?") + return ( + f"Edit proposal created (id: {edit_id}) for note {note_id}. " + f"Status: pending user approval." + ) + + +@tool +async def delete_note(note_id: str) -> str: + """Delete a note permanently by its UUID.""" + await execute_on_client(action="delete", table="notes", data={"id": note_id}) + return f"Note {note_id} deleted." + + +async def _refresh_summary(note_id: str, title: str, content: str) -> None: + """Generate and persist the AI summary for a note. Fire-and-forget.""" + try: + summary = await generate_note_summary(title, content) + if summary: + await execute_on_client( + action="update", + table="notes", + data={ + "id": note_id, + "updates": { + "aiSummary": summary, + "aiSummaryUpdatedAt": int(__import__("time").time() * 1000), + }, + }, + ) + except Exception: + pass # fire-and-forget; errors logged by generate_note_summary + + +NOTE_TOOLS: list[Any] = [ + list_notes, + get_note, + create_note, + update_note, + propose_note_edit, + delete_note, +] + +NOTE_READ_TOOLS: list[Any] = [ + list_notes, + get_note, +] diff --git a/api/app/agents/project_agent.py b/api/app/agents/project_agent.py new file mode 100644 index 0000000..4689b31 --- /dev/null +++ b/api/app/agents/project_agent.py @@ -0,0 +1,133 @@ +"""Project agent — full lifecycle management (list, get, create, update, archive, delete).""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.tools import tool + +from app.core.ws_context import execute_on_client + + +@tool +async def list_projects( + client_id: str = "", + include_archived: int = 0, +) -> str: + """List projects, optionally filtered by client_id. + include_archived: 1 to include archived projects, 0 for active only (default). + """ + result = await execute_on_client( + action="select", + table="projects", + filters={ + "clientId": client_id or None, + "includeArchived": bool(include_archived), + }, + ) + rows = result.get("rows", []) + if not rows: + return "No projects found." + lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows] + return f"Found {len(rows)} project(s):\n" + "\n".join(lines) + + +@tool +async def list_all_projects() -> str: + """List every project regardless of client or status. + Use only when the user wants a complete cross-client overview. + """ + result = await execute_on_client(action="select", table="projects") + rows = result.get("rows", []) + if not rows: + return "No projects found." + lines = [f"- {r['name']} (status: {r['status']}, id: {r['id']})" for r in rows] + return f"All projects ({len(rows)}):\n" + "\n".join(lines) + + +@tool +async def get_project(project_id: str) -> str: + """Fetch a single project by its UUID.""" + result = await execute_on_client(action="get", table="projects", data={"id": project_id}) + row = result.get("row") + if not row: + return f"Project {project_id} not found." + return ( + f"Project: '{row['name']}' (id: {row['id']}, status: {row['status']}, " + f"clientId: {row.get('clientId', 'none')})" + ) + + +@tool +async def create_project( + name: str, + client_id: str = "", +) -> str: + """Create a new project. + name: human-readable project name (required) + client_id: optional UUID of the owning client + """ + result = await execute_on_client( + action="insert", + table="projects", + data={"name": name, "clientId": client_id or None}, + ) + row = result["row"] + return f"Project created: '{row['name']}' (id: {row['id']})" + + +@tool +async def update_project( + project_id: str, + name: str = "", + client_id: str = "", + status: str = "", + ai_summary: str = "", +) -> str: + """Update a project. Only pass fields that should change. + project_id: UUID of the project (required) + status: active | archived + ai_summary: AI-generated summary text (populate only when explicitly requested) + """ + updates: dict[str, Any] = {} + if name: + updates["name"] = name + if client_id: + updates["clientId"] = client_id + if status: + updates["status"] = status + if ai_summary: + updates["aiSummary"] = ai_summary + result = await execute_on_client( + action="update", + table="projects", + data={"id": project_id, "updates": updates}, + ) + row = result["row"] + return f"Project updated: '{row['name']}' (id: {row['id']}, status: {row['status']})" + + +@tool +async def delete_project(project_id: str) -> str: + """Permanently delete a project and orphan its tasks. + IMPORTANT: prefer update_project(status='archived') unless the user + has explicitly confirmed they want permanent deletion. + """ + await execute_on_client(action="delete", table="projects", data={"id": project_id}) + return f"Project {project_id} permanently deleted." + + +PROJECT_TOOLS: list[Any] = [ + list_projects, + list_all_projects, + get_project, + create_project, + update_project, + delete_project, +] + +PROJECT_READ_TOOLS: list[Any] = [ + list_projects, + list_all_projects, + get_project, +] diff --git a/api/app/agents/relations_agent.py b/api/app/agents/relations_agent.py new file mode 100644 index 0000000..5e98ab7 --- /dev/null +++ b/api/app/agents/relations_agent.py @@ -0,0 +1,63 @@ +"""Relations agent — read-only tool wrapping MemoryMiddleware.query_relations.""" + +from __future__ import annotations + +from typing import Any + +from langchain_core.tools import tool + +from app.core.memory_middleware import MemoryMiddleware +from app.db import async_session + +# Injected at tool-factory time by _brief_research_tools(); not a module-level global. +# Each tool closure captures the user_id bound at factory time. + + +def make_query_relations_tool(user_id: str, trace_id: str | None = None) -> Any: + """Return a query_relations tool bound to *user_id*.""" + + @tool + async def query_relations( + subject_label: str = "", + predicate: str = "", + object_label: str = "", + limit: int = 10, + ) -> str: + """Query the relational memory graph for entity relationships. + + Returns rows where subject ↔ predicate ↔ object match the given filters. + All parameters are optional — omit to retrieve all relations up to limit. + + subject_label: entity label on the left side (e.g. a client name, "Acme Corp"). + predicate: relationship type (e.g. "mentioned_in", "works_at", "related_to"). + object_label: entity label on the right side (e.g. a project name, "Website Redesign"). + limit: max rows to return (default 10). + """ + import logging + logger = logging.getLogger(__name__) + logger.info( + "relations_agent: query_relations trace=%s user=%s subject=%r predicate=%r object=%r", + trace_id or "-", user_id, subject_label, predicate, object_label, + ) + + async with async_session() as db: + memory = MemoryMiddleware(db) + rows = await memory.query_relations( + user_id=user_id, + subject=subject_label or None, + predicate=predicate or None, + object_=object_label or None, + limit=limit, + ) + + if not rows: + return "No relational memory entries found for the given filters." + + lines = [ + f"- {r.subject_label} —[{r.predicate}]→ {r.object_label}" + + (f" (confidence: {r.confidence:.2f})" if r.confidence is not None else "") + for r in rows + ] + return f"Found {len(rows)} relation(s):\n" + "\n".join(lines) + + return query_relations diff --git a/api/app/agents/task_agent.py b/api/app/agents/task_agent.py new file mode 100644 index 0000000..7761122 --- /dev/null +++ b/api/app/agents/task_agent.py @@ -0,0 +1,358 @@ +"""Task agent — full CRUD for tasks and task comments.""" + +from __future__ import annotations + +from datetime import datetime, timezone +import re +from typing import Any + +from langchain_core.tools import tool + +from app.core.ws_context import execute_on_client + +_UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" +) + + +def _is_uuid(value: str) -> bool: + return bool(_UUID_RE.match(value)) + + +# ── Task tools ──────────────────────────────────────────────────────── + + +@tool +async def list_tasks( + project_id: str = "", + status: str = "", + priority: str = "", + assignee: str = "", + search: str = "", + order_by: str = "", + order_dir: str = "", + due_date_from: int = -1, + due_date_to: int = -1, + created_at_from: int = -1, + created_at_to: int = -1, + completed_at_from: int = -1, + completed_at_to: int = -1, + is_ai_suggested: int = -1, + limit: int = 50, + offset: int = 0, +) -> str: + """List tasks with optional filters. Returns up to `limit` results (default 50). + + project_id: UUID of the project to scope results to. + status: filter by status — todo | in_progress | done. + priority: filter by priority — high | medium | low. + assignee: substring to match against assignee names. OMIT unless the user explicitly + names a person or refers to themselves ("my tasks", "assigned to me", "mine"). + Do NOT default to the current user. + search: substring search across title and description. + order_by: sort field — dueDate | priority | createdAt | completedAt. + order_dir: asc (default) | desc. + due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit. + created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit. + completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit. + is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any. + limit: max rows to return (default 50). Use with offset to paginate. + offset: skip first N rows (default 0). + + Tip — combine *_from and *_to for a closed range; pass only one for open-ended. + Tip — prefer count_tasks for "how many" questions to avoid listing rows. + Tip — for natural-language windows ("today", "tomorrow", "this week", "last month", etc.) + take due_date_from / due_date_to verbatim from the DATE CONTEXT block in the system prompt; + do not compute boundaries from the current UTC instant. + """ + normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" + filters: dict[str, Any] = { + "projectId": normalized_project_id or None, + "status": status or None, + "priority": priority or None, + "search": search or None, + "orderBy": order_by or None, + "orderDir": order_dir or None, + "limit": limit, + "offset": offset, + } + if assignee: + filters["assignee"] = assignee + if due_date_from != -1: + filters["dueDateFrom"] = due_date_from + if due_date_to != -1: + filters["dueDateTo"] = due_date_to + if created_at_from != -1: + filters["createdAtFrom"] = created_at_from + if created_at_to != -1: + filters["createdAtTo"] = created_at_to + if completed_at_from != -1: + filters["completedAtFrom"] = completed_at_from + if completed_at_to != -1: + filters["completedAtTo"] = completed_at_to + if is_ai_suggested != -1: + filters["isAiSuggested"] = is_ai_suggested + + result = await execute_on_client(action="select", table="tasks", filters=filters) + rows = result.get("rows", []) + if not rows: + return "No tasks found matching the given filters." + lines = [ + f"- {r['title']} (status: {r['status']}, priority: {r['priority']}, " + f"dueDate: {r.get('dueDate')}, completedAt: {r.get('completedAt')}, " + f"projectId: {r.get('projectId')}, id: {r['id']})" + for r in rows + ] + return f"Found {len(rows)} task(s):\n" + "\n".join(lines) + + +@tool +async def count_tasks( + project_id: str = "", + status: str = "", + priority: str = "", + assignee: str = "", + search: str = "", + due_date_from: int = -1, + due_date_to: int = -1, + created_at_from: int = -1, + created_at_to: int = -1, + completed_at_from: int = -1, + completed_at_to: int = -1, + is_ai_suggested: int = -1, +) -> str: + """Count tasks matching the given filters without returning rows. + + Use this instead of list_tasks for "how many" questions — it is much cheaper. + Same filter parameters as list_tasks (no limit/offset/order_by needed). + assignee: OMIT unless the user explicitly names a person or refers to themselves + ("my tasks"). Do NOT default to the current user. + due_date_from / due_date_to: ms epoch range for dueDate. Use -1 to omit. + created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit. + completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit. + Tip — for natural-language windows take due_date_from / due_date_to from the DATE CONTEXT block; + do not compute boundaries from the current UTC instant. + """ + normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" + filters: dict[str, Any] = { + "projectId": normalized_project_id or None, + "status": status or None, + "priority": priority or None, + "search": search or None, + } + if assignee: + filters["assignee"] = assignee + if due_date_from != -1: + filters["dueDateFrom"] = due_date_from + if due_date_to != -1: + filters["dueDateTo"] = due_date_to + if created_at_from != -1: + filters["createdAtFrom"] = created_at_from + if created_at_to != -1: + filters["createdAtTo"] = created_at_to + if completed_at_from != -1: + filters["completedAtFrom"] = completed_at_from + if completed_at_to != -1: + filters["completedAtTo"] = completed_at_to + if is_ai_suggested != -1: + filters["isAiSuggested"] = is_ai_suggested + + result = await execute_on_client(action="count", table="tasks", filters=filters) + return f"Task count: {result.get('count', 0)}" + + +@tool +async def create_task( + title: str, + description: str = "", + status: str = "todo", + priority: str = "medium", + assignees: str = "[]", + due_date: int = 0, + project_id: str = "", + is_ai_suggested: int = 0, +) -> str: + """Create a new task. + title: task title (required) + description: optional details + status: todo | in_progress | done (default: todo) + priority: high | medium | low (default: medium) + assignees: JSON-encoded array of assignee names, e.g. '["Alice"]' + due_date: Unix timestamp in milliseconds; 0 means no due date + project_id: optional UUID of the parent project + is_ai_suggested: 1 if proactively suggested, 0 if user-requested + + completedAt is set automatically when status is 'done'. + """ + result = await execute_on_client( + action="insert", + table="tasks", + data={ + "title": title, + "description": description or None, + "status": status, + "priority": priority, + "assignee": assignees, + "dueDate": due_date or None, + "projectId": project_id or None, + "isAiSuggested": is_ai_suggested, + }, + ) + row = result["row"] + return ( + f"Task created: '{row['title']}' " + f"(id: {row['id']}, status: {row['status']}, priority: {row['priority']}, projectId: {row.get('projectId')})" + ) + + +@tool +async def update_task( + task_id: str, + title: str = "", + description: str = "", + status: str = "", + priority: str = "", + assignees: str = "", + due_date: int = -1, + project_id: str = "", +) -> str: + """Update fields on an existing task. Only pass fields you want to change. + task_id: the task's UUID (required) + due_date: -1 means unchanged; 0 clears the due date; any positive value sets it + + completedAt is managed automatically: + - setting status to 'done' records the current timestamp + - changing status away from 'done' clears completedAt + """ + updates: dict[str, Any] = {} + if title: + updates["title"] = title + if description: + updates["description"] = description + if status: + updates["status"] = status + if priority: + updates["priority"] = priority + if assignees: + updates["assignee"] = assignees + if due_date != -1: + updates["dueDate"] = due_date or None + if project_id: + updates["projectId"] = project_id + result = await execute_on_client( + action="update", + table="tasks", + data={"id": task_id, "updates": updates}, + ) + row = result["row"] + return f"Task updated: '{row['title']}' (id: {row['id']}, status: {row['status']}, projectId: {row.get('projectId')})" + + +@tool +async def delete_task(task_id: str) -> str: + """Delete a task permanently by its UUID.""" + await execute_on_client(action="delete", table="tasks", data={"id": task_id}) + return f"Task {task_id} deleted." + + +@tool +async def list_tasks_due_today(user_timezone: str = "UTC", include_done: bool = False) -> str: + """List all tasks whose due date falls on today's date. + + user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York'). + Always pass the user's timezone so 'today' is computed in their local time. + include_done: set True to also include already-completed tasks due today (default False). + """ + try: + from zoneinfo import ZoneInfo + tz = ZoneInfo(user_timezone or "UTC") + except Exception: + tz = timezone.utc + now_local = datetime.now(tz=tz) + start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz) + start_ms = int(start_dt.timestamp() * 1000) + end_ms = start_ms + 86_400_000 - 1 + filters: dict[str, Any] = {"dueDateFrom": start_ms, "dueDateTo": end_ms} + if not include_done: + filters["status"] = "todo" + result = await execute_on_client( + action="select", + table="tasks", + filters=filters, + ) + rows = result.get("rows", []) + if not rows: + return "No tasks are due today." + lines = [ + f"- {r['title']} (priority: {r['priority']}, status: {r['status']}, " + f"projectId: {r.get('projectId')}, id: {r['id']})" + for r in rows + ] + return f"Tasks due today ({len(rows)}):\n" + "\n".join(lines) + + +# ── Task comment tools ──────────────────────────────────────────────── + + +@tool +async def list_task_comments(task_id: str) -> str: + """List all comments on a task by its UUID.""" + result = await execute_on_client( + action="select", + table="taskComments", + filters={"taskId": task_id}, + ) + rows = result.get("rows", []) + if not rows: + return f"No comments found for task {task_id}." + lines = [f"- [{r['author']}]: {r['content']} (id: {r['id']})" for r in rows] + return f"Found {len(rows)} comment(s):\n" + "\n".join(lines) + + +@tool +async def add_task_comment(task_id: str, author: str, content: str) -> str: + """Add a comment to a task. + task_id: UUID of the task to comment on + author: name or ID of the comment author + content: comment text + """ + result = await execute_on_client( + action="insert", + table="taskComments", + data={"taskId": task_id, "author": author, "content": content}, + ) + row = result.get("row", {}) + row_author = row.get("author", author) + row_task_id = row.get("taskId") or row.get("task_id") or task_id + row_comment_id = row.get("id", "unknown") + return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})." + + +@tool +async def delete_task_comment(comment_id: str) -> str: + """Delete a task comment by its UUID.""" + await execute_on_client(action="delete", table="taskComments", data={"id": comment_id}) + return f"Comment {comment_id} deleted." + + +# ── Agent ───────────────────────────────────────────────────────────── + + +TASK_TOOLS: list[Any] = [ + list_tasks, + count_tasks, + create_task, + update_task, + delete_task, + list_tasks_due_today, + list_task_comments, + add_task_comment, + delete_task_comment, +] + +TASK_READ_TOOLS: list[Any] = [ + list_tasks, + count_tasks, + list_tasks_due_today, + list_task_comments, +] diff --git a/api/app/agents/timeline_agent.py b/api/app/agents/timeline_agent.py new file mode 100644 index 0000000..beeedb1 --- /dev/null +++ b/api/app/agents/timeline_agent.py @@ -0,0 +1,270 @@ +"""Timeline agent — project milestone management (list, create, update, delete).""" + +from __future__ import annotations + +import re +from datetime import datetime, timezone +from typing import Any + +from langchain_core.tools import tool + +from app.core.ws_context import execute_on_client + +_UUID_RE = re.compile( + r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$" +) + + +def _is_uuid(value: str) -> bool: + return bool(_UUID_RE.match(value)) + + +@tool +async def list_timelines( + project_id: str = "", + type: str = "", + is_completed: int = -1, + is_ai_suggested: int = -1, + order_by: str = "", + order_dir: str = "", + date_from: int = -1, + date_to: int = -1, + created_at_from: int = -1, + created_at_to: int = -1, + completed_at_from: int = -1, + completed_at_to: int = -1, + limit: int = 50, + offset: int = 0, +) -> str: + """List timeline events (milestones, checkpoints, activities) with optional filters. + + project_id: UUID to scope results to a specific project. + type: filter by event type — milestone | checkpoint | activity. + is_completed: 0 = incomplete only, 1 = completed only, -1 = any (default). + is_ai_suggested: 0 or 1 to filter by AI-suggested flag; -1 = any. + order_by: sort field — date (default) | createdAt | completedAt. + order_dir: asc (default) | desc. + date_from / date_to: ms epoch range for the event date. Use -1 to omit. + created_at_from / created_at_to: ms epoch range for createdAt. Use -1 to omit. + completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit. + limit: max rows to return (default 50). Use with offset to paginate. + offset: skip first N rows (default 0). + + Tip — combine *_from and *_to for a closed range; pass only one for open-ended. + Tip — prefer count_timelines for "how many" questions to avoid listing rows. + Tip — for natural-language windows ("today", "this week", "last month", etc.) + take date_from / date_to verbatim from the DATE CONTEXT block in the system prompt; + do not compute boundaries from the current UTC instant. + """ + normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" + filters: dict[str, Any] = { + "projectId": normalized_project_id or None, + "orderBy": order_by or None, + "orderDir": order_dir or None, + "limit": limit, + "offset": offset, + } + if type: + filters["type"] = type + if is_completed != -1: + filters["isCompleted"] = is_completed + if is_ai_suggested != -1: + filters["isAiSuggested"] = is_ai_suggested + if date_from != -1: + filters["dateFrom"] = date_from + if date_to != -1: + filters["dateTo"] = date_to + if created_at_from != -1: + filters["createdAtFrom"] = created_at_from + if created_at_to != -1: + filters["createdAtTo"] = created_at_to + if completed_at_from != -1: + filters["completedAtFrom"] = completed_at_from + if completed_at_to != -1: + filters["completedAtTo"] = completed_at_to + + result = await execute_on_client(action="select", table="timelines", filters=filters) + rows = result.get("rows", []) + if not rows: + return "No timeline events found." + lines = [ + f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, " + f"completed: {bool(r.get('isCompleted'))}, completedAt: {r.get('completedAt')}, " + f"projectId: {r.get('projectId')}, id: {r['id']})" + for r in rows + ] + return f"Found {len(rows)} timeline event(s):\n" + "\n".join(lines) + + +@tool +async def count_timelines( + project_id: str = "", + type: str = "", + is_completed: int = -1, + is_ai_suggested: int = -1, + date_from: int = -1, + date_to: int = -1, + created_at_from: int = -1, + created_at_to: int = -1, + completed_at_from: int = -1, + completed_at_to: int = -1, +) -> str: + """Count timeline events matching the given filters without returning rows. + + Use this instead of list_timelines for "how many" questions — it is much cheaper. + Same filter parameters as list_timelines (no limit/offset/order_by needed). + + date_from / date_to: ms epoch range for the event date. Use -1 to omit. + completed_at_from / completed_at_to: ms epoch range for completedAt. Use -1 to omit. + Tip — for natural-language windows take date_from / date_to from the DATE CONTEXT block; + do not compute boundaries from the current UTC instant. + """ + normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else "" + filters: dict[str, Any] = {"projectId": normalized_project_id or None} + if type: + filters["type"] = type + if is_completed != -1: + filters["isCompleted"] = is_completed + if is_ai_suggested != -1: + filters["isAiSuggested"] = is_ai_suggested + if date_from != -1: + filters["dateFrom"] = date_from + if date_to != -1: + filters["dateTo"] = date_to + if created_at_from != -1: + filters["createdAtFrom"] = created_at_from + if created_at_to != -1: + filters["createdAtTo"] = created_at_to + if completed_at_from != -1: + filters["completedAtFrom"] = completed_at_from + if completed_at_to != -1: + filters["completedAtTo"] = completed_at_to + + result = await execute_on_client(action="count", table="timelines", filters=filters) + return f"Timeline event count: {result.get('count', 0)}" + + +@tool +async def create_timeline( + project_id: str, + title: str, + date: int, + type: str = "milestone", + is_completed: int = 0, + is_ai_suggested: int = 0, +) -> str: + """Create a project timeline event. + project_id: REQUIRED UUID of the parent project + title: descriptive name for the event + date: Unix timestamp in milliseconds for the event date + type: milestone (default) | checkpoint | activity + is_completed: 1 if already completed, 0 if not (default 0) + is_ai_suggested: 1 if proactively suggested, 0 if user-requested + + completedAt is set automatically when is_completed is 1. + """ + result = await execute_on_client( + action="insert", + table="timelines", + data={ + "projectId": project_id, + "title": title, + "date": date, + "type": type, + "isCompleted": is_completed, + "isAiSuggested": is_ai_suggested, + }, + ) + row = result["row"] + return f"Timeline event created: '{row['title']}' (id: {row['id']}, date: {row['date']}, type: {row.get('type')})" + + +@tool +async def update_timeline( + timeline_id: str, + title: str = "", + date: int = -1, + is_completed: int = -1, +) -> str: + """Update a timeline event. Only pass fields that should change. + timeline_id: UUID of the event (required) + date: -1 means unchanged; any other value sets the new date (ms timestamp) + is_completed: 0 = mark incomplete, 1 = mark complete, -1 = unchanged + + completedAt is managed automatically: + - setting is_completed to 1 records the current timestamp + - setting is_completed to 0 clears completedAt + """ + updates: dict[str, Any] = {} + if title: + updates["title"] = title + if date != -1: + updates["date"] = date + if is_completed != -1: + updates["isCompleted"] = is_completed + result = await execute_on_client( + action="update", + table="timelines", + data={"id": timeline_id, "updates": updates}, + ) + row = result["row"] + return f"Timeline event updated: '{row['title']}' (id: {row['id']})" + + +@tool +async def delete_timeline(timeline_id: str) -> str: + """Delete a timeline event permanently by its UUID.""" + await execute_on_client(action="delete", table="timelines", data={"id": timeline_id}) + return f"Timeline event {timeline_id} deleted." + + +@tool +async def list_timelines_today(user_timezone: str = "UTC", include_completed: bool = True) -> str: + """List all timeline events whose date falls on today. + + user_timezone: IANA timezone name (e.g. 'Europe/Rome', 'America/New_York'). + Always pass the user's timezone so 'today' is computed in their local time. + include_completed: set False to exclude already-completed events (default True). + """ + try: + from zoneinfo import ZoneInfo + tz = ZoneInfo(user_timezone or "UTC") + except Exception: + tz = timezone.utc + now_local = datetime.now(tz=tz) + start_dt = datetime(now_local.year, now_local.month, now_local.day, tzinfo=tz) + start_ms = int(start_dt.timestamp() * 1000) + end_ms = start_ms + 86_400_000 - 1 + filters: dict[str, Any] = {"dateFrom": start_ms, "dateTo": end_ms} + if not include_completed: + filters["isCompleted"] = 0 + result = await execute_on_client( + action="select", + table="timelines", + filters=filters, + ) + rows = result.get("rows", []) + if not rows: + return "No timeline events today." + lines = [ + f"- {r['title']} (date: {r['date']}, type: {r.get('type')}, " + f"completed: {bool(r.get('isCompleted'))}, projectId: {r.get('projectId')}, id: {r['id']})" + for r in rows + ] + return f"Timeline events today ({len(rows)}):\n" + "\n".join(lines) + + +TIMELINE_TOOLS: list[Any] = [ + list_timelines, + count_timelines, + list_timelines_today, + create_timeline, + update_timeline, + delete_timeline, +] + +TIMELINE_READ_TOOLS: list[Any] = [ + list_timelines, + count_timelines, + list_timelines_today, +] diff --git a/api/app/api/__init__.py b/api/app/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/app/api/deps.py b/api/app/api/deps.py new file mode 100644 index 0000000..0339d0d --- /dev/null +++ b/api/app/api/deps.py @@ -0,0 +1,14 @@ +"""Shared FastAPI dependencies. + +``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth`` +(the canonical location per Step 9). This module re-exports them so that all +existing route imports (``from app.api.deps import get_current_user``) continue +to work without modification. + +Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL +instead of reading it from the JWT payload. +""" + +from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401 + +__all__ = ["get_current_user", "oauth2_scheme"] diff --git a/api/app/api/middleware/__init__.py b/api/app/api/middleware/__init__.py new file mode 100644 index 0000000..f67fc41 --- /dev/null +++ b/api/app/api/middleware/__init__.py @@ -0,0 +1,19 @@ +"""API middleware package. + +Exports the three middleware components introduced in Step 9: + - Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme`` + - Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter) + - Sanitizer: ``SanitizerMiddleware`` +""" + +from app.api.middleware.auth import get_current_user, oauth2_scheme +from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter +from app.api.middleware.sanitizer import SanitizerMiddleware + +__all__ = [ + "get_current_user", + "oauth2_scheme", + "TierRateLimitMiddleware", + "limiter", + "SanitizerMiddleware", +] diff --git a/api/app/api/middleware/auth.py b/api/app/api/middleware/auth.py new file mode 100644 index 0000000..3c92471 --- /dev/null +++ b/api/app/api/middleware/auth.py @@ -0,0 +1,103 @@ +"""Auth middleware — JWT validation dependency. + +``get_current_user`` is the FastAPI dependency used by all protected routes. +It decodes the Bearer JWT (identity + expiry), then fetches the current tier +from the ``subscriptions`` table so that tier changes take effect immediately +without requiring token re-issue. + +Exempt routes (no JWT required): + - POST /api/v1/auth/register + - POST /api/v1/auth/login + - POST /api/v1/billing/webhook +""" + +from __future__ import annotations + +from fastapi import Depends, HTTPException, status +from fastapi.security import OAuth2PasswordBearer +from jose import JWTError, jwt +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config.settings import settings +from app.db import get_session +from app.schemas import UserProfile + +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") + + +async def get_current_user( + token: str = Depends(oauth2_scheme), + db: AsyncSession = Depends(get_session), +) -> UserProfile: + """Validate a Bearer JWT and return the authenticated user. + + The JWT is used for identity and expiry only. The tier is fetched live + from the ``subscriptions`` table so that upgrades/downgrades take effect + immediately. Falls back to ``'free'`` when no subscription row exists. + + Raises HTTP 401 on any invalid or expired token. + """ + credentials_exc = HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not validate credentials", + headers={"WWW-Authenticate": "Bearer"}, + ) + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + user_id: str | None = payload.get("sub") + email: str | None = payload.get("email") + if not user_id or not email: + raise credentials_exc + except JWTError: + raise credentials_exc + + # Live tier lookup — subscription row is the authoritative source. + # In dev, fall back to 'power' (unlimited) so quota limits don't + # block local development when no Stripe subscription exists. + from app.models import Subscription, User # noqa: PLC0415 + + result = await db.execute( + select(Subscription.tier).where(Subscription.user_id == user_id) + ) + default_tier = "power" if settings.ENV == "dev" else "free" + tier: str = result.scalar_one_or_none() or default_tier + + # Fetch name/surname/avatar_url/onboarding_completed_at/password_hash from user row. + user_result = await db.execute( + select( + User.name, User.surname, User.avatar_url, User.onboarding_completed_at, + User.password_hash, + ).where(User.id == user_id) + ) + user_row = user_result.one_or_none() + + # Convert onboarding_completed_at to epoch ms (int) or None. + onboarding_ms: int | None = None + if user_row and user_row.onboarding_completed_at is not None: + onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000) + + # Load decrypted core memory. + from app.core.memory_middleware import MemoryMiddleware # noqa: PLC0415 + + memory_dict: dict[str, str] = {} + try: + mw = MemoryMiddleware(db) + blocks = await mw.list_core_blocks(user_id) + memory_dict = {b["label"]: b["value"] for b in blocks} + except Exception: + pass # Non-critical — return empty memory on failure + + return UserProfile( + id=user_id, + email=email, + name=user_row.name if user_row else None, + surname=user_row.surname if user_row else None, + avatar_url=user_row.avatar_url if user_row else None, + has_password=bool(user_row.password_hash) if user_row else False, + tier=tier, + onboarding_completed_at=onboarding_ms, + memory=memory_dict, + ) # type: ignore[arg-type] diff --git a/api/app/api/middleware/rate_limit.py b/api/app/api/middleware/rate_limit.py new file mode 100644 index 0000000..4a2af76 --- /dev/null +++ b/api/app/api/middleware/rate_limit.py @@ -0,0 +1,129 @@ +"""Tier-aware rate limiting middleware. + +Uses a per-user sliding-window counter (in-process, no Redis required). +The ``slowapi`` Limiter is also exported for optional route-level decoration. + +Limits (requests per minute): + - free: 20 + - pro: 60 + - power: 120 + - team: 200 + +Exempt paths bypass the limiter entirely: + - POST /api/v1/auth/register + - POST /api/v1/auth/login + - POST /api/v1/billing/webhook + - GET /api/v1/health +""" + +from __future__ import annotations + +import json +import time +from collections import defaultdict + +from fastapi import Request, Response +from jose import JWTError, jwt +from slowapi import Limiter +from slowapi.util import get_remote_address +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + +from app.config.settings import settings + +_TIER_LIMITS: dict[str, int] = { + "free": 20, + "pro": 60, + "power": 120, + "team": 200, +} + +_EXEMPT_PATHS: frozenset[str] = frozenset( + { + "/api/v1/auth/register", + "/api/v1/auth/login", + "/api/v1/billing/webhook", + "/api/v1/health", + } +) + + +def _get_user_id_from_jwt(request: Request) -> str: + """Key function for the slowapi Limiter: returns JWT sub or remote IP.""" + auth = request.headers.get("Authorization", "") + token = auth.removeprefix("Bearer ").strip() + if not token: + return get_remote_address(request) + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + return payload.get("sub") or get_remote_address(request) + except JWTError: + return get_remote_address(request) + + +# Exported Limiter instance — available for optional route-level decoration. +limiter = Limiter(key_func=_get_user_id_from_jwt) + + +class TierRateLimitMiddleware(BaseHTTPMiddleware): + """Sliding-window rate limiter applied globally across all non-exempt routes. + + Each authenticated user gets their own 60-second window sized by tier. + Unauthenticated requests pass through (the auth dependency will reject them + with 401 before the route handler runs). + """ + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + # user_id → list of request timestamps (float, seconds since epoch) + self._window: dict[str, list[float]] = defaultdict(list) + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override] + if request.url.path in _EXEMPT_PATHS: + return await call_next(request) + + # Extract JWT claims — if no valid token, pass through for auth dep to handle. + auth = request.headers.get("Authorization", "") + token = auth.removeprefix("Bearer ").strip() + if not token: + return await call_next(request) + + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + user_id: str = payload.get("sub") or get_remote_address(request) + tier: str = payload.get("tier", "free") + except JWTError: + return await call_next(request) + + limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"]) + now = time.monotonic() + window_start = now - 60.0 + + # Slide the window: discard timestamps older than 60 seconds. + timestamps = [t for t in self._window[user_id] if t > window_start] + + if len(timestamps) >= limit: + retry_after = max(1, int(60 - (now - min(timestamps)))) + return Response( + content=json.dumps( + { + "detail": ( + f"Rate limit exceeded ({limit} req/min for {tier} tier). " + f"Retry in {retry_after}s." + ) + } + ), + status_code=429, + headers={ + "Retry-After": str(retry_after), + "Content-Type": "application/json", + }, + ) + + timestamps.append(now) + self._window[user_id] = timestamps + return await call_next(request) diff --git a/api/app/api/middleware/sanitizer.py b/api/app/api/middleware/sanitizer.py new file mode 100644 index 0000000..4dd3531 --- /dev/null +++ b/api/app/api/middleware/sanitizer.py @@ -0,0 +1,138 @@ +"""Response sanitizer middleware. + +Scans JSON responses from the /api/v1/chat endpoint and strips any fragments +that could reveal server-side prompt IP: + - System prompt openers ("You are a/an/the …") + - Agent routing metadata ("Available agents:", "intent classifier", …) + - LangChain tool schema fragments (``"type": "function"``) + - Internal reasoning markers (, , [INST], …) + - Exact-match known prompt fingerprints + +The middleware only activates for paths under /api/v1/chat. + +Any sanitisation event is logged as a WARNING with the request path and the +names of the fields that were modified. +""" + +from __future__ import annotations + +import json +import logging +import re + +from fastapi import Request, Response +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.types import ASGIApp + +logger = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Detection patterns — order matters: fingerprints checked first (exact), +# then compiled regexes. +# --------------------------------------------------------------------------- + +_FINGERPRINTS: tuple[str, ...] = ( + "You are an intent classifier", + "Respond with just the agent name", + "Summarize these agent results", + "Available agents:", + "route to:", +) + +_PATTERNS: tuple[re.Pattern[str], ...] = ( + re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL), + re.compile(r"Available agents\s*:", re.IGNORECASE), + re.compile(r"\bintent classifier\b", re.IGNORECASE), + re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema + re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE), + re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers + re.compile(r"route\s+to\s*:", re.IGNORECASE), + re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE), +) + + +def _sanitize_text(text: str) -> tuple[str, bool]: + """Scan *text* for prompt fragments and replace matches with ``[REDACTED]``. + + Returns ``(cleaned_text, was_changed)``. + """ + # Fingerprint check — if any exact phrase is present, redact the whole string. + for fp in _FINGERPRINTS: + if fp in text: + return "[REDACTED]", True + + changed = False + for pattern in _PATTERNS: + new_text, n = pattern.subn("[REDACTED]", text) + if n: + text = new_text + changed = True + + return text, changed + + +class SanitizerMiddleware(BaseHTTPMiddleware): + """Strip prompt IP from /api/v1/chat JSON responses.""" + + def __init__(self, app: ASGIApp) -> None: + super().__init__(app) + + async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override] + response: Response = await call_next(request) + + # Only process chat endpoint responses. + if not request.url.path.startswith("/api/v1/chat"): + return response + + # Read body — collect streaming chunks. + body_bytes = b"" + async for chunk in response.body_iterator: + body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode() + + # Skip non-JSON bodies (shouldn't happen on /chat, but be safe). + try: + body = json.loads(body_bytes.decode("utf-8")) + except (json.JSONDecodeError, UnicodeDecodeError): + return Response( + content=body_bytes, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + if not isinstance(body, dict): + return Response( + content=body_bytes, + status_code=response.status_code, + headers=dict(response.headers), + media_type=response.media_type, + ) + + # Walk top-level string fields and sanitise. + sanitised_fields: list[str] = [] + for key, value in body.items(): + if isinstance(value, str): + cleaned, changed = _sanitize_text(value) + if changed: + body[key] = cleaned + sanitised_fields.append(key) + + if sanitised_fields: + logger.warning( + "Sanitizer redacted prompt fragments", + extra={ + "path": request.url.path, + "fields": sanitised_fields, + }, + ) + + new_body = json.dumps(body).encode("utf-8") + headers = dict(response.headers) + headers["content-length"] = str(len(new_body)) + + return Response( + content=new_body, + status_code=response.status_code, + headers=headers, + media_type="application/json", + ) diff --git a/api/app/api/routes/__init__.py b/api/app/api/routes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/app/api/routes/auth.py b/api/app/api/routes/auth.py new file mode 100644 index 0000000..73a8d67 --- /dev/null +++ b/api/app/api/routes/auth.py @@ -0,0 +1,795 @@ +"""Auth routes: register, login, refresh, me, OAuth social login, onboarding. + +Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens +tables). Passwords are hashed with bcrypt; refresh tokens are stored as +SHA-256 hashes so plaintext never reaches the DB. + +OAuth (Google): + GET /auth/oauth/{provider}/authorize — returns consent-screen URL + state + POST /auth/oauth/{provider}/callback — exchanges code, issues JWT tokens +""" + +from __future__ import annotations + +import hashlib +import json +import time +import urllib.parse +import uuid +from datetime import datetime, timedelta, timezone +from typing import Literal + +import bcrypt +from cryptography.fernet import Fernet +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import RedirectResponse +from jose import jwt +from pydantic import BaseModel, Field +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.auth.oauth_providers import GoogleOAuthProvider, generate_pkce_pair +from app.config.settings import settings +from app.core.llm import get_llm +from app.core.memory_middleware import MemoryMiddleware +from app.db import get_session +from app.models import OAuthAccount, RefreshToken, User +from app.schemas import AuthTokens, UserProfile + +router = APIRouter(prefix="/auth", tags=["auth"]) + + +# ── OAuth provider registry ─────────────────────────────────────────── + +def _get_google_provider() -> GoogleOAuthProvider: + if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET: + raise HTTPException( + status.HTTP_503_SERVICE_UNAVAILABLE, + "Google login is not configured on this server", + ) + return GoogleOAuthProvider( + client_id=settings.GOOGLE_AUTH_CLIENT_ID, + client_secret=settings.GOOGLE_AUTH_CLIENT_SECRET, + redirect_uri=settings.OAUTH_REDIRECT_URI, + ) + + +_PROVIDERS = {"google": _get_google_provider} + +# In-memory state store: state → (code_verifier, expires_at_epoch_s) +# Production note: replace with Redis for multi-process deployments. +_pending_states: dict[str, tuple[str, float]] = {} +_STATE_TTL_SECONDS = 600 # 10 minutes + + +# ── Internal helpers ───────────────────────────────────────────────── + + +def _hash_password(password: str) -> str: + return bcrypt.hashpw(password.encode(), bcrypt.gensalt()).decode() + + +def _verify_password(password: str, hashed: str) -> bool: + return bcrypt.checkpw(password.encode(), hashed.encode()) + + +def _hash_token(plain_token: str) -> str: + """SHA-256 of the plain refresh token string.""" + return hashlib.sha256(plain_token.encode()).hexdigest() + + +def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]: + """Return (signed JWT, expires_at_ms).""" + now = int(time.time()) + exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 + payload = { + "sub": user_id, + "email": email, + "tier": tier, + "exp": exp, + "iat": now, + } + token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) + return token, exp * 1000 # ms for client + + +# ── Request bodies ──────────────────────────────────────────────────── + + +class _RegisterRequest(BaseModel): + email: str + password: str + name: str | None = None + surname: str | None = None + + +class _LoginRequest(BaseModel): + email: str + password: str + + +class _RefreshRequest(BaseModel): + refresh_token: str + + +# ── Routes ──────────────────────────────────────────────────────────── + + +@router.post("/register", response_model=AuthTokens, status_code=status.HTTP_201_CREATED) +async def register( + body: _RegisterRequest, + db: AsyncSession = Depends(get_session), +) -> AuthTokens: + """Create a new account and return JWT tokens.""" + existing = await db.execute(select(User).where(User.email == body.email)) + if existing.scalar_one_or_none() is not None: + raise HTTPException(status.HTTP_409_CONFLICT, "Email already registered") + + user = User( + id=str(uuid.uuid4()), + email=body.email, + name=body.name, + surname=body.surname, + password_hash=_hash_password(body.password), + tier="free", + encryption_key=Fernet.generate_key().decode(), + ) + db.add(user) + await db.flush() # get user.id without committing + + plain_token = str(uuid.uuid4()) + expires_at = datetime.now(timezone.utc) + timedelta( + days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS + ) + rt = RefreshToken( + user_id=user.id, + token_hash=_hash_token(plain_token), + expires_at=expires_at, + ) + db.add(rt) + await db.commit() + + access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) + return AuthTokens( + access_token=access_token, + refresh_token=plain_token, + expires_at=expires_at_ms, + ) + + +@router.post("/login", response_model=AuthTokens) +async def login( + body: _LoginRequest, + db: AsyncSession = Depends(get_session), +) -> AuthTokens: + """Validate credentials and return JWT tokens.""" + result = await db.execute(select(User).where(User.email == body.email)) + user = result.scalar_one_or_none() + if user is None or not _verify_password(body.password, user.password_hash): + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials") + + plain_token = str(uuid.uuid4()) + expires_at = datetime.now(timezone.utc) + timedelta( + days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS + ) + rt = RefreshToken( + user_id=user.id, + token_hash=_hash_token(plain_token), + expires_at=expires_at, + ) + db.add(rt) + await db.commit() + + access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) + return AuthTokens( + access_token=access_token, + refresh_token=plain_token, + expires_at=expires_at_ms, + ) + + +@router.post("/refresh", response_model=AuthTokens) +async def refresh( + body: _RefreshRequest, + db: AsyncSession = Depends(get_session), +) -> AuthTokens: + """Rotate a refresh token and return a new token pair.""" + token_hash = _hash_token(body.refresh_token) + result = await db.execute( + select(RefreshToken).where(RefreshToken.token_hash == token_hash) + ) + rt = result.scalar_one_or_none() + + now = datetime.now(timezone.utc) + if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token") + + # Rotate: delete old token, issue new one. + await db.delete(rt) + + user_result = await db.execute(select(User).where(User.id == rt.user_id)) + user = user_result.scalar_one_or_none() + if user is None: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found") + + plain_token = str(uuid.uuid4()) + new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS) + new_rt = RefreshToken( + user_id=user.id, + token_hash=_hash_token(plain_token), + expires_at=new_expires, + ) + db.add(new_rt) + await db.commit() + + access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) + return AuthTokens( + access_token=access_token, + refresh_token=plain_token, + expires_at=expires_at_ms, + ) + + +class _UpdateProfileRequest(BaseModel): + name: str | None = None + surname: str | None = None + + +@router.get("/me", response_model=UserProfile) +async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile: + """Return the profile for the authenticated user.""" + return current_user + + +@router.put("/me", response_model=UserProfile) +async def update_profile( + body: _UpdateProfileRequest, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> UserProfile: + """Update the authenticated user's name and surname.""" + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + + if body.name is not None: + user.name = body.name + if body.surname is not None: + user.surname = body.surname + + await db.commit() + await db.refresh(user) + + return UserProfile( + id=user.id, + email=user.email, + name=user.name, + surname=user.surname, + avatar_url=user.avatar_url, + tier=current_user.tier, + ) + + +# ── OAuth helpers ───────────────────────────────────────────────────── + + +async def _issue_refresh_token(user: User, db: AsyncSession) -> tuple[str, AuthTokens]: + """Create a refresh token row and return (plain_token, AuthTokens).""" + plain_token = str(uuid.uuid4()) + expires_at = datetime.now(timezone.utc) + timedelta( + days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS + ) + rt = RefreshToken( + user_id=user.id, + token_hash=_hash_token(plain_token), + expires_at=expires_at, + ) + db.add(rt) + access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) + return plain_token, AuthTokens( + access_token=access_token, + refresh_token=plain_token, + expires_at=expires_at_ms, + ) + + +# ── OAuth request/response schemas ─────────────────────────────────── + + +class _OAuthAuthorizeResponse(BaseModel): + url: str + state: str + + +class _OAuthCallbackRequest(BaseModel): + code: str + state: str + + +# ── OAuth routes ────────────────────────────────────────────────────── + + +@router.get( + "/oauth/{provider}/web-callback", + summary="Web-facing OAuth redirect — bounces to the adiuvai:// deep link", + include_in_schema=False, +) +async def oauth_web_callback( + provider: Literal["google"], + code: str, + state: str, +) -> RedirectResponse: + """Google redirects here after user consent. + + This endpoint immediately redirects to the Electron deep-link URI so the + desktop app receives the authorization code. It is intentionally simple — + no state validation here (the Electron app + backend callback do that). + + Registered in Google Cloud Console as: + http://localhost:8000/api/v1/auth/oauth/google/web-callback (dev) + https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback (prod) + """ + params = urllib.parse.urlencode({"code": code, "state": state, "provider": provider}) + deep_link = f"adiuvai://oauth/callback?{params}" + return RedirectResponse(url=deep_link, status_code=302) + + +@router.get( + "/oauth/{provider}/authorize", + response_model=_OAuthAuthorizeResponse, + summary="Start OAuth flow — returns the provider consent-screen URL", +) +async def oauth_authorize( + provider: Literal["google"], +) -> _OAuthAuthorizeResponse: + """Generate a PKCE state + code_challenge and return the authorization URL. + + The client opens this URL in the system browser. After the user grants + consent, the provider redirects to the deep-link URI (adiuvai://oauth/callback) + with ``code`` and ``state`` query params. The client then calls + ``POST /auth/oauth/{provider}/callback`` with those values. + """ + provider_factory = _PROVIDERS.get(provider) + if provider_factory is None: + raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}") + + oauth_provider = provider_factory() + state = str(uuid.uuid4()) + code_verifier, code_challenge = generate_pkce_pair() + + # Purge expired states to prevent unbounded growth. + now = time.time() + expired = [s for s, (_, exp) in _pending_states.items() if exp < now] + for s in expired: + del _pending_states[s] + + _pending_states[state] = (code_verifier, now + _STATE_TTL_SECONDS) + + url = oauth_provider.get_authorization_url(state=state, code_challenge=code_challenge) + return _OAuthAuthorizeResponse(url=url, state=state) + + +@router.post( + "/oauth/{provider}/callback", + response_model=AuthTokens, + summary="Complete OAuth flow — exchange code and issue JWT tokens", +) +async def oauth_callback( + provider: Literal["google"], + body: _OAuthCallbackRequest, + db: AsyncSession = Depends(get_session), +) -> AuthTokens: + """Validate state, exchange the authorization code, and sign in (or register) the user. + + Resolution order: + 1. ``oauth_accounts`` row match → existing user, log in. + 2. Email match + ``email_verified=True`` → link OAuth account to existing user. + 3. No match → create new user (password_hash=None, avatar from provider). + """ + provider_factory = _PROVIDERS.get(provider) + if provider_factory is None: + raise HTTPException(status.HTTP_400_BAD_REQUEST, f"Unknown provider: {provider}") + + # Validate state (CSRF protection). + now = time.time() + entry = _pending_states.pop(body.state, None) + if entry is None or entry[1] < now: + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state") + + code_verifier, _ = entry + + oauth_provider = provider_factory() + + # Exchange code for tokens. + try: + token_data = await oauth_provider.exchange_code( + code=body.code, + code_verifier=code_verifier, + redirect_uri=settings.OAUTH_REDIRECT_URI, + ) + except Exception: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, "Failed to exchange authorization code" + ) + + access_token_google = token_data.get("access_token") + if not access_token_google: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "No access token in provider response") + + # Fetch user identity. + try: + userinfo = await oauth_provider.get_userinfo(access_token_google) + except Exception: + raise HTTPException(status.HTTP_400_BAD_REQUEST, "Failed to fetch user info from provider") + + # ── Resolution order ────────────────────────────────────────────── + + # 1. Existing OAuth link? + oauth_result = await db.execute( + select(OAuthAccount).where( + OAuthAccount.provider == provider, + OAuthAccount.provider_user_id == userinfo.provider_user_id, + ) + ) + oauth_account = oauth_result.scalar_one_or_none() + + if oauth_account is not None: + user_result = await db.execute(select(User).where(User.id == oauth_account.user_id)) + user = user_result.scalar_one() + # Backfill avatar if the user doesn't have one yet. + if user.avatar_url is None and userinfo.avatar_url: + user.avatar_url = userinfo.avatar_url + await db.commit() + plain_token, tokens = await _issue_refresh_token(user, db) + await db.commit() + return tokens + + # 2. Email match with a verified Google email → link accounts. + if userinfo.email_verified: + email_result = await db.execute(select(User).where(User.email == userinfo.email)) + existing_user = email_result.scalar_one_or_none() + + if existing_user is not None: + new_link = OAuthAccount( + user_id=existing_user.id, + provider=provider, + provider_user_id=userinfo.provider_user_id, + provider_email=userinfo.email, + ) + db.add(new_link) + if existing_user.avatar_url is None and userinfo.avatar_url: + existing_user.avatar_url = userinfo.avatar_url + plain_token, tokens = await _issue_refresh_token(existing_user, db) + await db.commit() + return tokens + + # Guard: if the email is already taken but we couldn't auto-link (e.g. + # email_verified=False), refuse with 409 instead of hitting a DB constraint. + if not userinfo.email_verified: + conflict = await db.execute(select(User).where(User.email == userinfo.email)) + if conflict.scalar_one_or_none() is not None: + raise HTTPException( + status.HTTP_409_CONFLICT, + "An account with this email already exists. " + "Please sign in with your password.", + ) + + # 3. New user — social-only account (no password). + new_user = User( + id=str(uuid.uuid4()), + email=userinfo.email, + name=userinfo.name, + password_hash=None, + avatar_url=userinfo.avatar_url, + tier="free", + encryption_key=Fernet.generate_key().decode(), + ) + db.add(new_user) + await db.flush() # populate new_user.id + + new_oauth = OAuthAccount( + user_id=new_user.id, + provider=provider, + provider_user_id=userinfo.provider_user_id, + provider_email=userinfo.email, + ) + db.add(new_oauth) + + plain_token, tokens = await _issue_refresh_token(new_user, db) + await db.commit() + return tokens + + +# ── Onboarding helpers ──────────────────────────────────────────────── + + +async def _build_profile(user_id: str, email: str, db: AsyncSession) -> UserProfile: + """Re-fetch and return a full UserProfile (reuses get_current_user logic).""" + + # We can't call the FastAPI dependency directly, but we can replicate + # the core logic inline. Instead, we just re-query the same way. + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription.tier).where(Subscription.user_id == user_id) + ) + default_tier = "power" if settings.ENV == "dev" else "free" + tier: str = result.scalar_one_or_none() or default_tier + + user_result = await db.execute( + select( + User.name, User.surname, User.avatar_url, User.onboarding_completed_at, + User.password_hash, + ).where(User.id == user_id) + ) + user_row = user_result.one_or_none() + + onboarding_ms: int | None = None + if user_row and user_row.onboarding_completed_at is not None: + onboarding_ms = int(user_row.onboarding_completed_at.timestamp() * 1000) + + memory_dict: dict[str, str] = {} + try: + mw = MemoryMiddleware(db) + blocks = await mw.list_core_blocks(user_id) + memory_dict = {b["label"]: b["value"] for b in blocks} + except Exception: + pass + + return UserProfile( + id=user_id, + email=email, + name=user_row.name if user_row else None, + surname=user_row.surname if user_row else None, + avatar_url=user_row.avatar_url if user_row else None, + has_password=bool(user_row.password_hash) if user_row else False, + tier=tier, + onboarding_completed_at=onboarding_ms, + memory=memory_dict, + ) + + +# ── Onboarding routes ──────────────────────────────────────────────── + + +class _UpdateMemoryRequest(BaseModel): + memory: dict[str, str] = Field(default_factory=dict) + mark_onboarded: bool = False + + +@router.put("/me/memory", response_model=UserProfile) +async def update_memory( + body: _UpdateMemoryRequest, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> UserProfile: + """Update core memory key/value pairs and optionally mark onboarding complete.""" + mw = MemoryMiddleware(db) + for key, value in body.memory.items(): + await mw.update_core(current_user.id, key, value) + if body.mark_onboarded: + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + user.onboarding_completed_at = datetime.now(timezone.utc) + await db.commit() + return await _build_profile(current_user.id, current_user.email, db) + + +@router.post("/me/onboarding/reset") +async def reset_onboarding( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +): + """Reset onboarding so the wizard runs again on next login.""" + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + user.onboarding_completed_at = None + await db.commit() + return {"status": "reset"} + + +class _NormalizeRequest(BaseModel): + inputs: dict[str, str] + + +class _NormalizeResponse(BaseModel): + normalized: dict[str, str] + + +@router.post("/onboarding/normalize", response_model=_NormalizeResponse) +async def normalize_onboarding( + body: _NormalizeRequest, + current_user: UserProfile = Depends(get_current_user), +) -> _NormalizeResponse: + """One-shot LLM normalization for free-text onboarding answers.""" + if not body.inputs: + return _NormalizeResponse(normalized={}) + try: + llm = get_llm(model="gpt-4o-mini", temperature=0) + prompt = ( + "You normalize user onboarding answers into clean, ≤3-word canonical labels.\n" + "Return a JSON object with the same keys and normalized values.\n" + "Examples: 'i build websites' → 'Web Developer', 'tech-ish stuff' → 'Technology'\n" + f"Input: {json.dumps(body.inputs)}" + ) + response = await llm.ainvoke( + [ + {"role": "system", "content": "You normalize user inputs. Return JSON only."}, + {"role": "user", "content": prompt}, + ], + ) + normalized = json.loads(response.content) + return _NormalizeResponse(normalized=normalized) + except Exception: + # LLM failure must never block onboarding — return inputs unchanged + return _NormalizeResponse(normalized=body.inputs) + + +# ── Password management ─────────────────────────────────────────────── + + +class _ChangePasswordRequest(BaseModel): + current_password: str = Field(min_length=1) + new_password: str = Field(min_length=8) + + +@router.put("/me/password", status_code=status.HTTP_200_OK) +async def change_password( + body: _ChangePasswordRequest, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, bool]: + """Change the authenticated user's password. + + Requires the current password for verification. + Returns 400 for social-only users (no password set). + """ + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + + if user.password_hash is None: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + "This account uses social login and has no password to change", + ) + + if not _verify_password(body.current_password, user.password_hash): + raise HTTPException(status.HTTP_400_BAD_REQUEST, "Current password is incorrect") + + user.password_hash = _hash_password(body.new_password) + await db.commit() + return {"ok": True} + + +# ── OAuth account management ───────────────────────────────────────── + + +@router.get("/me/oauth-accounts", response_model=list[dict]) +async def list_oauth_accounts( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> list[dict]: + """List all OAuth providers linked to the authenticated user.""" + result = await db.execute( + select(OAuthAccount).where(OAuthAccount.user_id == current_user.id) + ) + accounts = result.scalars().all() + return [ + { + "provider": a.provider, + "provider_email": a.provider_email, + "created_at": int(a.created_at.timestamp() * 1000), + } + for a in accounts + ] + + +@router.delete("/me/oauth-accounts/{provider}", status_code=status.HTTP_200_OK) +async def unlink_oauth_account( + provider: str, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, bool]: + """Unlink an OAuth provider from the authenticated user. + + Refuses if the user has no password and this is their only login method. + """ + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + + oauth_result = await db.execute( + select(OAuthAccount).where( + OAuthAccount.user_id == current_user.id, + OAuthAccount.provider == provider, + ) + ) + account = oauth_result.scalar_one_or_none() + if account is None: + raise HTTPException(status.HTTP_404_NOT_FOUND, f"No linked {provider} account found") + + # Safety: don't let users lock themselves out. + all_oauth = await db.execute( + select(OAuthAccount).where(OAuthAccount.user_id == current_user.id) + ) + oauth_count = len(all_oauth.scalars().all()) + + if user.password_hash is None and oauth_count <= 1: + raise HTTPException( + status.HTTP_400_BAD_REQUEST, + "Cannot unlink the only login method. Set a password first.", + ) + + await db.delete(account) + await db.commit() + return {"ok": True} + + +# ── Avatar update ───────────────────────────────────────────────────── + + +class _UpdateAvatarRequest(BaseModel): + avatar_url: str = Field(min_length=1) + + +@router.put("/me/avatar", response_model=UserProfile) +async def update_avatar( + body: _UpdateAvatarRequest, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> UserProfile: + """Update the authenticated user's avatar URL. + + Accepts {"avatar_url": "https://..."} — the client uploads the image + to its own storage and passes the resulting URL here. + """ + if not body.avatar_url.startswith(("https://", "http://", "data:image/")): + raise HTTPException(status.HTTP_400_BAD_REQUEST, "Invalid avatar URL") + + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + user.avatar_url = body.avatar_url + await db.commit() + + return await _build_profile(current_user.id, current_user.email, db) + + +# ── Account deletion ───────────────────────────────────────────────── + + +@router.delete("/me", status_code=status.HTTP_200_OK) +async def delete_account( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, bool]: + """Permanently delete the authenticated user's account. + + Cascades: refresh tokens, OAuth accounts, subscription, and all memory + rows are deleted via SQLAlchemy relationship cascades. Stripe subscription + is cancelled if active. + """ + # Cancel Stripe subscription if present. + try: + from app.billing.stripe_service import stripe_service # noqa: PLC0415 + await stripe_service.cancel_subscription(current_user.id, db) + except HTTPException: + pass # No subscription — that's fine + + # Delete all memory rows (core, associative, episodic, proactive). + try: + from app.models import ( # noqa: PLC0415 + MemoryAssociative, MemoryCore, MemoryEpisodic, MemoryProactive, + ) + for model in (MemoryCore, MemoryAssociative, MemoryEpisodic, MemoryProactive): + await db.execute( + model.__table__.delete().where(model.user_id == current_user.id) + ) + except Exception: + pass # Non-critical — cascade on User will handle most + + # Delete the user row — cascades handle refresh_tokens, oauth_accounts, subscription. + result = await db.execute(select(User).where(User.id == current_user.id)) + user = result.scalar_one() + await db.delete(user) + await db.commit() + + return {"ok": True} diff --git a/api/app/api/routes/billing.py b/api/app/api/routes/billing.py new file mode 100644 index 0000000..fe21b38 --- /dev/null +++ b/api/app/api/routes/billing.py @@ -0,0 +1,132 @@ +"""Billing routes: Stripe checkout, webhook, subscription management. + +Business logic lives in ``app.billing.stripe_service.StripeService``. +The route layer handles HTTP concerns (request parsing, response shaping) +and delegates everything else to the service singleton. +""" + +from __future__ import annotations + +from typing import Any + +from fastapi import APIRouter, Depends, Header, HTTPException, Request, status +from pydantic import BaseModel +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.billing.stripe_service import stripe_service +from app.db import get_session +from app.schemas import BillingTier, UserProfile + +router = APIRouter(prefix="/billing", tags=["billing"]) + + +# ── Request bodies ───────────────────────────────────────────────────── + +class _CheckoutRequest(BaseModel): + tier: BillingTier + + +# ── Routes ───────────────────────────────────────────────────────────── + +@router.post("/checkout", response_model=dict) +async def create_checkout( + body: _CheckoutRequest, + current_user: UserProfile = Depends(get_current_user), +) -> dict[str, str]: + """Create a Stripe checkout session for a tier upgrade. + + Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured. + """ + url = stripe_service.create_checkout_session(current_user.id, body.tier) + return {"checkout_url": url} + + +@router.post("/webhook", response_model=dict) +async def stripe_webhook( + request: Request, + stripe_signature: str = Header(default="", alias="Stripe-Signature"), + db: AsyncSession = Depends(get_session), +) -> dict[str, bool]: + """Handle Stripe webhook events. + + No JWT auth — authenticated via Stripe signature verification instead. + Returns 200 immediately when Stripe is not configured (local dev). + """ + payload = await request.body() + await stripe_service.handle_webhook(payload, stripe_signature, db) + return {"ok": True} + + +@router.get("/subscription", response_model=dict) +async def get_subscription( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, Any]: + """Return the current subscription info for the authenticated user.""" + sub = await stripe_service.get_subscription(current_user.id, db) + if sub is None: + return { + "tier": current_user.tier, + "status": "free", + "stripe_subscription_id": None, + "current_period_end": None, + } + return sub + + +@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK) +async def cancel_subscription( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, bool]: + """Cancel the active subscription.""" + await stripe_service.cancel_subscription(current_user.id, db) + return {"ok": True} + + +@router.get("/invoices", response_model=list[dict]) +async def list_invoices( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> list[dict[str, Any]]: + """Return billing history (invoices) from Stripe. + + Returns an empty list when Stripe is not configured. + """ + invoices = await stripe_service.list_invoices(current_user.id, db) + return invoices + + +# ── Quota check ──────────────────────────────────────────────────────── + +from app.billing.quota import check_folder_quota, QuotaExceeded # noqa: E402 + + +class QuotaCheckRequest(BaseModel): + feature: str + estimated_files: int + + +@router.post("/quota/check") +async def quota_check( + payload: QuotaCheckRequest, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict: + """Pre-flight folder quota check. 402 if tier limits would be exceeded.""" + if payload.feature != "folder_index": + raise HTTPException(status_code=400, detail="Unknown feature") + try: + await check_folder_quota( + user_id=current_user.id, + tier=current_user.tier, + estimated_files=payload.estimated_files, + db=db, + ) + except QuotaExceeded as exc: + raise HTTPException( + status_code=402, + detail={"reason": exc.reason, "message": str(exc)}, + ) + return {"ok": True} diff --git a/api/app/api/routes/chat.py b/api/app/api/routes/chat.py new file mode 100644 index 0000000..3908b0f --- /dev/null +++ b/api/app/api/routes/chat.py @@ -0,0 +1,116 @@ +"""Chat routes: POST /chat (REST fallback) and POST /chat/embed (text → vector). + +WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device). +""" + +from __future__ import annotations + +import uuid +from typing import Literal + +from fastapi import APIRouter, Depends, HTTPException +from fastapi.responses import JSONResponse +from pydantic import BaseModel + +from app.api.deps import get_current_user +from app.core.brief_agent import run_home_brief, run_project_brief +from app.core.deep_agent import run_home +from app.core.llm import embed +from app.core.memory_middleware import MemoryMiddleware +from app.db import async_session +from app.schemas import ChatRequest, UserProfile + +router = APIRouter(prefix="/chat", tags=["chat"]) + + +# ── Embed helpers ───────────────────────────────────────────────────────── + + +class _EmbedRequest(BaseModel): + text: str + + +class _EmbedResponse(BaseModel): + vector: list[float] + + +# ── Endpoints ───────────────────────────────────────────────────────────── + + +@router.post("") +async def chat( + body: ChatRequest, + current_user: UserProfile = Depends(get_current_user), +) -> JSONResponse: + """REST fallback for home chat when websocket streaming is unavailable.""" + response = await run_home( + user_id=current_user.id, + message=body.message, + context=body.context.model_dump(), + ) + return JSONResponse(content={"response": response}) + + +class _BriefRequest(BaseModel): + mode: Literal["home", "project"] + project_id: str | None = None + + +class _BriefResponse(BaseModel): + response: str + + +@router.post("/brief", response_model=_BriefResponse) +async def brief( + body: _BriefRequest, + current_user: UserProfile = Depends(get_current_user), +) -> _BriefResponse: + """REST fallback for brief when the device WebSocket is not ready.""" + if body.mode == "project": + if not body.project_id: + raise HTTPException(status_code=422, detail="project_id required for project mode") + try: + uuid.UUID(body.project_id) + except ValueError: + raise HTTPException(status_code=422, detail="project_id must be a valid UUID") + + request_id = str(uuid.uuid4()) + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context( + current_user.id, + "", + trace_id=request_id, + session_id=request_id, + ) + + context: dict = { + "_debug": {"request_id": request_id, "user_id": current_user.id}, + **memory_context, + } + + chunks: list[str] = [] + if body.mode == "project": + stream = run_project_brief(current_user.id, body.project_id, context) # type: ignore[arg-type] + else: + stream = run_home_brief(current_user.id, context) + + async for event_type, data in stream: + if event_type == "token" and data: + chunks.append(str(data)) + + return _BriefResponse(response="".join(chunks)) + + +@router.post("/embed", response_model=_EmbedResponse) +async def embed_text( + body: _EmbedRequest, + current_user: UserProfile = Depends(get_current_user), +) -> _EmbedResponse: + """Generate a 1536-dim embedding vector for the given text. + + Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT). + Used by Electron (vectordb.ts) for local note search. + """ + vector = await embed(body.text) + return _EmbedResponse(vector=vector) diff --git a/api/app/api/routes/device_ws.py b/api/app/api/routes/device_ws.py new file mode 100644 index 0000000..5116b8e --- /dev/null +++ b/api/app/api/routes/device_ws.py @@ -0,0 +1,864 @@ +"""Device WebSocket endpoint. + +Persistent connection from Electron devices to the backend. + + WS /api/v1/ws/device?token= + +Auth: JWT passed as ``?token=`` query parameter (Bearer header is not +available during the WebSocket handshake). + +Protocol: + 1. Client connects → JWT validated → connection accepted. + 2. Client sends ``device_hello`` frame: ``{ type, device_id, scout_ids }``. + 3. Backend registers the connection in ``DeviceConnectionManager``. + 4. Session enters message dispatch loop + heartbeat. + +Incoming frame dispatch: + - ``tool_result`` → resolves a pending tool-call Future. + - ``journey_start`` → starts a guided setup journey session. + - ``journey_message`` → continues a journey conversation. + - ``pong`` → heartbeat acknowledgement (updates last-seen). + - unknown types → logged, ignored. + +Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s. + +On disconnect: + - Unregisters from DeviceConnectionManager. + - Marks all in-progress AgentRunLog rows for this user as ``error`` + with message "device disconnected". +""" + +from __future__ import annotations + +import asyncio +import json +import logging +from uuid import uuid4 + +from fastapi import APIRouter, WebSocket, WebSocketDisconnect +from jose import JWTError, jwt +from sqlalchemy import update + +from app.api.routes.scout_setup import handle_journey_message, handle_journey_start +from app.config.settings import settings +from app.scouts.engine import ScoutEngine +from app.core.scout_runner import trigger_pending_runs +from app.core.scout_session_buffer import session_buffer +from app.core.brief_agent import run_home_brief, run_project_brief +from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream +from app.core.output_formatter import extract_canvas_block +from app.core.device_manager import device_manager +from app.core.memory_middleware import MemoryMiddleware +from app.core.output_formatter import StreamFormatter +from app.core.ws_context import clear_client_executor, set_client_executor +from app.db import async_session +from app.models import ScoutRunLog +from app.schemas import WsFrameType, WsStreamEnd +from app.schemas.contextual import ContextualScope, render_scope_block + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/ws", tags=["device-ws"]) + +# ── v7 folder index session state ───────────────────────────────────── +# Keyed by sessionId; value: { user_id, project_id, processed, total, cancelled } +_index_sessions: dict[str, dict] = {} + +_HEARTBEAT_INTERVAL = 30 # seconds +_PONG_TIMEOUT = 10 # seconds — grace window after a ping + + +@router.websocket("/device") +async def device_ws(websocket: WebSocket) -> None: + """Persistent WebSocket endpoint for Electron device connections. + + Authentication is via ``?token=`` query parameter. + """ + # ── 1. Authenticate before accepting ───────────────────────────── + token = websocket.query_params.get("token", "") + try: + payload = jwt.decode( + token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] + ) + user_id: str | None = payload.get("sub") + if not user_id: + raise JWTError("missing sub") + except JWTError: + await websocket.close(code=1008) # Policy Violation + return + + await websocket.accept() + + # ── 2. Await device_hello frame ─────────────────────────────────── + try: + raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0) + except (asyncio.TimeoutError, WebSocketDisconnect): + await websocket.close(code=1008) + return + + try: + hello = json.loads(raw) + if hello.get("type") != WsFrameType.device_hello: + raise ValueError("expected device_hello as first frame") + device_id: str = hello["device_id"] + scout_ids: list[str] = hello.get("scout_ids", []) + except (KeyError, ValueError, json.JSONDecodeError) as exc: + logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc) + await websocket.close(code=1008) + return + + # ── 3. Register connection ──────────────────────────────────────── + device_manager.register(user_id, device_id, websocket) + logger.info( + "device_ws: connected user=%s device=%s scouts=%s", + user_id, + device_id, + scout_ids, + ) + + # Trigger any overdue agent runs now that the device is connected. + asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager)) + + # Drain any queued scout proposals and deliver to the client (non-blocking). + async def _deliver_pending_safe() -> None: + import uuid as _uuid # noqa: PLC0415 + try: + await ScoutEngine().deliver_pending(_uuid.UUID(user_id), websocket) + except Exception: + logger.exception("scout deliver_pending failed for user %s", user_id) + + asyncio.create_task(_deliver_pending_safe()) + + # ── 4. Concurrent message loop + heartbeat ──────────────────────── + try: + await asyncio.gather( + _message_loop(websocket, user_id), + _heartbeat_loop(websocket), + ) + except WebSocketDisconnect: + pass + except Exception as exc: + logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc) + finally: + device_manager.unregister(user_id) + logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id) + await _mark_runs_disconnected(user_id) + + +# ── Message dispatch loop ───────────────────────────────────────────── + +async def _message_loop(websocket: WebSocket, user_id: str) -> None: + """Receive frames from Electron and dispatch to the appropriate handler.""" + async for raw in websocket.iter_text(): + try: + frame: dict = json.loads(raw) + except json.JSONDecodeError: + logger.warning("device_ws: invalid JSON from user=%s", user_id) + continue + + frame_type = frame.get("type") + + if frame_type == WsFrameType.tool_result: + call_id = frame.get("id") + if call_id: + device_manager.resolve_pending_call(user_id, call_id, frame) + else: + logger.warning( + "device_ws: tool_result missing id from user=%s", user_id + ) + + elif frame_type == WsFrameType.home_request: + asyncio.create_task( + _handle_home_request(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.brief_request: + asyncio.create_task( + _handle_brief_request(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.task_brief_request: + asyncio.create_task( + _handle_task_brief_request(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.journey_start: + asyncio.create_task( + _handle_journey_start(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.journey_message: + asyncio.create_task( + _handle_journey_message(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.index_session_start: + asyncio.create_task( + _handle_index_session_start(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.index_file_batch: + asyncio.create_task( + _handle_index_file_batch(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.index_session_cancel: + await _handle_index_session_cancel(websocket, frame) + + elif frame_type == WsFrameType.contextual_request: + asyncio.create_task( + _handle_contextual_request(websocket, user_id, frame) + ) + + elif frame_type == WsFrameType.contextual_scope_update: + asyncio.create_task( + _handle_contextual_scope_update(websocket, user_id, frame) + ) + + elif frame_type == "scout_proposal_ack": + proposal_id = frame.get("proposal_id") + if proposal_id: + try: + await ScoutEngine().ack_proposal(proposal_id) + except Exception: + logger.exception("scout ack_proposal failed for %s", proposal_id) + + elif frame_type == "pong": + # Heartbeat ack — nothing to do, connection is alive. + pass + + else: + logger.debug( + "device_ws: unknown frame type %r from user=%s", frame_type, user_id + ) + + +# ── v3 Chat Handlers ────────────────────────────────────────────────── + +async def _make_ws_executor(websocket: WebSocket, user_id: str): + """Return a callback that sends tool_call frames and awaits tool_result.""" + async def _executor(payload: dict) -> dict: + payload["type"] = WsFrameType.tool_call + await websocket.send_text(json.dumps(payload)) + future = device_manager.create_pending_call(user_id, payload["id"]) + return await future + return _executor + + +async def _handle_home_request( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a home_request frame — streams HomeFormatter output back on the socket.""" + request_id = frame.get("request_id") or str(uuid4()) + message: str = frame.get("message", "") + session_id: str = frame.get("session_id") or str(uuid4()) + project_id: str | None = frame.get("project_id") or frame.get("projectId") or None + logger.info( + "device_ws: home_request_start user=%s req=%s session=%s project=%s msg=%s", + user_id, + request_id, + session_id, + project_id, + message[:200], + ) + + # ── Memory: enrich context before LLM call ──────────────────────── + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context( + user_id, + message, + trace_id=request_id, + session_id=session_id, + ) + + context: dict = { + "conversation_history": frame.get("conversation_history", []), + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, + "format_prefs": frame.get("format_prefs"), + **memory_context, + } + + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + response_chunks: list[str] = [] + try: + event_stream = run_home_stream(user_id, message, context, project_id=project_id) + formatter = StreamFormatter(request_id=request_id) + async for ws_frame in formatter.format(event_stream): + await websocket.send_text(ws_frame.model_dump_json()) + # Collect text chunks to build the full response for episode storage + if ws_frame.type == "stream_text": # type: ignore[union-attr] + response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] + except Exception as exc: + logger.error( + "device_ws: home_request failed user=%s req=%s: %s", + user_id, request_id, exc, + ) + finally: + clear_client_executor() + + # ── Memory: store episode after response ────────────────────────── + async with async_session() as db: + memory = MemoryMiddleware(db) + await memory.store_episode( + user_id, session_id, message, "".join(response_chunks), trace_id=request_id + ) + logger.info( + "device_ws: home_request_end user=%s req=%s session=%s response_chars=%d", + user_id, + request_id, + session_id, + len("".join(response_chunks)), + ) + + +# ── v8 Contextual Sidebar Handlers ─────────────────────────────────── + + +def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"): + """Return a session-scoped buffer proxy for the given user+session. + + Returns a _ContextualBufferProxy that exposes append_system_message(). + Defined at module level so tests can monkeypatch it. + The channel kwarg is accepted for forward-compatibility. + """ + from app.core.scout_session_buffer import ContextualBufferProxy # noqa: PLC0415 + return ContextualBufferProxy(session_buffer, user_id, session_id) + + +async def _handle_contextual_request( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a contextual_request frame — runs the contextual agent and streams frames.""" + request_id = frame.get("request_id") or str(uuid4()) + message: str = frame.get("message", "") + session_id: str = frame.get("session_id") or str(uuid4()) + scope_payload: dict = frame.get("scope", {}) + logger.info( + "device_ws: contextual_request_start user=%s req=%s session=%s msg=%s", + user_id, + request_id, + session_id, + message[:200], + ) + + scope = ContextualScope.model_validate(scope_payload) + + # Enrich context with memory before the LLM call. + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context( + user_id, + message, + trace_id=request_id, + session_id=session_id, + ) + + context: dict = { + "conversation_history": frame.get("conversation_history", []), + "format_prefs": frame.get("format_prefs"), + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, + **memory_context, + } + + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + response_chunks: list[str] = [] + try: + event_stream = run_contextual_stream( + user_id=user_id, + message=message, + context=context, + scope=scope, + ) + formatter = StreamFormatter(request_id=request_id) + async for ws_frame in formatter.format(event_stream): + await websocket.send_text(ws_frame.model_dump_json()) + if ws_frame.type == "stream_text": # type: ignore[union-attr] + response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] + except Exception as exc: + logger.error( + "device_ws: contextual_request failed user=%s req=%s: %s", + user_id, request_id, exc, + ) + finally: + clear_client_executor() + + # Store episode so the contextual agent can recall prior turns. + async with async_session() as db: + memory = MemoryMiddleware(db) + await memory.store_episode( + user_id, session_id, message, "".join(response_chunks), trace_id=request_id + ) + logger.info( + "device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d", + user_id, + request_id, + session_id, + len("".join(response_chunks)), + ) + + +async def _handle_contextual_scope_update( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a contextual_scope_update frame. + + Injects a synthetic system message into the session buffer so the next + agent turn knows the user navigated. No LLM call is made. + """ + session_id: str = frame.get("session_id") or str(uuid4()) + scope = ContextualScope.model_validate(frame.get("scope", {})) + block = render_scope_block(scope) + buf = get_session_buffer(user_id, session_id, channel="contextual") + buf.append_system_message( + f"User navigated to a new view. {block} Treat this as the new active context." + ) + await websocket.send_text(json.dumps({ + "type": WsFrameType.contextual_scope_ack, + "session_id": session_id, + })) + logger.info( + "device_ws: contextual_scope_update user=%s session=%s page=%s", + user_id, session_id, scope.page, + ) + + +async def _handle_brief_request( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a brief_request frame — streams plain-text brief back on the socket. + + No episode storage — briefs are not conversations. + """ + import uuid as _uuid + + request_id = frame.get("request_id") or str(uuid4()) + session_id = frame.get("session_id") or str(uuid4()) + mode: str = frame.get("mode", "home") + project_id: str | None = frame.get("project_id") + + logger.info( + "device_ws: brief_request_start user=%s req=%s mode=%s project_id=%s", + user_id, request_id, mode, project_id, + ) + + # Validate project_id for project mode before touching LLM. + if mode == "project": + try: + if not project_id: + raise ValueError("project_id required for project mode") + _uuid.UUID(project_id) + except (ValueError, AttributeError) as exc: + logger.warning( + "device_ws: brief_request invalid project_id user=%s req=%s: %s", + user_id, request_id, exc, + ) + await websocket.send_text( + WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json() + ) + return + + # Enrich context with memory (no user message — use empty string as probe). + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context( + user_id, + "", + trace_id=request_id, + session_id=session_id, + ) + + context: dict = { + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, + "format_prefs": frame.get("format_prefs"), + **memory_context, + } + + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + try: + if mode == "project": + event_stream = run_project_brief(user_id, project_id, context) # type: ignore[arg-type] + else: + event_stream = run_home_brief(user_id, context) + + formatter = StreamFormatter(request_id=request_id) + async for ws_frame in formatter.format(event_stream): + await websocket.send_text(ws_frame.model_dump_json()) + except Exception as exc: + logger.error( + "device_ws: brief_request failed user=%s req=%s: %s", + user_id, request_id, exc, + ) + await websocket.send_text( + WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json() + ) + finally: + clear_client_executor() + + logger.info( + "device_ws: brief_request_end user=%s req=%s mode=%s", + user_id, request_id, mode, + ) + + +# ── v6 Task Brief Handler ──────────────────────────────────────────── + + +async def _handle_task_brief_request( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a task_brief_request frame — Stage-1 executive assistant deep research. + + Streams the briefing markdown back to the client. + On stream_end, emits a ``canvas_draft`` mutation if the agent produced one. + """ + request_id = frame.get("request_id") or str(uuid4()) + session_id = frame.get("session_id") or str(uuid4()) + task_id: str = frame.get("task_id") or frame.get("taskId") or "" + project_id: str | None = frame.get("project_id") or frame.get("projectId") or None + + logger.info( + "device_ws: task_brief_request_start user=%s req=%s task=%s project=%s [cache_miss]", + user_id, request_id, task_id, project_id, + ) + + if not task_id: + await websocket.send_text( + WsStreamEnd(request_id=request_id, error="task_id is required").model_dump_json() + ) + return + + async with async_session() as db: + memory = MemoryMiddleware(db) + memory_context = await memory.enrich_context( + user_id, + f"task brief: {task_id}", + trace_id=request_id, + session_id=session_id, + ) + + context: dict = { + "_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id}, + "format_prefs": frame.get("format_prefs"), + **memory_context, + } + + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + response_chunks: list[str] = [] + + try: + event_stream = run_task_brief_research_stream(user_id, task_id, context, project_id=project_id) + formatter = StreamFormatter(request_id=request_id) + async for ws_frame in formatter.format(event_stream): + if ws_frame.type == "stream_text": # type: ignore[union-attr] + response_chunks.append(ws_frame.chunk) # type: ignore[union-attr] + await websocket.send_text(ws_frame.model_dump_json()) + elif ws_frame.type == "stream_start": + await websocket.send_text(ws_frame.model_dump_json()) + # stream_end is emitted below with mutations — skip formatter's version + except Exception as exc: + logger.error( + "device_ws: task_brief_request failed user=%s req=%s task=%s: %s", + user_id, request_id, task_id, exc, + ) + await websocket.send_text( + WsStreamEnd(request_id=request_id, error=str(exc)).model_dump_json() + ) + return + finally: + clear_client_executor() + + # Extract canvas block then emit stream_end with optional mutations. + full_response = "".join(response_chunks) + _visible, canvas_content, canvas_kind = extract_canvas_block(full_response) + + mutations: list[dict] = [] + if canvas_content: + mutations.append({ + "type": "canvas_draft", + "content": canvas_content, + "kind": canvas_kind, + }) + + await websocket.send_text( + WsStreamEnd(request_id=request_id, mutations=mutations or None).model_dump_json() + ) + + logger.info( + "device_ws: task_brief_request_end user=%s req=%s task=%s response_chars=%d canvas=%s", + user_id, request_id, task_id, len(full_response), canvas_kind or "none", + ) + + +# ── v4 Journey Handlers ───────────────────────────────────────────── + + +async def _handle_journey_start( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a journey_start frame — explores directory and sends first question.""" + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + try: + reply = await handle_journey_start(user_id, frame) + await websocket.send_text(json.dumps(reply)) + except Exception as exc: + logger.error( + "device_ws: journey_start failed user=%s: %s", user_id, exc + ) + await websocket.send_text(json.dumps({ + "type": "journey_reply", + "session_id": frame.get("session_id", ""), + "message": f"Failed to start journey: {exc}", + "done": True, + "prompt_template": None, + })) + finally: + clear_client_executor() + + +async def _handle_journey_message( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Handle a journey_message frame — continues the journey conversation.""" + executor = await _make_ws_executor(websocket, user_id) + set_client_executor(executor) + try: + reply = await handle_journey_message(user_id, frame) + await websocket.send_text(json.dumps(reply)) + except Exception as exc: + session_id = frame.get("session_id", "") + logger.error( + "device_ws: journey_message failed user=%s session=%s: %s", + user_id, session_id, exc, + ) + await websocket.send_text(json.dumps({ + "type": "journey_reply", + "session_id": session_id, + "message": f"Journey error: {exc}", + "done": True, + "prompt_template": None, + })) + finally: + clear_client_executor() + + +# ── v7 Folder Index Handlers ────────────────────────────────────────── + + +async def _handle_index_session_start( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Register a new folder index session. No response sent — client is declaring intent.""" + session_id: str = frame.get("sessionId") or frame.get("session_id") or "" + project_id: str | None = frame.get("projectId") or frame.get("project_id") + total: int = int(frame.get("totalFiles") or frame.get("total_files") or 0) + + if not session_id: + logger.warning("device_ws: index_session_start missing sessionId user=%s", user_id) + return + + _index_sessions[session_id] = { + "user_id": user_id, + "project_id": project_id, + "processed": 0, + "total": total, + "cancelled": False, + } + logger.info( + "device_ws: index_session_start user=%s session=%s project=%s total=%d", + user_id, session_id, project_id, total, + ) + + +async def _handle_index_session_cancel( + websocket: WebSocket, + frame: dict, +) -> None: + """Mark a session as cancelled and emit index_session_done(cancelled).""" + session_id: str = frame.get("sessionId") or frame.get("session_id") or "" + session = _index_sessions.get(session_id) + if session: + session["cancelled"] = True + + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_session_done, + "sessionId": session_id, + "status": "cancelled", + })) + _index_sessions.pop(session_id, None) + logger.info("device_ws: index_session_cancel session=%s", session_id) + + +async def _handle_index_file_batch( + websocket: WebSocket, + user_id: str, + frame: dict, +) -> None: + """Process a batch of files for an index session, streaming results back.""" + # Lazy imports to avoid heavy load at module startup. + from app.core.folder_indexer import ( # noqa: PLC0415 + summarize_image, + summarize_pdf, + summarize_docx, + summarize_text, + ) + from app.billing.tier_manager import tier_manager # noqa: PLC0415 + from app.billing.quota import add_token_usage # noqa: PLC0415 + + session_id: str = frame.get("sessionId") or frame.get("session_id") or "" + files: list[dict] = frame.get("files", []) + + session = _index_sessions.get(session_id) + if not session or session.get("cancelled"): + return + + async with async_session() as db: + tier = await tier_manager.get_tier(user_id, db) + raw_cap = tier_manager.get_feature_value(tier, "folder_monthly_tokens") + cap: int | None = None if raw_cap == -1 else raw_cap + + for file_info in files: + if session.get("cancelled"): + return + + # Electron's toSnakeCase converts payload keys, so accept both forms. + rel_path: str = file_info.get("relPath") or file_info.get("rel_path") or "" + kind: str = file_info.get("kind") or "text" + content: str = file_info.get("content") or "" + ext: str = file_info.get("ext") or "" + mime: str = file_info.get("mime") or "application/octet-stream" + name: str = rel_path.split("/")[-1] or rel_path + + try: + if kind == "image": + res = await summarize_image(image_b64=content, mime=mime) + elif kind == "pdf": + res = await summarize_pdf(pdf_b64=content, name=name) + elif kind == "docx": + res = await summarize_docx(docx_b64=content, name=name) + else: + res = await summarize_text(content=content, ext=ext, name=name) + except Exception as exc: + logger.warning( + "device_ws: index_file_batch summarize failed session=%s path=%s: %s", + session_id, rel_path, exc, + ) + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_file_result, + "sessionId": session_id, + "relPath": rel_path, + "summary": None, + "tokensUsed": 0, + "error": str(exc), + })) + session["processed"] += 1 + continue + + # Account for token usage and check cap. + usage = await add_token_usage( + user_id=user_id, + feature="folder_index", + tokens=res.tokens_used, + db=db, + cap=cap, + ) + + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_file_result, + "sessionId": session_id, + "relPath": rel_path, + "summary": res.summary, + "tokensUsed": res.tokens_used, + })) + session["processed"] += 1 + + if usage.exhausted: + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_session_done, + "sessionId": session_id, + "status": "quota_exceeded", + })) + _index_sessions.pop(session_id, None) + logger.info( + "device_ws: index_session quota_exceeded user=%s session=%s", + user_id, session_id, + ) + return + + # After processing the batch, emit progress. + processed = session["processed"] + total = session["total"] + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_session_progress, + "sessionId": session_id, + "processed": processed, + "total": total, + })) + + if processed >= total: + await websocket.send_text(json.dumps({ + "type": WsFrameType.index_session_done, + "sessionId": session_id, + "status": "completed", + })) + _index_sessions.pop(session_id, None) + logger.info( + "device_ws: index_session_done completed user=%s session=%s processed=%d", + user_id, session_id, processed, + ) + + +# ── Heartbeat ───────────────────────────────────────────────────────── + +async def _heartbeat_loop(websocket: WebSocket) -> None: + """Send a ping frame every 30 s to keep the connection alive.""" + while True: + await asyncio.sleep(_HEARTBEAT_INTERVAL) + await websocket.send_text(json.dumps({"type": "ping"})) + + +# ── Disconnect cleanup ──────────────────────────────────────────────── + +async def _mark_runs_disconnected(user_id: str) -> None: + """Mark all in-progress ScoutRunLog rows as 'error' for this user.""" + try: + async with async_session() as db: + await db.execute( + update(ScoutRunLog) + .where( + ScoutRunLog.user_id == user_id, + ScoutRunLog.status == "running", + ) + .values( + status="error", + errors=["device disconnected"], + ) + ) + await db.commit() + except Exception as exc: + logger.error( + "device_ws: failed to mark runs as disconnected for user=%s: %s", + user_id, + exc, + ) diff --git a/api/app/api/routes/memory.py b/api/app/api/routes/memory.py new file mode 100644 index 0000000..ffc5cfe --- /dev/null +++ b/api/app/api/routes/memory.py @@ -0,0 +1,225 @@ +"""Memory management routes — view/edit/delete user memory tiers. + +All routes require authentication. Data is always user-scoped. +""" + +from __future__ import annotations + +import logging +from datetime import datetime, timezone +from typing import Annotated + +from fastapi import APIRouter, Depends, Header, HTTPException, status +from pydantic import BaseModel, Field +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.core.memory_middleware import MemoryMiddleware +from app.db import get_session +from app.models import ( + ExtractionQueue, + MemoryAssociative, + MemoryCore, + MemoryEpisodic, + MemoryProactive, + MemoryRelation, +) +from app.schemas import UserProfile + +router = APIRouter(prefix="/memory", tags=["memory"]) + +logger = logging.getLogger(__name__) + +_ALLOWED_PREDICATES = { + "works_at", + "reports_to", + "stakeholder_of", + "last_contacted_on", + "owes_followup", + "manages", + "collaborates_with", + "owns", + "member_of", + "custom", +} + + +# ── Response schemas ───────────────────────────────────────────────────────── + +class RelationOut(BaseModel): + id: str + subject_label: str + subject_type: str + predicate: str + object_label: str + object_type: str + confidence: float + last_confirmed_at: int | None = None # epoch ms + + +class RelationPatch(BaseModel): + subject_label: str | None = None + object_label: str | None = None + predicate: str | None = None + confidence: float | None = Field(None, ge=0.0, le=1.0) + + +class CoreAddBody(BaseModel): + key: str = Field(..., min_length=1, max_length=255) + value: str = Field(..., min_length=1) + + +# ── Helpers ────────────────────────────────────────────────────────────────── + +def _relation_to_out(row: MemoryRelation) -> RelationOut: + last_ms: int | None = None + if row.last_confirmed_at is not None: + last_ms = int(row.last_confirmed_at.timestamp() * 1000) + return RelationOut( + id=row.id, + subject_label=row.subject_label, + subject_type=row.subject_type, + predicate=row.predicate, + object_label=row.object_label, + object_type=row.object_type, + confidence=row.confidence, + last_confirmed_at=last_ms, + ) + + +# ── Routes ─────────────────────────────────────────────────────────────────── + +@router.get("/core", response_model=dict[str, str]) +async def get_core_memory( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, str]: + """Return all core memory k/v pairs (plaintext) for the current user.""" + mw = MemoryMiddleware(db) + blocks = await mw.list_core_blocks(current_user.id) + return {b["label"]: b["value"] for b in blocks} + + +@router.delete("/core/{key}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_core_key( + key: str, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> None: + """Delete a single core memory key (GDPR Art. 17).""" + mw = MemoryMiddleware(db) + deleted = await mw.delete_core(current_user.id, key) + if not deleted: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Key not found") + + +@router.post("/core", status_code=status.HTTP_201_CREATED, response_model=dict[str, str]) +async def add_core_key( + body: CoreAddBody, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, str]: + """Add or overwrite a core memory key/value pair.""" + mw = MemoryMiddleware(db) + await mw.update_core(current_user.id, body.key, body.value) + return {body.key: body.value} + + +@router.get("/relational", response_model=list[RelationOut]) +async def get_relational_memory( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> list[RelationOut]: + """Return all relational memory rows for the current user.""" + mw = MemoryMiddleware(db) + rows = await mw.query_relations(current_user.id, limit=200) + return [_relation_to_out(r) for r in rows] + + +@router.patch("/relational/{relation_id}", response_model=RelationOut) +async def patch_relation( + relation_id: str, + body: RelationPatch, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> RelationOut: + """Edit a relation row's labels, predicate, or confidence.""" + if body.predicate is not None and body.predicate not in _ALLOWED_PREDICATES: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail=f"predicate must be one of: {sorted(_ALLOWED_PREDICATES)}", + ) + + result = await db.execute( + select(MemoryRelation).where( + MemoryRelation.id == relation_id, + MemoryRelation.user_id == current_user.id, + ) + ) + row = result.scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found") + + if body.subject_label is not None: + row.subject_label = body.subject_label + if body.object_label is not None: + row.object_label = body.object_label + if body.predicate is not None: + row.predicate = body.predicate + if body.confidence is not None: + row.confidence = body.confidence + row.last_confirmed_at = datetime.now(timezone.utc) + + await db.commit() + await db.refresh(row) + logger.info("memory: patch_relation user=%s relation=%s", current_user.id, relation_id) + return _relation_to_out(row) + + +@router.delete("/relational/{relation_id}", status_code=status.HTTP_204_NO_CONTENT) +async def delete_relation( + relation_id: str, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> None: + """Hard-delete a relation row (GDPR Art. 17).""" + result = await db.execute( + select(MemoryRelation).where( + MemoryRelation.id == relation_id, + MemoryRelation.user_id == current_user.id, + ) + ) + row = result.scalar_one_or_none() + if row is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Relation not found") + await db.delete(row) + await db.commit() + logger.info("memory: delete_relation user=%s relation=%s", current_user.id, relation_id) + + +@router.post("/forget-all", status_code=status.HTTP_204_NO_CONTENT) +async def forget_all( + x_confirm: Annotated[str | None, Header(alias="X-Confirm")] = None, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> None: + """Wipe all memory tiers for the current user (GDPR Art. 17). + + Requires ``X-Confirm: true`` header. Does NOT delete the user account. + """ + if x_confirm != "true": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Missing or invalid X-Confirm header. Send X-Confirm: true to confirm.", + ) + + uid = current_user.id + await db.execute(delete(MemoryCore).where(MemoryCore.user_id == uid)) + await db.execute(delete(MemoryAssociative).where(MemoryAssociative.user_id == uid)) + await db.execute(delete(MemoryEpisodic).where(MemoryEpisodic.user_id == uid)) + await db.execute(delete(MemoryProactive).where(MemoryProactive.user_id == uid)) + await db.execute(delete(MemoryRelation).where(MemoryRelation.user_id == uid)) + await db.execute(delete(ExtractionQueue).where(ExtractionQueue.user_id == uid)) + await db.commit() + logger.warning("memory: forget_all GDPR wipe user=%s", uid) diff --git a/api/app/api/routes/scout_setup.py b/api/app/api/routes/scout_setup.py new file mode 100644 index 0000000..36f8717 --- /dev/null +++ b/api/app/api/routes/scout_setup.py @@ -0,0 +1,513 @@ +"""Chatbot Journey — WS-based guided conversation to build an ScoutConfig. + +The journey is driven entirely through WebSocket frames (no REST endpoints). +The device WS handler dispatches ``journey_start`` and ``journey_message`` +frames to the functions exported here. + +Journey flow: + 1. FE sends ``journey_start`` frame with basic agent info (directory, + data_types, schedule). + 2. Server creates an in-memory session, sets up a WS executor so the + setup LLM can use file-system tools, does a first directory scrape, + and sends back a ``journey_reply`` with the first question. + 3. FE sends ``journey_message`` frames for each user reply. + 4. Server appends the user message, calls the LLM (which may read files + via tools), and sends back a ``journey_reply``. + 5. After 3-5 turns the LLM wraps up by emitting an ``ScoutConfig`` JSON + block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``. + 6. Server parses and validates the JSON with Pydantic, sends + ``journey_reply`` with ``done=True`` and the serialised config. + FE stores it locally. +""" + +from __future__ import annotations + +import json +import logging +import time +import uuid +from dataclasses import dataclass, field +from typing import Any + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage + +from app.agents.filesystem_agent import make_directory_tools +from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context +from app.core.llm import get_agent_llm, model_for_agent +from app.schemas import ScoutConfig + +logger = logging.getLogger(__name__) + +# ── Session TTL ─────────────────────────────────────────────────────────── + +_SESSION_TTL_SECONDS: int = 1800 # 30 minutes + +# Sentinel strings used to delimit the LLM-produced ScoutConfig JSON. +_CONFIG_START = "AGENT_CONFIG_START" +_CONFIG_END = "AGENT_CONFIG_END" + +# Minimum turns before we consider nudging the LLM to wrap up. +_MIN_TURNS_BEFORE_NUDGE: int = 3 +# Hard cap to avoid infinite loops (safety net, not the primary stopping criterion). +_MAX_TURNS: int = 15 +# Max tool-calling steps per LLM invocation. +_MAX_TOOL_STEPS: int = 6 + +# ── In-memory session store ─────────────────────────────────────────────── + + +@dataclass +class JourneySession: + session_id: str + user_id: str + agent_type: str # "local" | "cloud" + directory: str + data_types: list[str] + history: list[dict[str, Any]] = field(default_factory=list) + system_prompt: str = "" + langfuse_prompt: Any = None + created_at: float = field(default_factory=time.monotonic) + + def is_expired(self) -> bool: + return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS + + +# session_id → session +_sessions: dict[str, JourneySession] = {} + + +def get_journey_session(session_id: str, user_id: str) -> JourneySession | None: + """Retrieve session; return None on missing, expired, or wrong owner.""" + s = _sessions.get(session_id) + if s is None or s.is_expired(): + _sessions.pop(session_id, None) + return None + if s.user_id != user_id: + return None + return s + + +# ── System prompt ───────────────────────────────────────────────────────── + +_JOURNEY_SYSTEM_PROMPT = """\ +You are a friendly assistant helping a freelancer configure a data-extraction agent. +Your job is to understand what files the user has in their directory and produce a +structured ScoutConfig JSON that the extraction agent will use as its instruction set. + +You have access to file-system tools to explore the user's directory: +- list_directory: see folder structure and file names +- read_file_content: peek at a file's content +- get_file_metadata: check file size, extension, dates + +The user's configured directory is: {directory} +Target data types: {data_types} + +## Your process + +### Step 1 — Explore the directory +Use list_directory and read_file_content to understand what types of files are present +(HTML emails, plain-text documents, CSVs, etc.). + +### Step 2 — Identify content types +For each distinct file type found, decide: +- A short id (e.g. "email_html", "plain_text", "csv") +- Which preprocessing handler to use: "email_html" for HTML emails, "generic" for everything else +- A human-readable label and optional detection_hint + +### Step 3 — Ask focused questions (one at a time) +Cover these topics based on what you discovered: +1. How to map content to entity types (task / note / timeline entry) +2. Field mapping rules (e.g. email Subject → task title, filename → note title) +3. Priority or status rules (e.g. "urgent" in subject → high priority) +4. Date extraction (e.g. "by Friday" → dueDate) +5. Exclusion rules (e.g. skip newsletters, skip files with no project match) + +### Step 4 — Produce the ScoutConfig JSON +Once you are ≥ 90% confident, output the final config between these exact markers +(each on its own line): + +{config_start} +{{ + "content_types": [ + {{ + "id": "email_html", + "label": "Email HTML", + "detection_hint": "HTML file with From/To/Subject headers", + "preprocessing": "email_html", + "extraction_prompt": "Detailed extraction instructions for this content type..." + }} + ], + "global_rules": [ + "If the file cannot be matched to any project, do not create any entity." + ], + "data_types": {data_types_json} +}} +{config_end} + +## Rules for the extraction_prompt field +- Describe when to create a task vs note vs timeline entry (be specific and concrete) +- Include field mapping rules based on what you found in the directory +- Include priority/status/date rules if applicable +- Do NOT include projectId logic — the runner handles project assignment automatically +- Do NOT mention isAiSuggested — the runner always sets it to 1 + +## Constraints +- Never ask about projects, projectId, or how to link records to projects +- Never include projectId or project creation logic in the generated config +- Keep asking questions until ≥ 90% confident, then output the JSON immediately + +{existing_section}\ +Begin by exploring the directory, then ask your first question.\ +""" + + +def _build_system_prompt( + directory: str, + data_types: list[str], + existing_config: str | None = None, +) -> tuple[str, Any]: + """Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``.""" + existing_section = ( + "\nThe user already has the following ScoutConfig — refine it based on their answers:\n" + f"```json\n{existing_config}\n```\n" + if existing_config + else "" + ) + template, prompt_obj = get_prompt_or_fallback( + "journey_system", _JOURNEY_SYSTEM_PROMPT + ) + compiled = compile_prompt( + template, + prompt_obj, + directory=directory, + data_types=", ".join(data_types), + data_types_json=json.dumps(data_types), + config_start=_CONFIG_START, + config_end=_CONFIG_END, + existing_section=existing_section, + ) + return compiled, prompt_obj + + +# ── ScoutConfig extraction ──────────────────────────────────────────────── + + +def _extract_agent_config(text: str) -> str | None: + """Return validated ScoutConfig JSON string from between markers, or None. + + Parses the JSON with Pydantic to ensure it conforms to the schema before + returning. Returns None if markers are absent or JSON is invalid. + """ + if _CONFIG_START not in text or _CONFIG_END not in text: + return None + start_idx = text.index(_CONFIG_START) + len(_CONFIG_START) + end_idx = text.index(_CONFIG_END) + raw = text[start_idx:end_idx].strip() + if not raw: + return None + try: + parsed = ScoutConfig.model_validate_json(raw) + return parsed.model_dump_json() + except Exception as exc: + logger.warning("agent_setup: failed to parse ScoutConfig JSON: %s", exc) + return None + + +# ── LLM call with tool support ─────────────────────────────────────────── + + +def _as_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + return "".join(parts) + return str(content) + + +async def _call_llm_with_tools( + system_prompt: str, + history: list[dict[str, Any]], + tools: list[Any], + *, + user_id: str = "", + session_id: str = "", + langfuse_prompt: Any = None, +) -> str: + """Build LangChain messages from history and invoke the LLM with tools. + + Handles tool-calling loops: if the LLM calls tools, execute them and + continue until a final text response is produced. + """ + lf = get_langfuse() + messages: list[Any] = [SystemMessage(content=system_prompt)] + for turn in history: + if turn["role"] == "user": + messages.append(HumanMessage(content=turn["content"])) + else: + messages.append(AIMessage(content=turn["content"])) + + llm = get_agent_llm("setup", temperature=0.4) + llm_with_tools = llm.bind_tools(tools) + tool_map = {tool_def.name: tool_def for tool_def in tools} + + _lf_ctx = langfuse_context(user_id=user_id or None, session_id=session_id or None) + _lf_ctx.__enter__() + + _span_ctx = ( + lf.start_as_current_observation( + as_type="span", + name="journey-setup", + input=history[-1]["content"] if history else "", + ) + if lf else None + ) + _span = _span_ctx.__enter__() if _span_ctx else None + + try: + for step in range(_MAX_TOOL_STEPS): + _gen_ctx = ( + lf.start_as_current_observation( + as_type="generation", + name="journey-setup-llm", + model=model_for_agent("setup"), + prompt=langfuse_prompt, + input=messages, + ) + if lf else None + ) + _gen = _gen_ctx.__enter__() if _gen_ctx else None + response: AIMessage = await llm_with_tools.ainvoke(messages) + if _gen_ctx: + _gen.update(output=_as_text(response.content), usage_details=extract_usage(response)) + _gen_ctx.__exit__(None, None, None) + + resp_text = _as_text(response.content) + + # Guard against empty responses (e.g. model returned finish_reason + # 'error' which LiteLLM maps to 'stop' with empty content). + if not response.tool_calls and not resp_text.strip(): + logger.warning( + "agent_setup: journey LLM returned empty response at step %d — retrying", + step, + ) + # Drop the empty AIMessage so we don't pollute history, and retry. + continue + + messages.append(response) + + if not response.tool_calls: + if _span: + _span.update(output=resp_text) + return resp_text + + for call in response.tool_calls: + call_name = str(call.get("name", "")) + call_args = call.get("args", {}) + logger.info( + "agent_setup: journey tool_call name=%s args=%s", + call_name, + json.dumps(call_args, ensure_ascii=True)[:500], + ) + + tool_fn = tool_map.get(call_name) + if tool_fn is None: + tool_output = f"Unknown tool: {call_name}" + else: + tool_output = await tool_fn.ainvoke(call_args) + + logger.info( + "agent_setup: journey tool_result name=%s output=%s", + call_name, + str(tool_output)[:800], + ) + messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) + + # Fallback: exceeded max steps. + final = await llm.ainvoke(messages) + final_text = _as_text(final.content) + if _span: + _span.update(output=final_text) + return final_text or ( + "Sorry, I had trouble processing the files. " + "Could you try again? If the issue persists, the files might be too large for me to analyse." + ) + finally: + if _span_ctx: + _span_ctx.__exit__(None, None, None) + _lf_ctx.__exit__(None, None, None) + if lf: + lf.flush() + + +# ── Journey handlers (called from device_ws.py) ────────────────────────── + + +async def handle_journey_start( + user_id: str, + frame: dict[str, Any], +) -> dict[str, Any]: + """Handle a ``journey_start`` WS frame. + + Creates a session, runs the setup LLM with directory exploration, + and returns the ``journey_reply`` payload. + """ + agent_type = frame.get("agent_type", "local") + directory = frame.get("directory", "") + data_types = frame.get("data_types", []) + existing_config = frame.get("existing_config") + + # Use the session_id provided by the FE so the reply matches the + # listener key; fall back to a generated one if absent. + session_id = frame.get("session_id") or str(uuid.uuid4()) + system_prompt, langfuse_prompt = _build_system_prompt(directory, data_types, existing_config) + + session = JourneySession( + session_id=session_id, + user_id=user_id, + agent_type=agent_type, + directory=directory, + data_types=data_types, + system_prompt=system_prompt, + langfuse_prompt=langfuse_prompt, + ) + + # Seed with an initial user message — some providers require at least one + # user/input message to be present. + seed_history: list[dict[str, Any]] = [ + {"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."}, + ] + ai_reply = await _call_llm_with_tools( + system_prompt=system_prompt, + history=seed_history, + tools=make_directory_tools(directory), + user_id=user_id, + session_id=session_id, + langfuse_prompt=langfuse_prompt, + ) + + session.history.extend(seed_history) + session.history.append({"role": "assistant", "content": ai_reply}) + _sessions[session_id] = session + + logger.info( + "agent_setup: journey session %s started for user %s (directory=%s)", + session_id, + user_id, + directory, + ) + + # Check if the LLM produced the config on the first turn (unlikely but possible). + agent_config = _extract_agent_config(ai_reply) + done = agent_config is not None + + display_message = ai_reply + if done: + display_message = ( + ai_reply[: ai_reply.index(_CONFIG_START)].strip() + or "Here is your agent configuration. You can save it or continue refining." + ) + _sessions.pop(session_id, None) + + return { + "type": "journey_reply", + "session_id": session_id, + "message": display_message, + "done": done, + "agent_config": agent_config, + } + + +async def handle_journey_message( + user_id: str, + frame: dict[str, Any], +) -> dict[str, Any]: + """Handle a ``journey_message`` WS frame. + + Appends the user message, calls the LLM, and returns the + ``journey_reply`` payload. + """ + session_id = frame.get("session_id", "") + message = frame.get("message", "") + + session = get_journey_session(session_id, user_id) + if session is None: + return { + "type": "journey_reply", + "session_id": session_id, + "message": "Journey session not found or expired. Please start a new setup.", + "done": True, + "agent_config": None, + } + + # Append user turn. + session.history.append({"role": "user", "content": message}) + + # Call the LLM with tools. + session_tools = make_directory_tools(session.directory) + ai_reply = await _call_llm_with_tools( + system_prompt=session.system_prompt, + history=session.history, + tools=session_tools, + user_id=session.user_id, + session_id=session_id, + langfuse_prompt=session.langfuse_prompt, + ) + + session.history.append({"role": "assistant", "content": ai_reply}) + + # Check if the LLM produced the final config. + agent_config = _extract_agent_config(ai_reply) + done = agent_config is not None + + # If the LLM didn't produce a config, nudge it once it hits the hard safety cap. + if not done: + turns = sum(1 for t in session.history if t["role"] == "user") + if turns >= _MAX_TURNS: + nudge_content = ( + "[System: You have enough information. Please generate the final " + f"ScoutConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]" + ) + session.history.append({"role": "user", "content": nudge_content}) + + nudge_reply = await _call_llm_with_tools( + system_prompt=session.system_prompt, + history=session.history, + tools=session_tools, + user_id=session.user_id, + session_id=session_id, + langfuse_prompt=session.langfuse_prompt, + ) + session.history.append({"role": "assistant", "content": nudge_reply}) + + agent_config = _extract_agent_config(nudge_reply) + if agent_config is not None: + done = True + ai_reply = nudge_reply + + display_message = ai_reply + if done: + display_message = ( + ai_reply[: ai_reply.index(_CONFIG_START)].strip() + if _CONFIG_START in ai_reply + else "Here is your agent configuration. You can save it or continue refining." + ) + _sessions.pop(session_id, None) + logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id) + + return { + "type": "journey_reply", + "session_id": session_id, + "message": display_message, + "done": done, + "agent_config": agent_config, + } diff --git a/api/app/api/routes/scout_webhooks.py b/api/app/api/routes/scout_webhooks.py new file mode 100644 index 0000000..cf89020 --- /dev/null +++ b/api/app/api/routes/scout_webhooks.py @@ -0,0 +1,120 @@ +"""Gmail Pub/Sub push receiver. + +Google Pub/Sub push subscriptions deliver Gmail watch notifications as POST +requests with a JSON envelope. The body payload contains a base64-encoded +JSON blob with ``emailAddress`` + ``historyId``. We resolve the user by +email, look up their cloud_scout_configs row for provider='gmail', and +hand off to ScoutEngine.trigger_scout. + +Authentication: Pub/Sub push includes an OIDC JWT in the Authorization +header. We verify it against Google's public keys with the audience +configured in our Pub/Sub subscription. + +Dev mode: when ``GMAIL_PUBSUB_AUDIENCE`` is empty, JWT verification is +skipped and a warning is logged. Production must set this env var. +""" + +from __future__ import annotations + +import base64 +import json +import logging +import uuid + +from fastapi import APIRouter, Header, HTTPException, Request, status +from sqlalchemy import select + +from app.config.settings import settings +from app.db import async_session +from app.models import CloudScoutConfig, User +from app.scouts.engine import ScoutEngine + +logger = logging.getLogger(__name__) +router = APIRouter(prefix="/scouts/webhooks", tags=["scout-webhooks"]) + + +def _verify_pubsub_jwt(token: str) -> bool: + """Verify the Google Pub/Sub OIDC JWT. + + Returns True when valid, False on any verification failure. + + Dev skip: if ``settings.GMAIL_PUBSUB_AUDIENCE`` is empty, logs a + warning and returns True so local development works without a real + Pub/Sub subscription. Production must configure the audience. + """ + if not token: + return False + + if not settings.GMAIL_PUBSUB_AUDIENCE: + logger.warning( + "GMAIL_PUBSUB_AUDIENCE not set — skipping Pub/Sub JWT verification (dev mode only)" + ) + return True + + try: + from google.auth.transport import requests as g_requests # noqa: PLC0415 + from google.oauth2 import id_token # noqa: PLC0415 + + id_token.verify_oauth2_token( + token, + g_requests.Request(), + audience=settings.GMAIL_PUBSUB_AUDIENCE, + ) + return True + except Exception: + logger.warning("pubsub jwt verification failed", exc_info=True) + return False + + +@router.post("/gmail", status_code=status.HTTP_204_NO_CONTENT) +async def gmail_pubsub( + request: Request, + authorization: str = Header(default=""), +) -> None: + """Receive a Gmail Pub/Sub push notification. + + Verifies the OIDC JWT, decodes the Pub/Sub envelope, resolves the user + by email, and triggers ScoutEngine.trigger_scout for each enabled Gmail + scout belonging to that user. + + Returns 204 No Content on success (including benign no-ops like unknown + email or empty message data). Returns 401 on JWT verification failure. + """ + token = authorization.removeprefix("Bearer ").strip() + if not _verify_pubsub_jwt(token): + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid Pub/Sub JWT") + + body = await request.json() + msg = body.get("message") or {} + raw = msg.get("data") + if not raw: + return # ack without action — empty message data + + try: + decoded = json.loads(base64.b64decode(raw).decode()) + except Exception: + logger.warning("pubsub payload decode failed") + return + + email = decoded.get("emailAddress") + if not email: + return + + async with async_session() as session: + user_q = await session.execute(select(User).where(User.email == email)) + user = user_q.scalar_one_or_none() + if user is None: + logger.info("pubsub: no user for %s — ignoring", email) + return + scouts_q = await session.execute( + select(CloudScoutConfig).where( + CloudScoutConfig.user_id == user.id, + CloudScoutConfig.provider == "gmail", + CloudScoutConfig.enabled == True, # noqa: E712 + ) + ) + scouts = scouts_q.scalars().all() + + engine = ScoutEngine() + for scout in scouts: + await engine.trigger_scout(uuid.UUID(str(scout.id))) diff --git a/api/app/api/routes/scouts.py b/api/app/api/routes/scouts.py new file mode 100644 index 0000000..9c07932 --- /dev/null +++ b/api/app/api/routes/scouts.py @@ -0,0 +1,807 @@ +"""Scout routes. + +Backend responsibilities are intentionally minimal: + GET /scouts/catalog — static catalog for UI display + POST /scouts/can-create — billing eligibility check + POST /scouts/trigger — trigger a local scout run + +Scout configuration is owned by the Electron app and is not persisted +in backend scout-config tables. + +Gmail OAuth setup (scout-specific consent): + GET /scouts/oauth/gmail/authorize — returns consent-screen URL + GET /scouts/oauth/gmail/web-callback — bounces to deep link (excluded from schema) + POST /scouts/oauth/gmail/callback — exchanges code, stores encrypted token +""" + +from __future__ import annotations + +import asyncio +import logging +import secrets +import time +import urllib.parse +import uuid +from datetime import datetime, timezone + +from fastapi import APIRouter, Depends, HTTPException, status +from fastapi.responses import RedirectResponse +from sqlalchemy import delete as sa_delete, func, select +from sqlalchemy.ext.asyncio import AsyncSession + +from pydantic import BaseModel + +from app.api.deps import get_current_user +from app.auth.oauth_providers import generate_pkce_pair +from app.billing.tier_manager import FEATURES +from app.config.settings import settings +from app.core.scout_runner import is_agent_running, run_local_agent +from app.core.device_manager import device_manager +from app.core.note_summarizer import generate_note_summary +from app.db import get_session +from app.integrations import decrypt_token, encrypt_token +from app.models import CloudScoutConfig, ScoutRunLog, LocalScoutConfig +from app.scouts.connectors.registry import get_connector +from app.schemas import ( + CloudScoutCreateRequest, + CloudScoutResponse, + CloudScoutUpdateRequest, + ScoutCatalogItem, + ScoutCreationCheckRequest, + ScoutCreationCheckResponse, + ScoutRunLogResponse, + ScoutTriggerRequest, + UserProfile, +) + +logger = logging.getLogger(__name__) + +router = APIRouter(prefix="/scouts", tags=["scouts"]) + + +# ── Datetime helpers ────────────────────────────────────────────────── + +def _dt_ms(dt: datetime) -> int: + return int(dt.timestamp() * 1000) + + +def _dt_ms_opt(dt: datetime | None) -> int | None: + return int(dt.timestamp() * 1000) if dt else None + + +def _to_data_types(values: list[str]) -> list[str]: + normalize = { + "task": "tasks", "tasks": "tasks", + "note": "notes", "notes": "notes", + "timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines", + "project": "projects", "projects": "projects", + } + seen: set[str] = set() + result: list[str] = [] + for v in values: + mapped = normalize.get(v) + if mapped and mapped not in seen: + seen.add(mapped) + result.append(mapped) + return result + + +def _to_run_log_response(log: ScoutRunLog) -> ScoutRunLogResponse: + return ScoutRunLogResponse( + id=log.id, + agent_id=log.scout_id, + agent_type=log.scout_type, # type: ignore[arg-type] + status=log.status, # type: ignore[arg-type] + items_processed=log.items_processed, + items_created=log.items_created, + errors=log.errors or [], + started_at=_dt_ms(log.started_at), + completed_at=_dt_ms_opt(log.completed_at), + ) + + +def _enforce_agent_limit(tier: str, current_count: int) -> int: + limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"] + if limit != -1 and current_count >= limit: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.", + ) + return limit + + +async def _enforce_run_frequency( + tier: str, + user_id: str, + db: AsyncSession, +) -> None: + """Raise HTTP 402 if the user has exceeded their daily batch run limit.""" + limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"] + if limit == -1: + return # unlimited + + today_start = datetime.now(timezone.utc).replace( + hour=0, minute=0, second=0, microsecond=0 + ) + result = await db.execute( + select(func.count(ScoutRunLog.id)).where( + ScoutRunLog.user_id == user_id, + ScoutRunLog.started_at >= today_start, + ) + ) + runs_today: int = result.scalar_one() + + if runs_today >= limit: + raise HTTPException( + status_code=status.HTTP_402_PAYMENT_REQUIRED, + detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.", + ) + + +# ── Catalog ─────────────────────────────────────────────────────────── + +@router.get("/catalog", response_model=list[ScoutCatalogItem]) +async def get_agent_catalog( + current_user: UserProfile = Depends(get_current_user), +) -> list[ScoutCatalogItem]: + """Return the static list of available agent types and their descriptions.""" + return [ + ScoutCatalogItem( + type="local_directory", + name="Local Directory Monitor", + description="Watches local directories, extracts data from files using AI", + ), + ScoutCatalogItem( + type="gmail", + name="Gmail Connector", + description="Scans Gmail inbox, extracts tasks/notes from emails", + ), + ScoutCatalogItem( + type="teams", + name="Microsoft Teams Connector", + description="Monitors Teams messages, extracts action items", + ), + ScoutCatalogItem( + type="outlook", + name="Outlook Connector", + description="Scans Outlook inbox, extracts tasks/notes", + ), + ] + + +@router.post("/can-create", response_model=ScoutCreationCheckResponse) +async def can_create_agent( + body: ScoutCreationCheckRequest, + current_user: UserProfile = Depends(get_current_user), +) -> ScoutCreationCheckResponse: + """Check if the user can create one more agent based on billing tier. + + Since configuration is client-owned, the Electron app sends its current + active agent count and the backend applies tier limits. + """ + limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"] + allowed = limit == -1 or body.active_agents < limit + return ScoutCreationCheckResponse( + allowed=allowed, + tier=current_user.tier, + active_agents=body.active_agents, + limit=limit, + ) + + +@router.post("/trigger", response_model=ScoutRunLogResponse, status_code=status.HTTP_202_ACCEPTED) +async def trigger_agent_run( + body: ScoutTriggerRequest, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> ScoutRunLogResponse: + """Trigger a local agent run using client-provided configuration.""" + _enforce_agent_limit(current_user.tier, body.active_agents) + await _enforce_run_frequency(current_user.tier, current_user.id, db) + + last_run_dt = ( + datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc) + if body.last_run_at + else None + ) + config = LocalScoutConfig( + id=str(uuid.uuid4()), + user_id=current_user.id, + device_id=body.device_id, + name="Local Directory Monitor", + directory_paths=[body.directory], + data_types=_to_data_types(body.what_to_extract), + prompt_template=body.custom_agent_prompt or "", + scout_config=body.agent_config, + file_extensions=[], + schedule_cron=body.batch_interval, + enabled=True, + last_run_at=last_run_dt, + ) + + # Use the FE's stable agent_id if provided, fall back to the ephemeral config id. + stable_agent_id = body.agent_id or config.id + + if is_agent_running(stable_agent_id): + raise HTTPException( + status_code=status.HTTP_409_CONFLICT, + detail="Agent is already running. Only one run per agent is allowed at a time.", + ) + + run_log = ScoutRunLog( + scout_id=stable_agent_id, + scout_type="local", + user_id=current_user.id, + status="running", + ) + db.add(run_log) + await db.commit() + await db.refresh(run_log) + + run_context = { + "type": "agent_batch", + "run_id": run_log.id, + "agent_id": stable_agent_id, + } + + asyncio.create_task( + run_local_agent(current_user.id, config, run_log, device_manager, run_context) + ) + + return _to_run_log_response(run_log) + + +# ── Note summary endpoint ────────────────────────────────────────────────────── + + +class NoteSummarizeRequest(BaseModel): + title: str + content: str + + +class NoteSummarizeResponse(BaseModel): + summary: str + + +@router.post("/notes/summarize", response_model=NoteSummarizeResponse) +async def summarize_note( + body: NoteSummarizeRequest, + current_user: UserProfile = Depends(get_current_user), +) -> NoteSummarizeResponse: + """Generate an AI summary for a note. Used by the Electron backfill on startup.""" + summary = await generate_note_summary(body.title, body.content) + return NoteSummarizeResponse(summary=summary) + + +# ── Cloud scout CRUD ────────────────────────────────────────────────────────── + +_DEFAULT_CLOUD_SCHEDULE = "0 */6 * * *" + + +def _to_cloud_response(scout: CloudScoutConfig) -> dict: + return { + "id": scout.id, + "user_id": scout.user_id, + "provider": scout.provider, + "name": scout.name, + "data_types": scout.data_types or [], + "prompt_template": scout.prompt_template or "", + "schedule_cron": scout.schedule_cron, + "filter_config": scout.filter_config, + "auto_trash_spam": scout.auto_trash_spam, + "enabled": scout.enabled, + "last_run_at": _dt_ms_opt(scout.last_run_at), + "gmail_address": scout.gmail_address, + "oauth_connected": scout.oauth_token_encrypted is not None, + "created_at": _dt_ms(scout.created_at), + "updated_at": _dt_ms(scout.updated_at), + } + + +@router.get("/cloud", response_model=list[CloudScoutResponse]) +async def list_cloud_scouts( + db: AsyncSession = Depends(get_session), + current_user: UserProfile = Depends(get_current_user), +): + rows = (await db.execute( + select(CloudScoutConfig).where(CloudScoutConfig.user_id == current_user.id) + )).scalars().all() + return [_to_cloud_response(s) for s in rows] + + +@router.post("/cloud", response_model=CloudScoutResponse, status_code=status.HTTP_201_CREATED) +async def create_cloud_scout( + body: CloudScoutCreateRequest, + db: AsyncSession = Depends(get_session), + current_user: UserProfile = Depends(get_current_user), +): + scout = CloudScoutConfig( + id=str(uuid.uuid4()), + user_id=current_user.id, + provider=body.provider, + name=body.name, + data_types=body.data_types, + prompt_template=body.prompt_template, + filter_config=body.filter_config, + schedule_cron=body.schedule_cron or _DEFAULT_CLOUD_SCHEDULE, + auto_trash_spam=body.auto_trash_spam, + enabled=True, + ) + db.add(scout) + await db.commit() + await db.refresh(scout) + return _to_cloud_response(scout) + + +@router.put("/cloud/{scout_id}", response_model=CloudScoutResponse) +async def update_cloud_scout( + scout_id: str, + body: CloudScoutUpdateRequest, + db: AsyncSession = Depends(get_session), + current_user: UserProfile = Depends(get_current_user), +): + scout = await db.get(CloudScoutConfig, scout_id) + if scout is None or scout.user_id != current_user.id: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found") + if body.name is not None: + scout.name = body.name + if body.data_types is not None: + scout.data_types = body.data_types + if body.prompt_template is not None: + scout.prompt_template = body.prompt_template + if body.schedule_cron is not None: + scout.schedule_cron = body.schedule_cron + if body.filter_config is not None: + scout.filter_config = body.filter_config + if body.auto_trash_spam is not None: + scout.auto_trash_spam = body.auto_trash_spam + if body.enabled is not None: + scout.enabled = body.enabled + await db.commit() + await db.refresh(scout) + return _to_cloud_response(scout) + + +@router.delete("/cloud/{scout_id}") +async def delete_cloud_scout( + scout_id: str, + db: AsyncSession = Depends(get_session), + current_user: UserProfile = Depends(get_current_user), +): + scout = await db.get(CloudScoutConfig, scout_id) + if scout is None or scout.user_id != current_user.id: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found") + # Core deletes bypass the polymorphic ScoutRunLog relationship whose + # varchar scout_id vs uuid id join is not directly comparable in Postgres. + # scout_run_logs.scout_id is a plain string (matches the str scout_id); + # scout_triage_queue rows cascade automatically via their FK ondelete. + await db.execute(sa_delete(ScoutRunLog).where(ScoutRunLog.scout_id == scout_id)) + await db.execute(sa_delete(CloudScoutConfig).where(CloudScoutConfig.id == scout_id)) + await db.commit() + return {"ok": True} + + +@router.get("/cloud/{scout_id}/gmail-labels") +async def list_gmail_labels( + scout_id: str, + db: AsyncSession = Depends(get_session), + current_user: UserProfile = Depends(get_current_user), +): + scout = await db.get(CloudScoutConfig, scout_id) + if scout is None or scout.user_id != current_user.id: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found") + try: + connector = get_connector("gmail") + except KeyError: + return [] + return await connector.list_labels(scout) + + +@router.post("/cloud/{scout_id}/gmail-disconnect", response_model=CloudScoutResponse) +async def disconnect_gmail( + scout_id: str, + db: AsyncSession = Depends(get_session), + current_user: UserProfile = Depends(get_current_user), +): + scout = await db.get(CloudScoutConfig, scout_id) + if scout is None or scout.user_id != current_user.id: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found") + try: + connector = get_connector("gmail") + await connector.stop_watch(scout) + except KeyError: + pass + scout.oauth_token_encrypted = None + scout.gmail_history_id = None + scout.gmail_watch_expires_at = None + scout.gmail_address = None + scout.enabled = False + await db.commit() + await db.refresh(scout) + return _to_cloud_response(scout) + + +# ── Gmail OAuth setup (scout-specific) ─────────────────────────────────────── + +# Scopes required for Gmail scout connectivity. +_GMAIL_SCOUT_SCOPES = [ + "openid", + "email", + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/gmail.modify", +] + +# Google OAuth endpoints. +_GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" +_GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token" + +# In-memory pending OAuth states for scout Gmail consent. +# +# state → { +# "code_verifier": str, +# "user_id": str, +# "expires_at": float (epoch seconds), +# "mode": "reconnect" | "create", +# "scout_id": str | None, # set for reconnect mode +# "draft": {name, prompt_template, auto_trash_spam} | None, # set for create mode +# "token_encrypted": str | None, # populated after a successful create-mode callback +# "gmail_address": str | None, +# } +# +# Zero-trust: in create mode the encrypted Gmail token lives ONLY here, in +# process memory, for at most _SCOUT_OAUTH_TTL_SECONDS. It is persisted to the +# DB only when the user finalizes the scout (POST /scouts/cloud/finalize). +# An abandoned/errored flow leaves no scout row and no stored token. +# +# Production note: this in-memory store is single-process only — replace with +# Redis (keyed by state, TTL'd) for multi-worker deployments. +_pending_scout_oauth_states: dict[str, dict] = {} +_SCOUT_OAUTH_TTL_SECONDS = 900 # 15 minutes + + +def _purge_expired_oauth_states() -> None: + now = time.time() + expired = [s for s, e in _pending_scout_oauth_states.items() if e.get("expires_at", 0) < now] + for s in expired: + del _pending_scout_oauth_states[s] + + +def _scout_gmail_redirect_uri() -> str: + """Derive the scout Gmail web-callback URI from the configured base OAUTH_REDIRECT_URI. + + ``OAUTH_REDIRECT_URI`` is the full path used for login OAuth + (e.g. http://localhost:8000/api/v1/auth/oauth/google/web-callback). + We strip the path to get the scheme+host base, then append the scout path. + """ + parsed = urllib.parse.urlparse(settings.OAUTH_REDIRECT_URI) + base = f"{parsed.scheme}://{parsed.netloc}" + return f"{base}/api/v1/scouts/oauth/gmail/web-callback" + + +class _ScoutGmailAuthorizeResponse(BaseModel): + authorize_url: str + + +class _ScoutGmailCallbackBody(BaseModel): + code: str + state: str + + +class _ScoutGmailAuthorizeDraftBody(BaseModel): + name: str + prompt_template: str = "" + auto_trash_spam: bool = False + + +class _ScoutGmailFinalizeBody(BaseModel): + session: str + filter_config: dict | None = None + + +def _build_gmail_authorize_url(state: str, code_challenge: str) -> str: + """Build the Google consent URL for the scout Gmail flow (shared by both modes).""" + redirect_uri = _scout_gmail_redirect_uri() + params = { + "client_id": settings.GOOGLE_AUTH_CLIENT_ID, + "redirect_uri": redirect_uri, + "response_type": "code", + "scope": " ".join(_GMAIL_SCOUT_SCOPES), + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "consent", + } + return f"{_GOOGLE_AUTH_URL}?{urllib.parse.urlencode(params)}" + + +@router.get("/oauth/gmail/authorize", response_model=_ScoutGmailAuthorizeResponse) +async def scout_gmail_oauth_authorize( + scout_id: str, + current_user: UserProfile = Depends(get_current_user), +) -> _ScoutGmailAuthorizeResponse: + """Start the Gmail OAuth flow for a specific cloud scout. + + Returns the Google consent-screen URL. The client opens this URL in the + system browser; after consent Google redirects to web-callback which bounces + to the ``adiuvai://scout/oauth/gmail/callback`` deep link. + """ + if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET: + raise HTTPException( + status.HTTP_503_SERVICE_UNAVAILABLE, + "Google OAuth is not configured on this server", + ) + + code_verifier, code_challenge = generate_pkce_pair() + state = secrets.token_urlsafe(32) + + _purge_expired_oauth_states() + + _pending_scout_oauth_states[state] = { + "code_verifier": code_verifier, + "user_id": current_user.id, + "expires_at": time.time() + _SCOUT_OAUTH_TTL_SECONDS, + "mode": "reconnect", + "scout_id": scout_id, + "draft": None, + "token_encrypted": None, + "gmail_address": None, + } + + return _ScoutGmailAuthorizeResponse( + authorize_url=_build_gmail_authorize_url(state, code_challenge) + ) + + +@router.post("/oauth/gmail/authorize-draft", response_model=_ScoutGmailAuthorizeResponse) +async def scout_gmail_oauth_authorize_draft( + body: _ScoutGmailAuthorizeDraftBody, + current_user: UserProfile = Depends(get_current_user), +) -> _ScoutGmailAuthorizeResponse: + """Start the Gmail OAuth flow in *creation* mode — no scout row exists yet. + + The draft scout fields are held in the pending OAuth session; the scout is + only created once the user finalizes (POST /scouts/cloud/finalize). + """ + if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET: + raise HTTPException( + status.HTTP_503_SERVICE_UNAVAILABLE, + "Google OAuth is not configured on this server", + ) + + code_verifier, code_challenge = generate_pkce_pair() + state = secrets.token_urlsafe(32) + + _purge_expired_oauth_states() + + _pending_scout_oauth_states[state] = { + "code_verifier": code_verifier, + "user_id": current_user.id, + "expires_at": time.time() + _SCOUT_OAUTH_TTL_SECONDS, + "mode": "create", + "scout_id": None, + "draft": { + "name": body.name, + "prompt_template": body.prompt_template, + "auto_trash_spam": body.auto_trash_spam, + }, + "token_encrypted": None, + "gmail_address": None, + } + + return _ScoutGmailAuthorizeResponse( + authorize_url=_build_gmail_authorize_url(state, code_challenge) + ) + + +@router.get("/oauth/gmail/web-callback", include_in_schema=False) +async def scout_gmail_oauth_web_callback(code: str, state: str) -> RedirectResponse: + """Google redirects here after Gmail consent. + + Immediately bounces to the Electron deep link so the desktop app + receives the authorization code. + """ + params = urllib.parse.urlencode({"code": code, "state": state}) + deep_link = f"adiuvai://scout/oauth/gmail/callback?{params}" + return RedirectResponse(url=deep_link, status_code=302) + + +@router.post("/oauth/gmail/callback") +async def scout_gmail_oauth_callback( + body: _ScoutGmailCallbackBody, + db: AsyncSession = Depends(get_session), + current_user: UserProfile = Depends(get_current_user), +) -> dict: + """Exchange the Gmail authorization code and store the encrypted token on the scout. + + Called by the Electron app after it receives the deep-link callback with + the ``code`` and ``state`` params. + """ + entry = _pending_scout_oauth_states.pop(body.state, None) + if ( + entry is None + or entry["expires_at"] < time.time() + or entry["user_id"] != current_user.id + ): + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state") + + code_verifier = entry["code_verifier"] + mode = entry["mode"] + scout_id = entry.get("scout_id") + + redirect_uri = _scout_gmail_redirect_uri() + + import httpx + async with httpx.AsyncClient() as client: + response = await client.post( + _GOOGLE_TOKEN_URL, + data={ + "client_id": settings.GOOGLE_AUTH_CLIENT_ID, + "client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET, + "code": body.code, + "code_verifier": code_verifier, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + try: + response.raise_for_status() + except httpx.HTTPStatusError as exc: + logger.error("Gmail token exchange failed: %s", exc.response.text) + raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code") + + token_data = response.json() + + creds_dict: dict = { + "token": token_data["access_token"], + "refresh_token": token_data.get("refresh_token"), + "token_uri": _GOOGLE_TOKEN_URL, + "client_id": settings.GOOGLE_AUTH_CLIENT_ID, + "client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET, + "scopes": [ + "https://www.googleapis.com/auth/gmail.readonly", + "https://www.googleapis.com/auth/gmail.modify", + ], + } + encrypted = encrypt_token(creds_dict) + + # Fetch the connected Gmail address for display. + gmail_address: str | None = None + try: + from googleapiclient.discovery import build + from google.oauth2.credentials import Credentials + + def _fetch_email() -> str | None: + creds = Credentials( + token=creds_dict["token"], + refresh_token=creds_dict.get("refresh_token"), + token_uri=creds_dict["token_uri"], + client_id=creds_dict["client_id"], + client_secret=creds_dict["client_secret"], + scopes=creds_dict["scopes"], + ) + service = build("gmail", "v1", credentials=creds, cache_discovery=False) + profile = service.users().getProfile(userId="me").execute() + return profile.get("emailAddress") + + gmail_address = await asyncio.to_thread(_fetch_email) + except Exception: + logger.exception("failed to fetch gmail address (mode=%s)", mode) + + if mode == "create": + # Do NOT create a scout yet. Hold the encrypted token + address in the + # transient in-memory session; the scout is created at finalize. + entry["token_encrypted"] = encrypted + entry["gmail_address"] = gmail_address + entry["expires_at"] = time.time() + _SCOUT_OAUTH_TTL_SECONDS + _pending_scout_oauth_states[body.state] = entry + return {"ok": True, "session_id": body.state, "gmail_address": gmail_address} + + # mode == "reconnect": update the existing scout in place. + scout = await db.get(CloudScoutConfig, scout_id) + if scout is None or scout.user_id != current_user.id: + raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found") + scout.oauth_token_encrypted = encrypted + scout.gmail_address = gmail_address + + await db.commit() + + # Attempt to set up Gmail push watch so we start receiving Pub/Sub notifications. + try: + connector = get_connector("gmail") + await connector.setup_watch(scout) + await db.commit() + except KeyError: + logger.warning("gmail connector not registered — skipping setup_watch for scout %s", scout_id) + except Exception: + logger.exception("setup_watch failed for scout %s", scout_id) + + return {"ok": True, "session_id": None, "gmail_address": gmail_address} + + +@router.get("/oauth/gmail/session-labels") +async def scout_gmail_session_labels( + session: str, + current_user: UserProfile = Depends(get_current_user), +) -> list[dict]: + """List Gmail labels for a pending create-mode OAuth session (no scout row yet). + + Builds a Gmail service from the session's transient decrypted token. + Returns [] on any error. + """ + entry = _pending_scout_oauth_states.get(session) + if ( + entry is None + or entry["expires_at"] < time.time() + or entry["user_id"] != current_user.id + or entry.get("token_encrypted") is None + ): + raise HTTPException(status.HTTP_404_NOT_FOUND, "Session not found or expired") + + try: + from app.scouts.connectors.gmail import _gmail_service_from_token + + creds = decrypt_token(entry["token_encrypted"]) + + def _sync() -> list[dict]: + service = _gmail_service_from_token(creds) + resp = service.users().labels().list(userId="me").execute() + return [{"id": lbl["id"], "name": lbl["name"]} for lbl in resp.get("labels", [])] + + return await asyncio.to_thread(_sync) + except Exception: + logger.exception("session-labels failed for session %s", session) + return [] + + +@router.post("/cloud/finalize", response_model=CloudScoutResponse, status_code=status.HTTP_201_CREATED) +async def finalize_cloud_scout( + body: _ScoutGmailFinalizeBody, + db: AsyncSession = Depends(get_session), + current_user: UserProfile = Depends(get_current_user), +): + """Create the cloud scout from a completed create-mode OAuth session. + + This is the only path that persists the Gmail token for a newly-created + scout. Abandoned flows never reach here, so they leave no orphan rows. + """ + entry = _pending_scout_oauth_states.pop(body.session, None) + if ( + entry is None + or entry["expires_at"] < time.time() + or entry["user_id"] != current_user.id + or entry.get("mode") != "create" + or entry.get("token_encrypted") is None + ): + raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth session") + + draft = entry["draft"] or {} + scout = CloudScoutConfig( + id=str(uuid.uuid4()), + user_id=current_user.id, + provider="gmail", + name=draft.get("name", ""), + data_types=[], + prompt_template=draft.get("prompt_template", ""), + filter_config=body.filter_config, + schedule_cron=_DEFAULT_CLOUD_SCHEDULE, + auto_trash_spam=draft.get("auto_trash_spam", False), + enabled=True, + oauth_token_encrypted=entry["token_encrypted"], + gmail_address=entry.get("gmail_address"), + ) + db.add(scout) + await db.commit() + await db.refresh(scout) + + # Best-effort Gmail push watch — failure must not block scout creation. + try: + connector = get_connector("gmail") + await connector.setup_watch(scout) + await db.commit() + except KeyError: + logger.warning("gmail connector not registered — skipping setup_watch for scout %s", scout.id) + except Exception: + logger.exception("setup_watch failed for scout %s", scout.id) + + return _to_cloud_response(scout) diff --git a/api/app/auth/__init__.py b/api/app/auth/__init__.py new file mode 100644 index 0000000..b45e86e --- /dev/null +++ b/api/app/auth/__init__.py @@ -0,0 +1 @@ +"OAuth provider abstractions and utilities." diff --git a/api/app/auth/oauth_providers.py b/api/app/auth/oauth_providers.py new file mode 100644 index 0000000..3363528 --- /dev/null +++ b/api/app/auth/oauth_providers.py @@ -0,0 +1,135 @@ +"""OAuth 2.0 + PKCE provider abstractions. + +Each provider implements a three-step flow designed for a desktop (public) client: + + 1. get_authorization_url(state, code_challenge) → str + Build the provider's consent-screen URL. State and code_challenge are + generated server-side; the client opens this URL in the system browser. + + 2. exchange_code(code, code_verifier, redirect_uri) → dict + Exchange the short-lived authorization code for an access token. + The code_verifier proves ownership of the PKCE challenge. + + 3. get_userinfo(access_token) → OAuthUserInfo + Fetch the canonical user identity from the provider. + +Currently supported providers: + - GoogleOAuthProvider (scope: openid email profile) + +Adding a new provider: + - Implement the three methods above. + - Register in _PROVIDERS inside routes/auth.py. +""" + +from __future__ import annotations + +import base64 +import hashlib +import os +import urllib.parse +from dataclasses import dataclass + +import httpx + + +# ── Data transfer objects ───────────────────────────────────────────── + + +@dataclass +class OAuthUserInfo: + """Normalized user identity returned by any provider.""" + + provider_user_id: str + email: str + email_verified: bool + avatar_url: str | None + name: str | None + + +# ── PKCE helpers ────────────────────────────────────────────────────── + + +def generate_pkce_pair() -> tuple[str, str]: + """Generate a (code_verifier, code_challenge) pair for PKCE S256. + + The code_verifier is a random 32-byte URL-safe base64 string. + The code_challenge is SHA-256(code_verifier) base64url-encoded (no padding). + """ + code_verifier = base64.urlsafe_b64encode(os.urandom(32)).rstrip(b"=").decode() + digest = hashlib.sha256(code_verifier.encode()).digest() + code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode() + return code_verifier, code_challenge + + +# ── Google provider ─────────────────────────────────────────────────── + + +class GoogleOAuthProvider: + """Google OAuth 2.0 provider (openid email profile scope). + + Uses Google's standard authorization endpoint with PKCE S256. + Does NOT use google-auth-oauthlib to keep the flow generic and async. + """ + + name = "google" + + _AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth" + _TOKEN_URL = "https://oauth2.googleapis.com/token" + _USERINFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo" + + def __init__(self, client_id: str, client_secret: str, redirect_uri: str) -> None: + self.client_id = client_id + self.client_secret = client_secret + self.redirect_uri = redirect_uri + + def get_authorization_url(self, state: str, code_challenge: str) -> str: + """Build the Google consent-screen URL.""" + params = { + "client_id": self.client_id, + "redirect_uri": self.redirect_uri, + "response_type": "code", + "scope": "openid email profile", + "state": state, + "code_challenge": code_challenge, + "code_challenge_method": "S256", + "access_type": "offline", + "prompt": "select_account", + } + return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}" + + async def exchange_code( + self, code: str, code_verifier: str, redirect_uri: str + ) -> dict: + """Exchange authorization code for an access token.""" + async with httpx.AsyncClient() as client: + response = await client.post( + self._TOKEN_URL, + data={ + "client_id": self.client_id, + "client_secret": self.client_secret, + "code": code, + "code_verifier": code_verifier, + "grant_type": "authorization_code", + "redirect_uri": redirect_uri, + }, + ) + response.raise_for_status() + return response.json() + + async def get_userinfo(self, access_token: str) -> OAuthUserInfo: + """Fetch the authenticated user's identity from Google.""" + async with httpx.AsyncClient() as client: + response = await client.get( + self._USERINFO_URL, + headers={"Authorization": f"Bearer {access_token}"}, + ) + response.raise_for_status() + data = response.json() + + return OAuthUserInfo( + provider_user_id=data["sub"], + email=data["email"], + email_verified=data.get("email_verified", False), + avatar_url=data.get("picture"), + name=data.get("name"), + ) diff --git a/api/app/billing/__init__.py b/api/app/billing/__init__.py new file mode 100644 index 0000000..ef83f83 --- /dev/null +++ b/api/app/billing/__init__.py @@ -0,0 +1,4 @@ +from app.billing.stripe_service import stripe_service +from app.billing.tier_manager import tier_manager + +__all__ = ["stripe_service", "tier_manager"] diff --git a/api/app/billing/quota.py b/api/app/billing/quota.py new file mode 100644 index 0000000..f22767c --- /dev/null +++ b/api/app/billing/quota.py @@ -0,0 +1,139 @@ +"""Quota checks and atomic token-usage accounting for folder integration.""" +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime, timezone + +from sqlalchemy import select, update +from sqlalchemy.dialects.postgresql import insert as pg_insert +from sqlalchemy.ext.asyncio import AsyncSession + +from app.billing.tier_manager import TierManager +from app.models import MonthlyTokenUsage +from app.schemas import BillingTier + + +class QuotaExceeded(Exception): + """Raised when a folder operation cannot proceed under the user's tier.""" + + def __init__(self, reason: str, message: str) -> None: + super().__init__(message) + self.reason = reason # "max_files" | "monthly_tokens" + + +@dataclass +class TokenUsageResult: + tokens_used: int + exhausted: bool + + +def _current_year_month() -> str: + return datetime.now(timezone.utc).strftime("%Y-%m") + + +_tier_manager = TierManager() + + +async def check_folder_quota( + *, + user_id: str, + tier: BillingTier, + estimated_files: int, + db: AsyncSession, +) -> None: + """Raise QuotaExceeded if folder_max_files or folder_monthly_tokens + would be violated. -1 in either feature means unlimited.""" + max_files = _tier_manager.get_feature_value(tier, "folder_max_files") + if max_files != -1 and estimated_files > max_files: + raise QuotaExceeded( + "max_files", + f"Folder has {estimated_files} files; tier '{tier}' allows max {max_files}.", + ) + + cap = _tier_manager.get_feature_value(tier, "folder_monthly_tokens") + if cap == -1: + return + ym = _current_year_month() + row = ( + await db.execute( + select(MonthlyTokenUsage).where( + MonthlyTokenUsage.user_id == user_id, + MonthlyTokenUsage.year_month == ym, + MonthlyTokenUsage.feature == "folder_index", + ) + ) + ).scalar_one_or_none() + used = row.tokens_used if row else 0 + if used >= cap: + raise QuotaExceeded( + "monthly_tokens", + f"Monthly token budget exhausted ({used}/{cap}); resets next month.", + ) + + +async def add_token_usage( + *, + user_id: str, + feature: str, + tokens: int, + db: AsyncSession, + cap: int | None = None, +) -> TokenUsageResult: + """Atomically add `tokens` to MonthlyTokenUsage row for (user, current month, feature). + + Uses PostgreSQL ``INSERT … ON CONFLICT DO UPDATE`` when available; falls + back to a read-then-write on other engines (e.g. aiosqlite in tests). + Returns post-update total and whether cap is exhausted. + """ + ym = _current_year_month() + + # Detect dialect to choose between native upsert and portable fallback. + dialect_name: str = db.bind.dialect.name if db.bind is not None else "" # type: ignore[union-attr] + + if dialect_name == "postgresql": + # Native atomic upsert — production path. + stmt = ( + pg_insert(MonthlyTokenUsage) + .values( + user_id=user_id, + year_month=ym, + feature=feature, + tokens_used=tokens, + ) + .on_conflict_do_update( + index_elements=["user_id", "year_month", "feature"], + set_={"tokens_used": MonthlyTokenUsage.tokens_used + tokens}, + ) + .returning(MonthlyTokenUsage.tokens_used) + ) + used: int = (await db.execute(stmt)).scalar_one() + await db.commit() + else: + # Portable fallback — used in tests (SQLite) and any non-PG engine. + row = ( + await db.execute( + select(MonthlyTokenUsage).where( + MonthlyTokenUsage.user_id == user_id, + MonthlyTokenUsage.year_month == ym, + MonthlyTokenUsage.feature == feature, + ) + ) + ).scalar_one_or_none() + + if row is None: + row = MonthlyTokenUsage( + user_id=user_id, + year_month=ym, + feature=feature, + tokens_used=tokens, + ) + db.add(row) + else: + row.tokens_used += tokens + + await db.commit() + await db.refresh(row) + used = row.tokens_used + + exhausted = cap is not None and cap != -1 and used >= cap + return TokenUsageResult(tokens_used=used, exhausted=exhausted) diff --git a/api/app/billing/stripe_service.py b/api/app/billing/stripe_service.py new file mode 100644 index 0000000..19ccc08 --- /dev/null +++ b/api/app/billing/stripe_service.py @@ -0,0 +1,295 @@ +"""Stripe service: checkout sessions, webhook handling, subscription management. + +Subscription records are persisted in the PostgreSQL ``subscriptions`` table. +All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not +configured, enabling local development without live credentials. +""" + +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import stripe as stripe_lib +from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.config.settings import settings + +# Stripe price IDs per tier — replace with real IDs in production .env +TIER_PRICE_IDS: dict[str, str] = { + "pro": "price_pro_monthly", + "power": "price_power_monthly", + "team": "price_team_monthly", +} + + +class StripeService: + """Wraps all Stripe interactions and owns subscription persistence.""" + + # ── Internal helpers ──────────────────────────────────────────────── + + def _configured(self) -> bool: + return bool(settings.STRIPE_SECRET_KEY) + + def _client(self) -> Any: + stripe_lib.api_key = settings.STRIPE_SECRET_KEY + return stripe_lib + + # ── Public API ────────────────────────────────────────────────────── + + def create_checkout_session( + self, + user_id: str, + tier: str, + success_url: str = "https://app.adiuvai.app/billing/success?session_id={CHECKOUT_SESSION_ID}", + cancel_url: str = "https://app.adiuvai.app/billing/cancel", + ) -> str: + """Create a Stripe checkout session and return the URL. + + Returns a stub URL when Stripe is not configured. + Raises ``HTTP 400`` for the free tier or an unknown tier. + """ + if tier == "free": + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Cannot create a checkout session for the free tier", + ) + + price_id = TIER_PRICE_IDS.get(tier) + if not price_id: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail=f"Unknown tier: {tier}", + ) + + if not self._configured(): + return "https://stripe.com/stub-checkout" + + s = self._client() + session = s.checkout.Session.create( + payment_method_types=["card"], + mode="subscription", + line_items=[{"price": price_id, "quantity": 1}], + success_url=success_url, + cancel_url=cancel_url, + metadata={"user_id": user_id, "tier": tier}, + ) + return session.url + + async def handle_webhook( + self, + payload: bytes, + sig_header: str, + db: AsyncSession, + ) -> None: + """Process a Stripe webhook event. + + Verifies the signature, then dispatches on event type. + Raises ``HTTP 400`` on signature mismatch. + No-ops when Stripe is not configured. + """ + if not self._configured(): + return + + try: + s = self._client() + event = s.Webhook.construct_event( + payload, sig_header, settings.STRIPE_WEBHOOK_SECRET + ) + except stripe_lib.error.SignatureVerificationError: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="Invalid Stripe signature", + ) + + event_type: str = event["type"] + data: dict[str, Any] = event["data"]["object"] + + if event_type == "checkout.session.completed": + user_id = data.get("metadata", {}).get("user_id") + tier = data.get("metadata", {}).get("tier", "free") + sub_id = data.get("subscription") + period_end_ts = data.get("current_period_end") + period_end = ( + datetime.fromtimestamp(period_end_ts, tz=timezone.utc) + if period_end_ts + else None + ) + if user_id: + await self._upsert_subscription( + db, user_id, sub_id, tier, "active", period_end + ) + + elif event_type == "customer.subscription.updated": + sub_id = data.get("id") + new_status = data.get("status", "active") + period_end_ts = data.get("current_period_end") + period_end = ( + datetime.fromtimestamp(period_end_ts, tz=timezone.utc) + if period_end_ts + else None + ) + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, status=new_status, current_period_end=period_end + ) + + elif event_type == "customer.subscription.deleted": + sub_id = data.get("id") + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, tier="free", status="canceled" + ) + + elif event_type == "invoice.payment_failed": + sub_id = data.get("subscription") + if sub_id: + await self._update_subscription_by_stripe_id( + db, sub_id, status="past_due" + ) + + await db.commit() + + async def get_subscription( + self, user_id: str, db: AsyncSession + ) -> dict[str, Any] | None: + """Return the subscription record for ``user_id``, or ``None`` if absent.""" + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None: + return None + return { + "tier": sub.tier, + "stripe_subscription_id": sub.stripe_subscription_id, + "status": sub.status, + "current_period_end": ( + int(sub.current_period_end.timestamp() * 1000) + if sub.current_period_end + else None + ), + } + + async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None: + """Cancel the user's Stripe subscription and downgrade them to free. + + Raises ``HTTP 404`` when no active subscription exists. + """ + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None or not sub.stripe_subscription_id: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="No active subscription found", + ) + + if self._configured(): + s = self._client() + s.Subscription.cancel(sub.stripe_subscription_id) + + sub.tier = "free" + sub.status = "canceled" + await db.commit() + + async def list_invoices( + self, user_id: str, db: AsyncSession, limit: int = 24 + ) -> list[dict[str, Any]]: + """Return recent invoices for the user from Stripe. + + Returns an empty list when Stripe is not configured or the user has + no ``stripe_customer_id``. + """ + if not self._configured(): + return [] + + from app.models import User # noqa: PLC0415 + + result = await db.execute( + select(User.stripe_customer_id).where(User.id == user_id) + ) + customer_id = result.scalar_one_or_none() + if not customer_id: + return [] + + try: + s = self._client() + invoices = s.Invoice.list(customer=customer_id, limit=limit) + return [ + { + "id": inv.id, + "amount_due": inv.amount_due, + "amount_paid": inv.amount_paid, + "currency": inv.currency, + "status": inv.status, + "created": inv.created * 1000, # epoch ms + "invoice_url": inv.hosted_invoice_url, + "invoice_pdf": inv.invoice_pdf, + } + for inv in invoices.auto_paging_iter() + ] + except Exception: + return [] + + # ── Private DB helpers ─────────────────────────────────────────────── + + async def _upsert_subscription( + self, + db: AsyncSession, + user_id: str, + stripe_subscription_id: str | None, + tier: str, + sub_status: str, + current_period_end: datetime | None, + ) -> None: + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription).where(Subscription.user_id == user_id) + ) + sub = result.scalar_one_or_none() + if sub is None: + sub = Subscription(user_id=user_id) + db.add(sub) + sub.stripe_subscription_id = stripe_subscription_id + sub.tier = tier + sub.status = sub_status + sub.current_period_end = current_period_end + + async def _update_subscription_by_stripe_id( + self, + db: AsyncSession, + stripe_subscription_id: str, + *, + tier: str | None = None, + status: str | None = None, + current_period_end: datetime | None = None, + ) -> None: + from app.models import Subscription # noqa: PLC0415 + + result = await db.execute( + select(Subscription).where( + Subscription.stripe_subscription_id == stripe_subscription_id + ) + ) + sub = result.scalar_one_or_none() + if sub is None: + return + if tier is not None: + sub.tier = tier + if status is not None: + sub.status = status + if current_period_end is not None: + sub.current_period_end = current_period_end + + +# Module-level singleton shared across the app. +stripe_service = StripeService() diff --git a/api/app/billing/tier_manager.py b/api/app/billing/tier_manager.py new file mode 100644 index 0000000..c09ce8d --- /dev/null +++ b/api/app/billing/tier_manager.py @@ -0,0 +1,149 @@ +"""Tier manager: feature matrix and quota enforcement. + +``TierManager`` is the single source of truth for what each billing tier +allows. ``get_tier`` queries the ``subscriptions`` table for the live tier. +Quota-enforcement helpers take ``tier`` directly — the caller already has it +from ``current_user.tier`` (provided by ``get_current_user``). +""" + +from __future__ import annotations + +from typing import Any + +from fastapi import HTTPException, status +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.schemas import BillingTier + +# Feature matrix per tier. -1 means unlimited; 0 means disabled. +FEATURES: dict[str, dict[str, Any]] = { + "free": { + "agents": 3, + "batch_active": 2, + "batch_runs_per_day": 5, + "providers": 1, + "batch_builder": False, + "sso": False, + "real_embeddings": False, # keyword fallback only + "realtime_extraction": False, # batch queue (Phase 2) + "relational_memory": False, # relational tier (Phase 3) — Pro+ + "proactive_mining": False, # Power+ only (Phase 5) + "folder_max_files": 200, + "folder_monthly_tokens": 100_000, + }, + "pro": { + "agents": -1, # unlimited + "batch_active": 10, + "batch_runs_per_day": 50, + "providers": -1, + "batch_builder": False, + "sso": False, + "real_embeddings": True, # pgvector cosine search + "realtime_extraction": True, # fire-and-forget asyncio.create_task + "relational_memory": True, # person/project predicates + "proactive_mining": False, # Power+ only (Phase 5) + "folder_max_files": 5000, + "folder_monthly_tokens": 2_000_000, + }, + "power": { + "agents": -1, + "batch_active": -1, # unlimited + "batch_runs_per_day": -1, # unlimited + "providers": -1, + "batch_builder": True, + "sso": False, + "real_embeddings": True, + "realtime_extraction": True, + "relational_memory": True, # all predicates incl. custom + "proactive_mining": True, # scheduled pattern mining (Phase 5) + "folder_max_files": -1, # unlimited + "folder_monthly_tokens": -1, # unlimited + }, + "team": { + "agents": -1, + "batch_active": -1, + "batch_runs_per_day": -1, # unlimited + "providers": -1, + "batch_builder": True, + "sso": True, + "real_embeddings": True, + "realtime_extraction": True, + "relational_memory": True, # all predicates incl. custom + "proactive_mining": True, # scheduled pattern mining (Phase 5) + "folder_max_files": -1, # unlimited + "folder_monthly_tokens": -1, # unlimited + }, +} + +# Requests-per-minute limit per tier. +RATE_LIMITS: dict[str, int] = { + "free": 20, + "pro": 60, + "power": 120, + "team": 200, +} + + +class TierManager: + """Centralises tier feature-gating, rate-limit lookups, and quota checks.""" + + # ── Tier lookup ───────────────────────────────────────────────────── + + async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier: + """Return the current billing tier for ``user_id`` from the DB. + + Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod + when no subscription row exists. + """ + from app.models import Subscription # noqa: PLC0415 + from app.config.settings import settings # noqa: PLC0415 + + result = await db.execute( + select(Subscription.tier).where(Subscription.user_id == user_id) + ) + tier: str | None = result.scalar_one_or_none() + if tier is None or tier not in FEATURES: + return "power" if settings.ENV == "dev" else "free" + return tier # type: ignore[return-value] + + # ── Feature access ─────────────────────────────────────────────────── + + def check_feature(self, tier: BillingTier, feature: str) -> bool: + """Return ``True`` if ``tier`` has ``feature`` enabled. + + For numeric features, any value > 0 or -1 (unlimited) counts as enabled. + """ + value = FEATURES.get(tier, FEATURES["free"]).get(feature) + if value is None: + return False + if isinstance(value, bool): + return value + return value != 0 + + def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None: + """Raise ``HTTP 403`` if ``tier`` does not have ``feature``.""" + if not self.check_feature(tier, feature): + detail = ( + f"Feature '{feature}' requires {tier_name} tier or above." + if tier_name + else f"Feature '{feature}' is not available on your current tier." + ) + raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) + + def get_feature_value(self, tier: BillingTier, feature: str) -> int: + """Return integer feature value for tier. -1 means unlimited.""" + value = FEATURES.get(tier, FEATURES["free"]).get(feature) + if not isinstance(value, int): + return 0 + return value + + # ── Rate limiting ──────────────────────────────────────────────────── + + def get_rate_limit(self, tier: BillingTier) -> int: + """Return the requests-per-minute limit for ``tier``.""" + return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) + + +# Module-level singleton shared across the app. +tier_manager = TierManager() diff --git a/api/app/config/__init__.py b/api/app/config/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/app/config/settings.py b/api/app/config/settings.py new file mode 100644 index 0000000..f3ede2c --- /dev/null +++ b/api/app/config/settings.py @@ -0,0 +1,95 @@ +from typing import Literal +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class Settings(BaseSettings): + DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuvai" + JWT_SECRET: str = "change-me-in-production" + JWT_ALGORITHM: str = "HS256" + JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30 + JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30 + + STRIPE_SECRET_KEY: str = "" + STRIPE_WEBHOOK_SECRET: str = "" + + OPENAI_API_KEY: str = "" + ANTHROPIC_API_KEY: str = "" + GOOGLE_API_KEY: str = "" + CEREBRAS_API_KEY: str = "" + GROQ_API_KEY: str = "" + DEEPSEEK_API_KEY: str = "" + + LLM_MODEL: str = "gpt-4o" + LLM_EMBED_MODEL: str = "text-embedding-3-small" + + # Per-agent model overrides. Leave empty to fall back to LLM_MODEL. + LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use) + LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream) + LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner) + LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner) + LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs) + LLM_MODEL_TASK_BRIEF_AGENT: str = "" # task-brief-agent (per-task deep research) + LLM_MODEL_SETUP_AGENT: str = "" # agent-setup journey + LLM_MODEL_MEMORY_EXTRACTOR: str = "" # memory-extractor (Phase 2 extract/decide) + LLM_MODEL_MEMORY_MINER: str = "" # memory-miner (Phase 5 proactive mining) + LLM_MODEL_MEMORY_AUDITOR: str = "" # memory-auditor (Phase 7 weekly audit) + + # GitHub Copilot OAuth token storage directory. + # Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot). + # In Docker, set this to a path backed by a named volume so tokens survive restarts. + GITHUB_COPILOT_TOKEN_DIR: str = "" + + # OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows. + GMAIL_CLIENT_ID: str = "" + GMAIL_CLIENT_SECRET: str = "" + MS_CLIENT_ID: str = "" + MS_CLIENT_SECRET: str = "" + # MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts). + MS_TENANT_ID: str = "common" + + # Google Login OAuth credentials — scope: openid email profile. + # Separate from GMAIL_CLIENT_ID/SECRET (which uses gmail.readonly scope). + GOOGLE_AUTH_CLIENT_ID: str = "" + GOOGLE_AUTH_CLIENT_SECRET: str = "" + # The redirect URI registered in Google Cloud Console. + # Google redirects here after consent; this backend route then bounces to + # the adiuvai:// deep link so the Electron app receives the code. + # Dev: http://localhost:8000/api/v1/auth/oauth/google/web-callback + # Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback + OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback" + + # Gmail Pub/Sub topic for push notifications. + # Full resource name, e.g. "projects/my-project/topics/gmail-push". + # Leave empty in dev — setup_watch will skip registration gracefully. + GMAIL_PUBSUB_TOPIC: str = "" + # OIDC token audience for Pub/Sub push subscription JWT verification. + # Set to the service account email or audience string configured in the + # Pub/Sub push subscription. Leave empty in dev to skip verification + # (a warning is logged — never silent in production). + GMAIL_PUBSUB_AUDIENCE: str = "" + + # Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth + # tokens stored in cloud_agent_configs.oauth_token_encrypted. + # Generate with: from cryptography.fernet import Fernet; Fernet.generate_key() + OAUTH_ENCRYPTION_KEY: str = "" + + CORS_ORIGINS: list[str] = [ + "app://.", + "http://localhost:3000", + "http://localhost:5173", + "http://localhost:4173", # Vite preview (web SPA) + "https://app.adiuvai.com", # Production web portal + ] + + LANGFUSE_SECRET_KEY: str = "" + LANGFUSE_PUBLIC_KEY: str = "" + LANGFUSE_BASE_URL: str = "https://cloud.langfuse.com" + + SCHEDULER_ENABLED: bool = True + + ENV: Literal["dev", "prod"] = "dev" + + model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8", extra="ignore") + + +settings = Settings() diff --git a/api/app/core/__init__.py b/api/app/core/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/api/app/core/brief_agent.py b/api/app/core/brief_agent.py new file mode 100644 index 0000000..954f890 --- /dev/null +++ b/api/app/core/brief_agent.py @@ -0,0 +1,228 @@ +"""Brief agent — produces plain-text home and project status briefs. + +Read-only tool subset only. Never calls _normalize_tagged_list_lines — +the brief prompt forbids XML tags, so skipping post-processing is intentional. +""" + +from __future__ import annotations + +from collections.abc import AsyncGenerator +from datetime import date +from typing import Any + +from app.agents.note_agent import NOTE_READ_TOOLS +from app.agents.project_agent import PROJECT_READ_TOOLS +from app.agents.task_agent import TASK_READ_TOOLS +from app.agents.timeline_agent import TIMELINE_READ_TOOLS +from app.core.deep_agent import ( + _language_instruction, + _proactive_hints_injection, + _read_only_memory_tools, + _relational_memory_injection, + _run_single_agent_stream, + _trace_id_from_context, + build_brief_multi_project_manifest, +) +from app.core.langfuse_client import compile_prompt, get_prompt_or_fallback + +_LANGUAGE_NAMES: dict[str, str] = { + "en": "English", "it": "Italian", "es": "Spanish", + "fr": "French", "de": "German", + "english": "English", "italian": "Italian", "italiano": "Italian", + "spanish": "Spanish", "español": "Spanish", + "french": "French", "français": "French", + "german": "German", "deutsch": "German", +} + +_HOME_BRIEF_FALLBACK = """\ +You are the user's personal assistant producing a short daily brief. + +ROLE +Act like a calm, attentive secretary writing a stand-up note for your boss. +Warm and human, never breezy. Never cheerful filler, never emojis, never +"here is your brief" meta-text. The user is opening the app mid-workday and +is probably stressed — your job is to lower cognitive load, not add noise. + +TOOLS — always call before writing +Pull fresh data every run. Do not invent counts or titles. Use at minimum: +- list_tasks_due_today — tasks the user owes today +- list_timelines_today — events starting or ending today +- list_all_projects — projects currently in progress or at risk +- memory_list_blocks / memory_get — personal context about people, clients, + payment habits, working preferences +If a tool returns nothing, simply omit that topic. Never report zeros. + +WHAT TO INCLUDE +1. Tasks due today (title + priority; group the 1-2 most important). +2. Timeline events starting or ending today (and anything that starts/ends + tomorrow if the user has a very light day). +3. Active projects that need a nudge — stalled, blocked, or awaiting input. +4. Memory-aware colour where it sharpens the brief. Examples: + - "Client Rossi tends to pay late — the Acme invoice is 6 days out." + - "You usually dislike meetings before 10:00 — the call at 09:30 is unusual." + Only add a memory line when it changes what the user does. Do not pad. + +WHAT TO OMIT +- Zero-counts ("no overdue items", "0 meetings today"). +- Statistics ("2 active projects, 3 completed tasks"). +- Headers, titles, greetings, sign-offs, dates, emojis, slang. +- Meta-phrases ("here is", "let me know if", "hope this helps"). +- XML/HTML tags of any kind. Plain prose only. + +LIGHT-DAY CLAUSE +If tasks + events + active-project-nudges together produce fewer than two +sentences of content, also list 1-2 projects in status on_hold or waiting +and ask a single, specific question about them — e.g. "Is the Bianchi +redesign still paused, or ready to pick back up?" One question max, grounded +in a real project name. + +VOICE +- Calm. Concise. Human. Short sentences. +- Use **bold** sparingly for task titles, project names, and people's names. +- No bullet lists. Flow as 2-4 sentences of prose. + +LENGTH +2-4 sentences total. Hard cap 4. If the day is truly empty, one sentence. + +Respond in the user's language ({language}). Today is {today}.\ +""" + +_PROJECT_BRIEF_FALLBACK = """\ +You are the project assistant producing a short status brief for ONE project. + +ROLE +A senior project manager summarising state-of-play for the owner. Factual, +sharp, forward-looking. Never reassuring filler, never emojis. + +SCOPE +Work only with project_id = {project_id}. Do not mention or pull data from +other projects. Use tools to fetch fresh data: +- get_project — current status, dates, description +- list_tasks(project_id) — open work, split by status +- list_timelines(project_id) — milestones hit, upcoming, overdue +- list_notes(project_id) — any recent decisions or blockers +- memory_get — relevant context about the client, collaborators, constraints + +STRUCTURE — follow exactly, one short paragraph per section, no headers +1. **State.** One sentence: current phase, health (on track / at risk / blocked), + and why. Cite the concrete signal (overdue milestone, stalled tasks, recent + blocker note). +2. **What's moving.** What was completed or progressed recently. Name specific + tasks or milestones. +3. **Next steps.** The 1-3 most important things the user should do next, in + priority order. Be concrete — task name, who owns it, when due if known. + If waiting on someone else, name them and what the ask is. +4. **Risks / memory-flagged items.** One line max. Only include when there is + a real risk or a relevant memory (e.g. late-paying client, tight deadline, + scope change). Omit the section entirely if nothing to say. + +WHAT TO OMIT +- Zero-counts ("no overdue tasks"). +- Generic advice ("keep up the good work"). +- Greetings, headers, bullet lists, emojis, sign-offs, meta-phrases. +- XML/HTML tags or bracketed id lists. Plain prose only. + +VOICE +- Direct. Factual. No fluff. +- Use **bold** sparingly for task titles, milestone names, and the owner's name. +- Short sentences. Prefer verbs over nouns ("Client review is blocking release" + not "There is a blocker which is the client review"). + +LENGTH +4-8 sentences total across the 3-4 sections. Hard cap 8. + +Respond in the user's language ({language}). Today is {today}.\ +""" + + +def _resolve_language(context: dict[str, Any]) -> str: + core = context.get("core_memory") or {} + raw = (core.get("language") or "en").strip().lower() + return _LANGUAGE_NAMES.get(raw, raw.title()) or "English" + + +def _build_read_tools(user_id: str, trace_id: str | None) -> list[Any]: + return [ + *TASK_READ_TOOLS, + *PROJECT_READ_TOOLS, + *TIMELINE_READ_TOOLS, + *NOTE_READ_TOOLS, + *_read_only_memory_tools(user_id, trace_id), + ] + + +async def run_home_brief( + user_id: str, + context: dict[str, Any], +) -> AsyncGenerator[tuple[str, Any], None]: + """Stream a plain-text daily home brief. + + Yields (event_type, data) tuples identical to _run_single_agent_stream. + Do NOT post-process output through _normalize_tagged_list_lines. + """ + from app.agents.folder_agent import FOLDER_TOOLS + + trace_id = _trace_id_from_context(context) + today = date.today().isoformat() + language = _resolve_language(context) + + raw_template, langfuse_prompt = get_prompt_or_fallback("home_brief", _HOME_BRIEF_FALLBACK) + system_prompt = compile_prompt(raw_template, langfuse_prompt, language=language, today=today) + system_prompt += _relational_memory_injection(context) + system_prompt += _proactive_hints_injection(context) + system_prompt += _language_instruction(context) + if today not in system_prompt: + system_prompt += f"\nToday is {today}." + + brief_manifest = await build_brief_multi_project_manifest() + system_prompt = system_prompt + ("\n\n" + brief_manifest if brief_manifest else "") + + tools = [*_build_read_tools(user_id, trace_id), *FOLDER_TOOLS] + async for event in _run_single_agent_stream( + user_id=user_id, + system_prompt=system_prompt, + message="Generate the daily brief.", + context=context, + langfuse_prompt=langfuse_prompt, + agent_name="brief-agent", + tools=tools, + ): + yield event + + +async def run_project_brief( + user_id: str, + project_id: str, + context: dict[str, Any], +) -> AsyncGenerator[tuple[str, Any], None]: + """Stream a plain-text project status brief for project_id. + + Yields (event_type, data) tuples identical to _run_single_agent_stream. + Do NOT post-process output through _normalize_tagged_list_lines. + """ + trace_id = _trace_id_from_context(context) + today = date.today().isoformat() + language = _resolve_language(context) + + raw_template, langfuse_prompt = get_prompt_or_fallback("project_brief", _PROJECT_BRIEF_FALLBACK) + system_prompt = compile_prompt( + raw_template, langfuse_prompt, + language=language, today=today, project_id=project_id, + ) + system_prompt += _relational_memory_injection(context) + system_prompt += _proactive_hints_injection(context) + system_prompt += _language_instruction(context) + if today not in system_prompt: + system_prompt += f"\nToday is {today}." + + tools = _build_read_tools(user_id, trace_id) + async for event in _run_single_agent_stream( + user_id=user_id, + system_prompt=system_prompt, + message=f"Generate the project status brief for project {project_id}.", + context=context, + langfuse_prompt=langfuse_prompt, + agent_name="brief-agent", + tools=tools, + ): + yield event diff --git a/api/app/core/deep_agent.py b/api/app/core/deep_agent.py new file mode 100644 index 0000000..1a91c6b --- /dev/null +++ b/api/app/core/deep_agent.py @@ -0,0 +1,1329 @@ +"""Single-agent runners for home and contextual chat contexts.""" + +from __future__ import annotations + +import json +import logging +import re +from datetime import date +from collections.abc import AsyncGenerator +from typing import Any + +from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage +from langchain_core.tools import tool + +from app.agents.client_agent import CLIENT_TOOLS +from app.agents.note_agent import NOTE_TOOLS +from app.agents.project_agent import PROJECT_TOOLS +from app.agents.relations_agent import make_query_relations_tool +from app.agents.task_agent import TASK_TOOLS +from app.agents.timeline_agent import TIMELINE_TOOLS +from app.core.scout_session_buffer import session_buffer +from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context +from app.core.llm import get_agent_llm, model_for_agent +from app.core.memory_middleware import MemoryMiddleware +from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector +from app.db import async_session + +logger = logging.getLogger(__name__) + +MAX_HISTORY_TURNS = 20 + +# Mapping of core-memory language values to natural-language names for prompts. +_LANGUAGE_NAMES: dict[str, str] = { + "en": "English", "it": "Italian", "es": "Spanish", + "fr": "French", "de": "German", + "english": "English", "italian": "Italian", "italiano": "Italian", + "spanish": "Spanish", "español": "Spanish", + "french": "French", "français": "French", + "german": "German", "deutsch": "German", +} + + +def _language_instruction(context: dict[str, Any]) -> str: + """Return a system-prompt suffix that tells the LLM to respond in the user's language. + + Returns an empty string when the language is English or unknown — saves tokens. + """ + core = context.get("core_memory") or {} + raw = (core.get("language") or "").strip().lower() + if not raw: + return "" + lang = _LANGUAGE_NAMES.get(raw, raw.title()) # best-effort capitalisation + if lang.lower() == "english": + return "" + return ( + f"\n\nIMPORTANT: Always respond in {lang}. " + f"All your output text must be written in {lang}." + ) + +MANIFEST_TOKEN_BUDGET = 3000 # rough budget for block + + +def format_folder_manifest(manifest: dict | None) -> str: + """Format a folder manifest into the block. + + Truncates by mtime DESC if estimated tokens exceed MANIFEST_TOKEN_BUDGET. + Returns empty string if manifest is None or has no files. + """ + if not manifest or not manifest.get("files"): + return "" + files = list(manifest["files"]) + files.sort(key=lambda f: f.get("mtimeMs", 0), reverse=True) + + header = ( + f"\npath: {manifest.get('folderPath', '?')} " + f"({len(files)} files, scanned {manifest.get('lastScannedAt', '?')})\nfiles:\n" + ) + footer_template = "… {} more files omitted, use read_project_folder_file to access by path\n" + + char_budget = MANIFEST_TOKEN_BUDGET * 4 # ~4 chars/token + body = "" + included = 0 + for f in files: + line = f"- /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}\n" + if len(header) + len(body) + len(line) + len(footer_template.format(0)) > char_budget: + break + body += line + included += 1 + omitted = len(files) - included + if omitted > 0: + return header + body + footer_template.format(omitted) + return header + body + "" + + +async def _fetch_project_manifest(project_id: str) -> dict | None: + """Fetch manifest from Electron via execute_on_client. Returns None if unlinked or error.""" + from app.core.ws_context import execute_on_client + try: + result = await execute_on_client( + action="read_project_folder_manifest", + data={"projectId": project_id}, + ) + if not result or not result.get("folderPath"): + return None + return result + except Exception: + return None + + +async def build_brief_multi_project_manifest() -> str: + """Build a compact multi-project manifest for the daily brief agent. + + Calls execute_on_client('list_projects_with_folder_manifests') and keeps + the top 5 most-recently-modified files per project. + """ + try: + result = await execute_on_client( + action="list_projects_with_folder_manifests", + data={}, + ) + except Exception: + return "" + projects = (result or {}).get("projects") or [] + if not projects: + return "" + blocks: list[str] = [""] + any_entry = False + for p in projects: + all_files = p.get("files", []) or [] + files = sorted(all_files, key=lambda f: f.get("mtimeMs", 0), reverse=True)[:5] + blocks.append(f"project: {p.get('projectName','?')} [{p.get('projectId','?')}]") + blocks.append(f" path: {p.get('folderPath','?')} (scanned {p.get('lastScannedAt','?')})") + if not all_files: + blocks.append(" (no indexed files yet — folder is linked but empty or unscanned)") + else: + for f in files: + blocks.append(f" - /{f['relPath']} [{f.get('kind','text')}] {f.get('summary','')}") + if len(all_files) > 5: + blocks.append(f" … {len(all_files) - 5} more files (use read_project_folder_file by relPath)") + any_entry = True + if not any_entry: + return "" + blocks.append("") + return "\n".join(blocks) + + +def _datetime_context_injection(context: dict[str, Any]) -> str: + """Build a comprehensive DATE CONTEXT block with pre-computed ms-epoch boundaries for common ranges.""" + fp = context.get("format_prefs") + if not fp or not isinstance(fp, dict): + return "" + try: + from zoneinfo import ZoneInfo + from datetime import datetime as _dt, timezone as _utc, timedelta as _td + + tz_name: str = str(fp.get("timezone") or "UTC") + now_iso: str = str(fp.get("now_iso") or "") + date_fmt: str = str(fp.get("date_format") or "dd/MM/yyyy") + time_fmt: str = str(fp.get("time_format") or "24h") + + tz = ZoneInfo(tz_name) + if now_iso: + now_utc = _dt.fromisoformat(now_iso.replace("Z", "+00:00")) + else: + now_utc = _dt.now(_utc.utc) + + now_ms = int(now_utc.timestamp() * 1000) + now_local = now_utc.astimezone(tz) + now_local_str = now_local.strftime("%Y-%m-%d %H:%M") + weekday_str = now_local.strftime("%A") + y, m, d = now_local.year, now_local.month, now_local.day + + def _day(year: int, month: int, day: int) -> tuple[int, int]: + s = _dt(year, month, day, tzinfo=tz) + e = s + _td(days=1) + return int(s.timestamp() * 1000), int(e.timestamp() * 1000) - 1 + + def _between(start: "_dt", end_excl: "_dt") -> tuple[int, int]: + return int(start.timestamp() * 1000), int(end_excl.timestamp() * 1000) - 1 + + today_s, today_e = _day(y, m, d) + yd = now_local - _td(days=1) + yesterday_s, yesterday_e = _day(yd.year, yd.month, yd.day) + tm = now_local + _td(days=1) + tomorrow_s, tomorrow_e = _day(tm.year, tm.month, tm.day) + + # ISO week (Mon–Sun) + monday = _dt(y, m, d, tzinfo=tz) - _td(days=now_local.weekday()) + last_monday = monday - _td(weeks=1) + next_monday = monday + _td(weeks=1) + this_week_s, this_week_e = _between(monday, next_monday) + last_week_s, last_week_e = _between(last_monday, monday) + next_week_s, next_week_e = _between(next_monday, next_monday + _td(weeks=1)) + + # Calendar months + this_m_start = _dt(y, m, 1, tzinfo=tz) + next_m_start = _dt(y + (m // 12), m % 12 + 1, 1, tzinfo=tz) + last_m_start = _dt(y - (1 if m == 1 else 0), 12 if m == 1 else m - 1, 1, tzinfo=tz) + next2_m = next_m_start.month % 12 + 1 + next2_y = next_m_start.year + (1 if next_m_start.month == 12 else 0) + next2_m_start = _dt(next2_y, next2_m, 1, tzinfo=tz) + this_month_s, this_month_e = _between(this_m_start, next_m_start) + last_month_s, last_month_e = _between(last_m_start, this_m_start) + next_month_s, next_month_e = _between(next_m_start, next2_m_start) + + # Calendar years + this_yr_s, this_yr_e = _between(_dt(y, 1, 1, tzinfo=tz), _dt(y + 1, 1, 1, tzinfo=tz)) + last_yr_s, last_yr_e = _between(_dt(y - 1, 1, 1, tzinfo=tz), _dt(y, 1, 1, tzinfo=tz)) + + sunday = monday + _td(days=6) + last_sunday = last_monday + _td(days=6) + next_sunday = next_monday + _td(days=6) + + return ( + f"\n\nDATE CONTEXT (timezone: {tz_name}, dateFormat: {date_fmt}, timeFormat: {time_fmt})\n" + f"now_local: {now_local_str} ({weekday_str})\n" + f"now_ms: {now_ms}\n\n" + f"today [{today_s}, {today_e}] {y:04d}-{m:02d}-{d:02d}\n" + f"tomorrow [{tomorrow_s}, {tomorrow_e}] {tm.strftime('%Y-%m-%d')}\n" + f"yesterday [{yesterday_s}, {yesterday_e}] {yd.strftime('%Y-%m-%d')}\n" + f"this_week [{this_week_s}, {this_week_e}] {monday.strftime('%Y-%m-%d')} → {sunday.strftime('%Y-%m-%d')} (Mon–Sun)\n" + f"last_week [{last_week_s}, {last_week_e}] {last_monday.strftime('%Y-%m-%d')} → {last_sunday.strftime('%Y-%m-%d')}\n" + f"next_week [{next_week_s}, {next_week_e}] {next_monday.strftime('%Y-%m-%d')} → {next_sunday.strftime('%Y-%m-%d')}\n" + f"this_month [{this_month_s}, {this_month_e}] {y:04d}-{m:02d}\n" + f"last_month [{last_month_s}, {last_month_e}] {last_m_start.strftime('%Y-%m')}\n" + f"next_month [{next_month_s}, {next_month_e}] {next_m_start.strftime('%Y-%m')}\n" + f"this_year [{this_yr_s}, {this_yr_e}] {y:04d}\n" + f"last_year [{last_yr_s}, {last_yr_e}] {y - 1:04d}\n\n" + f"When calling list_tasks_due_today or list_timelines_today, always pass user_timezone=\"{tz_name}\".\n" + f"When presenting dates, format using dateFormat={date_fmt} and timeFormat={time_fmt}." + ) + except Exception: + return "" + + +def _proactive_hints_injection(context: dict[str, Any]) -> str: + """Return a system-prompt paragraph listing proactive behavioral hints. + + Returns empty string when no hints or confidence below threshold. + Capped at 600 chars. + """ + hints: list[str] = context.get("proactive_hints") or [] + if not hints: + return "" + body = "\n".join(f"- {h}" for h in hints) + section = f"\n\nI noticed (behavioral patterns):\n{body}" + if len(section) > 600: + section = section[:597] + "..." + return section + + +def _relational_memory_injection(context: dict[str, Any]) -> str: + """Return a system-prompt paragraph listing known people/projects from relational memory. + + Returns empty string when no relational rows or tier is Free. + Capped at 800 chars to control token spend. + """ + relations: list[str] = context.get("relational_memory") or [] + if not relations: + return "" + body = "\n".join(f"- {r}" for r in relations) + section = f"\n\nKnown people & projects:\n{body}" + if len(section) > 800: + section = section[:797] + "..." + return section + + +_IDENTITY_KEYS = ("user_name", "job_role", "industry", "primary_use_case", "tone_preference") + + +def _user_identity_injection(context: dict[str, Any]) -> str: + """Return a compact user-profile block from core memory onboarding fields. + + Returns empty string when no onboarding keys are present. + """ + core = context.get("core_memory") or {} + parts: list[str] = [] + for key in _IDENTITY_KEYS: + val = (core.get(key) or "").strip() + if val: + parts.append(f"- {key}: {val}") + if not parts: + return "" + return "\n\nUser profile:\n" + "\n".join(parts) + + +def _request_context_block(context: dict[str, Any]) -> str: + """Return a small block with per-request scope and resolved project context.""" + parts: list[str] = [] + scope = context.get("scope") + if scope and isinstance(scope, dict): + parts.append(f"scope: {json.dumps(scope, ensure_ascii=True)}") + resolved = context.get("resolved_project_id") + if resolved and isinstance(resolved, str): + parts.append(f"resolved_project_id: {resolved}") + return "\n".join(parts) + + +_HOME_SYSTEM_PROMPT = """\ +You are adiuvAI's home executive assistant.{user_identity} +You are not a chatbot — you are a proactive partner who runs ahead of the user, anticipates what they need next, and closes every reply with a concrete next step or a clarifying question. + +# How you work +- Use tools before answering anything factual. Never guess counts, dates, or status. +- Prefer parallel tool calls when the questions are independent (e.g. counts per status). Chain calls when one result feeds the next. +- After delivering the answer, propose the next useful action: a follow-up task to draft, a deadline at risk, a project to triage, a person to remind. Use what you know about the user (job role, industry, primary use case) to make the suggestion relevant. +- Match the user's tone preference. Default to warm-but-direct; stay concise. +- When the user asks to remember, forget, or update something, use memory tools. + +# Filter discipline +- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine"). +- The user's own name in the User profile block is for context only — it is NOT a default filter. +- When in doubt, omit `assignee` and return the global result. + +# Output format +Return markdown. Reference entities with these tags exactly — one id per tag, each tag on its own line, no prefix/suffix text on the same line: + id id id id + +When the answer contains a list of entities (any of the tags above), structure the reply as three blocks separated by blank lines: + 1. One short intro line stating what is coming (count + scope, e.g. "Ecco i tuoi 18 task ad alta priorità:"). Match the user's language. + 2. All entity tags, one per line, consecutive, no prose interleaved. Do NOT put titles, dates, priorities, or any descriptive text on the same line as a tag or between tags. + 3. One short closing recap (1–2 sentences) that points out a pattern, risk, or insight noticed in the list, and ends with a concrete next step or clarifying question. + +For single-entity answers skip blocks 1 and 3 if they would be redundant; just emit the tag. + +For analytical answers (status overviews, breakdowns by category/priority/project, comparisons, trends, "resoconto", "panoramica") consider returning a chart block when it communicates the answer faster than prose. The decision is yours — skip charts for trivial single-number answers. Schema: + {{"chartType":"pie|bar|line|area|radar|radial","title":"...","data":[{{"name":"...","value":N}},...], "config":{{"value":{{"label":"...","color":"var(--chart-1)"}} }} }} +- pie for share-of-total breakdowns; bar for category comparisons; line/area for time series; radar for multi-dimension. +- data rows must include a "name" field; numeric series keys must match config keys. +- Use var(--chart-1) through var(--chart-5) for colors, cycling 1-5 in series order. Do NOT wrap in hsl() or oklch() — these are complete CSS values already. + +For upcoming-timeline questions ("prossimi eventi"), include only future items in the current month unless the user asks otherwise. + +# Date filtering +{date_context} + +When filtering tasks/timelines/notes by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself. +For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms. +For "today" / "tomorrow" queries, prefer list_tasks_due_today / list_timelines_today with user_timezone from DATE CONTEXT. + +# Language +{language_instruction} + +# Known people & projects +{relational_memory} + +# Behavioral hints +{proactive_hints} + +# Request context +{request_context}\ +""" + +_CONTEXTUAL_SYSTEM_PROMPT = """You are adiuvAI's contextual assistant. The user is working inside the app and has opened a side chat anchored to a specific view ("current view"). Help them act on that view: recap, plan, create entities, answer questions. + +Rules: +1. Base context (current view summary) is provided every turn. Treat it as ground truth for ids and names; never invent them. +2. ALL reads go through `get_page_details`. The legacy tools `list_projects`, `get_project`, `list_tasks`, `get_task`, `list_notes`, `get_note` are NOT available in this channel — do not attempt to call them. To find an entity by name, call `get_page_details({entityType: 'projects_all' | 'tasks_all' | 'timeline_all'})` to list, then `get_page_details({entityType: '', entityId})` for the full snapshot. +3. When the user requests an action that creates or updates an entity: + - If the current view is a project and no project is specified, use the current project automatically. + - If the current view is the global Tasks / Projects / Timeline list and no project is specified, ASK before attaching to any project. Don't silently create orphan entities. +4. The current view can change mid-conversation (user navigates). When you see a system message "User navigated to ...", treat the new view as the active context. Prior turns remain visible but the active scope shifts. +5. Notes: you can read note bodies via `get_page_details({entityType:'note'})`. You CANNOT edit, summarize-to-replace, or append. Tell the user "note editing is coming in a later release" if asked. +6. Be concise. Default to 1-3 short paragraphs. Bullet lists fine. Don't restate the user's request. +7. Never expose ids in prose. Use names. Ids only travel through tool calls. + +# Date context +{date_context} + +# Language +{language_instruction} +""" + +_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\ +You are an executive assistant preparing a briefing dossier for your principal before they act on a specific task. +Your job: gather all relevant context, synthesize it into a tight actionable dossier, and — if the task requires writing (email, message, document) — produce a ready-to-use draft.{user_identity} + +# Research workflow +Follow these steps in order, using tools: +1. Read the task fully (title, description, due date, priority, status, project, comments). +2. Fetch the parent project (`get_project`) to understand scope, aiSummary, and any linked client. +3. If the project has a clientId: call `get_client(id)` to retrieve full client details. +4. Call `query_relations` (subject_label=client_name or task subject) to find cross-project connections — e.g. the same client appearing in multiple projects. +5. Search associative memory (`search_associative`) and archival memory (`archival_memory_search`) using the task title + client name as query phrases to surface relevant past interactions. +6. Read core memory blocks for tone preference, language, and user style: `memory_get("tone_preference")`, `memory_get("language")`. +7. Determine task kind: is this a writing task (email reply, message, follow-up, proposal)? If yes, draft a ready-to-send piece. + +# Output structure +Write the briefing in the user's language. Use this exact structure: + +**What needs to be done** +(1–2 sentences, concrete and specific — what action the user must take) + +**Context you should know** +(bullet points covering: client background, related projects, prior interactions, tone/style notes, any relevant deadlines or dependencies) + +**Suggested first step** +(one specific, immediately actionable instruction) + +If this is a writing task, append a canvas block at the very end: + +...ready-to-use draft here... + + +Do NOT include the canvas block for non-writing tasks. +Do NOT repeat verbatim task fields the user already sees in the UI. +Be concrete — no vague advice. Every bullet should be a fact that changes what the user does. + +# Date context +{date_context} + +# Language +{language_instruction} + +# Known people & projects +{relational_memory} + +# Request context +{request_context}\ +""" + +_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT = """\ +You are an executive assistant continuing a conversation with your principal. +You have already prepared and delivered a research briefing for the active task. The user has read it.{user_identity} + +Your briefing: +--- +{briefing_context} +--- + +Continue from here. Do NOT repeat the briefing. Refer to it when relevant. +Help the user execute: edit drafts, refine wording, look up additional details, plan next steps. +Stay terse — your principal is a busy executive. + +# Date context +{date_context} + +# Language +{language_instruction} + +# Known people & projects +{relational_memory} + +# Request context +{request_context}\ +""" + +def _as_text(content: Any) -> str: + if content is None: + return "" + if isinstance(content, str): + return content + if isinstance(content, list): + parts: list[str] = [] + for item in content: + if isinstance(item, str): + parts.append(item) + elif isinstance(item, dict): + text = item.get("text") + if isinstance(text, str): + parts.append(text) + return "".join(parts) + return str(content) + + +def _candidate_tokens(message: str) -> list[str]: + tokens = re.findall(r"[a-zA-Z0-9_-]+", message.lower()) + return [token for token in tokens if len(token) >= 3] + + +async def _resolve_project_id_from_message(message: str) -> str | None: + """Resolve likely project UUID from user message using client project list.""" + try: + result = await execute_on_client(action="select", table="projects") + except Exception as exc: + logger.warning("deep_agent: project resolve select failed: %s", exc) + return None + + rows = result.get("rows", []) + if not isinstance(rows, list) or not rows: + return None + + tokens = _candidate_tokens(message) + scored: list[tuple[int, dict[str, Any]]] = [] + for row in rows: + if not isinstance(row, dict): + continue + name = str(row.get("name", "")).lower() + score = sum(1 for token in tokens if token in name) + if score > 0: + scored.append((score, row)) + + if not scored: + return None + + scored.sort(key=lambda item: item[0], reverse=True) + top_score = scored[0][0] + top_rows = [row for score, row in scored if score == top_score] + if len(top_rows) != 1: + return None + + project_id = top_rows[0].get("id") + return project_id if isinstance(project_id, str) else None + + +def _needs_project_resolution(message: str) -> bool: + lowered = message.lower() + return any(keyword in lowered for keyword in ["project", "progetto", "progetti", "whitelist"]) + + +async def _prepare_context(message: str, context: dict[str, Any]) -> dict[str, Any]: + prepared = dict(context) + if _needs_project_resolution(message): + resolved_project_id = await _resolve_project_id_from_message(message) + if resolved_project_id: + prepared["resolved_project_id"] = resolved_project_id + logger.info("deep_agent: resolved_project_id=%s", resolved_project_id) + return prepared + + +def _all_tools() -> list[Any]: + return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS] + + +# ── Contextual sidebar tools ────────────────────────────────────────── + + +@tool +async def get_page_details( + entity_type: str = "", + entity_id: str = "", +) -> str: + """Fetch full details for the entity currently in view. + + entity_type: one of 'project' | 'task' | 'note' | 'timeline_event' | + 'tasks_all' | 'projects_all' | 'timeline_all'. + entity_id: UUID of the entity for singular entity views. Omit for list views. + + The Electron drizzle-executor fulfils this op against local SQLite and + returns the row(s) as a JSON tool result. + """ + result = await execute_on_client( + action="get_page_details", + table=entity_type or "unknown", + data={"entityId": entity_id or None}, + ) + if not result: + return "No details found." + return str(result) + + +def _contextual_tools(user_id: str, trace_id: str | None) -> list[Any]: + """Return the tool palette for the contextual sidebar agent. + + Read ops go through get_page_details only — legacy list_*/get_* tools + return shallow snapshots and cause the agent to under-answer (see + smoke trace 0b46841484ba7d024ed9f8d5ac8b1df0). Writes are limited + to entity creation + task update; note edits are next-sprint. + """ + from app.agents.note_agent import create_note # noqa: PLC0415 + from app.agents.task_agent import create_task, update_task # noqa: PLC0415 + from app.agents.timeline_agent import create_timeline # noqa: PLC0415 + + return [ + get_page_details, + create_task, + update_task, + create_note, + create_timeline, + *_memory_tools(user_id, trace_id), + ] + + +def _trace_id_from_context(context: dict[str, Any]) -> str | None: + debug = context.get("_debug") + if isinstance(debug, dict): + request_id = debug.get("request_id") + if isinstance(request_id, str) and request_id: + return request_id + return None + + +def _session_id_from_context(context: dict[str, Any]) -> str | None: + debug = context.get("_debug") + if isinstance(debug, dict): + session_id = debug.get("session_id") + if isinstance(session_id, str) and session_id: + return session_id + return None + + +def _build_system_prompt(name: str, fallback: str, context: dict[str, Any]) -> tuple[str, Any]: + """Fetch Langfuse template and compile all per-request slots into one system prompt.""" + template, prompt_obj = get_prompt_or_fallback(name, fallback) + text = compile_prompt( + template, prompt_obj, + date_context=_datetime_context_injection(context).strip(), + language_instruction=_language_instruction(context).strip(), + user_identity=_user_identity_injection(context).strip(), + relational_memory=_relational_memory_injection(context).strip(), + proactive_hints=_proactive_hints_injection(context).strip(), + request_context=_request_context_block(context), + ) + return text, prompt_obj + + +_TAG_LINE_RE = re.compile(r"<(task|timeline)>\[[^\]]+\]") +_TIMELINE_DMY_RE = re.compile(r"(?P\d{2})/(?P\d{2})/(?P\d{4})") + + +def _is_upcoming_timeline_query(message: str) -> bool: + lowered = message.lower() + has_upcoming = "prossim" in lowered or "upcoming" in lowered or "next" in lowered + has_timeline_topic = any( + token in lowered + for token in ("event", "evento", "eventi", "timeline", "milestone", "scaden") + ) + return has_upcoming and has_timeline_topic + + +def _timeline_date_in_current_month_or_future(dmy: str) -> bool: + match = _TIMELINE_DMY_RE.search(dmy) + if not match: + return True + try: + parsed = date( + int(match.group("y")), + int(match.group("m")), + int(match.group("d")), + ) + except ValueError: + return True + + today = date.today() + return parsed >= today and parsed.year == today.year and parsed.month == today.month + + +def _normalize_tagged_list_lines(text: str, message: str) -> str: + if not text: + return text + + upcoming_timeline_only = _is_upcoming_timeline_query(message) + output_lines: list[str] = [] + + for line in text.splitlines(): + matches = list(_TAG_LINE_RE.finditer(line)) + if not matches: + output_lines.append(line) + continue + + had_non_tag_text = _TAG_LINE_RE.sub("", line).strip(" -\t0123456789.*:)") + if not had_non_tag_text and len(matches) == 1: + tag_text = matches[0].group(0) + if ( + upcoming_timeline_only + and "" in tag_text + and not _timeline_date_in_current_month_or_future(line) + ): + continue + output_lines.append(tag_text) + continue + + for match in matches: + tag_text = match.group(0) + if ( + upcoming_timeline_only + and "" in tag_text + and not _timeline_date_in_current_month_or_future(line) + ): + continue + output_lines.append(tag_text) + + return "\n".join(output_lines) + + +def _normalize_memory_label(path_or_label: str) -> str: + value = path_or_label.strip() + if value.startswith("/memories/"): + value = value[len("/memories/"):] + value = value.strip("/") + return value + + +def _memory_tools(user_id: str, trace_id: str | None) -> list[Any]: + @tool + async def memory_list_blocks() -> str: + """List all core memory blocks currently stored for the user.""" + logger.info("deep_agent: memory_list_blocks trace=%s user=%s", trace_id or "-", user_id) + async with async_session() as db: + memory = MemoryMiddleware(db) + blocks = await memory.list_core_blocks(user_id) + if not blocks: + return "No memory blocks found." + lines = [f"- {b['label']}: {b['value']}" for b in blocks] + return "Memory blocks:\n" + "\n".join(lines) + + @tool + async def memory_get(path_or_label: str) -> str: + """Get one memory block by label or /memories/