Compare commits
8 Commits
feature/mi
...
706bf88883
| Author | SHA1 | Date | |
|---|---|---|---|
| 706bf88883 | |||
| 4ff0b27084 | |||
| 61d2a18234 | |||
| b3687719b6 | |||
| f80bdfa8f7 | |||
| 617a17db40 | |||
| 92716cb89a | |||
| cfc9d7a942 |
31
.env.example
31
.env.example
@@ -4,19 +4,9 @@ ENV=dev
|
|||||||
# ── Database ──────────────────────────────────────────────────────────────────
|
# ── Database ──────────────────────────────────────────────────────────────────
|
||||||
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
|
||||||
|
|
||||||
# ── Redis ─────────────────────────────────────────────────────────────────────
|
# ── Auth ──────────────────────────────────────────────────────────────────────
|
||||||
REDIS_URL=redis://localhost:6379/0
|
JWT_SECRET=replace-with-a-long-random-secret
|
||||||
|
JWT_ALGORITHM=HS256
|
||||||
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
|
|
||||||
# Generate keypair:
|
|
||||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
|
||||||
# openssl rsa -in private.pem -pubout -out public.pem
|
|
||||||
# Paste PEM content with literal \n for newlines.
|
|
||||||
#
|
|
||||||
# Private key — ONLY used by the Auth Service (JWT signing).
|
|
||||||
JWT_PRIVATE_KEY=
|
|
||||||
# Public key — used by all services / Traefik ForwardAuth (JWT verification).
|
|
||||||
JWT_PUBLIC_KEY=
|
|
||||||
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
|
||||||
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
|
||||||
|
|
||||||
@@ -27,6 +17,7 @@ OPENAI_API_KEY=
|
|||||||
ANTHROPIC_API_KEY=
|
ANTHROPIC_API_KEY=
|
||||||
GOOGLE_API_KEY=
|
GOOGLE_API_KEY=
|
||||||
LLM_MODEL=gpt-4o
|
LLM_MODEL=gpt-4o
|
||||||
|
LLM_ROUTER_MODEL=gpt-4o-mini
|
||||||
|
|
||||||
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
# ── Stripe (leave empty to stub billing) ──────────────────────────────────────
|
||||||
STRIPE_SECRET_KEY=
|
STRIPE_SECRET_KEY=
|
||||||
@@ -51,17 +42,3 @@ QDRANT_API_KEY=
|
|||||||
# ── CORS ──────────────────────────────────────────────────────────────────────
|
# ── CORS ──────────────────────────────────────────────────────────────────────
|
||||||
# Comma-separated list parsed by Settings (override default if needed)
|
# Comma-separated list parsed by Settings (override default if needed)
|
||||||
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
# CORS_ORIGINS=["app://.","http://localhost:3000"]
|
||||||
|
|
||||||
# ── Langfuse (observability) ─────────────────────────────────────────────────
|
|
||||||
LANGFUSE_SECRET_KEY=sk-lf-...
|
|
||||||
LANGFUSE_PUBLIC_KEY=pk-lf-...
|
|
||||||
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL
|
|
||||||
|
|
||||||
# ── Cloudflare (Traefik ACME DNS-01 challenge) ───────────────────────────────
|
|
||||||
CF_DNS_API_TOKEN=
|
|
||||||
ACME_EMAIL=
|
|
||||||
|
|
||||||
# ── PostgreSQL (used by docker-compose) ──────────────────────────────────────
|
|
||||||
POSTGRES_USER=postgres
|
|
||||||
POSTGRES_PASSWORD=postgres
|
|
||||||
POSTGRES_DB=adiuva
|
|
||||||
6
.gitignore
vendored
6
.gitignore
vendored
@@ -13,9 +13,6 @@ env/
|
|||||||
# Environment variables
|
# Environment variables
|
||||||
.env
|
.env
|
||||||
|
|
||||||
# Cryptographic keys
|
|
||||||
*.pem
|
|
||||||
|
|
||||||
# IDE
|
# IDE
|
||||||
.vscode/
|
.vscode/
|
||||||
.idea/
|
.idea/
|
||||||
@@ -35,6 +32,3 @@ Thumbs.db
|
|||||||
# Claude Code
|
# Claude Code
|
||||||
.claude/
|
.claude/
|
||||||
logs/
|
logs/
|
||||||
|
|
||||||
# Eval private test data
|
|
||||||
services/batch-agent/eval/fixtures/private_data/
|
|
||||||
|
|||||||
@@ -3,34 +3,37 @@ FROM python:3.12-slim AS builder
|
|||||||
|
|
||||||
WORKDIR /build
|
WORKDIR /build
|
||||||
|
|
||||||
COPY services/chat/requirements.txt ./requirements.txt
|
COPY requirements.txt .
|
||||||
RUN pip install --upgrade pip && \
|
RUN pip install --upgrade pip && \
|
||||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
||||||
|
|
||||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
# ── runtime ──────────────────────────────────────────────────────────────────
|
||||||
FROM python:3.12-slim AS runtime
|
FROM python:3.12-slim AS runtime
|
||||||
|
|
||||||
|
# Non-root user
|
||||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
||||||
|
|
||||||
WORKDIR /app
|
WORKDIR /app
|
||||||
|
|
||||||
|
# Copy installed packages from builder
|
||||||
COPY --from=builder /install /usr/local
|
COPY --from=builder /install /usr/local
|
||||||
|
|
||||||
# Shared module
|
# Copy application source
|
||||||
COPY shared/ shared/
|
COPY app/ app/
|
||||||
|
|
||||||
# Service source
|
# Copy Alembic migration files
|
||||||
COPY services/chat/app/ app/
|
COPY alembic/ alembic/
|
||||||
|
COPY alembic.ini .
|
||||||
|
|
||||||
|
# Ensure appuser owns the working directory
|
||||||
RUN chown -R appuser:appgroup /app
|
RUN chown -R appuser:appgroup /app
|
||||||
|
|
||||||
USER appuser
|
USER appuser
|
||||||
|
|
||||||
EXPOSE 8000
|
EXPOSE 8000
|
||||||
|
|
||||||
# Chat service is CPU-bound (LLM calls) — use multiple workers
|
|
||||||
CMD ["gunicorn", "app.main:app", \
|
CMD ["gunicorn", "app.main:app", \
|
||||||
"-k", "uvicorn.workers.UvicornWorker", \
|
"-k", "uvicorn.workers.UvicornWorker", \
|
||||||
"--bind", "0.0.0.0:8000", \
|
"--bind", "0.0.0.0:8000", \
|
||||||
"--workers", "2", \
|
"--workers", "4", \
|
||||||
"--timeout", "120"]
|
"--timeout", "120"]
|
||||||
@@ -739,7 +739,7 @@ adiuva-api/
|
|||||||
│ │
|
│ │
|
||||||
│ ├── core/ # Orchestration engine
|
│ ├── core/ # Orchestration engine
|
||||||
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
|
||||||
│ │ ├── llm.py # LiteLLM factory (get_llm)
|
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm)
|
||||||
│ │ ├── orchestrator.py # Intent classification & routing
|
│ │ ├── orchestrator.py # Intent classification & routing
|
||||||
│ │ └── execution_plan.py # Plan builder, templates, cache
|
│ │ └── execution_plan.py # Plan builder, templates, cache
|
||||||
│ │
|
│ │
|
||||||
|
|||||||
@@ -1,92 +0,0 @@
|
|||||||
"""Deprecate backend agent config tables.
|
|
||||||
|
|
||||||
The Electron client is now the source of truth for agent configuration
|
|
||||||
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
|
||||||
billing checks and trigger/run logs only.
|
|
||||||
|
|
||||||
Revision ID: 9a1f2d0b6c7e
|
|
||||||
Revises: 818478c251dc
|
|
||||||
Create Date: 2026-03-16
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Sequence, Union
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
|
||||||
from alembic import op
|
|
||||||
from sqlalchemy.dialects import postgresql
|
|
||||||
|
|
||||||
revision: str = "9a1f2d0b6c7e"
|
|
||||||
down_revision: Union[str, None] = "818478c251dc"
|
|
||||||
branch_labels: Union[str, Sequence[str], None] = None
|
|
||||||
depends_on: Union[str, Sequence[str], None] = None
|
|
||||||
|
|
||||||
|
|
||||||
def upgrade() -> None:
|
|
||||||
bind = op.get_bind()
|
|
||||||
inspector = sa.inspect(bind)
|
|
||||||
existing = set(inspector.get_table_names())
|
|
||||||
|
|
||||||
if "cloud_agent_configs" in existing:
|
|
||||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
|
||||||
op.drop_table("cloud_agent_configs")
|
|
||||||
|
|
||||||
if "local_agent_configs" in existing:
|
|
||||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
|
||||||
op.drop_table("local_agent_configs")
|
|
||||||
|
|
||||||
|
|
||||||
def downgrade() -> None:
|
|
||||||
op.create_table(
|
|
||||||
"local_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("device_id", sa.String(255), nullable=False),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
|
||||||
|
|
||||||
op.execute(
|
|
||||||
"""
|
|
||||||
DO $$ BEGIN
|
|
||||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
|
||||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
|
||||||
END $$;
|
|
||||||
"""
|
|
||||||
)
|
|
||||||
|
|
||||||
op.create_table(
|
|
||||||
"cloud_agent_configs",
|
|
||||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
|
||||||
sa.Column(
|
|
||||||
"provider",
|
|
||||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
|
||||||
nullable=False,
|
|
||||||
),
|
|
||||||
sa.Column("name", sa.String(255), nullable=False),
|
|
||||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
|
||||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
|
||||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
|
||||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
|
||||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
|
||||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
|
||||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
|
||||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
|
||||||
sa.PrimaryKeyConstraint("id"),
|
|
||||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
|
||||||
)
|
|
||||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
|
||||||
5
app/agents/__init__.py
Normal file
5
app/agents/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
"""Agent tool modules — imported by deep_agent.py to build sub-agent graphs."""
|
||||||
|
|
||||||
|
from app.agents import timeline_agent, note_agent, project_agent, task_agent
|
||||||
|
|
||||||
|
__all__ = ["timeline_agent", "note_agent", "project_agent", "task_agent"]
|
||||||
@@ -1,50 +1,22 @@
|
|||||||
"""Note agent — Markdown note management (list, get, create, update, delete).
|
"""Note agent — tool definitions for Markdown note CRUD."""
|
||||||
|
|
||||||
Shared tool definitions used by both Chat and Batch Agent services.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from shared.llm import embed
|
from app.core.llm import embed
|
||||||
from shared.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
NOTE_SYSTEM_PROMPT = (
|
|
||||||
"You are a note-taking assistant. You help users create, retrieve, update,\n"
|
|
||||||
"and delete Markdown notes in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - content is always Markdown; preserve formatting when updating\n"
|
|
||||||
" - project_id is optional; link a note to a project when mentioned\n"
|
|
||||||
" - When updating, call get_note first if you need to read existing content\n"
|
|
||||||
" before appending or replacing sections\n"
|
|
||||||
" - list_notes without project_id returns all notes; scope with project_id\n"
|
|
||||||
" when the user is working within a specific project\n"
|
|
||||||
" - project_id must be a UUID; if you only know a project name, do not pass it as project_id\n"
|
|
||||||
" - Do not fabricate note content — reflect what the user provides or what\n"
|
|
||||||
" is already in the note (retrieved via get_note)."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_notes(project_id: str = "") -> str:
|
async def list_notes(project_id: str = "") -> str:
|
||||||
"""List notes, optionally scoped to a project by project_id."""
|
"""List notes, optionally scoped to a project by project_id."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="notes",
|
table="notes",
|
||||||
filters={"projectId": normalized_project_id or None},
|
filters={"projectId": project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -133,10 +105,4 @@ async def delete_note(note_id: str) -> str:
|
|||||||
return f"Note {note_id} deleted."
|
return f"Note {note_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
NOTE_TOOLS: list[Any] = [
|
|
||||||
list_notes,
|
|
||||||
get_note,
|
|
||||||
create_note,
|
|
||||||
update_note,
|
|
||||||
delete_note,
|
|
||||||
]
|
|
||||||
@@ -1,7 +1,4 @@
|
|||||||
"""Project agent — full lifecycle management (list, get, create, update, archive, delete).
|
"""Project agent — tool definitions for project lifecycle CRUD."""
|
||||||
|
|
||||||
Shared tool definitions used by both Chat and Batch Agent services.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
@@ -9,23 +6,7 @@ from typing import Any
|
|||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from shared.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
PROJECT_SYSTEM_PROMPT = (
|
|
||||||
"You are a project management assistant. You help users create, find,\n"
|
|
||||||
"update, and archive projects in their workspace.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: active, archived\n"
|
|
||||||
" - client_id is optional; link to a client only when explicitly mentioned\n"
|
|
||||||
" - ai_summary is populated only when the user asks for a project summary;\n"
|
|
||||||
" derive it from context data — do not fabricate content\n"
|
|
||||||
" - Use list_projects for scoped queries; list_all_projects only when the\n"
|
|
||||||
" user wants a complete cross-client view including archived projects\n"
|
|
||||||
" - get_project requires a project UUID; resolve the ID first by calling\n"
|
|
||||||
" list_projects if you only have a project name\n"
|
|
||||||
" - Prefer archiving (update_project status=archived) over deletion;\n"
|
|
||||||
" only call delete_project when the user explicitly confirms deletion."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -136,11 +117,4 @@ async def delete_project(project_id: str) -> str:
|
|||||||
return f"Project {project_id} permanently deleted."
|
return f"Project {project_id} permanently deleted."
|
||||||
|
|
||||||
|
|
||||||
PROJECT_TOOLS: list[Any] = [
|
|
||||||
list_projects,
|
|
||||||
list_all_projects,
|
|
||||||
get_project,
|
|
||||||
create_project,
|
|
||||||
update_project,
|
|
||||||
delete_project,
|
|
||||||
]
|
|
||||||
@@ -1,41 +1,13 @@
|
|||||||
"""Task agent — full CRUD for tasks and task comments.
|
"""Task agent — tool definitions for task and task comment CRUD."""
|
||||||
|
|
||||||
Shared tool definitions used by both Chat and Batch Agent services.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from shared.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
TASK_SYSTEM_PROMPT = (
|
|
||||||
"You are a task management assistant for a project workspace.\n"
|
|
||||||
"You create, update, list, and track tasks and their comments.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - status must be one of: todo, in_progress, done\n"
|
|
||||||
" - priority must be one of: high, medium, low\n"
|
|
||||||
" - due_date is a Unix timestamp in milliseconds; convert human dates\n"
|
|
||||||
" - assignees is a JSON-encoded array of strings (e.g. '[\"Alice\",\"Bob\"]')\n"
|
|
||||||
" - project_id is optional; link to a project when the user mentions one\n"
|
|
||||||
" - is_ai_suggested: 1 only when proactively proposing a task the user\n"
|
|
||||||
" did not explicitly request; 0 otherwise\n"
|
|
||||||
" - Use list_tasks_due_today for 'what's due today' queries\n"
|
|
||||||
" - For update_task, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Always confirm the action in plain, user-friendly language."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Task tools ────────────────────────────────────────────────────────
|
# ── Task tools ────────────────────────────────────────────────────────
|
||||||
@@ -50,12 +22,11 @@ async def list_tasks(
|
|||||||
) -> str:
|
) -> str:
|
||||||
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
"""List tasks, optionally filtered by project_id, status (todo|in_progress|done),
|
||||||
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
a search string, or an order_by field name (dueDate|priority|createdAt)."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
filters={
|
filters={
|
||||||
"projectId": normalized_project_id or None,
|
"projectId": project_id or None,
|
||||||
"status": status or None,
|
"status": status or None,
|
||||||
"search": search or None,
|
"search": search or None,
|
||||||
"orderBy": order_by or None,
|
"orderBy": order_by or None,
|
||||||
@@ -81,6 +52,7 @@ async def create_task(
|
|||||||
due_date: int = 0,
|
due_date: int = 0,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a new task.
|
"""Create a new task.
|
||||||
title: task title (required)
|
title: task title (required)
|
||||||
@@ -91,6 +63,7 @@ async def create_task(
|
|||||||
due_date: Unix timestamp in milliseconds; 0 means no due date
|
due_date: Unix timestamp in milliseconds; 0 means no due date
|
||||||
project_id: optional UUID of the parent project
|
project_id: optional UUID of the parent project
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms; 1 when confirmed
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -104,6 +77,7 @@ async def create_task(
|
|||||||
"dueDate": due_date or None,
|
"dueDate": due_date or None,
|
||||||
"projectId": project_id or None,
|
"projectId": project_id or None,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -123,10 +97,12 @@ async def update_task(
|
|||||||
assignees: str = "",
|
assignees: str = "",
|
||||||
due_date: int = -1,
|
due_date: int = -1,
|
||||||
project_id: str = "",
|
project_id: str = "",
|
||||||
|
is_approved: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update fields on an existing task. Only pass fields you want to change.
|
"""Update fields on an existing task. Only pass fields you want to change.
|
||||||
task_id: the task's UUID (required)
|
task_id: the task's UUID (required)
|
||||||
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
due_date: -1 means unchanged; 0 clears the due date; any positive value sets it
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the value
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
@@ -143,6 +119,8 @@ async def update_task(
|
|||||||
updates["dueDate"] = due_date or None
|
updates["dueDate"] = due_date or None
|
||||||
if project_id:
|
if project_id:
|
||||||
updates["projectId"] = project_id
|
updates["projectId"] = project_id
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="tasks",
|
table="tasks",
|
||||||
@@ -210,11 +188,8 @@ async def add_task_comment(task_id: str, author: str, content: str) -> str:
|
|||||||
table="taskComments",
|
table="taskComments",
|
||||||
data={"taskId": task_id, "author": author, "content": content},
|
data={"taskId": task_id, "author": author, "content": content},
|
||||||
)
|
)
|
||||||
row = result.get("row", {})
|
row = result["row"]
|
||||||
row_author = row.get("author", author)
|
return f"Comment added by {row['author']} on task {row['taskId']} (comment id: {row['id']})."
|
||||||
row_task_id = row.get("taskId") or row.get("task_id") or task_id
|
|
||||||
row_comment_id = row.get("id", "unknown")
|
|
||||||
return f"Comment added by {row_author} on task {row_task_id} (comment id: {row_comment_id})."
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
@@ -224,16 +199,4 @@ async def delete_task_comment(comment_id: str) -> str:
|
|||||||
return f"Comment {comment_id} deleted."
|
return f"Comment {comment_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
# ── Exports ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
TASK_TOOLS: list[Any] = [
|
|
||||||
list_tasks,
|
|
||||||
create_task,
|
|
||||||
update_task,
|
|
||||||
delete_task,
|
|
||||||
list_tasks_due_today,
|
|
||||||
list_task_comments,
|
|
||||||
add_task_comment,
|
|
||||||
delete_task_comment,
|
|
||||||
]
|
|
||||||
@@ -1,47 +1,21 @@
|
|||||||
"""Timeline agent — project milestone management (list, create, update, delete).
|
"""Timeline agent — tool definitions for project milestone CRUD."""
|
||||||
|
|
||||||
Shared tool definitions used by both Chat and Batch Agent services.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import re
|
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
from langchain_core.tools import tool
|
||||||
|
|
||||||
from shared.ws_context import execute_on_client
|
from app.core.ws_context import execute_on_client
|
||||||
|
|
||||||
_UUID_RE = re.compile(
|
|
||||||
r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[1-5][0-9a-fA-F]{3}-[89abAB][0-9a-fA-F]{3}-[0-9a-fA-F]{12}$"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _is_uuid(value: str) -> bool:
|
|
||||||
return bool(_UUID_RE.match(value))
|
|
||||||
|
|
||||||
TIMELINE_SYSTEM_PROMPT = (
|
|
||||||
"You are a project timeline assistant. Timelines are milestone dates that\n"
|
|
||||||
"track progress on a project — they are not calendar events.\n\n"
|
|
||||||
"Rules:\n"
|
|
||||||
" - project_id is REQUIRED for every create; confirm with the user if unknown\n"
|
|
||||||
" - For listing, project_id must be a UUID; never pass plain names as project_id\n"
|
|
||||||
" - date is a Unix timestamp in milliseconds; convert human-readable dates\n"
|
|
||||||
" - is_ai_suggested: 1 when proactively proposing a timeline, 0 otherwise\n"
|
|
||||||
" - For update_timeline, use -1 for integer fields you do not want to change\n"
|
|
||||||
" - Listing without a project_id returns all timelines across projects\n"
|
|
||||||
" - Always echo the title and formatted date in your confirmation."
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
@tool
|
||||||
async def list_timelines(project_id: str = "") -> str:
|
async def list_timelines(project_id: str = "") -> str:
|
||||||
"""List timelines. Provide project_id to scope to a specific project."""
|
"""List timelines. Provide project_id to scope to a specific project."""
|
||||||
normalized_project_id = project_id if (project_id and _is_uuid(project_id)) else ""
|
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="select",
|
action="select",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
filters={"projectId": normalized_project_id or None},
|
filters={"projectId": project_id or None},
|
||||||
)
|
)
|
||||||
rows = result.get("rows", [])
|
rows = result.get("rows", [])
|
||||||
if not rows:
|
if not rows:
|
||||||
@@ -56,12 +30,14 @@ async def create_timeline(
|
|||||||
title: str,
|
title: str,
|
||||||
date: int,
|
date: int,
|
||||||
is_ai_suggested: int = 0,
|
is_ai_suggested: int = 0,
|
||||||
|
is_approved: int = 0,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a project timeline (milestone).
|
"""Create a project timeline (milestone).
|
||||||
project_id: REQUIRED UUID of the parent project
|
project_id: REQUIRED UUID of the parent project
|
||||||
title: descriptive name for the milestone
|
title: descriptive name for the milestone
|
||||||
date: Unix timestamp in milliseconds
|
date: Unix timestamp in milliseconds
|
||||||
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
is_ai_suggested: 1 if proactively suggested, 0 if user-requested
|
||||||
|
is_approved: 0 until the user confirms
|
||||||
"""
|
"""
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="insert",
|
action="insert",
|
||||||
@@ -71,6 +47,7 @@ async def create_timeline(
|
|||||||
"title": title,
|
"title": title,
|
||||||
"date": date,
|
"date": date,
|
||||||
"isAiSuggested": is_ai_suggested,
|
"isAiSuggested": is_ai_suggested,
|
||||||
|
"isApproved": is_approved,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
row = result["row"]
|
row = result["row"]
|
||||||
@@ -82,16 +59,20 @@ async def update_timeline(
|
|||||||
timeline_id: str,
|
timeline_id: str,
|
||||||
title: str = "",
|
title: str = "",
|
||||||
date: int = -1,
|
date: int = -1,
|
||||||
|
is_approved: int = -1,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Update a timeline. Only pass fields that should change.
|
"""Update a timeline. Only pass fields that should change.
|
||||||
timeline_id: UUID of the timeline (required)
|
timeline_id: UUID of the timeline (required)
|
||||||
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
date: -1 means unchanged; any other value sets the new date (ms timestamp)
|
||||||
|
is_approved: -1 means unchanged; 0 or 1 sets the approval state
|
||||||
"""
|
"""
|
||||||
updates: dict[str, Any] = {}
|
updates: dict[str, Any] = {}
|
||||||
if title:
|
if title:
|
||||||
updates["title"] = title
|
updates["title"] = title
|
||||||
if date != -1:
|
if date != -1:
|
||||||
updates["date"] = date
|
updates["date"] = date
|
||||||
|
if is_approved != -1:
|
||||||
|
updates["isApproved"] = is_approved
|
||||||
result = await execute_on_client(
|
result = await execute_on_client(
|
||||||
action="update",
|
action="update",
|
||||||
table="timelines",
|
table="timelines",
|
||||||
@@ -108,9 +89,4 @@ async def delete_timeline(timeline_id: str) -> str:
|
|||||||
return f"Timeline {timeline_id} deleted."
|
return f"Timeline {timeline_id} deleted."
|
||||||
|
|
||||||
|
|
||||||
TIMELINE_TOOLS: list[Any] = [
|
|
||||||
list_timelines,
|
|
||||||
create_timeline,
|
|
||||||
update_timeline,
|
|
||||||
delete_timeline,
|
|
||||||
]
|
|
||||||
14
app/api/deps.py
Normal file
14
app/api/deps.py
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
"""Shared FastAPI dependencies.
|
||||||
|
|
||||||
|
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
|
||||||
|
(the canonical location per Step 9). This module re-exports them so that all
|
||||||
|
existing route imports (``from app.api.deps import get_current_user``) continue
|
||||||
|
to work without modification.
|
||||||
|
|
||||||
|
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
|
||||||
|
instead of reading it from the JWT payload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
|
||||||
|
|
||||||
|
__all__ = ["get_current_user", "oauth2_scheme"]
|
||||||
19
app/api/middleware/__init__.py
Normal file
19
app/api/middleware/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
"""API middleware package.
|
||||||
|
|
||||||
|
Exports the three middleware components introduced in Step 9:
|
||||||
|
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
|
||||||
|
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
|
||||||
|
- Sanitizer: ``SanitizerMiddleware``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from app.api.middleware.auth import get_current_user, oauth2_scheme
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
|
||||||
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
"get_current_user",
|
||||||
|
"oauth2_scheme",
|
||||||
|
"TierRateLimitMiddleware",
|
||||||
|
"limiter",
|
||||||
|
"SanitizerMiddleware",
|
||||||
|
]
|
||||||
@@ -1,7 +1,14 @@
|
|||||||
"""Auth dependencies — JWT validation for the Auth Service.
|
"""Auth middleware — JWT validation dependency.
|
||||||
|
|
||||||
This is the canonical get_current_user used by protected endpoints
|
``get_current_user`` is the FastAPI dependency used by all protected routes.
|
||||||
within the Auth Service itself (/me, /me PUT).
|
It decodes the Bearer JWT (identity + expiry), then fetches the current tier
|
||||||
|
from the ``subscriptions`` table so that tier changes take effect immediately
|
||||||
|
without requiring token re-issue.
|
||||||
|
|
||||||
|
Exempt routes (no JWT required):
|
||||||
|
- POST /api/v1/auth/register
|
||||||
|
- POST /api/v1/auth/login
|
||||||
|
- POST /api/v1/billing/webhook
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -12,12 +19,9 @@ from jose import JWTError, jwt
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from shared.config import settings
|
from app.config.settings import settings
|
||||||
from shared.db import get_session
|
from app.db import get_session
|
||||||
from shared.models import Subscription, User
|
from app.schemas import UserProfile
|
||||||
from shared.schemas import UserProfile
|
|
||||||
|
|
||||||
from app.config import auth_settings
|
|
||||||
|
|
||||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
|
||||||
|
|
||||||
@@ -28,8 +32,11 @@ async def get_current_user(
|
|||||||
) -> UserProfile:
|
) -> UserProfile:
|
||||||
"""Validate a Bearer JWT and return the authenticated user.
|
"""Validate a Bearer JWT and return the authenticated user.
|
||||||
|
|
||||||
The JWT is used for identity and expiry. Tier is fetched live from the
|
The JWT is used for identity and expiry only. The tier is fetched live
|
||||||
subscriptions table so upgrades/downgrades take effect immediately.
|
from the ``subscriptions`` table so that upgrades/downgrades take effect
|
||||||
|
immediately. Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
|
||||||
|
Raises HTTP 401 on any invalid or expired token.
|
||||||
"""
|
"""
|
||||||
credentials_exc = HTTPException(
|
credentials_exc = HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
@@ -38,7 +45,7 @@ async def get_current_user(
|
|||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
payload = jwt.decode(
|
payload = jwt.decode(
|
||||||
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
)
|
)
|
||||||
user_id: str | None = payload.get("sub")
|
user_id: str | None = payload.get("sub")
|
||||||
email: str | None = payload.get("email")
|
email: str | None = payload.get("email")
|
||||||
@@ -47,14 +54,15 @@ async def get_current_user(
|
|||||||
except JWTError:
|
except JWTError:
|
||||||
raise credentials_exc
|
raise credentials_exc
|
||||||
|
|
||||||
# Live tier lookup
|
# Live tier lookup — subscription row is the authoritative source.
|
||||||
|
from app.models import Subscription, User # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
tier: str = result.scalar_one_or_none() or "free"
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
# Fetch name/surname
|
# Fetch name/surname from user row.
|
||||||
user_result = await db.execute(
|
user_result = await db.execute(
|
||||||
select(User.name, User.surname).where(User.id == user_id)
|
select(User.name, User.surname).where(User.id == user_id)
|
||||||
)
|
)
|
||||||
129
app/api/middleware/rate_limit.py
Normal file
129
app/api/middleware/rate_limit.py
Normal file
@@ -0,0 +1,129 @@
|
|||||||
|
"""Tier-aware rate limiting middleware.
|
||||||
|
|
||||||
|
Uses a per-user sliding-window counter (in-process, no Redis required).
|
||||||
|
The ``slowapi`` Limiter is also exported for optional route-level decoration.
|
||||||
|
|
||||||
|
Limits (requests per minute):
|
||||||
|
- free: 20
|
||||||
|
- pro: 60
|
||||||
|
- power: 120
|
||||||
|
- team: 200
|
||||||
|
|
||||||
|
Exempt paths bypass the limiter entirely:
|
||||||
|
- POST /api/v1/auth/register
|
||||||
|
- POST /api/v1/auth/login
|
||||||
|
- POST /api/v1/billing/webhook
|
||||||
|
- GET /api/v1/health
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import time
|
||||||
|
from collections import defaultdict
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from slowapi import Limiter
|
||||||
|
from slowapi.util import get_remote_address
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
_TIER_LIMITS: dict[str, int] = {
|
||||||
|
"free": 20,
|
||||||
|
"pro": 60,
|
||||||
|
"power": 120,
|
||||||
|
"team": 200,
|
||||||
|
}
|
||||||
|
|
||||||
|
_EXEMPT_PATHS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"/api/v1/auth/register",
|
||||||
|
"/api/v1/auth/login",
|
||||||
|
"/api/v1/billing/webhook",
|
||||||
|
"/api/v1/health",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_user_id_from_jwt(request: Request) -> str:
|
||||||
|
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
token = auth.removeprefix("Bearer ").strip()
|
||||||
|
if not token:
|
||||||
|
return get_remote_address(request)
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
return payload.get("sub") or get_remote_address(request)
|
||||||
|
except JWTError:
|
||||||
|
return get_remote_address(request)
|
||||||
|
|
||||||
|
|
||||||
|
# Exported Limiter instance — available for optional route-level decoration.
|
||||||
|
limiter = Limiter(key_func=_get_user_id_from_jwt)
|
||||||
|
|
||||||
|
|
||||||
|
class TierRateLimitMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Sliding-window rate limiter applied globally across all non-exempt routes.
|
||||||
|
|
||||||
|
Each authenticated user gets their own 60-second window sized by tier.
|
||||||
|
Unauthenticated requests pass through (the auth dependency will reject them
|
||||||
|
with 401 before the route handler runs).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
# user_id → list of request timestamps (float, seconds since epoch)
|
||||||
|
self._window: dict[str, list[float]] = defaultdict(list)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||||
|
if request.url.path in _EXEMPT_PATHS:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
|
||||||
|
auth = request.headers.get("Authorization", "")
|
||||||
|
token = auth.removeprefix("Bearer ").strip()
|
||||||
|
if not token:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
user_id: str = payload.get("sub") or get_remote_address(request)
|
||||||
|
tier: str = payload.get("tier", "free")
|
||||||
|
except JWTError:
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
|
||||||
|
now = time.monotonic()
|
||||||
|
window_start = now - 60.0
|
||||||
|
|
||||||
|
# Slide the window: discard timestamps older than 60 seconds.
|
||||||
|
timestamps = [t for t in self._window[user_id] if t > window_start]
|
||||||
|
|
||||||
|
if len(timestamps) >= limit:
|
||||||
|
retry_after = max(1, int(60 - (now - min(timestamps))))
|
||||||
|
return Response(
|
||||||
|
content=json.dumps(
|
||||||
|
{
|
||||||
|
"detail": (
|
||||||
|
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
|
||||||
|
f"Retry in {retry_after}s."
|
||||||
|
)
|
||||||
|
}
|
||||||
|
),
|
||||||
|
status_code=429,
|
||||||
|
headers={
|
||||||
|
"Retry-After": str(retry_after),
|
||||||
|
"Content-Type": "application/json",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
timestamps.append(now)
|
||||||
|
self._window[user_id] = timestamps
|
||||||
|
return await call_next(request)
|
||||||
139
app/api/middleware/sanitizer.py
Normal file
139
app/api/middleware/sanitizer.py
Normal file
@@ -0,0 +1,139 @@
|
|||||||
|
"""Response sanitizer middleware.
|
||||||
|
|
||||||
|
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
|
||||||
|
that could reveal server-side prompt IP:
|
||||||
|
- System prompt openers ("You are a/an/the …")
|
||||||
|
- Agent routing metadata ("Available agents:", "intent classifier", …)
|
||||||
|
- LangChain tool schema fragments (``"type": "function"``)
|
||||||
|
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
|
||||||
|
- Exact-match known prompt fingerprints
|
||||||
|
|
||||||
|
Binary responses (storage blobs, backup data) are never touched — the
|
||||||
|
middleware only activates for paths under /api/v1/chat.
|
||||||
|
|
||||||
|
Any sanitisation event is logged as a WARNING with the request path and the
|
||||||
|
names of the fields that were modified.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import re
|
||||||
|
|
||||||
|
from fastapi import Request, Response
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.types import ASGIApp
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Detection patterns — order matters: fingerprints checked first (exact),
|
||||||
|
# then compiled regexes.
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FINGERPRINTS: tuple[str, ...] = (
|
||||||
|
"You are an intent classifier",
|
||||||
|
"Respond with just the agent name",
|
||||||
|
"Summarize these agent results",
|
||||||
|
"Available agents:",
|
||||||
|
"route to:",
|
||||||
|
)
|
||||||
|
|
||||||
|
_PATTERNS: tuple[re.Pattern[str], ...] = (
|
||||||
|
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
|
||||||
|
re.compile(r"Available agents\s*:", re.IGNORECASE),
|
||||||
|
re.compile(r"\bintent classifier\b", re.IGNORECASE),
|
||||||
|
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
|
||||||
|
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
|
||||||
|
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
|
||||||
|
re.compile(r"route\s+to\s*:", re.IGNORECASE),
|
||||||
|
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_text(text: str) -> tuple[str, bool]:
|
||||||
|
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
|
||||||
|
|
||||||
|
Returns ``(cleaned_text, was_changed)``.
|
||||||
|
"""
|
||||||
|
# Fingerprint check — if any exact phrase is present, redact the whole string.
|
||||||
|
for fp in _FINGERPRINTS:
|
||||||
|
if fp in text:
|
||||||
|
return "[REDACTED]", True
|
||||||
|
|
||||||
|
changed = False
|
||||||
|
for pattern in _PATTERNS:
|
||||||
|
new_text, n = pattern.subn("[REDACTED]", text)
|
||||||
|
if n:
|
||||||
|
text = new_text
|
||||||
|
changed = True
|
||||||
|
|
||||||
|
return text, changed
|
||||||
|
|
||||||
|
|
||||||
|
class SanitizerMiddleware(BaseHTTPMiddleware):
|
||||||
|
"""Strip prompt IP from /api/v1/chat JSON responses."""
|
||||||
|
|
||||||
|
def __init__(self, app: ASGIApp) -> None:
|
||||||
|
super().__init__(app)
|
||||||
|
|
||||||
|
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
|
||||||
|
response: Response = await call_next(request)
|
||||||
|
|
||||||
|
# Only process chat endpoint responses.
|
||||||
|
if not request.url.path.startswith("/api/v1/chat"):
|
||||||
|
return response
|
||||||
|
|
||||||
|
# Read body — collect streaming chunks.
|
||||||
|
body_bytes = b""
|
||||||
|
async for chunk in response.body_iterator:
|
||||||
|
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
|
||||||
|
|
||||||
|
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
|
||||||
|
try:
|
||||||
|
body = json.loads(body_bytes.decode("utf-8"))
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError):
|
||||||
|
return Response(
|
||||||
|
content=body_bytes,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
media_type=response.media_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
if not isinstance(body, dict):
|
||||||
|
return Response(
|
||||||
|
content=body_bytes,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=dict(response.headers),
|
||||||
|
media_type=response.media_type,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Walk top-level string fields and sanitise.
|
||||||
|
sanitised_fields: list[str] = []
|
||||||
|
for key, value in body.items():
|
||||||
|
if isinstance(value, str):
|
||||||
|
cleaned, changed = _sanitize_text(value)
|
||||||
|
if changed:
|
||||||
|
body[key] = cleaned
|
||||||
|
sanitised_fields.append(key)
|
||||||
|
|
||||||
|
if sanitised_fields:
|
||||||
|
logger.warning(
|
||||||
|
"Sanitizer redacted prompt fragments",
|
||||||
|
extra={
|
||||||
|
"path": request.url.path,
|
||||||
|
"fields": sanitised_fields,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
new_body = json.dumps(body).encode("utf-8")
|
||||||
|
headers = dict(response.headers)
|
||||||
|
headers["content-length"] = str(len(new_body))
|
||||||
|
|
||||||
|
return Response(
|
||||||
|
content=new_body,
|
||||||
|
status_code=response.status_code,
|
||||||
|
headers=headers,
|
||||||
|
media_type="application/json",
|
||||||
|
)
|
||||||
317
app/api/routes/agent_setup.py
Normal file
317
app/api/routes/agent_setup.py
Normal file
@@ -0,0 +1,317 @@
|
|||||||
|
"""Chatbot Journey endpoints — guided conversation to build an agent prompt_template.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
POST /agents/journey/start — start a new journey session
|
||||||
|
POST /agents/journey/message — continue the conversation
|
||||||
|
|
||||||
|
Sessions are stored in-memory with a 30-minute TTL. Stale entries are
|
||||||
|
cleaned up lazily on access. Upgrade to Redis for multi-instance deployments.
|
||||||
|
|
||||||
|
Journey flow:
|
||||||
|
1. Client sends ``{ agent_type, agent_id? }`` to ``/start``.
|
||||||
|
2. Server creates a session, calls the LLM with a contextual system prompt,
|
||||||
|
and returns the first question.
|
||||||
|
3. Client sends follow-up messages to ``/message``.
|
||||||
|
4. After 3-5 turns the LLM wraps up by emitting a ``prompt_template`` block
|
||||||
|
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||||
|
5. Server parses the block, sets ``done=True``, and returns the template.
|
||||||
|
|
||||||
|
The ``prompt_template`` from the final response is meant to be stored in
|
||||||
|
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
||||||
|
by the Electron client (via the agent CRUD endpoints).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, status
|
||||||
|
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import CloudAgentConfig, LocalAgentConfig
|
||||||
|
from app.schemas import (
|
||||||
|
JourneyMessageRequest,
|
||||||
|
JourneyResponse,
|
||||||
|
JourneyStartRequest,
|
||||||
|
UserProfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents/journey", tags=["agents"])
|
||||||
|
|
||||||
|
# ── Session TTL ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
||||||
|
|
||||||
|
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
||||||
|
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
||||||
|
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
||||||
|
|
||||||
|
# Maximum number of conversation turns before the LLM is nudged to wrap up.
|
||||||
|
_MAX_TURNS: int = 5
|
||||||
|
|
||||||
|
# ── In-memory session store ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class _JourneySession:
|
||||||
|
session_id: str
|
||||||
|
user_id: str
|
||||||
|
agent_type: str # "local" | "cloud"
|
||||||
|
history: list[dict[str, Any]] = field(default_factory=list)
|
||||||
|
created_at: float = field(default_factory=time.monotonic)
|
||||||
|
|
||||||
|
def is_expired(self) -> bool:
|
||||||
|
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
||||||
|
|
||||||
|
|
||||||
|
# session_id → session
|
||||||
|
_sessions: dict[str, _JourneySession] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_session(session_id: str, user_id: str) -> _JourneySession:
|
||||||
|
"""Retrieve session; raise 404 on missing, expired, or wrong owner."""
|
||||||
|
s = _sessions.get(session_id)
|
||||||
|
if s is None or s.is_expired():
|
||||||
|
_sessions.pop(session_id, None)
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
|
if s.user_id != user_id:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Journey session not found or expired")
|
||||||
|
return s
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompt builder ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_LOCAL_PREAMBLE = """\
|
||||||
|
What kind of files are in the directories you want to monitor? \
|
||||||
|
(for example: emails saved as .eml, documents in .pdf or .txt, markdown notes, etc.)"""
|
||||||
|
|
||||||
|
_CLOUD_PREAMBLE = """\
|
||||||
|
What kind of emails or messages should I look for? \
|
||||||
|
(for example: client communications, invoices, meeting notes, project updates, etc.)"""
|
||||||
|
|
||||||
|
_SYSTEM_PROMPT_TEMPLATE = """\
|
||||||
|
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
||||||
|
Your job is to understand exactly what data the user wants to extract from their {source_description} \
|
||||||
|
and produce a detailed prompt_template that a separate AI will use as its instruction set.
|
||||||
|
|
||||||
|
Ask concise, focused questions one at a time. Cover these topics (not necessarily in this order):
|
||||||
|
1. The type and format of the source content.
|
||||||
|
2. Which data types to extract: tasks, notes, timelines, and/or projects.
|
||||||
|
3. How fields should be mapped (e.g. email subject → task title).
|
||||||
|
4. Priority or status rules (e.g. "urgent" keyword → high priority).
|
||||||
|
5. Any special handling, date extraction, or exclusions.
|
||||||
|
|
||||||
|
After 3-5 questions (when you have enough information), output the final prompt_template between \
|
||||||
|
these exact markers on their own lines:
|
||||||
|
|
||||||
|
{template_start}
|
||||||
|
<the complete extraction prompt here>
|
||||||
|
{template_end}
|
||||||
|
|
||||||
|
The prompt_template must be a self-contained instruction for an AI that receives a document/email/message \
|
||||||
|
and must return a JSON array of records in this shape:
|
||||||
|
[{{ "table": "<tasks|notes|timelines|projects>", "data": {{ <field: value> }} }}, ...]
|
||||||
|
|
||||||
|
Rules for the generated template:
|
||||||
|
- Be explicit about field names (camelCase: title, status, priority, dueDate, projectId, content, etc.).
|
||||||
|
- Include concrete examples of mappings.
|
||||||
|
- Mention that Electron adds id/createdAt/updatedAt automatically.
|
||||||
|
- Set isAiSuggested: true and isApproved: false on every record.
|
||||||
|
{existing_section}\
|
||||||
|
Do not ask more than {max_turns} questions total. Start with your first question now.\
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _build_system_prompt(agent_type: str, existing_template: str | None) -> str:
|
||||||
|
source_description = (
|
||||||
|
"files in local directories" if agent_type == "local" else "emails and messages from cloud providers"
|
||||||
|
)
|
||||||
|
existing_section = (
|
||||||
|
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
||||||
|
f"---\n{existing_template}\n---\n"
|
||||||
|
if existing_template
|
||||||
|
else ""
|
||||||
|
)
|
||||||
|
return _SYSTEM_PROMPT_TEMPLATE.format(
|
||||||
|
source_description=source_description,
|
||||||
|
template_start=_TEMPLATE_START,
|
||||||
|
template_end=_TEMPLATE_END,
|
||||||
|
existing_section=existing_section,
|
||||||
|
max_turns=_MAX_TURNS,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _first_question(agent_type: str) -> str:
|
||||||
|
return _LOCAL_PREAMBLE if agent_type == "local" else _CLOUD_PREAMBLE
|
||||||
|
|
||||||
|
|
||||||
|
# ── Template extraction ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _extract_template(text: str) -> str | None:
|
||||||
|
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
||||||
|
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
||||||
|
return None
|
||||||
|
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
||||||
|
end_idx = text.index(_TEMPLATE_END)
|
||||||
|
return text[start_idx:end_idx].strip() or None
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM call ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _call_llm(system_prompt: str, history: list[dict[str, Any]]) -> str:
|
||||||
|
"""Build LangChain messages from history and invoke the LLM."""
|
||||||
|
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
||||||
|
for turn in history:
|
||||||
|
if turn["role"] == "user":
|
||||||
|
messages.append(HumanMessage(content=turn["content"]))
|
||||||
|
else:
|
||||||
|
messages.append(AIMessage(content=turn["content"]))
|
||||||
|
|
||||||
|
llm = get_llm(model=None, temperature=0.4)
|
||||||
|
response = await llm.ainvoke(messages)
|
||||||
|
return response.content # type: ignore[return-value]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Existing-config loader ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _load_existing_template(
|
||||||
|
agent_id: str,
|
||||||
|
user_id: str,
|
||||||
|
db: AsyncSession,
|
||||||
|
) -> str | None:
|
||||||
|
"""Return the prompt_template of an existing agent config, or None."""
|
||||||
|
# Try local first, then cloud.
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local = local_result.scalar_one_or_none()
|
||||||
|
if local is not None:
|
||||||
|
return local.prompt_template
|
||||||
|
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud = cloud_result.scalar_one_or_none()
|
||||||
|
return cloud.prompt_template if cloud is not None else None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/start", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
|
async def start_journey(
|
||||||
|
body: JourneyStartRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Start a new Chatbot Journey session.
|
||||||
|
|
||||||
|
If ``agent_id`` is provided the session is pre-seeded with the existing
|
||||||
|
agent's ``prompt_template`` so the user can refine it.
|
||||||
|
"""
|
||||||
|
# Load existing template (may be None).
|
||||||
|
existing_template: str | None = None
|
||||||
|
if body.agent_id:
|
||||||
|
existing_template = await _load_existing_template(body.agent_id, current_user.id, db)
|
||||||
|
# If agent_id was given but not found, proceed without seeding (don't 404 —
|
||||||
|
# the user may be starting a fresh journey for a not-yet-persisted config).
|
||||||
|
|
||||||
|
system_prompt = _build_system_prompt(body.agent_type, existing_template)
|
||||||
|
first_question = _first_question(body.agent_type)
|
||||||
|
|
||||||
|
session_id = str(uuid.uuid4())
|
||||||
|
session = _JourneySession(
|
||||||
|
session_id=session_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
agent_type=body.agent_type,
|
||||||
|
# Seed history with the AI's first question so it stays consistent.
|
||||||
|
history=[{"role": "assistant", "content": first_question}],
|
||||||
|
)
|
||||||
|
# Store the system prompt inside the session for reuse in /message.
|
||||||
|
session.__dict__["_system_prompt"] = system_prompt # type: ignore[index]
|
||||||
|
_sessions[session_id] = session
|
||||||
|
|
||||||
|
logger.info("Journey session %s started for user %s (agent_type=%s)", session_id, current_user.id, body.agent_type)
|
||||||
|
return JourneyResponse(session_id=session_id, message=first_question, done=False)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/message", response_model=JourneyResponse, status_code=status.HTTP_200_OK)
|
||||||
|
async def send_journey_message(
|
||||||
|
body: JourneyMessageRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> JourneyResponse:
|
||||||
|
"""Send a message in an existing Chatbot Journey session.
|
||||||
|
|
||||||
|
The server appends the user's message to the conversation history,
|
||||||
|
calls the LLM, and appends the AI reply. When the LLM wraps up with a
|
||||||
|
``prompt_template`` block the response includes ``done=True`` and the
|
||||||
|
extracted template.
|
||||||
|
"""
|
||||||
|
session = _get_session(body.session_id, current_user.id)
|
||||||
|
system_prompt: str = session.__dict__.get("_system_prompt", _build_system_prompt(session.agent_type, None)) # type: ignore[assignment]
|
||||||
|
|
||||||
|
# Append user turn to history.
|
||||||
|
session.history.append({"role": "user", "content": body.message})
|
||||||
|
|
||||||
|
# Call the LLM with the full conversation so far.
|
||||||
|
ai_reply = await _call_llm(system_prompt, session.history)
|
||||||
|
|
||||||
|
# Append AI turn.
|
||||||
|
session.history.append({"role": "assistant", "content": ai_reply})
|
||||||
|
|
||||||
|
# Check if the LLM produced the final template.
|
||||||
|
prompt_template = _extract_template(ai_reply)
|
||||||
|
done = prompt_template is not None
|
||||||
|
|
||||||
|
# Strip the sentinel markers from the message shown to the user.
|
||||||
|
display_message = ai_reply
|
||||||
|
if done:
|
||||||
|
display_message = (
|
||||||
|
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
||||||
|
or "Here is your agent configuration. You can save it or continue refining."
|
||||||
|
)
|
||||||
|
|
||||||
|
if done:
|
||||||
|
logger.info("Journey session %s completed for user %s", body.session_id, current_user.id)
|
||||||
|
# Clean up the session immediately on completion.
|
||||||
|
_sessions.pop(body.session_id, None)
|
||||||
|
else:
|
||||||
|
# Nudge the LLM to wrap up after max turns.
|
||||||
|
turns = sum(1 for t in session.history if t["role"] == "user")
|
||||||
|
if turns >= _MAX_TURNS:
|
||||||
|
# Add a system-level nudge as a hidden user message.
|
||||||
|
session.history.append({
|
||||||
|
"role": "user",
|
||||||
|
"content": (
|
||||||
|
"[System: You have enough information. Please generate the final "
|
||||||
|
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
||||||
|
),
|
||||||
|
})
|
||||||
|
|
||||||
|
return JourneyResponse(
|
||||||
|
session_id=body.session_id,
|
||||||
|
message=display_message,
|
||||||
|
done=done,
|
||||||
|
prompt_template=prompt_template,
|
||||||
|
)
|
||||||
452
app/api/routes/agents.py
Normal file
452
app/api/routes/agents.py
Normal file
@@ -0,0 +1,452 @@
|
|||||||
|
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
||||||
|
|
||||||
|
Endpoints:
|
||||||
|
GET /agents/catalog — hardcoded agent type catalog
|
||||||
|
GET /agents/local — list user's local agent configs
|
||||||
|
POST /agents/local — create local agent (tier-gated)
|
||||||
|
PUT /agents/local/{agent_id} — partial update (ownership check)
|
||||||
|
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
||||||
|
GET /agents/cloud — list user's cloud agent configs
|
||||||
|
POST /agents/cloud — create cloud agent (tier-gated)
|
||||||
|
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
||||||
|
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
||||||
|
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
||||||
|
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, or_, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import FEATURES
|
||||||
|
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||||
|
from app.core.device_manager import device_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from app.schemas import (
|
||||||
|
AgentCatalogItem,
|
||||||
|
AgentRunLogResponse,
|
||||||
|
CloudAgentConfigCreate,
|
||||||
|
CloudAgentConfigResponse,
|
||||||
|
CloudAgentConfigUpdate,
|
||||||
|
LocalAgentConfigCreate,
|
||||||
|
LocalAgentConfigResponse,
|
||||||
|
LocalAgentConfigUpdate,
|
||||||
|
UserProfile,
|
||||||
|
)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/agents", tags=["agents"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Datetime helpers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _dt_ms(dt: datetime) -> int:
|
||||||
|
return int(dt.timestamp() * 1000)
|
||||||
|
|
||||||
|
|
||||||
|
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||||
|
return int(dt.timestamp() * 1000) if dt else None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Model → schema converters ─────────────────────────────────────────
|
||||||
|
|
||||||
|
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
||||||
|
return LocalAgentConfigResponse(
|
||||||
|
id=a.id,
|
||||||
|
name=a.name,
|
||||||
|
device_id=a.device_id,
|
||||||
|
directory_paths=a.directory_paths,
|
||||||
|
data_types=a.data_types,
|
||||||
|
prompt_template=a.prompt_template,
|
||||||
|
file_extensions=a.file_extensions,
|
||||||
|
schedule_cron=a.schedule_cron,
|
||||||
|
enabled=a.enabled,
|
||||||
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
||||||
|
return CloudAgentConfigResponse(
|
||||||
|
id=a.id,
|
||||||
|
provider=a.provider, # type: ignore[arg-type]
|
||||||
|
name=a.name,
|
||||||
|
data_types=a.data_types,
|
||||||
|
prompt_template=a.prompt_template,
|
||||||
|
schedule_cron=a.schedule_cron,
|
||||||
|
filter_config=a.filter_config,
|
||||||
|
enabled=a.enabled,
|
||||||
|
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||||
|
created_at=_dt_ms(a.created_at),
|
||||||
|
updated_at=_dt_ms(a.updated_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||||
|
return AgentRunLogResponse(
|
||||||
|
id=log.id,
|
||||||
|
agent_id=log.agent_id,
|
||||||
|
agent_type=log.agent_type, # type: ignore[arg-type]
|
||||||
|
status=log.status, # type: ignore[arg-type]
|
||||||
|
items_processed=log.items_processed,
|
||||||
|
items_created=log.items_created,
|
||||||
|
errors=log.errors or [],
|
||||||
|
started_at=_dt_ms(log.started_at),
|
||||||
|
completed_at=_dt_ms_opt(log.completed_at),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Ownership-checked lookups ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_local_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> LocalAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_cloud_agent_for_user(
|
||||||
|
agent_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> CloudAgentConfig:
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier limit helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return combined enabled local + cloud agent count for the user."""
|
||||||
|
local_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(LocalAgentConfig.id)).where(
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
LocalAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
cloud_count = (
|
||||||
|
await db.execute(
|
||||||
|
select(func.count(CloudAgentConfig.id)).where(
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
CloudAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
).scalar_one()
|
||||||
|
return local_count + cloud_count
|
||||||
|
|
||||||
|
|
||||||
|
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
||||||
|
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||||
|
if limit != -1 and current_count >= limit:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
||||||
|
|
||||||
|
class _RunsPage(BaseModel):
|
||||||
|
total: int
|
||||||
|
page: int
|
||||||
|
limit: int
|
||||||
|
items: list[AgentRunLogResponse]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Catalog ───────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/catalog", response_model=list[AgentCatalogItem])
|
||||||
|
async def get_agent_catalog(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> list[AgentCatalogItem]:
|
||||||
|
"""Return the static list of available agent types and their descriptions."""
|
||||||
|
return [
|
||||||
|
AgentCatalogItem(
|
||||||
|
type="local_directory",
|
||||||
|
name="Local Directory Monitor",
|
||||||
|
description="Watches local directories, extracts data from files using AI",
|
||||||
|
),
|
||||||
|
AgentCatalogItem(
|
||||||
|
type="gmail",
|
||||||
|
name="Gmail Connector",
|
||||||
|
description="Scans Gmail inbox, extracts tasks/notes from emails",
|
||||||
|
),
|
||||||
|
AgentCatalogItem(
|
||||||
|
type="teams",
|
||||||
|
name="Microsoft Teams Connector",
|
||||||
|
description="Monitors Teams messages, extracts action items",
|
||||||
|
),
|
||||||
|
AgentCatalogItem(
|
||||||
|
type="outlook",
|
||||||
|
name="Outlook Connector",
|
||||||
|
description="Scans Outlook inbox, extracts tasks/notes",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent CRUD ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
||||||
|
async def list_local_agents(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[LocalAgentConfigResponse]:
|
||||||
|
"""List all local directory agent configs owned by the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
return [_to_local_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_local_agent(
|
||||||
|
body: LocalAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Create a new local directory agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = LocalAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
name=body.name,
|
||||||
|
device_id=body.device_id,
|
||||||
|
directory_paths=body.directory_paths,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
file_extensions=body.file_extensions,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
||||||
|
async def update_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: LocalAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> LocalAgentConfigResponse:
|
||||||
|
"""Partially update a local agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_local_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/local/{agent_id}", response_model=dict)
|
||||||
|
async def delete_local_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
||||||
|
async def list_cloud_agents(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[CloudAgentConfigResponse]:
|
||||||
|
"""List all cloud connector agent configs owned by the authenticated user."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
||||||
|
)
|
||||||
|
return [_to_cloud_response(a) for a in result.scalars().all()]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_cloud_agent(
|
||||||
|
body: CloudAgentConfigCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Create a new cloud connector agent config.
|
||||||
|
|
||||||
|
The combined count of enabled local and cloud agents for the user is
|
||||||
|
checked against the ``batch_active`` limit for their billing tier.
|
||||||
|
"""
|
||||||
|
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||||
|
agent = CloudAgentConfig(
|
||||||
|
user_id=current_user.id,
|
||||||
|
provider=body.provider,
|
||||||
|
name=body.name,
|
||||||
|
data_types=body.data_types,
|
||||||
|
prompt_template=body.prompt_template,
|
||||||
|
oauth_token_encrypted=body.oauth_token_encrypted,
|
||||||
|
schedule_cron=body.schedule_cron,
|
||||||
|
filter_config=body.filter_config,
|
||||||
|
)
|
||||||
|
db.add(agent)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
||||||
|
async def update_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
body: CloudAgentConfigUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> CloudAgentConfigResponse:
|
||||||
|
"""Partially update a cloud agent config. Only provided fields are changed."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
for field, value in body.model_dump(exclude_unset=True).items():
|
||||||
|
setattr(agent, field, value)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(agent)
|
||||||
|
return _to_cloud_response(agent)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/cloud/{agent_id}", response_model=dict)
|
||||||
|
async def delete_cloud_agent(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
||||||
|
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||||
|
await db.delete(agent)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Run logs ──────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("/runs", response_model=_RunsPage)
|
||||||
|
async def list_run_logs(
|
||||||
|
agent_id: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=20, ge=1, le=100),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _RunsPage:
|
||||||
|
"""Return paginated run logs for the authenticated user.
|
||||||
|
|
||||||
|
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
||||||
|
"""
|
||||||
|
base_filter = [AgentRunLog.user_id == current_user.id]
|
||||||
|
if agent_id:
|
||||||
|
base_filter.append(AgentRunLog.agent_id == agent_id)
|
||||||
|
|
||||||
|
total = (
|
||||||
|
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
||||||
|
).scalar_one()
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(AgentRunLog)
|
||||||
|
.where(*base_filter)
|
||||||
|
.order_by(AgentRunLog.started_at.desc())
|
||||||
|
.offset((page - 1) * limit)
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
||||||
|
|
||||||
|
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Manual trigger stub ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||||
|
async def trigger_agent_run(
|
||||||
|
agent_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> AgentRunLogResponse:
|
||||||
|
"""Manually trigger an agent run.
|
||||||
|
|
||||||
|
Looks up the agent config (local or cloud) by ID with ownership check,
|
||||||
|
creates a run log entry with ``status="running"``, and returns it.
|
||||||
|
|
||||||
|
Actual dispatch to the agent runner is wired in Step 3.4 once
|
||||||
|
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||||
|
"""
|
||||||
|
# Determine agent type by trying local first, then cloud.
|
||||||
|
# Keep the full config object so we can pass it to the agent runner.
|
||||||
|
local_config: LocalAgentConfig | None = None
|
||||||
|
cloud_config: CloudAgentConfig | None = None
|
||||||
|
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.id == agent_id,
|
||||||
|
LocalAgentConfig.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local_config = local_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if local_config is not None:
|
||||||
|
agent_type = "local"
|
||||||
|
else:
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.id == agent_id,
|
||||||
|
CloudAgentConfig.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud_config = cloud_result.scalar_one_or_none()
|
||||||
|
if cloud_config is not None:
|
||||||
|
agent_type = "cloud"
|
||||||
|
else:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
|
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
user_id=current_user.id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
|
||||||
|
# Dispatch the run as a background task — returns 202 immediately.
|
||||||
|
if agent_type == "local" and local_config is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||||
|
)
|
||||||
|
elif agent_type == "cloud" and cloud_config is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||||
|
)
|
||||||
|
|
||||||
|
return _to_run_log_response(run_log)
|
||||||
@@ -1,6 +1,8 @@
|
|||||||
"""Auth routes: register, login, refresh, me.
|
"""Auth routes: register, login, refresh, me.
|
||||||
|
|
||||||
Extracted from app/api/routes/auth.py — uses shared.* imports instead of app.*.
|
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens
|
||||||
|
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
|
||||||
|
SHA-256 hashes so plaintext never reaches the DB.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -18,13 +20,11 @@ from pydantic import BaseModel
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from shared.config import settings
|
from app.api.deps import get_current_user
|
||||||
from shared.db import get_session
|
from app.config.settings import settings
|
||||||
from shared.models import RefreshToken, Subscription, User
|
from app.db import get_session
|
||||||
from shared.schemas import AuthTokens, UserProfile
|
from app.models import RefreshToken, User
|
||||||
|
from app.schemas import AuthTokens, UserProfile
|
||||||
from app.config import auth_settings
|
|
||||||
from app.deps import get_current_user
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/auth", tags=["auth"])
|
router = APIRouter(prefix="/auth", tags=["auth"])
|
||||||
|
|
||||||
@@ -46,7 +46,7 @@ def _hash_token(plain_token: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
||||||
"""Return (RS256-signed JWT, expires_at_ms)."""
|
"""Return (signed JWT, expires_at_ms)."""
|
||||||
now = int(time.time())
|
now = int(time.time())
|
||||||
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
|
||||||
payload = {
|
payload = {
|
||||||
@@ -56,19 +56,10 @@ def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
|
|||||||
"exp": exp,
|
"exp": exp,
|
||||||
"iat": now,
|
"iat": now,
|
||||||
}
|
}
|
||||||
token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
|
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM)
|
||||||
return token, exp * 1000 # ms for client
|
return token, exp * 1000 # ms for client
|
||||||
|
|
||||||
|
|
||||||
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
|
|
||||||
"""Fetch authoritative tier from subscriptions table."""
|
|
||||||
result = await db.execute(
|
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
|
||||||
)
|
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
|
||||||
return result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ────────────────────────────────────────────────────
|
# ── Request bodies ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -88,11 +79,6 @@ class _RefreshRequest(BaseModel):
|
|||||||
refresh_token: str
|
refresh_token: str
|
||||||
|
|
||||||
|
|
||||||
class _UpdateProfileRequest(BaseModel):
|
|
||||||
name: str | None = None
|
|
||||||
surname: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Routes ────────────────────────────────────────────────────────────
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@@ -116,7 +102,7 @@ async def register(
|
|||||||
encryption_key=Fernet.generate_key().decode(),
|
encryption_key=Fernet.generate_key().decode(),
|
||||||
)
|
)
|
||||||
db.add(user)
|
db.add(user)
|
||||||
await db.flush()
|
await db.flush() # get user.id without committing
|
||||||
|
|
||||||
plain_token = str(uuid.uuid4())
|
plain_token = str(uuid.uuid4())
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
@@ -149,9 +135,6 @@ async def login(
|
|||||||
if user is None or not _verify_password(body.password, user.password_hash):
|
if user is None or not _verify_password(body.password, user.password_hash):
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
|
||||||
|
|
||||||
# Fetch live tier for the JWT claim
|
|
||||||
tier = await _get_live_tier(db, user.id)
|
|
||||||
|
|
||||||
plain_token = str(uuid.uuid4())
|
plain_token = str(uuid.uuid4())
|
||||||
expires_at = datetime.now(timezone.utc) + timedelta(
|
expires_at = datetime.now(timezone.utc) + timedelta(
|
||||||
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
|
||||||
@@ -164,7 +147,7 @@ async def login(
|
|||||||
db.add(rt)
|
db.add(rt)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
return AuthTokens(
|
return AuthTokens(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
refresh_token=plain_token,
|
refresh_token=plain_token,
|
||||||
@@ -188,6 +171,7 @@ async def refresh(
|
|||||||
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
|
||||||
|
|
||||||
|
# Rotate: delete old token, issue new one.
|
||||||
await db.delete(rt)
|
await db.delete(rt)
|
||||||
|
|
||||||
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
user_result = await db.execute(select(User).where(User.id == rt.user_id))
|
||||||
@@ -195,9 +179,6 @@ async def refresh(
|
|||||||
if user is None:
|
if user is None:
|
||||||
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
|
||||||
|
|
||||||
# Fetch live tier for the new JWT
|
|
||||||
tier = await _get_live_tier(db, user.id)
|
|
||||||
|
|
||||||
plain_token = str(uuid.uuid4())
|
plain_token = str(uuid.uuid4())
|
||||||
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
|
||||||
new_rt = RefreshToken(
|
new_rt = RefreshToken(
|
||||||
@@ -208,7 +189,7 @@ async def refresh(
|
|||||||
db.add(new_rt)
|
db.add(new_rt)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
|
|
||||||
access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
|
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier)
|
||||||
return AuthTokens(
|
return AuthTokens(
|
||||||
access_token=access_token,
|
access_token=access_token,
|
||||||
refresh_token=plain_token,
|
refresh_token=plain_token,
|
||||||
@@ -216,6 +197,11 @@ async def refresh(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _UpdateProfileRequest(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
surname: str | None = None
|
||||||
|
|
||||||
|
|
||||||
@router.get("/me", response_model=UserProfile)
|
@router.get("/me", response_model=UserProfile)
|
||||||
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
|
||||||
"""Return the profile for the authenticated user."""
|
"""Return the profile for the authenticated user."""
|
||||||
171
app/api/routes/backup.py
Normal file
171
app/api/routes/backup.py
Normal file
@@ -0,0 +1,171 @@
|
|||||||
|
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
|
||||||
|
PostgreSQL ``backup_metadata`` table.
|
||||||
|
|
||||||
|
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
|
||||||
|
treating "history" as a ``{backup_id}`` path parameter.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from email.utils import parsedate_to_datetime
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import BackupMetadata as BackupMetadataModel
|
||||||
|
from app.schemas import BackupMetadata, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/backup", tags=["backup"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
|
||||||
|
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return total backup bytes stored by *user_id*."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
|
||||||
|
BackupMetadataModel.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_backup_quota(
|
||||||
|
user: UserProfile, size_bytes: int, db: AsyncSession
|
||||||
|
) -> None:
|
||||||
|
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
|
||||||
|
current = await _current_backup_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_backup_quota(
|
||||||
|
user.tier, current_bytes=current, additional_bytes=size_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("")
|
||||||
|
async def upload_backup(
|
||||||
|
request: Request,
|
||||||
|
x_backup_version: int = Header(..., alias="X-Backup-Version"),
|
||||||
|
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
|
||||||
|
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Upload an E2E-encrypted backup blob.
|
||||||
|
|
||||||
|
Metadata is passed via custom headers; the raw body is the encrypted blob.
|
||||||
|
"""
|
||||||
|
blob = await request.body()
|
||||||
|
reject_if_tampered(blob, x_backup_checksum)
|
||||||
|
await _check_backup_quota(current_user, len(blob), db)
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
row = BackupMetadataModel(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=current_user.id,
|
||||||
|
s3_key=s3_key,
|
||||||
|
version=x_backup_version,
|
||||||
|
timestamp=x_backup_timestamp,
|
||||||
|
checksum=x_backup_checksum,
|
||||||
|
size_bytes=len(blob),
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/history", response_model=list[BackupMetadata])
|
||||||
|
async def backup_history(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[BackupMetadata]:
|
||||||
|
"""Return backup metadata records for the authenticated user (no blob bytes)."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
BackupMetadata(
|
||||||
|
version=r.version,
|
||||||
|
timestamp=r.timestamp,
|
||||||
|
checksum=r.checksum,
|
||||||
|
chunk_count=1,
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("")
|
||||||
|
async def download_backup(
|
||||||
|
request: Request,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> Response:
|
||||||
|
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel)
|
||||||
|
.where(BackupMetadataModel.user_id == current_user.id)
|
||||||
|
.order_by(BackupMetadataModel.timestamp.desc())
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
latest = result.scalar_one_or_none()
|
||||||
|
if latest is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
|
||||||
|
|
||||||
|
ims_header = request.headers.get("If-Modified-Since")
|
||||||
|
if ims_header:
|
||||||
|
try:
|
||||||
|
ims_dt = parsedate_to_datetime(ims_header)
|
||||||
|
ims_ms = int(ims_dt.timestamp() * 1000)
|
||||||
|
if latest.timestamp <= ims_ms:
|
||||||
|
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
|
||||||
|
except Exception:
|
||||||
|
pass # malformed header — ignore and serve the blob
|
||||||
|
|
||||||
|
blob = await _blob_store.download(current_user.id, latest.s3_key)
|
||||||
|
return Response(
|
||||||
|
content=blob,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={
|
||||||
|
"X-Backup-Version": str(latest.version),
|
||||||
|
"X-Backup-Timestamp": str(latest.timestamp),
|
||||||
|
"X-Checksum": latest.checksum,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{backup_id}", response_model=dict)
|
||||||
|
async def delete_backup(
|
||||||
|
backup_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a specific backup by ID."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(BackupMetadataModel).where(
|
||||||
|
BackupMetadataModel.id == backup_id,
|
||||||
|
BackupMetadataModel.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
target = result.scalar_one_or_none()
|
||||||
|
if target is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
|
||||||
|
|
||||||
|
await _blob_store.delete(current_user.id, target.s3_key)
|
||||||
|
await db.delete(target)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
85
app/api/routes/billing.py
Normal file
85
app/api/routes/billing.py
Normal file
@@ -0,0 +1,85 @@
|
|||||||
|
"""Billing routes: Stripe checkout, webhook, subscription management.
|
||||||
|
|
||||||
|
Business logic lives in ``app.billing.stripe_service.StripeService``.
|
||||||
|
The route layer handles HTTP concerns (request parsing, response shaping)
|
||||||
|
and delegates everything else to the service singleton.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, Header, Request, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.db import get_session
|
||||||
|
from app.schemas import BillingTier, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/billing", tags=["billing"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Request bodies ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CheckoutRequest(BaseModel):
|
||||||
|
tier: BillingTier
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/checkout", response_model=dict)
|
||||||
|
async def create_checkout(
|
||||||
|
body: _CheckoutRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, str]:
|
||||||
|
"""Create a Stripe checkout session for a tier upgrade.
|
||||||
|
|
||||||
|
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
|
||||||
|
"""
|
||||||
|
url = stripe_service.create_checkout_session(current_user.id, body.tier)
|
||||||
|
return {"checkout_url": url}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/webhook", response_model=dict)
|
||||||
|
async def stripe_webhook(
|
||||||
|
request: Request,
|
||||||
|
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Handle Stripe webhook events.
|
||||||
|
|
||||||
|
No JWT auth — authenticated via Stripe signature verification instead.
|
||||||
|
Returns 200 immediately when Stripe is not configured (local dev).
|
||||||
|
"""
|
||||||
|
payload = await request.body()
|
||||||
|
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/subscription", response_model=dict)
|
||||||
|
async def get_subscription(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return the current subscription info for the authenticated user."""
|
||||||
|
sub = await stripe_service.get_subscription(current_user.id, db)
|
||||||
|
if sub is None:
|
||||||
|
return {
|
||||||
|
"tier": current_user.tier,
|
||||||
|
"status": "free",
|
||||||
|
"stripe_subscription_id": None,
|
||||||
|
"current_period_end": None,
|
||||||
|
}
|
||||||
|
return sub
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
|
||||||
|
async def cancel_subscription(
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Cancel the active subscription."""
|
||||||
|
await stripe_service.cancel_subscription(current_user.id, db)
|
||||||
|
return {"ok": True}
|
||||||
42
app/api/routes/chat.py
Normal file
42
app/api/routes/chat.py
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
"""Chat routes: POST /chat (REST fallback).
|
||||||
|
|
||||||
|
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from fastapi.responses import JSONResponse
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.deep_agent import run_home
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.db import async_session
|
||||||
|
from app.schemas import ChatRequest, ChatResponse, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/chat", tags=["chat"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("")
|
||||||
|
async def chat(
|
||||||
|
body: ChatRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> JSONResponse:
|
||||||
|
"""Route a chat message through the Home deep agent (non-streaming)."""
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(current_user.id, body.message)
|
||||||
|
|
||||||
|
context = {
|
||||||
|
**body.context.model_dump(),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
response_text = await run_home(
|
||||||
|
user_id=current_user.id,
|
||||||
|
message=body.message,
|
||||||
|
context=context,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
result = ChatResponse(response=response_text)
|
||||||
|
return JSONResponse(content=result.model_dump())
|
||||||
351
app/api/routes/device_ws.py
Normal file
351
app/api/routes/device_ws.py
Normal file
@@ -0,0 +1,351 @@
|
|||||||
|
"""Device WebSocket endpoint.
|
||||||
|
|
||||||
|
Persistent connection from Electron devices to the backend.
|
||||||
|
|
||||||
|
WS /api/v1/ws/device?token=<jwt>
|
||||||
|
|
||||||
|
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
|
||||||
|
available during the WebSocket handshake).
|
||||||
|
|
||||||
|
Protocol:
|
||||||
|
1. Client connects → JWT validated → connection accepted.
|
||||||
|
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
|
||||||
|
3. Backend registers the connection in ``DeviceConnectionManager``.
|
||||||
|
4. Session enters message dispatch loop + heartbeat.
|
||||||
|
|
||||||
|
Incoming frame dispatch:
|
||||||
|
- ``tool_result`` → resolves a pending tool-call Future.
|
||||||
|
- ``agent_data`` → enqueued in the per-run agent data queue.
|
||||||
|
- ``agent_complete`` → sends None sentinel to close the queue stream.
|
||||||
|
- ``pong`` → heartbeat acknowledgement (updates last-seen).
|
||||||
|
- unknown types → logged, ignored.
|
||||||
|
|
||||||
|
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
|
||||||
|
|
||||||
|
On disconnect:
|
||||||
|
- Unregisters from DeviceConnectionManager.
|
||||||
|
- Marks all in-progress AgentRunLog rows for this user as ``error``
|
||||||
|
with message "device disconnected".
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
|
||||||
|
from jose import JWTError, jwt
|
||||||
|
from sqlalchemy import update
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
|
from app.core.device_manager import device_manager
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
from app.core.deep_agent import run_home_stream, run_floating_stream
|
||||||
|
from app.core.output_formatter import HomeFormatter, FloatingFormatter
|
||||||
|
from app.core.ws_context import clear_client_executor, set_client_executor
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import AgentRunLog
|
||||||
|
from app.schemas import WsFrameType
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/ws", tags=["device-ws"])
|
||||||
|
|
||||||
|
_HEARTBEAT_INTERVAL = 30 # seconds
|
||||||
|
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
|
||||||
|
|
||||||
|
|
||||||
|
@router.websocket("/device")
|
||||||
|
async def device_ws(websocket: WebSocket) -> None:
|
||||||
|
"""Persistent WebSocket endpoint for Electron device connections.
|
||||||
|
|
||||||
|
Authentication is via ``?token=<jwt>`` query parameter.
|
||||||
|
"""
|
||||||
|
# ── 1. Authenticate before accepting ─────────────────────────────
|
||||||
|
token = websocket.query_params.get("token", "")
|
||||||
|
try:
|
||||||
|
payload = jwt.decode(
|
||||||
|
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
|
||||||
|
)
|
||||||
|
user_id: str | None = payload.get("sub")
|
||||||
|
if not user_id:
|
||||||
|
raise JWTError("missing sub")
|
||||||
|
except JWTError:
|
||||||
|
await websocket.close(code=1008) # Policy Violation
|
||||||
|
return
|
||||||
|
|
||||||
|
await websocket.accept()
|
||||||
|
|
||||||
|
# ── 2. Await device_hello frame ───────────────────────────────────
|
||||||
|
try:
|
||||||
|
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
|
||||||
|
except (asyncio.TimeoutError, WebSocketDisconnect):
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
hello = json.loads(raw)
|
||||||
|
if hello.get("type") != WsFrameType.device_hello:
|
||||||
|
raise ValueError("expected device_hello as first frame")
|
||||||
|
device_id: str = hello["device_id"]
|
||||||
|
agent_ids: list[str] = hello.get("agent_ids", [])
|
||||||
|
except (KeyError, ValueError, json.JSONDecodeError) as exc:
|
||||||
|
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
|
||||||
|
await websocket.close(code=1008)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Register connection ────────────────────────────────────────
|
||||||
|
device_manager.register(user_id, device_id, websocket)
|
||||||
|
logger.info(
|
||||||
|
"device_ws: connected user=%s device=%s agents=%s",
|
||||||
|
user_id,
|
||||||
|
device_id,
|
||||||
|
agent_ids,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Trigger any overdue agent runs now that the device is connected.
|
||||||
|
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||||
|
|
||||||
|
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||||
|
try:
|
||||||
|
await asyncio.gather(
|
||||||
|
_message_loop(websocket, user_id),
|
||||||
|
_heartbeat_loop(websocket),
|
||||||
|
)
|
||||||
|
except WebSocketDisconnect:
|
||||||
|
pass
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
|
||||||
|
finally:
|
||||||
|
device_manager.unregister(user_id)
|
||||||
|
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
|
||||||
|
await _mark_runs_disconnected(user_id)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Message dispatch loop ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||||
|
"""Receive frames from Electron and dispatch to the appropriate handler."""
|
||||||
|
async for raw in websocket.iter_text():
|
||||||
|
try:
|
||||||
|
frame: dict = json.loads(raw)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
logger.warning("device_ws: invalid JSON from user=%s", user_id)
|
||||||
|
continue
|
||||||
|
|
||||||
|
frame_type = frame.get("type")
|
||||||
|
|
||||||
|
if frame_type == WsFrameType.tool_result:
|
||||||
|
call_id = frame.get("id")
|
||||||
|
if call_id:
|
||||||
|
device_manager.resolve_pending_call(user_id, call_id, frame)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: tool_result missing id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_data:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
await queue.put(frame)
|
||||||
|
except RuntimeError:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data for unknown run user=%s run=%s",
|
||||||
|
user_id,
|
||||||
|
run_id,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_data missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.agent_complete:
|
||||||
|
run_id = frame.get("run_id")
|
||||||
|
if run_id:
|
||||||
|
try:
|
||||||
|
queue = device_manager.get_agent_data_queue(user_id, run_id)
|
||||||
|
# Sentinel: signals the agent data stream is finished.
|
||||||
|
await queue.put(None)
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"device_ws: agent_complete missing run_id from user=%s", user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.home_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_home_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == WsFrameType.floating_request:
|
||||||
|
asyncio.create_task(
|
||||||
|
_handle_floating_request(websocket, user_id, frame)
|
||||||
|
)
|
||||||
|
|
||||||
|
elif frame_type == "pong":
|
||||||
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
|
pass
|
||||||
|
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── v3 Chat Handlers ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _make_ws_executor(websocket: WebSocket, user_id: str):
|
||||||
|
"""Return a callback that sends tool_call frames and awaits tool_result."""
|
||||||
|
async def _executor(payload: dict) -> dict:
|
||||||
|
payload["type"] = WsFrameType.tool_call
|
||||||
|
call_id = payload["id"]
|
||||||
|
logger.info("ws_executor: sending tool_call id=%s action=%s", call_id, payload.get("action"))
|
||||||
|
await websocket.send_text(json.dumps(payload))
|
||||||
|
future = device_manager.create_pending_call(user_id, call_id)
|
||||||
|
result = await future
|
||||||
|
logger.info("ws_executor: tool_result id=%s result_type=%s result_keys=%s",
|
||||||
|
call_id, type(result).__name__,
|
||||||
|
list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
if result is None:
|
||||||
|
logger.error("ws_executor: future resolved to None for call_id=%s user=%s", call_id, user_id)
|
||||||
|
return result
|
||||||
|
return _executor
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_home_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
|
||||||
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
|
|
||||||
|
context: dict = {
|
||||||
|
"conversation_history": frame.get("conversation_history", []),
|
||||||
|
**memory_context,
|
||||||
|
}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
try:
|
||||||
|
event_stream = run_home_stream(
|
||||||
|
user_id, message, context, db_session_factory=async_session
|
||||||
|
)
|
||||||
|
formatter = HomeFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: home_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Memory: store episode after response ──────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def _handle_floating_request(
|
||||||
|
websocket: WebSocket,
|
||||||
|
user_id: str,
|
||||||
|
frame: dict,
|
||||||
|
) -> None:
|
||||||
|
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
|
||||||
|
request_id = frame.get("request_id") or str(uuid4())
|
||||||
|
message: str = frame.get("message", "")
|
||||||
|
session_id: str = frame.get("session_id") or str(uuid4())
|
||||||
|
scope: dict = frame.get("scope", {})
|
||||||
|
|
||||||
|
# ── Memory: enrich context before LLM call ────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
memory_context = await memory.enrich_context(user_id, message)
|
||||||
|
|
||||||
|
context: dict = {"scope": scope, **memory_context}
|
||||||
|
|
||||||
|
executor = await _make_ws_executor(websocket, user_id)
|
||||||
|
set_client_executor(executor)
|
||||||
|
response_chunks: list[str] = []
|
||||||
|
try:
|
||||||
|
event_stream = run_floating_stream(
|
||||||
|
user_id, message, context, scope=scope,
|
||||||
|
db_session_factory=async_session,
|
||||||
|
)
|
||||||
|
formatter = FloatingFormatter(request_id=request_id)
|
||||||
|
async for ws_frame in formatter.format(event_stream):
|
||||||
|
await websocket.send_text(ws_frame.model_dump_json())
|
||||||
|
if ws_frame.type == "stream_text": # type: ignore[union-attr]
|
||||||
|
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: floating_request failed user=%s req=%s: %s",
|
||||||
|
user_id, request_id, exc,
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_client_executor()
|
||||||
|
|
||||||
|
# ── Memory: store episode after response ──────────────────────────
|
||||||
|
async with async_session() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.store_episode(
|
||||||
|
user_id, session_id, message, "".join(response_chunks)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Heartbeat ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _heartbeat_loop(websocket: WebSocket) -> None:
|
||||||
|
"""Send a ping frame every 30 s to keep the connection alive."""
|
||||||
|
while True:
|
||||||
|
await asyncio.sleep(_HEARTBEAT_INTERVAL)
|
||||||
|
await websocket.send_text(json.dumps({"type": "ping"}))
|
||||||
|
|
||||||
|
|
||||||
|
# ── Disconnect cleanup ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _mark_runs_disconnected(user_id: str) -> None:
|
||||||
|
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
await db.execute(
|
||||||
|
update(AgentRunLog)
|
||||||
|
.where(
|
||||||
|
AgentRunLog.user_id == user_id,
|
||||||
|
AgentRunLog.status == "running",
|
||||||
|
)
|
||||||
|
.values(
|
||||||
|
status="error",
|
||||||
|
errors=["device disconnected"],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"device_ws: failed to mark runs as disconnected for user=%s: %s",
|
||||||
|
user_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
148
app/api/routes/plugins.py
Normal file
148
app/api/routes/plugins.py
Normal file
@@ -0,0 +1,148 @@
|
|||||||
|
"""Plugins routes: browse and install plugins from the marketplace.
|
||||||
|
|
||||||
|
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
|
||||||
|
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.db import get_session
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
from app.models import PluginInstallation, PluginReview as PluginReviewModel
|
||||||
|
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/plugins", tags=["plugins"])
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tier gate ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _require_plugin_tier(user: UserProfile) -> None:
|
||||||
|
"""Raise HTTP 403 for users below Power tier."""
|
||||||
|
if user.tier not in ("power", "team"):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
|
detail="Plugin marketplace requires Power tier or above",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local detail schema ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _PluginDetail(BaseModel):
|
||||||
|
plugin: PluginManifest
|
||||||
|
install_count: int
|
||||||
|
ratings: list[Any]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.get("", response_model=PluginListResponse)
|
||||||
|
async def list_plugins(
|
||||||
|
category: str | None = Query(default=None),
|
||||||
|
q: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Browse the plugin marketplace. Requires Power tier or above."""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/{plugin_id}", response_model=_PluginDetail)
|
||||||
|
async def get_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _PluginDetail:
|
||||||
|
"""Get full plugin details including install count. Requires Power tier or above."""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Fetch review ratings for this plugin
|
||||||
|
review_result = await db.execute(
|
||||||
|
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
|
||||||
|
)
|
||||||
|
reviews = review_result.scalars().all()
|
||||||
|
ratings = [
|
||||||
|
{
|
||||||
|
"reviewer_id": r.reviewer_id,
|
||||||
|
"decision": r.decision,
|
||||||
|
"notes": r.notes,
|
||||||
|
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
|
||||||
|
}
|
||||||
|
for r in reviews
|
||||||
|
]
|
||||||
|
|
||||||
|
return _PluginDetail(
|
||||||
|
plugin=entry["manifest"],
|
||||||
|
install_count=entry["install_count"],
|
||||||
|
ratings=ratings,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/{plugin_id}/install", response_model=dict)
|
||||||
|
async def install_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
|
||||||
|
|
||||||
|
Requires Power tier or above.
|
||||||
|
"""
|
||||||
|
_require_plugin_tier(current_user)
|
||||||
|
entry = await registry.get_plugin(db, plugin_id)
|
||||||
|
if entry is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
|
||||||
|
|
||||||
|
# Record the installation in plugin_installations
|
||||||
|
installation = PluginInstallation(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
)
|
||||||
|
db.add(installation)
|
||||||
|
await db.flush()
|
||||||
|
|
||||||
|
await revenue_share.record_install(
|
||||||
|
db,
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
amount_cents=entry["manifest"].price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
|
||||||
|
return {"ok": True, "download_url": download_url}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/{plugin_id}/install", response_model=dict)
|
||||||
|
async def uninstall_plugin(
|
||||||
|
plugin_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Unregister a plugin installation."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(PluginInstallation).where(
|
||||||
|
PluginInstallation.plugin_id == plugin_id,
|
||||||
|
PluginInstallation.user_id == current_user.id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
installation = result.scalar_one_or_none()
|
||||||
|
if installation is not None:
|
||||||
|
await db.delete(installation)
|
||||||
|
await db.commit()
|
||||||
|
await registry.record_uninstall(db, plugin_id)
|
||||||
|
return {"ok": True}
|
||||||
195
app/api/routes/storage.py
Normal file
195
app/api/routes/storage.py
Normal file
@@ -0,0 +1,195 @@
|
|||||||
|
"""Storage routes: CRUD for E2E-encrypted cloud records.
|
||||||
|
|
||||||
|
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
|
||||||
|
PostgreSQL ``storage_records`` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from sqlalchemy import func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.models import StorageRecord
|
||||||
|
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
|
||||||
|
from app.storage.blob_store import BlobStore
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["storage"])
|
||||||
|
|
||||||
|
_blob_store = BlobStore()
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local response schemas ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
class _CreateResponse(BaseModel):
|
||||||
|
id: str
|
||||||
|
created_at: int
|
||||||
|
|
||||||
|
|
||||||
|
class _RecordMeta(BaseModel):
|
||||||
|
id: str
|
||||||
|
table: str
|
||||||
|
checksum: str
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Helpers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
|
||||||
|
"""Return total bytes stored by *user_id*."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
|
||||||
|
StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return int(result.scalar_one())
|
||||||
|
|
||||||
|
|
||||||
|
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
|
||||||
|
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
|
||||||
|
current = await _current_usage_bytes(user.id, db)
|
||||||
|
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_record_for_user(
|
||||||
|
record_id: str, user_id: str, db: AsyncSession
|
||||||
|
) -> StorageRecord:
|
||||||
|
"""Look up a record and verify ownership. Returns 404 on mismatch
|
||||||
|
to prevent user enumeration attacks."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(StorageRecord).where(
|
||||||
|
StorageRecord.id == record_id, StorageRecord.user_id == user_id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
record = result.scalar_one_or_none()
|
||||||
|
if record is None:
|
||||||
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
|
||||||
|
return record
|
||||||
|
|
||||||
|
|
||||||
|
# ── Routes ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
|
||||||
|
async def create_record(
|
||||||
|
body: StorageRecordCreate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> _CreateResponse:
|
||||||
|
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
|
||||||
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
await _check_quota(current_user, len(body.blob), db)
|
||||||
|
|
||||||
|
record_id = str(uuid.uuid4())
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, body.table, record_id, body.blob, body.checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
record = StorageRecord(
|
||||||
|
id=record_id,
|
||||||
|
user_id=current_user.id,
|
||||||
|
table_name=body.table,
|
||||||
|
s3_key=s3_key,
|
||||||
|
checksum=body.checksum,
|
||||||
|
size_bytes=len(body.blob),
|
||||||
|
)
|
||||||
|
db.add(record)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(record)
|
||||||
|
|
||||||
|
created_at_ms = int(record.created_at.timestamp() * 1000)
|
||||||
|
return _CreateResponse(id=record_id, created_at=created_at_ms)
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records", response_model=list[_RecordMeta])
|
||||||
|
async def list_records(
|
||||||
|
table: str | None = Query(default=None),
|
||||||
|
page: int = Query(default=1, ge=1),
|
||||||
|
limit: int = Query(default=50, ge=1, le=200),
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> list[_RecordMeta]:
|
||||||
|
"""List record metadata for the authenticated user. Blob bytes are never returned."""
|
||||||
|
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
|
||||||
|
if table is not None:
|
||||||
|
query = query.where(StorageRecord.table_name == table)
|
||||||
|
query = query.offset((page - 1) * limit).limit(limit)
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
|
||||||
|
return [
|
||||||
|
_RecordMeta(
|
||||||
|
id=r.id,
|
||||||
|
table=r.table_name,
|
||||||
|
checksum=r.checksum,
|
||||||
|
created_at=int(r.created_at.timestamp() * 1000),
|
||||||
|
updated_at=int(r.updated_at.timestamp() * 1000),
|
||||||
|
)
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/records/{record_id}")
|
||||||
|
async def download_record(
|
||||||
|
record_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> Response:
|
||||||
|
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
blob = await _blob_store.download(current_user.id, record.s3_key)
|
||||||
|
return Response(
|
||||||
|
content=blob,
|
||||||
|
media_type="application/octet-stream",
|
||||||
|
headers={"X-Checksum": record.checksum},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@router.put("/records/{record_id}", response_model=dict)
|
||||||
|
async def update_record(
|
||||||
|
record_id: str,
|
||||||
|
body: StorageRecordUpdate,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Replace the blob for an existing record. Verifies checksum before storing."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
reject_if_tampered(body.blob, body.checksum)
|
||||||
|
|
||||||
|
delta = len(body.blob) - record.size_bytes
|
||||||
|
if delta > 0:
|
||||||
|
await _check_quota(current_user, delta, db)
|
||||||
|
|
||||||
|
s3_key = await _blob_store.upload(
|
||||||
|
current_user.id, record.table_name, record_id, body.blob, body.checksum
|
||||||
|
)
|
||||||
|
|
||||||
|
record.s3_key = s3_key
|
||||||
|
record.checksum = body.checksum
|
||||||
|
record.size_bytes = len(body.blob)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/records/{record_id}", response_model=dict)
|
||||||
|
async def delete_record(
|
||||||
|
record_id: str,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
db: AsyncSession = Depends(get_session),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete a record and its S3 blob."""
|
||||||
|
record = await _get_record_for_user(record_id, current_user.id, db)
|
||||||
|
await _blob_store.delete(current_user.id, record.s3_key)
|
||||||
|
await db.delete(record)
|
||||||
|
await db.commit()
|
||||||
|
return {"ok": True}
|
||||||
79
app/api/routes/vectors.py
Normal file
79
app/api/routes/vectors.py
Normal file
@@ -0,0 +1,79 @@
|
|||||||
|
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Depends
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
from app.api.deps import get_current_user
|
||||||
|
from app.core.llm import embed
|
||||||
|
from app.schemas import (
|
||||||
|
UserProfile,
|
||||||
|
VectorSearchRequest,
|
||||||
|
VectorSearchResponse,
|
||||||
|
VectorUpsertRequest,
|
||||||
|
)
|
||||||
|
from app.storage.encryption import reject_if_tampered
|
||||||
|
from app.storage.vector_store import VectorStore
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/storage", tags=["vectors"])
|
||||||
|
|
||||||
|
_vector_store = VectorStore()
|
||||||
|
|
||||||
|
|
||||||
|
class _VectorDeleteRequest(BaseModel):
|
||||||
|
ids: list[str]
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedRequest(BaseModel):
|
||||||
|
text: str
|
||||||
|
|
||||||
|
|
||||||
|
class _EmbedResponse(BaseModel):
|
||||||
|
vector: list[float]
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/upsert", response_model=dict)
|
||||||
|
async def upsert_vectors(
|
||||||
|
body: VectorUpsertRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
|
||||||
|
for item in body.vectors:
|
||||||
|
reject_if_tampered(item.blob, item.checksum)
|
||||||
|
await _vector_store.upsert(current_user.id, body.vectors)
|
||||||
|
return {"upserted": len(body.vectors)}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/search", response_model=VectorSearchResponse)
|
||||||
|
async def search_vectors(
|
||||||
|
body: VectorSearchRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> VectorSearchResponse:
|
||||||
|
"""Search the user-scoped vector namespace with an encrypted query blob."""
|
||||||
|
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
|
||||||
|
return VectorSearchResponse(results=results)
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/vectors", response_model=dict)
|
||||||
|
async def delete_vectors(
|
||||||
|
body: _VectorDeleteRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> dict[str, bool]:
|
||||||
|
"""Delete vectors by ID, scoped to the authenticated user."""
|
||||||
|
await _vector_store.delete(current_user.id, body.ids)
|
||||||
|
return {"ok": True}
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/vectors/embed", response_model=_EmbedResponse)
|
||||||
|
async def embed_text(
|
||||||
|
body: _EmbedRequest,
|
||||||
|
current_user: UserProfile = Depends(get_current_user),
|
||||||
|
) -> _EmbedResponse:
|
||||||
|
"""Generate a 1536-dim embedding vector for the given text.
|
||||||
|
|
||||||
|
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
|
||||||
|
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
|
||||||
|
"""
|
||||||
|
vector = await embed(body.text)
|
||||||
|
return _EmbedResponse(vector=vector)
|
||||||
4
app/billing/__init__.py
Normal file
4
app/billing/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
from app.billing.stripe_service import stripe_service
|
||||||
|
from app.billing.tier_manager import tier_manager
|
||||||
|
|
||||||
|
__all__ = ["stripe_service", "tier_manager"]
|
||||||
@@ -1,7 +1,7 @@
|
|||||||
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
"""Stripe service: checkout sessions, webhook handling, subscription management.
|
||||||
|
|
||||||
Adapted for the Billing microservice — uses shared.models and shared.db.
|
Subscription records are persisted in the PostgreSQL ``subscriptions`` table.
|
||||||
All Stripe calls are gracefully stubbed when STRIPE_SECRET_KEY is not
|
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not
|
||||||
configured, enabling local development without live credentials.
|
configured, enabling local development without live credentials.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -15,8 +15,7 @@ from fastapi import HTTPException, status
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from shared.config import settings
|
from app.config.settings import settings
|
||||||
from shared.models import Subscription
|
|
||||||
|
|
||||||
# Stripe price IDs per tier — replace with real IDs in production .env
|
# Stripe price IDs per tier — replace with real IDs in production .env
|
||||||
TIER_PRICE_IDS: dict[str, str] = {
|
TIER_PRICE_IDS: dict[str, str] = {
|
||||||
@@ -47,7 +46,11 @@ class StripeService:
|
|||||||
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
|
||||||
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
cancel_url: str = "https://app.adiuva.app/billing/cancel",
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Create a Stripe checkout session and return the URL."""
|
"""Create a Stripe checkout session and return the URL.
|
||||||
|
|
||||||
|
Returns a stub URL when Stripe is not configured.
|
||||||
|
Raises ``HTTP 400`` for the free tier or an unknown tier.
|
||||||
|
"""
|
||||||
if tier == "free":
|
if tier == "free":
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_400_BAD_REQUEST,
|
status_code=status.HTTP_400_BAD_REQUEST,
|
||||||
@@ -84,6 +87,8 @@ class StripeService:
|
|||||||
"""Process a Stripe webhook event.
|
"""Process a Stripe webhook event.
|
||||||
|
|
||||||
Verifies the signature, then dispatches on event type.
|
Verifies the signature, then dispatches on event type.
|
||||||
|
Raises ``HTTP 400`` on signature mismatch.
|
||||||
|
No-ops when Stripe is not configured.
|
||||||
"""
|
"""
|
||||||
if not self._configured():
|
if not self._configured():
|
||||||
return
|
return
|
||||||
@@ -150,7 +155,9 @@ class StripeService:
|
|||||||
async def get_subscription(
|
async def get_subscription(
|
||||||
self, user_id: str, db: AsyncSession
|
self, user_id: str, db: AsyncSession
|
||||||
) -> dict[str, Any] | None:
|
) -> dict[str, Any] | None:
|
||||||
"""Return the subscription record for user_id, or None."""
|
"""Return the subscription record for ``user_id``, or ``None`` if absent."""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription).where(Subscription.user_id == user_id)
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
@@ -169,7 +176,12 @@ class StripeService:
|
|||||||
}
|
}
|
||||||
|
|
||||||
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
|
||||||
"""Cancel the user's Stripe subscription and downgrade to free."""
|
"""Cancel the user's Stripe subscription and downgrade them to free.
|
||||||
|
|
||||||
|
Raises ``HTTP 404`` when no active subscription exists.
|
||||||
|
"""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription).where(Subscription.user_id == user_id)
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
@@ -199,6 +211,8 @@ class StripeService:
|
|||||||
sub_status: str,
|
sub_status: str,
|
||||||
current_period_end: datetime | None,
|
current_period_end: datetime | None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription).where(Subscription.user_id == user_id)
|
select(Subscription).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
@@ -220,6 +234,8 @@ class StripeService:
|
|||||||
status: str | None = None,
|
status: str | None = None,
|
||||||
current_period_end: datetime | None = None,
|
current_period_end: datetime | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription).where(
|
select(Subscription).where(
|
||||||
Subscription.stripe_subscription_id == stripe_subscription_id
|
Subscription.stripe_subscription_id == stripe_subscription_id
|
||||||
@@ -236,5 +252,5 @@ class StripeService:
|
|||||||
sub.current_period_end = current_period_end
|
sub.current_period_end = current_period_end
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
# Module-level singleton shared across the app.
|
||||||
stripe_service = StripeService()
|
stripe_service = StripeService()
|
||||||
@@ -1,8 +1,9 @@
|
|||||||
"""Tier manager: feature matrix and quota enforcement.
|
"""Tier manager: feature matrix and quota enforcement.
|
||||||
|
|
||||||
Single source of truth for what each billing tier allows.
|
``TierManager`` is the single source of truth for what each billing tier
|
||||||
Other services can query the /tier/{user_id} endpoint or rely on the
|
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier.
|
||||||
X-User-Tier header injected by Traefik.
|
Quota-enforcement helpers take ``tier`` directly — the caller already has it
|
||||||
|
from ``current_user.tier`` (provided by ``get_current_user``).
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -13,16 +14,13 @@ from fastapi import HTTPException, status
|
|||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
from shared.config import settings
|
from app.schemas import BillingTier
|
||||||
from shared.models import Subscription
|
|
||||||
from shared.schemas import BillingTier
|
|
||||||
|
|
||||||
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
# Feature matrix per tier. -1 means unlimited; 0 means disabled.
|
||||||
FEATURES: dict[str, dict[str, Any]] = {
|
FEATURES: dict[str, dict[str, Any]] = {
|
||||||
"free": {
|
"free": {
|
||||||
"agents": 3,
|
"agents": 3,
|
||||||
"batch_active": 2,
|
"batch_active": 2,
|
||||||
"batch_runs_per_day": 5,
|
|
||||||
"cloud_storage_gb": 0,
|
"cloud_storage_gb": 0,
|
||||||
"backup_gb": 0,
|
"backup_gb": 0,
|
||||||
"providers": 1,
|
"providers": 1,
|
||||||
@@ -31,9 +29,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"sso": False,
|
"sso": False,
|
||||||
},
|
},
|
||||||
"pro": {
|
"pro": {
|
||||||
"agents": -1,
|
"agents": -1, # unlimited
|
||||||
"batch_active": 10,
|
"batch_active": 10,
|
||||||
"batch_runs_per_day": 50,
|
|
||||||
"cloud_storage_gb": 5,
|
"cloud_storage_gb": 5,
|
||||||
"backup_gb": 5,
|
"backup_gb": 5,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -43,8 +40,7 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
},
|
},
|
||||||
"power": {
|
"power": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1, # unlimited
|
||||||
"batch_runs_per_day": -1,
|
|
||||||
"cloud_storage_gb": 25,
|
"cloud_storage_gb": 25,
|
||||||
"backup_gb": 25,
|
"backup_gb": 25,
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
@@ -55,9 +51,8 @@ FEATURES: dict[str, dict[str, Any]] = {
|
|||||||
"team": {
|
"team": {
|
||||||
"agents": -1,
|
"agents": -1,
|
||||||
"batch_active": -1,
|
"batch_active": -1,
|
||||||
"batch_runs_per_day": -1,
|
"cloud_storage_gb": -1, # unlimited
|
||||||
"cloud_storage_gb": -1,
|
"backup_gb": -1, # unlimited
|
||||||
"backup_gb": -1,
|
|
||||||
"providers": -1,
|
"providers": -1,
|
||||||
"batch_builder": True,
|
"batch_builder": True,
|
||||||
"plugin_marketplace": True,
|
"plugin_marketplace": True,
|
||||||
@@ -77,22 +72,30 @@ RATE_LIMITS: dict[str, int] = {
|
|||||||
class TierManager:
|
class TierManager:
|
||||||
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
"""Centralises tier feature-gating, rate-limit lookups, and quota checks."""
|
||||||
|
|
||||||
|
# ── Tier lookup ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
|
||||||
"""Return the current billing tier for user_id from the DB."""
|
"""Return the current billing tier for ``user_id`` from the DB.
|
||||||
|
|
||||||
|
Falls back to ``'free'`` when no subscription row exists.
|
||||||
|
"""
|
||||||
|
from app.models import Subscription # noqa: PLC0415
|
||||||
|
|
||||||
result = await db.execute(
|
result = await db.execute(
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
select(Subscription.tier).where(Subscription.user_id == user_id)
|
||||||
)
|
)
|
||||||
tier: str | None = result.scalar_one_or_none()
|
tier: str | None = result.scalar_one_or_none()
|
||||||
if tier is None or tier not in FEATURES:
|
if tier is None or tier not in FEATURES:
|
||||||
return "power" if settings.ENV == "dev" else "free"
|
return "free"
|
||||||
return tier # type: ignore[return-value]
|
return tier # type: ignore[return-value]
|
||||||
|
|
||||||
def get_features(self, tier: BillingTier) -> dict[str, Any]:
|
# ── Feature access ───────────────────────────────────────────────────
|
||||||
"""Return the full feature dict for a tier."""
|
|
||||||
return FEATURES.get(tier, FEATURES["free"])
|
|
||||||
|
|
||||||
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
def check_feature(self, tier: BillingTier, feature: str) -> bool:
|
||||||
"""Return True if tier has feature enabled."""
|
"""Return ``True`` if ``tier`` has ``feature`` enabled.
|
||||||
|
|
||||||
|
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
|
||||||
|
"""
|
||||||
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
value = FEATURES.get(tier, FEATURES["free"]).get(feature)
|
||||||
if value is None:
|
if value is None:
|
||||||
return False
|
return False
|
||||||
@@ -101,7 +104,7 @@ class TierManager:
|
|||||||
return value != 0
|
return value != 0
|
||||||
|
|
||||||
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
|
||||||
"""Raise HTTP 403 if tier does not have feature."""
|
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``."""
|
||||||
if not self.check_feature(tier, feature):
|
if not self.check_feature(tier, feature):
|
||||||
detail = (
|
detail = (
|
||||||
f"Feature '{feature}' requires {tier_name} tier or above."
|
f"Feature '{feature}' requires {tier_name} tier or above."
|
||||||
@@ -110,17 +113,25 @@ class TierManager:
|
|||||||
)
|
)
|
||||||
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
|
||||||
|
|
||||||
|
# ── Rate limiting ────────────────────────────────────────────────────
|
||||||
|
|
||||||
def get_rate_limit(self, tier: BillingTier) -> int:
|
def get_rate_limit(self, tier: BillingTier) -> int:
|
||||||
"""Return the requests-per-minute limit for tier."""
|
"""Return the requests-per-minute limit for ``tier``."""
|
||||||
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
|
||||||
|
|
||||||
|
# ── Storage quota ────────────────────────────────────────────────────
|
||||||
|
|
||||||
def enforce_quota(
|
def enforce_quota(
|
||||||
self,
|
self,
|
||||||
tier: BillingTier,
|
tier: BillingTier,
|
||||||
current_bytes: int = 0,
|
current_bytes: int = 0,
|
||||||
additional_bytes: int = 0,
|
additional_bytes: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raise HTTP 402 if the user would exceed their cloud storage quota."""
|
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota.
|
||||||
|
|
||||||
|
``tier`` is the caller's current tier (from ``current_user.tier``).
|
||||||
|
``current_bytes`` is the total bytes already stored (queried by caller).
|
||||||
|
"""
|
||||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
if limit_gb == 0:
|
if limit_gb == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -128,7 +139,7 @@ class TierManager:
|
|||||||
detail=f"Cloud storage is not available on the '{tier}' tier",
|
detail=f"Cloud storage is not available on the '{tier}' tier",
|
||||||
)
|
)
|
||||||
if limit_gb == -1:
|
if limit_gb == -1:
|
||||||
return
|
return # unlimited
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
if current_bytes + additional_bytes > limit_bytes:
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -142,7 +153,7 @@ class TierManager:
|
|||||||
current_bytes: int = 0,
|
current_bytes: int = 0,
|
||||||
additional_bytes: int = 0,
|
additional_bytes: int = 0,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Raise HTTP 402 if the user would exceed their backup quota."""
|
"""Raise ``HTTP 402`` if the user would exceed their backup quota."""
|
||||||
limit_gb: int = FEATURES[tier]["backup_gb"]
|
limit_gb: int = FEATURES[tier]["backup_gb"]
|
||||||
if limit_gb == 0:
|
if limit_gb == 0:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -150,7 +161,7 @@ class TierManager:
|
|||||||
detail=f"Backup is not available on the '{tier}' tier",
|
detail=f"Backup is not available on the '{tier}' tier",
|
||||||
)
|
)
|
||||||
if limit_gb == -1:
|
if limit_gb == -1:
|
||||||
return
|
return # unlimited
|
||||||
limit_bytes = limit_gb * 1024 ** 3
|
limit_bytes = limit_gb * 1024 ** 3
|
||||||
if current_bytes + additional_bytes > limit_bytes:
|
if current_bytes + additional_bytes > limit_bytes:
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
@@ -164,7 +175,7 @@ class TierManager:
|
|||||||
current_bytes: int = 0,
|
current_bytes: int = 0,
|
||||||
additional_bytes: int = 0,
|
additional_bytes: int = 0,
|
||||||
) -> bool:
|
) -> bool:
|
||||||
"""Return True if the user can store additional_bytes more data."""
|
"""Return ``True`` if the user can store ``additional_bytes`` more data."""
|
||||||
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
|
||||||
if limit_gb == 0:
|
if limit_gb == 0:
|
||||||
return False
|
return False
|
||||||
@@ -174,5 +185,5 @@ class TierManager:
|
|||||||
return current_bytes + additional_bytes <= limit_bytes
|
return current_bytes + additional_bytes <= limit_bytes
|
||||||
|
|
||||||
|
|
||||||
# Module-level singleton
|
# Module-level singleton shared across the app.
|
||||||
tier_manager = TierManager()
|
tier_manager = TierManager()
|
||||||
60
app/config/settings.py
Normal file
60
app/config/settings.py
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
from typing import Literal
|
||||||
|
from pydantic_settings import BaseSettings, SettingsConfigDict
|
||||||
|
|
||||||
|
|
||||||
|
class Settings(BaseSettings):
|
||||||
|
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
|
||||||
|
JWT_SECRET: str = "change-me-in-production"
|
||||||
|
JWT_ALGORITHM: str = "HS256"
|
||||||
|
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
|
||||||
|
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
|
||||||
|
|
||||||
|
STRIPE_SECRET_KEY: str = ""
|
||||||
|
STRIPE_WEBHOOK_SECRET: str = ""
|
||||||
|
|
||||||
|
S3_BUCKET: str = ""
|
||||||
|
S3_REGION: str = "us-east-1"
|
||||||
|
S3_ENDPOINT_URL: str = ""
|
||||||
|
AWS_ACCESS_KEY_ID: str = ""
|
||||||
|
AWS_SECRET_ACCESS_KEY: str = ""
|
||||||
|
|
||||||
|
PINECONE_API_KEY: str = ""
|
||||||
|
PINECONE_INDEX: str = "adiuva"
|
||||||
|
QDRANT_URL: str = ""
|
||||||
|
QDRANT_API_KEY: str = ""
|
||||||
|
|
||||||
|
OPENAI_API_KEY: str = ""
|
||||||
|
ANTHROPIC_API_KEY: str = ""
|
||||||
|
GOOGLE_API_KEY: str = ""
|
||||||
|
CEREBRAS_API_KEY: str = ""
|
||||||
|
|
||||||
|
LLM_MODEL: str = "gpt-4o"
|
||||||
|
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
|
||||||
|
LLM_EMBED_MODEL: str = "text-embedding-3-small"
|
||||||
|
|
||||||
|
# GitHub Copilot OAuth token storage directory.
|
||||||
|
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
|
||||||
|
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: str = ""
|
||||||
|
|
||||||
|
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
|
||||||
|
GMAIL_CLIENT_ID: str = ""
|
||||||
|
GMAIL_CLIENT_SECRET: str = ""
|
||||||
|
MS_CLIENT_ID: str = ""
|
||||||
|
MS_CLIENT_SECRET: str = ""
|
||||||
|
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
|
||||||
|
MS_TENANT_ID: str = "common"
|
||||||
|
|
||||||
|
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||||
|
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||||
|
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||||
|
OAUTH_ENCRYPTION_KEY: str = ""
|
||||||
|
|
||||||
|
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
|
||||||
|
|
||||||
|
ENV: Literal["dev", "prod"] = "dev"
|
||||||
|
|
||||||
|
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
||||||
|
|
||||||
|
|
||||||
|
settings = Settings()
|
||||||
718
app/core/agent_runner.py
Normal file
718
app/core/agent_runner.py
Normal file
@@ -0,0 +1,718 @@
|
|||||||
|
"""Agent run manager.
|
||||||
|
|
||||||
|
Drives two agent types:
|
||||||
|
|
||||||
|
* **Local directory agent** — sends an ``agent_run`` frame to the connected
|
||||||
|
Electron device, waits for the device to stream back file contents via
|
||||||
|
``agent_data`` frames, then calls the LLM to extract structured items from
|
||||||
|
each file and pushes inserts to Electron via tool-call round-trips.
|
||||||
|
|
||||||
|
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
||||||
|
Teams, Outlook) and pushes extracted items to Electron. **This path is
|
||||||
|
a stub** — provider integrations are implemented in Step 3.6.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
Background tasks are spawned with ``asyncio.create_task()``::
|
||||||
|
|
||||||
|
asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager))
|
||||||
|
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||||
|
|
||||||
|
The ``trigger_pending_runs`` function is called by the device WS endpoint
|
||||||
|
when Electron sends ``device_hello``, so any overdue runs fire immediately
|
||||||
|
when the device reconnects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from croniter import croniter
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Max seconds to wait for Electron to finish streaming file data.
|
||||||
|
_FILE_READ_TIMEOUT: int = 120
|
||||||
|
# Max seconds to wait for Electron to acknowledge a single tool-call insert.
|
||||||
|
_INSERT_TIMEOUT: int = 30
|
||||||
|
|
||||||
|
# ── Allowed tables & extraction schema hints ───────────────────────────────
|
||||||
|
|
||||||
|
_ALLOWED_TABLES: frozenset[str] = frozenset(
|
||||||
|
{"tasks", "notes", "timelines", "projects", "taskComments"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Field descriptions fed to the extraction LLM as concise schema references.
|
||||||
|
_TABLE_SCHEMAS: dict[str, str] = {
|
||||||
|
"tasks": (
|
||||||
|
"title (str, required), description (str), "
|
||||||
|
"status (todo|in_progress|done, default todo), "
|
||||||
|
"priority (high|medium|low, default medium), "
|
||||||
|
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
|
||||||
|
),
|
||||||
|
"notes": "title (str, required), content (str, markdown), projectId (str)",
|
||||||
|
"timelines": (
|
||||||
|
"title (str, required), projectId (str, required), date (ms timestamp int)"
|
||||||
|
),
|
||||||
|
"projects": "name (str, required), clientId (str)",
|
||||||
|
"taskComments": "taskId (str, required), author (str), content (str, required)",
|
||||||
|
}
|
||||||
|
|
||||||
|
_EXTRACTION_SYSTEM_PROMPT = """\
|
||||||
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
Given a document, extract structured records matching the user's instructions.
|
||||||
|
|
||||||
|
Output a JSON array (no markdown fences, no explanation) of objects shaped:
|
||||||
|
[{{"table": "<table_name>", "data": {{...fields}}}}, ...]
|
||||||
|
|
||||||
|
Allowed table names and their fields:
|
||||||
|
{table_schemas}
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Only extract tables listed in the "data_types" instructions.
|
||||||
|
- Use camelCase field names exactly as shown above.
|
||||||
|
- Omit optional fields you cannot determine; do not invent data.
|
||||||
|
- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved.
|
||||||
|
- If nothing relevant is found, return an empty JSON array: []
|
||||||
|
- Return ONLY the JSON array.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cron helper ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
|
||||||
|
"""Return ``True`` if the next scheduled run time has already passed.
|
||||||
|
|
||||||
|
Always validates the cron expression first — an invalid expression returns
|
||||||
|
``False`` (fail-safe: never trigger an unparseable schedule).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if last_run_at is None:
|
||||||
|
# Validate the expression before deciding this is overdue.
|
||||||
|
croniter(schedule_cron, now)
|
||||||
|
return True
|
||||||
|
ts = last_run_at
|
||||||
|
if ts.tzinfo is None:
|
||||||
|
ts = ts.replace(tzinfo=timezone.utc)
|
||||||
|
cron = croniter(schedule_cron, ts)
|
||||||
|
next_run: datetime = cron.get_next(datetime)
|
||||||
|
return now >= next_run
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
|
||||||
|
return False # Fail-safe: don't trigger if expression is invalid.
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM extraction ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_items_from_content(
|
||||||
|
prompt_template: str,
|
||||||
|
file_content: str,
|
||||||
|
data_types: list[str],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Call the LLM to extract structured records from *file_content*.
|
||||||
|
|
||||||
|
Returns a validated list of ``{table: str, data: dict}`` objects.
|
||||||
|
Items referencing tables not in *data_types* are discarded.
|
||||||
|
"""
|
||||||
|
allowed = [t for t in data_types if t in _ALLOWED_TABLES]
|
||||||
|
if not allowed:
|
||||||
|
return []
|
||||||
|
|
||||||
|
schema_text = "\n".join(
|
||||||
|
f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed
|
||||||
|
)
|
||||||
|
system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text)
|
||||||
|
user_prompt = (
|
||||||
|
f"User instructions: {prompt_template}\n\n"
|
||||||
|
f"Extract these record types: {', '.join(allowed)}\n\n"
|
||||||
|
f"Document:\n{file_content[:8000]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_llm()
|
||||||
|
raw = ""
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
|
||||||
|
)
|
||||||
|
raw = str(response.content).strip()
|
||||||
|
items: list[dict] = json.loads(raw)
|
||||||
|
if not isinstance(items, list):
|
||||||
|
raise ValueError("LLM response is not a JSON array")
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r",
|
||||||
|
exc,
|
||||||
|
raw,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
# Other exceptions (LLM API errors, network errors) propagate to the
|
||||||
|
# caller (run_local_agent) which records them per-file in the run log.
|
||||||
|
|
||||||
|
validated: list[dict[str, Any]] = []
|
||||||
|
for item in items:
|
||||||
|
table = item.get("table")
|
||||||
|
data = item.get("data")
|
||||||
|
if not isinstance(table, str) or table not in allowed:
|
||||||
|
continue
|
||||||
|
if not isinstance(data, dict) or not data:
|
||||||
|
continue
|
||||||
|
# Strip any server-generated or forbidden fields.
|
||||||
|
for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"):
|
||||||
|
data.pop(_field, None)
|
||||||
|
validated.append({"table": table, "data": data})
|
||||||
|
return validated
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool-call insert helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_insert_to_client(
|
||||||
|
user_id: str,
|
||||||
|
table: str,
|
||||||
|
data: dict[str, Any],
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send an ``insert`` tool_call frame to Electron and await the tool_result.
|
||||||
|
|
||||||
|
All inserts include ``isAiSuggested=1, isApproved=0`` so the user can
|
||||||
|
review AI-produced records before they are treated as confirmed.
|
||||||
|
|
||||||
|
Raises ``asyncio.TimeoutError`` if Electron does not respond within
|
||||||
|
``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device
|
||||||
|
disconnects before the frame can be sent.
|
||||||
|
"""
|
||||||
|
call_id = str(uuid.uuid4())
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": call_id,
|
||||||
|
"action": "insert",
|
||||||
|
"table": table,
|
||||||
|
"data": {**data, "isAiSuggested": 1, "isApproved": 0},
|
||||||
|
}
|
||||||
|
fut = device_mgr.create_pending_call(user_id, call_id)
|
||||||
|
await device_mgr.send_frame(user_id, payload)
|
||||||
|
return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_local_agent(
|
||||||
|
user_id: str,
|
||||||
|
config: LocalAgentConfig,
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a local directory agent run end-to-end.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
|
||||||
|
1. Verify the device identified by ``config.device_id`` is currently online.
|
||||||
|
2. Pre-create the agent_data queue so no incoming frames are lost.
|
||||||
|
3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types).
|
||||||
|
4. Consume ``agent_data`` frames until the ``None`` sentinel from
|
||||||
|
``agent_complete``.
|
||||||
|
5. For each received file call the LLM to extract ``{table, data}`` items.
|
||||||
|
6. Push each item to Electron as an ``insert`` tool-call; include
|
||||||
|
``isAiSuggested=1, isApproved=0`` so users can review AI suggestions.
|
||||||
|
7. Persist the run outcome (status, counts, errors) and update
|
||||||
|
``config.last_run_at``.
|
||||||
|
"""
|
||||||
|
run_id = run_log.id
|
||||||
|
|
||||||
|
# ── 1. Device online check ─────────────────────────────────────────
|
||||||
|
if not device_mgr.is_online(user_id, config.device_id):
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: skip run=%s — device %r offline for user=%s",
|
||||||
|
run_id,
|
||||||
|
config.device_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Device {config.device_id!r} is not connected"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 2. Pre-create agent_data queue ────────────────────────────────
|
||||||
|
try:
|
||||||
|
device_mgr.get_agent_data_queue(user_id, run_id)
|
||||||
|
except RuntimeError:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=["Device disconnected before agent run could start"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Send agent_run frame ────────────────────────────────────────
|
||||||
|
frame: dict[str, Any] = {
|
||||||
|
"type": "agent_run",
|
||||||
|
"run_id": run_id,
|
||||||
|
"agent_id": config.id,
|
||||||
|
"config": {
|
||||||
|
"paths": config.directory_paths,
|
||||||
|
"file_extensions": config.file_extensions,
|
||||||
|
"prompt_template": config.prompt_template,
|
||||||
|
"data_types": config.data_types,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
await device_mgr.send_frame(user_id, frame)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to send agent_run frame: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: sent agent_run run=%s agent=%s user=%s",
|
||||||
|
run_id,
|
||||||
|
config.id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 4. Consume agent_data frames ──────────────────────────────────
|
||||||
|
files: list[dict[str, Any]] = []
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
queue = device_mgr.get_agent_data_queue(user_id, run_id)
|
||||||
|
deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT
|
||||||
|
while True:
|
||||||
|
remaining = deadline - asyncio.get_event_loop().time()
|
||||||
|
if remaining <= 0:
|
||||||
|
errors.append("Timed out waiting for file data from device")
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
frame_data = await asyncio.wait_for(queue.get(), timeout=remaining)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append("Timed out waiting for file data from device")
|
||||||
|
break
|
||||||
|
if frame_data is None:
|
||||||
|
# Sentinel from agent_complete — stream is done.
|
||||||
|
break
|
||||||
|
files.extend(frame_data.get("files", []))
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Queue error reading agent data: {exc}")
|
||||||
|
|
||||||
|
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
for file_info in files:
|
||||||
|
file_path: str = file_info.get("path", "<unknown>")
|
||||||
|
content: str = file_info.get("content", "")
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
try:
|
||||||
|
extracted = await _extract_items_from_content(
|
||||||
|
config.prompt_template, content, config.data_types
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM extraction error for {file_path!r}: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for item in extracted:
|
||||||
|
try:
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
user_id, item["table"], item["data"], device_mgr
|
||||||
|
)
|
||||||
|
if result.get("error"):
|
||||||
|
errors.append(
|
||||||
|
f"Insert failed ({item['table']}, {file_path!r}): {result['error']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
items_created += 1
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append(
|
||||||
|
f"Timed out awaiting insert ack ({item['table']}, {file_path!r})"
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}")
|
||||||
|
|
||||||
|
# ── 7. Finalise ────────────────────────────────────────────────────
|
||||||
|
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
||||||
|
|
||||||
|
if errors and items_created == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="local",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
|
run_id,
|
||||||
|
final_status,
|
||||||
|
items_processed,
|
||||||
|
items_created,
|
||||||
|
len(errors),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Default lookback window when an agent has never run before.
|
||||||
|
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
||||||
|
|
||||||
|
|
||||||
|
async def run_cloud_agent(
|
||||||
|
user_id: str,
|
||||||
|
config: CloudAgentConfig,
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a cloud connector agent run end-to-end.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
|
||||||
|
1. Verify the user's device is online — results are pushed to Electron
|
||||||
|
via WS tool-call frames. If no device is connected, abort.
|
||||||
|
2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``.
|
||||||
|
3. Instantiate the provider client (Gmail or MS Graph).
|
||||||
|
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
|
||||||
|
the first run) applying ``config.filter_config`` filters.
|
||||||
|
5. For each message/email call ``_extract_items_from_content`` with
|
||||||
|
``config.prompt_template`` to get structured ``{table, data}`` items.
|
||||||
|
6. Push each item to Electron as an ``insert`` tool-call.
|
||||||
|
7. If the provider refreshed its access token, re-encrypt and write it
|
||||||
|
back to ``config.oauth_token_encrypted``.
|
||||||
|
8. Persist the run outcome via ``_finalize_run``.
|
||||||
|
"""
|
||||||
|
run_id = run_log.id
|
||||||
|
|
||||||
|
# ── 1. Device online check ─────────────────────────────────────────
|
||||||
|
if not device_mgr.is_online(user_id):
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: skip cloud run=%s — no device online for user=%s",
|
||||||
|
run_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=["No connected device — cloud agent results cannot be delivered"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 2. Decrypt OAuth token ─────────────────────────────────────────
|
||||||
|
from app.integrations import decrypt_token, encrypt_token, get_provider
|
||||||
|
|
||||||
|
if not config.oauth_token_encrypted:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
||||||
|
except ValueError as exc:
|
||||||
|
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Instantiate provider client ────────────────────────────────
|
||||||
|
try:
|
||||||
|
provider = get_provider(config.provider, credentials_info)
|
||||||
|
except ValueError as exc:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[str(exc)],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 4. Fetch messages ─────────────────────────────────────────────
|
||||||
|
since: datetime | None = config.last_run_at
|
||||||
|
if since is None:
|
||||||
|
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
||||||
|
if since.tzinfo is None:
|
||||||
|
since = since.replace(tzinfo=timezone.utc)
|
||||||
|
|
||||||
|
errors: list[str] = []
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
try:
|
||||||
|
if config.provider == "gmail":
|
||||||
|
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "outlook":
|
||||||
|
raw_messages = await provider.fetch_emails( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
elif config.provider == "teams":
|
||||||
|
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
|
||||||
|
filter_config=config.filter_config,
|
||||||
|
since=since,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raw_messages = []
|
||||||
|
except RuntimeError as exc:
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: provider fetch failed for cloud agent %s: %s",
|
||||||
|
config.id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Provider fetch failed: {exc}"],
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
|
||||||
|
config.id,
|
||||||
|
len(raw_messages),
|
||||||
|
config.provider,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
||||||
|
for msg in raw_messages:
|
||||||
|
content_text = msg.as_text
|
||||||
|
if not content_text:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
try:
|
||||||
|
extracted = await _extract_items_from_content(
|
||||||
|
config.prompt_template, content_text, config.data_types
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM extraction error for message {msg.id!r}: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for item in extracted:
|
||||||
|
try:
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
user_id, item["table"], item["data"], device_mgr
|
||||||
|
)
|
||||||
|
if result.get("error"):
|
||||||
|
errors.append(
|
||||||
|
f"Insert failed ({item['table']}, msg={msg.id!r}): {result['error']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
items_created += 1
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append(
|
||||||
|
f"Timed out awaiting insert ack ({item['table']}, msg={msg.id!r})"
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Insert error ({item['table']}, msg={msg.id!r}): {exc}")
|
||||||
|
|
||||||
|
# ── 7. Persist refreshed token (if any) ───────────────────────────
|
||||||
|
refreshed = getattr(provider, "refreshed_credentials", None)
|
||||||
|
if refreshed:
|
||||||
|
try:
|
||||||
|
new_encrypted = encrypt_token(refreshed)
|
||||||
|
async with async_session() as db:
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
||||||
|
)
|
||||||
|
cfg_row = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg_row:
|
||||||
|
cfg_row.oauth_token_encrypted = new_encrypted
|
||||||
|
await db.commit()
|
||||||
|
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: failed to persist refreshed token for agent %s: %s", config.id, exc)
|
||||||
|
|
||||||
|
# ── 8. Finalise ────────────────────────────────────────────────────
|
||||||
|
if errors and items_created == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="cloud",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
|
run_id,
|
||||||
|
final_status,
|
||||||
|
items_processed,
|
||||||
|
items_created,
|
||||||
|
len(errors),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pending-run trigger ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def trigger_pending_runs(
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Dispatch any overdue agent runs after an Electron device connects.
|
||||||
|
|
||||||
|
Called as a background task from the device WS endpoint on ``device_hello``.
|
||||||
|
|
||||||
|
Scheduling rules:
|
||||||
|
|
||||||
|
* **Local agents**: only triggered when ``config.device_id == device_id``.
|
||||||
|
* **Cloud agents**: triggered on any connected device (no device binding).
|
||||||
|
* Runs execute **sequentially** to avoid flooding the WS connection.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
LocalAgentConfig.enabled == True, # noqa: E712
|
||||||
|
LocalAgentConfig.device_id == device_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
|
||||||
|
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
CloudAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
|
||||||
|
|
||||||
|
# Build ordered list of overdue (type, config) pairs.
|
||||||
|
pending: list[tuple[str, Any]] = []
|
||||||
|
for cfg in local_configs:
|
||||||
|
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||||
|
pending.append(("local", cfg))
|
||||||
|
for cfg in cloud_configs:
|
||||||
|
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||||
|
pending.append(("cloud", cfg))
|
||||||
|
|
||||||
|
if not pending:
|
||||||
|
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
for agent_type, cfg in pending:
|
||||||
|
# Create a fresh run log for this scheduled dispatch.
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=cfg.id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
|
||||||
|
if agent_type == "local":
|
||||||
|
await run_local_agent(user_id, cfg, run_log, device_mgr)
|
||||||
|
else:
|
||||||
|
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_run(
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
items_processed: int = 0,
|
||||||
|
items_created: int = 0,
|
||||||
|
errors: list[str] | None = None,
|
||||||
|
update_config_last_run: bool = False,
|
||||||
|
config_id: str | None = None,
|
||||||
|
config_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Persist the run outcome and optionally update ``LocalAgentConfig.last_run_at``.
|
||||||
|
|
||||||
|
Uses a fresh DB session so this is safe to call from background tasks
|
||||||
|
after the original request session has closed.
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
managed = await db.merge(run_log)
|
||||||
|
managed.status = status
|
||||||
|
managed.items_processed = items_processed
|
||||||
|
managed.items_created = items_created
|
||||||
|
managed.errors = errors or []
|
||||||
|
managed.completed_at = now
|
||||||
|
|
||||||
|
if update_config_last_run and config_id:
|
||||||
|
if config_type == "local":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
elif config_type == "cloud":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
|
||||||
|
)
|
||||||
458
app/core/deep_agent.py
Normal file
458
app/core/deep_agent.py
Normal file
@@ -0,0 +1,458 @@
|
|||||||
|
"""Deep Agent — LangGraph hierarchical supervisors for home and floating modes.
|
||||||
|
|
||||||
|
Two supervisor graphs (both ``create_react_agent``):
|
||||||
|
* **HomeSupervisor** — gathers data from multiple domains, presents
|
||||||
|
structured overview with tool-result blocks.
|
||||||
|
* **FloatingSupervisor** — focused, scoped assistant for a single entity/domain.
|
||||||
|
|
||||||
|
Each supervisor delegates to four sub-agent tools, each a compiled
|
||||||
|
``create_react_agent`` wrapping the domain CRUD tools (task, project, note,
|
||||||
|
timeline). The sub-agents talk to Electron via ``execute_on_client``.
|
||||||
|
|
||||||
|
Streaming uses ``astream(stream_mode=["messages", "updates"])`` so that
|
||||||
|
callers can sniff:
|
||||||
|
* ``("messages", (token, metadata))`` — text tokens for streaming
|
||||||
|
* ``("updates", ...)`` — tool call results for mutations
|
||||||
|
|
||||||
|
An ``update_core_memory`` tool is available to both supervisors for
|
||||||
|
persisting user preferences mid-conversation (MemGPT-style).
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from typing import Any, AsyncGenerator
|
||||||
|
|
||||||
|
from langchain_core.messages import AIMessageChunk, HumanMessage
|
||||||
|
from langchain_core.tools import tool
|
||||||
|
from langgraph.prebuilt import create_react_agent
|
||||||
|
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.core.ws_context import (
|
||||||
|
clear_tool_result_collector,
|
||||||
|
set_tool_result_collector,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Sub-agent tool imports ────────────────────────────────────────────
|
||||||
|
|
||||||
|
from app.agents.task_agent import ( # noqa: E402
|
||||||
|
add_task_comment,
|
||||||
|
create_task,
|
||||||
|
delete_task,
|
||||||
|
delete_task_comment,
|
||||||
|
list_task_comments,
|
||||||
|
list_tasks,
|
||||||
|
list_tasks_due_today,
|
||||||
|
update_task,
|
||||||
|
)
|
||||||
|
from app.agents.note_agent import ( # noqa: E402
|
||||||
|
create_note,
|
||||||
|
delete_note,
|
||||||
|
get_note,
|
||||||
|
list_notes,
|
||||||
|
update_note,
|
||||||
|
)
|
||||||
|
from app.agents.project_agent import ( # noqa: E402
|
||||||
|
create_project,
|
||||||
|
delete_project,
|
||||||
|
get_project,
|
||||||
|
list_all_projects,
|
||||||
|
list_projects,
|
||||||
|
update_project,
|
||||||
|
)
|
||||||
|
from app.agents.timeline_agent import ( # noqa: E402
|
||||||
|
create_timeline,
|
||||||
|
delete_timeline,
|
||||||
|
list_timelines,
|
||||||
|
update_timeline,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Sub-agent definitions ─────────────────────────────────────────────
|
||||||
|
|
||||||
|
_TASK_TOOLS = [
|
||||||
|
list_tasks,
|
||||||
|
create_task,
|
||||||
|
update_task,
|
||||||
|
delete_task,
|
||||||
|
list_tasks_due_today,
|
||||||
|
list_task_comments,
|
||||||
|
add_task_comment,
|
||||||
|
delete_task_comment,
|
||||||
|
]
|
||||||
|
|
||||||
|
_NOTE_TOOLS = [list_notes, get_note, create_note, update_note, delete_note]
|
||||||
|
|
||||||
|
_PROJECT_TOOLS = [
|
||||||
|
list_projects,
|
||||||
|
list_all_projects,
|
||||||
|
get_project,
|
||||||
|
create_project,
|
||||||
|
update_project,
|
||||||
|
delete_project,
|
||||||
|
]
|
||||||
|
|
||||||
|
_TIMELINE_TOOLS = [list_timelines, create_timeline, update_timeline, delete_timeline]
|
||||||
|
|
||||||
|
|
||||||
|
def _build_subagent_tool(
|
||||||
|
name: str,
|
||||||
|
description: str,
|
||||||
|
system_prompt: str,
|
||||||
|
tools: list,
|
||||||
|
):
|
||||||
|
"""Build a compiled sub-agent graph and wrap it as a LangChain tool."""
|
||||||
|
subgraph = create_react_agent(
|
||||||
|
model=get_llm(),
|
||||||
|
tools=tools,
|
||||||
|
prompt=system_prompt,
|
||||||
|
name=name,
|
||||||
|
)
|
||||||
|
|
||||||
|
@tool(name, description=description)
|
||||||
|
async def _run(query: str) -> str:
|
||||||
|
result = await subgraph.ainvoke(
|
||||||
|
{"messages": [HumanMessage(content=query)]}
|
||||||
|
)
|
||||||
|
messages = result["messages"]
|
||||||
|
# Return the last AI message content
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None):
|
||||||
|
return str(msg.content)
|
||||||
|
return "No response from sub-agent."
|
||||||
|
|
||||||
|
return _run
|
||||||
|
|
||||||
|
|
||||||
|
def _make_subagent_tools() -> list:
|
||||||
|
"""Create the four sub-agent tools for the supervisor."""
|
||||||
|
return [
|
||||||
|
_build_subagent_tool(
|
||||||
|
name="task_agent",
|
||||||
|
description=(
|
||||||
|
"Manages tasks and comments: list, create, update, delete, "
|
||||||
|
"due-today, comments. Delegate task-related queries here."
|
||||||
|
),
|
||||||
|
system_prompt=(
|
||||||
|
"You are a task management assistant. You create, update, list, "
|
||||||
|
"and track tasks and their comments.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: todo, in_progress, done\n"
|
||||||
|
" - priority must be one of: high, medium, low\n"
|
||||||
|
" - due_date is a Unix timestamp in milliseconds\n"
|
||||||
|
" - assignees is a JSON-encoded array of strings\n"
|
||||||
|
" - is_approved defaults to 0; set to 1 only when the user confirms\n"
|
||||||
|
" - For update_task, use -1 for integer fields you do not want to change\n"
|
||||||
|
" - Always confirm the action in plain, user-friendly language."
|
||||||
|
),
|
||||||
|
tools=_TASK_TOOLS,
|
||||||
|
),
|
||||||
|
_build_subagent_tool(
|
||||||
|
name="note_agent",
|
||||||
|
description=(
|
||||||
|
"Manages notes: list, get, create, update, delete. "
|
||||||
|
"Delegate note-related queries here."
|
||||||
|
),
|
||||||
|
system_prompt=(
|
||||||
|
"You are a note-taking assistant. You help users create, retrieve, "
|
||||||
|
"update, and delete Markdown notes in their workspace.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - content is always Markdown; preserve formatting when updating\n"
|
||||||
|
" - When updating, call get_note first if you need to read existing "
|
||||||
|
"content before appending or replacing sections\n"
|
||||||
|
" - Do not fabricate note content."
|
||||||
|
),
|
||||||
|
tools=_NOTE_TOOLS,
|
||||||
|
),
|
||||||
|
_build_subagent_tool(
|
||||||
|
name="project_agent",
|
||||||
|
description=(
|
||||||
|
"Manages projects: list, get, create, update, archive, delete. "
|
||||||
|
"Delegate project-related queries here."
|
||||||
|
),
|
||||||
|
system_prompt=(
|
||||||
|
"You are a project management assistant. You help users create, "
|
||||||
|
"find, update, and archive projects.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - status must be one of: active, archived\n"
|
||||||
|
" - Prefer archiving over deletion\n"
|
||||||
|
" - ai_summary is populated only when the user asks for a summary."
|
||||||
|
),
|
||||||
|
tools=_PROJECT_TOOLS,
|
||||||
|
),
|
||||||
|
_build_subagent_tool(
|
||||||
|
name="timeline_agent",
|
||||||
|
description=(
|
||||||
|
"Manages project timelines (milestones): list, create, update, "
|
||||||
|
"delete. Delegate timeline/milestone queries here."
|
||||||
|
),
|
||||||
|
system_prompt=(
|
||||||
|
"You are a project timeline assistant. Timelines are milestone "
|
||||||
|
"dates that track progress on a project.\n\n"
|
||||||
|
"Rules:\n"
|
||||||
|
" - project_id is REQUIRED for every create\n"
|
||||||
|
" - date is a Unix timestamp in milliseconds\n"
|
||||||
|
" - For update_timeline, use -1 for integer fields you do not "
|
||||||
|
"want to change."
|
||||||
|
),
|
||||||
|
tools=_TIMELINE_TOOLS,
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# ── Update core memory tool ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def _make_update_core_memory_tool(user_id: str, db_session_factory):
|
||||||
|
"""Create a tool that persists a key/value preference in core memory."""
|
||||||
|
|
||||||
|
@tool
|
||||||
|
async def update_core_memory(key: str, value: str) -> str:
|
||||||
|
"""Save a user preference or fact to long-term core memory.
|
||||||
|
key: short label for the memory (e.g. 'preferred_language', 'timezone')
|
||||||
|
value: the value to remember
|
||||||
|
Use this when the user states a preference or fact worth remembering.
|
||||||
|
"""
|
||||||
|
from app.core.memory_middleware import MemoryMiddleware
|
||||||
|
|
||||||
|
async with db_session_factory() as db:
|
||||||
|
memory = MemoryMiddleware(db)
|
||||||
|
await memory.update_core(user_id, key, value)
|
||||||
|
return f"Remembered: {key} = {value}"
|
||||||
|
|
||||||
|
return update_core_memory
|
||||||
|
|
||||||
|
|
||||||
|
# ── System prompts ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
_HOME_SYSTEM = (
|
||||||
|
"You are Adiuva, a smart workspace assistant on the Home dashboard.\n"
|
||||||
|
"Your job is to help the user by gathering data from their workspace and "
|
||||||
|
"presenting a comprehensive overview.\n\n"
|
||||||
|
"You have sub-agent tools (task_agent, note_agent, project_agent, "
|
||||||
|
"timeline_agent) that can query and mutate workspace data. Delegate to "
|
||||||
|
"the appropriate sub-agent(s) based on the user's request. You can call "
|
||||||
|
"multiple sub-agents if needed.\n\n"
|
||||||
|
"You also have an update_core_memory tool — use it when the user states "
|
||||||
|
"a preference or important fact worth remembering long-term.\n\n"
|
||||||
|
"## Entity References\n"
|
||||||
|
"When your response mentions specific workspace entities, embed them "
|
||||||
|
"inline using entity tags so the UI can render interactive components.\n"
|
||||||
|
"Format: <type>[comma-separated UUIDs]</type>\n"
|
||||||
|
"Supported types: task, project, note, timeline\n\n"
|
||||||
|
"Example response:\n"
|
||||||
|
" Here is your project:\n"
|
||||||
|
" <project>[abc-123-def]</project>\n"
|
||||||
|
" It has these pending tasks:\n"
|
||||||
|
" <task>[def-456,ghi-789]</task>\n\n"
|
||||||
|
"IMPORTANT: Only include IDs of entities that are directly relevant to "
|
||||||
|
"the user's question. Do NOT dump all entity IDs returned by a tool — "
|
||||||
|
"filter to only the ones the user asked about or that matter for the answer.\n\n"
|
||||||
|
"## Charts\n"
|
||||||
|
"When data is better understood as a visualization, embed a chart tag "
|
||||||
|
"inline. The frontend renders it using shadcn/ui Recharts components.\n"
|
||||||
|
"Format: <chart>{{JSON}}</chart>\n\n"
|
||||||
|
"JSON shape:\n"
|
||||||
|
' {{"chartType":"<type>","title":"...","data":[...],"config":{{...}}}}\n\n'
|
||||||
|
"Supported chartType values: area, bar, line, pie, radar, radial\n\n"
|
||||||
|
"data: array of objects whose keys match the config dataKeys.\n"
|
||||||
|
"config: {{ dataKey: {{ label, color }} }} — follows shadcn ChartConfig.\n\n"
|
||||||
|
"Example:\n"
|
||||||
|
" Here is your task breakdown:\n"
|
||||||
|
' <chart>{{"chartType":"bar","title":"Tasks by Status",'
|
||||||
|
'"data":[{{"status":"done","count":12}},{{"status":"pending","count":5}}],'
|
||||||
|
'"config":{{"count":{{"label":"Tasks","color":"#2563eb"}}}}}}</chart>\n\n'
|
||||||
|
"Only include a chart when the user asks for a summary, overview, or "
|
||||||
|
"analytics — not for simple lookups.\n\n"
|
||||||
|
"Memory context:\n{memory_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
_FLOATING_SYSTEM = (
|
||||||
|
"You are Adiuva, a focused workspace assistant in the floating panel.\n"
|
||||||
|
"The user is currently working in the '{scope_type}' section"
|
||||||
|
"{scope_detail}.\n\n"
|
||||||
|
"You have sub-agent tools (task_agent, note_agent, project_agent, "
|
||||||
|
"timeline_agent) that can query and mutate workspace data. Focus your "
|
||||||
|
"help on the user's current scope, but you can use other sub-agents "
|
||||||
|
"if the request requires it.\n\n"
|
||||||
|
"You also have an update_core_memory tool — use it when the user states "
|
||||||
|
"a preference or important fact worth remembering long-term.\n\n"
|
||||||
|
"Provide direct, conversational responses.\n\n"
|
||||||
|
"Memory context:\n{memory_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _format_memory_context(memory: dict[str, Any]) -> str:
|
||||||
|
"""Format the memory dict into a readable string for the system prompt."""
|
||||||
|
if not memory:
|
||||||
|
return "(no memory available)"
|
||||||
|
parts = []
|
||||||
|
if memory.get("core_memory"):
|
||||||
|
parts.append("Preferences: " + json.dumps(memory["core_memory"]))
|
||||||
|
if memory.get("associative_memory"):
|
||||||
|
parts.append("Related memories: " + "; ".join(memory["associative_memory"][:3]))
|
||||||
|
if memory.get("episodic_memory"):
|
||||||
|
parts.append("Recent sessions: " + "; ".join(memory["episodic_memory"][:3]))
|
||||||
|
if memory.get("proactive_hints"):
|
||||||
|
parts.append("Patterns: " + "; ".join(memory["proactive_hints"][:3]))
|
||||||
|
return "\n".join(parts) if parts else "(no memory available)"
|
||||||
|
|
||||||
|
|
||||||
|
# ── Graph builders ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def build_home_graph(
|
||||||
|
user_id: str,
|
||||||
|
memory_context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
):
|
||||||
|
"""Build the Home supervisor graph."""
|
||||||
|
subagent_tools = _make_subagent_tools()
|
||||||
|
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
|
||||||
|
all_tools = subagent_tools + [memory_tool]
|
||||||
|
|
||||||
|
prompt = _HOME_SYSTEM.format(
|
||||||
|
memory_context=_format_memory_context(memory_context),
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_react_agent(
|
||||||
|
model=get_llm(),
|
||||||
|
tools=all_tools,
|
||||||
|
prompt=prompt,
|
||||||
|
name="home_supervisor",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def build_floating_graph(
|
||||||
|
user_id: str,
|
||||||
|
memory_context: dict[str, Any],
|
||||||
|
scope: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
):
|
||||||
|
"""Build the Floating supervisor graph."""
|
||||||
|
subagent_tools = _make_subagent_tools()
|
||||||
|
memory_tool = _make_update_core_memory_tool(user_id, db_session_factory)
|
||||||
|
all_tools = subagent_tools + [memory_tool]
|
||||||
|
|
||||||
|
scope_type = scope.get("type", "general")
|
||||||
|
scope_id = scope.get("id")
|
||||||
|
scope_detail = f" (id: {scope_id})" if scope_id else ""
|
||||||
|
|
||||||
|
prompt = _FLOATING_SYSTEM.format(
|
||||||
|
scope_type=scope_type,
|
||||||
|
scope_detail=scope_detail,
|
||||||
|
memory_context=_format_memory_context(memory_context),
|
||||||
|
)
|
||||||
|
|
||||||
|
return create_react_agent(
|
||||||
|
model=get_llm(),
|
||||||
|
tools=all_tools,
|
||||||
|
prompt=prompt,
|
||||||
|
name="floating_supervisor",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stream event type ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Events yielded by run_*_stream:
|
||||||
|
# ("token", str) — text token for streaming
|
||||||
|
# ("tool_start", dict) — {"name": "task_agent", "args": {...}}
|
||||||
|
# ("tool_end", dict) — {"name": "task_agent", "result": "..."}
|
||||||
|
|
||||||
|
|
||||||
|
# ── Stream runners ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _run_graph_stream(
|
||||||
|
graph,
|
||||||
|
message: str,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run a supervisor graph with streaming, yielding event tuples.
|
||||||
|
|
||||||
|
Uses ``stream_mode=["messages", "updates"]`` to get both token-level
|
||||||
|
streaming and update events for tool calls.
|
||||||
|
"""
|
||||||
|
inputs = {"messages": [HumanMessage(content=message)]}
|
||||||
|
|
||||||
|
collector: list[dict] = []
|
||||||
|
set_tool_result_collector(collector)
|
||||||
|
try:
|
||||||
|
async for stream_mode, chunk in graph.astream(
|
||||||
|
inputs,
|
||||||
|
stream_mode=["messages", "updates"],
|
||||||
|
):
|
||||||
|
if stream_mode == "messages":
|
||||||
|
msg, metadata = chunk
|
||||||
|
# Only yield tokens from the supervisor's final response
|
||||||
|
# (not from sub-agent internal LLM calls)
|
||||||
|
if (
|
||||||
|
isinstance(msg, AIMessageChunk)
|
||||||
|
and msg.content
|
||||||
|
and not msg.tool_calls
|
||||||
|
and isinstance(metadata, dict)
|
||||||
|
and metadata.get("langgraph_node") == "agent"
|
||||||
|
):
|
||||||
|
yield ("token", str(msg.content))
|
||||||
|
|
||||||
|
elif stream_mode == "updates":
|
||||||
|
# Updates is a dict of {node_name: state_update}
|
||||||
|
if not isinstance(chunk, dict):
|
||||||
|
continue
|
||||||
|
for node_name, state_update in chunk.items():
|
||||||
|
if node_name != "tools":
|
||||||
|
continue
|
||||||
|
# Tool node executed — extract tool call results
|
||||||
|
tool_messages = state_update.get("messages", [])
|
||||||
|
for tool_msg in tool_messages:
|
||||||
|
if hasattr(tool_msg, "name") and hasattr(tool_msg, "content"):
|
||||||
|
yield (
|
||||||
|
"tool_end",
|
||||||
|
{"name": tool_msg.name, "result": str(tool_msg.content)},
|
||||||
|
)
|
||||||
|
finally:
|
||||||
|
clear_tool_result_collector()
|
||||||
|
|
||||||
|
# Yield the collected mutations so callers can attach them to stream_end
|
||||||
|
yield ("mutations", collector)
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run the Home supervisor and yield streaming events."""
|
||||||
|
graph = build_home_graph(user_id, context, db_session_factory)
|
||||||
|
async for event in _run_graph_stream(graph, message):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def run_floating_stream(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
scope: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> AsyncGenerator[tuple[str, Any], None]:
|
||||||
|
"""Run the Floating supervisor and yield streaming events."""
|
||||||
|
graph = build_floating_graph(user_id, context, scope, db_session_factory)
|
||||||
|
async for event in _run_graph_stream(graph, message):
|
||||||
|
yield event
|
||||||
|
|
||||||
|
|
||||||
|
async def run_home(
|
||||||
|
user_id: str,
|
||||||
|
message: str,
|
||||||
|
context: dict[str, Any],
|
||||||
|
db_session_factory,
|
||||||
|
) -> str:
|
||||||
|
"""Run the Home supervisor (non-streaming) and return full response text."""
|
||||||
|
graph = build_home_graph(user_id, context, db_session_factory)
|
||||||
|
result = await graph.ainvoke(
|
||||||
|
{"messages": [HumanMessage(content=message)]}
|
||||||
|
)
|
||||||
|
messages = result["messages"]
|
||||||
|
for msg in reversed(messages):
|
||||||
|
if hasattr(msg, "content") and msg.content and not getattr(msg, "tool_calls", None):
|
||||||
|
return str(msg.content)
|
||||||
|
return ""
|
||||||
183
app/core/device_manager.py
Normal file
183
app/core/device_manager.py
Normal file
@@ -0,0 +1,183 @@
|
|||||||
|
"""Device connection manager.
|
||||||
|
|
||||||
|
Maintains in-memory state for all active Electron → backend WebSocket
|
||||||
|
connections. One connection per user (latest replaces previous).
|
||||||
|
|
||||||
|
The manager participates in two interaction patterns:
|
||||||
|
|
||||||
|
1. **Tool-call round-trip** (bidirectional CRUD):
|
||||||
|
- Backend sends ``tool_call`` frame → Electron executes CRUD → returns
|
||||||
|
``tool_result`` frame.
|
||||||
|
- ``create_pending_call`` registers a Future keyed by ``call_id``.
|
||||||
|
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
|
||||||
|
receive the result dict from Electron.
|
||||||
|
|
||||||
|
2. **Agent-data streaming** (local directory agent runs):
|
||||||
|
- Backend sends ``agent_run`` frame → Electron reads files and sends
|
||||||
|
back a stream of ``agent_data`` frames followed by ``agent_complete``.
|
||||||
|
- ``get_agent_data_queue`` returns (or creates) an asyncio.Queue for
|
||||||
|
a specific ``run_id`` so the agent runner can iterate frames.
|
||||||
|
|
||||||
|
The ``device_manager`` module-level singleton is imported by both the
|
||||||
|
device WS route and the agent runner.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
|
||||||
|
from fastapi import WebSocket
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class DeviceConnection:
|
||||||
|
"""State for a single connected Electron device."""
|
||||||
|
|
||||||
|
ws: WebSocket
|
||||||
|
device_id: str
|
||||||
|
# Futures indexed by tool_call id — resolved when tool_result arrives.
|
||||||
|
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
||||||
|
# Per-run queues for agent_data / agent_complete frames.
|
||||||
|
agent_data_queues: dict[str, asyncio.Queue[dict | None]] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class DeviceConnectionManager:
|
||||||
|
"""Singleton registry of active Electron WebSocket connections.
|
||||||
|
|
||||||
|
Thread/task safety note: asyncio is single-threaded by design. All
|
||||||
|
mutations happen inside await-points on the main event loop, so no
|
||||||
|
locking is required for the in-memory dicts.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._connections: dict[str, DeviceConnection] = {}
|
||||||
|
|
||||||
|
# ── Registration ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
|
||||||
|
"""Store the active connection for *user_id*, replacing any previous one."""
|
||||||
|
if user_id in self._connections:
|
||||||
|
old = self._connections[user_id]
|
||||||
|
logger.info(
|
||||||
|
"device_manager: replacing existing connection for user=%s device=%s",
|
||||||
|
user_id,
|
||||||
|
old.device_id,
|
||||||
|
)
|
||||||
|
# Cancel any futures that were waiting on the old connection.
|
||||||
|
for fut in old.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
|
||||||
|
logger.info(
|
||||||
|
"device_manager: registered user=%s device=%s", user_id, device_id
|
||||||
|
)
|
||||||
|
|
||||||
|
def unregister(self, user_id: str) -> None:
|
||||||
|
"""Remove the connection for *user_id* and cancel any pending futures."""
|
||||||
|
conn = self._connections.pop(user_id, None)
|
||||||
|
if conn is None:
|
||||||
|
return
|
||||||
|
for fut in conn.pending_calls.values():
|
||||||
|
if not fut.done():
|
||||||
|
fut.cancel()
|
||||||
|
logger.info("device_manager: unregistered user=%s", user_id)
|
||||||
|
|
||||||
|
# ── Presence queries ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_ws(self, user_id: str) -> WebSocket | None:
|
||||||
|
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
return conn.ws if conn else None
|
||||||
|
|
||||||
|
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
|
||||||
|
"""Return ``True`` if the user has an active connection.
|
||||||
|
|
||||||
|
If *device_id* is provided also checks that it matches the connected device.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
return False
|
||||||
|
if device_id is not None:
|
||||||
|
return conn.device_id == device_id
|
||||||
|
return True
|
||||||
|
|
||||||
|
# ── Frame sending ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def send_frame(self, user_id: str, frame: dict) -> None:
|
||||||
|
"""Send *frame* as a JSON text message to the device.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if the user is not connected.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"send_frame: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
await conn.ws.send_text(json.dumps(frame))
|
||||||
|
|
||||||
|
# ── Tool-call round-trip ──────────────────────────────────────────
|
||||||
|
|
||||||
|
def create_pending_call(
|
||||||
|
self, user_id: str, call_id: str
|
||||||
|
) -> asyncio.Future[dict]:
|
||||||
|
"""Register a Future that will be resolved when the tool_result arrives.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if the user is not connected.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"create_pending_call: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
fut: asyncio.Future[dict] = loop.create_future()
|
||||||
|
conn.pending_calls[call_id] = fut
|
||||||
|
return fut
|
||||||
|
|
||||||
|
def resolve_pending_call(
|
||||||
|
self, user_id: str, call_id: str, result: dict
|
||||||
|
) -> None:
|
||||||
|
"""Fulfil the Future registered under *call_id* with the Electron result.
|
||||||
|
|
||||||
|
No-ops if the call_id is unknown (already timed out or cancelled).
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
return
|
||||||
|
fut = conn.pending_calls.pop(call_id, None)
|
||||||
|
if fut is not None and not fut.done():
|
||||||
|
fut.set_result(result)
|
||||||
|
|
||||||
|
# ── Agent-data queue ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def get_agent_data_queue(
|
||||||
|
self, user_id: str, run_id: str
|
||||||
|
) -> asyncio.Queue[dict | None]:
|
||||||
|
"""Return (creating if absent) the queue for *run_id* agent frames.
|
||||||
|
|
||||||
|
The agent runner reads from this queue. The device WS handler writes
|
||||||
|
to it. ``None`` is the sentinel that signals the stream is finished.
|
||||||
|
"""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"get_agent_data_queue: user {user_id!r} is not connected"
|
||||||
|
)
|
||||||
|
if run_id not in conn.agent_data_queues:
|
||||||
|
conn.agent_data_queues[run_id] = asyncio.Queue()
|
||||||
|
return conn.agent_data_queues[run_id]
|
||||||
|
|
||||||
|
def cleanup_agent_data_queue(self, user_id: str, run_id: str) -> None:
|
||||||
|
"""Remove the queue for *run_id* once a run has completed."""
|
||||||
|
conn = self._connections.get(user_id)
|
||||||
|
if conn:
|
||||||
|
conn.agent_data_queues.pop(run_id, None)
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton — import this everywhere.
|
||||||
|
device_manager = DeviceConnectionManager()
|
||||||
116
app/core/llm.py
Normal file
116
app/core/llm.py
Normal file
@@ -0,0 +1,116 @@
|
|||||||
|
"""LLM factory — centralised model instantiation via LiteLLM.
|
||||||
|
|
||||||
|
Every agent and the deep-agent supervisors call ``get_llm()`` or ``get_router_llm()``
|
||||||
|
instead of directly constructing a provider-specific class. The model string
|
||||||
|
follows the `LiteLLM model naming convention
|
||||||
|
<https://docs.litellm.ai/docs/providers>`_:
|
||||||
|
|
||||||
|
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
|
||||||
|
* Anthropic: ``anthropic/claude-3.5-sonnet``
|
||||||
|
* Google: ``gemini/gemini-pro``
|
||||||
|
* Ollama: ``ollama/llama3``
|
||||||
|
* Bedrock: ``bedrock/anthropic.claude-v2``
|
||||||
|
|
||||||
|
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
|
||||||
|
— no code changes required.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import litellm
|
||||||
|
|
||||||
|
from langchain_openai import ChatOpenAI
|
||||||
|
from langchain_litellm import ChatLiteLLM
|
||||||
|
from litellm import get_supported_openai_params # noqa: F401 – validates install
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
|
||||||
|
# Drop them silently instead of raising UnsupportedParamsError.
|
||||||
|
litellm.drop_params = True
|
||||||
|
|
||||||
|
|
||||||
|
def _api_key_for_model(model: str) -> str | None:
|
||||||
|
"""Return the most appropriate API key for the given LiteLLM model string."""
|
||||||
|
if model.startswith("anthropic/"):
|
||||||
|
return settings.ANTHROPIC_API_KEY or None
|
||||||
|
if model.startswith("gemini/") or model.startswith("google/"):
|
||||||
|
return settings.GOOGLE_API_KEY or None
|
||||||
|
if model.startswith("cerebras/"):
|
||||||
|
return settings.CEREBRAS_API_KEY or None
|
||||||
|
if model.startswith("github_copilot/"):
|
||||||
|
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
|
||||||
|
# No API key is required; returning None lets LiteLLM handle auth.
|
||||||
|
return None
|
||||||
|
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
|
||||||
|
return settings.OPENAI_API_KEY or None
|
||||||
|
|
||||||
|
|
||||||
|
def get_llm(
|
||||||
|
*,
|
||||||
|
model: str | None = None,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
"""Return a LangChain chat model backed by LiteLLM.
|
||||||
|
|
||||||
|
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
|
||||||
|
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
|
||||||
|
``openai`` client transparently when the model string contains a provider
|
||||||
|
prefix (``anthropic/…``, ``gemini/…``, etc.).
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model:
|
||||||
|
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
|
||||||
|
temperature:
|
||||||
|
Sampling temperature. ``0`` = deterministic.
|
||||||
|
"""
|
||||||
|
model = model or settings.LLM_MODEL
|
||||||
|
|
||||||
|
# Point LiteLLM to the custom token directory when configured.
|
||||||
|
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
||||||
|
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
||||||
|
|
||||||
|
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
|
||||||
|
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
|
||||||
|
if "/" in model:
|
||||||
|
return ChatLiteLLM(model=model, temperature=temperature)
|
||||||
|
|
||||||
|
return ChatOpenAI(
|
||||||
|
model=model,
|
||||||
|
temperature=temperature,
|
||||||
|
api_key=_api_key_for_model(model),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def get_router_llm(
|
||||||
|
*,
|
||||||
|
temperature: float = 0,
|
||||||
|
) -> ChatOpenAI | ChatLiteLLM:
|
||||||
|
"""Return the lighter model used for intent classification / routing."""
|
||||||
|
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
|
||||||
|
|
||||||
|
|
||||||
|
async def embed(text: str) -> list[float]:
|
||||||
|
"""Return an embedding vector for *text*.
|
||||||
|
|
||||||
|
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
|
||||||
|
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
|
||||||
|
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
|
||||||
|
model names to preserve existing behaviour.
|
||||||
|
"""
|
||||||
|
model = settings.LLM_EMBED_MODEL
|
||||||
|
|
||||||
|
if model.startswith("github_copilot/") or "/" in model:
|
||||||
|
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
|
||||||
|
# so the provider's auth mechanism is applied correctly.
|
||||||
|
response = await litellm.aembedding(model=model, input=[text])
|
||||||
|
return response.data[0]["embedding"]
|
||||||
|
|
||||||
|
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
|
||||||
|
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
||||||
|
response = await client.embeddings.create(model=model, input=text)
|
||||||
|
return response.data[0].embedding
|
||||||
231
app/core/memory_middleware.py
Normal file
231
app/core/memory_middleware.py
Normal file
@@ -0,0 +1,231 @@
|
|||||||
|
"""Memory Middleware — enrich requests with memory context and store interactions.
|
||||||
|
|
||||||
|
Four-tier memory model (MemGPT-style):
|
||||||
|
core — persistent key/value user preferences, always injected
|
||||||
|
associative — semantic similarity search via pgvector (top-k)
|
||||||
|
episodic — recent session summaries (last N)
|
||||||
|
proactive — behavioral patterns above confidence threshold
|
||||||
|
|
||||||
|
All memory content is encrypted at rest using the per-user Fernet key
|
||||||
|
stored in User.encryption_key. Decryption happens in-memory only.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
memory = MemoryMiddleware(db_session)
|
||||||
|
context = await memory.enrich_context(user_id, message)
|
||||||
|
# ... run agent ...
|
||||||
|
await memory.store_episode(user_id, session_id, message, response)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import (
|
||||||
|
MemoryAssociative,
|
||||||
|
MemoryCore,
|
||||||
|
MemoryEpisodic,
|
||||||
|
MemoryProactive,
|
||||||
|
User,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Tuning constants
|
||||||
|
_ASSOCIATIVE_TOP_K = 5
|
||||||
|
_EPISODIC_RECENT_N = 10
|
||||||
|
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
|
||||||
|
|
||||||
|
|
||||||
|
class MemoryMiddleware:
|
||||||
|
"""Enrich agent context with memory and persist interactions after."""
|
||||||
|
|
||||||
|
def __init__(self, db: AsyncSession) -> None:
|
||||||
|
self._db = db
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def enrich_context(self, user_id: str, message: str) -> dict[str, Any]:
|
||||||
|
"""Build memory context dict to inject into the agent before LLM call.
|
||||||
|
|
||||||
|
Returns a dict with keys:
|
||||||
|
core_memory — {key: plaintext_value, ...}
|
||||||
|
associative_memory — [plaintext_content, ...] (top-k by keyword match)
|
||||||
|
episodic_memory — [plaintext_summary, ...] (most recent N)
|
||||||
|
proactive_hints — [plaintext_pattern, ...] (above threshold)
|
||||||
|
"""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
core = await self._load_core(user_id, fernet)
|
||||||
|
associative = await self._load_associative(user_id, message, fernet)
|
||||||
|
episodic = await self._load_episodic(user_id, fernet)
|
||||||
|
proactive = await self._load_proactive(user_id, fernet)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"core_memory": core,
|
||||||
|
"associative_memory": associative,
|
||||||
|
"episodic_memory": episodic,
|
||||||
|
"proactive_hints": proactive,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def store_episode(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
session_id: str,
|
||||||
|
message: str,
|
||||||
|
response: str,
|
||||||
|
) -> None:
|
||||||
|
"""Summarise and store a completed interaction in episodic memory.
|
||||||
|
|
||||||
|
The summary is a simple heuristic concatenation (no LLM call) to keep
|
||||||
|
latency low. Full LLM summarisation can be added in a later step.
|
||||||
|
"""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
summary = f"User: {message[:200]}\nAssistant: {response[:200]}"
|
||||||
|
encrypted = _encrypt(fernet, summary)
|
||||||
|
|
||||||
|
row = MemoryEpisodic(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
summary_encrypted=encrypted,
|
||||||
|
session_id=session_id,
|
||||||
|
)
|
||||||
|
self._db.add(row)
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
async def update_core(self, user_id: str, key: str, value: str) -> None:
|
||||||
|
"""Upsert a core memory key/value for a user."""
|
||||||
|
fernet = await self._get_fernet(user_id)
|
||||||
|
if fernet is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
encrypted = _encrypt(fernet, value)
|
||||||
|
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(
|
||||||
|
MemoryCore.user_id == user_id,
|
||||||
|
MemoryCore.key == key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
existing = result.scalar_one_or_none()
|
||||||
|
if existing is not None:
|
||||||
|
existing.value_encrypted = encrypted
|
||||||
|
else:
|
||||||
|
self._db.add(MemoryCore(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
key=key,
|
||||||
|
value_encrypted=encrypted,
|
||||||
|
))
|
||||||
|
try:
|
||||||
|
await self._db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
|
||||||
|
await self._db.rollback()
|
||||||
|
|
||||||
|
# ── Private helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _get_fernet(self, user_id: str) -> Fernet | None:
|
||||||
|
"""Load the user's Fernet key from DB. Returns None if missing."""
|
||||||
|
result = await self._db.execute(select(User).where(User.id == user_id))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
|
if user is None or not user.encryption_key:
|
||||||
|
logger.warning("memory: no encryption_key for user=%s", user_id)
|
||||||
|
return None
|
||||||
|
return Fernet(user.encryption_key.encode())
|
||||||
|
|
||||||
|
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryCore).where(MemoryCore.user_id == user_id)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: dict[str, str] = {}
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.value_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out[row.key] = plaintext
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_associative(
|
||||||
|
self, user_id: str, message: str, fernet: Fernet
|
||||||
|
) -> list[str]:
|
||||||
|
"""Load top-k associative memories.
|
||||||
|
|
||||||
|
Production: uses pgvector cosine similarity on the message embedding.
|
||||||
|
Current implementation: keyword-based fallback (no external embedding call)
|
||||||
|
so tests pass without a live OpenAI key.
|
||||||
|
"""
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryAssociative)
|
||||||
|
.where(MemoryAssociative.user_id == user_id)
|
||||||
|
.order_by(MemoryAssociative.updated_at.desc())
|
||||||
|
.limit(_ASSOCIATIVE_TOP_K)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.content_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_episodic(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryEpisodic)
|
||||||
|
.where(MemoryEpisodic.user_id == user_id)
|
||||||
|
.order_by(MemoryEpisodic.created_at.desc())
|
||||||
|
.limit(_EPISODIC_RECENT_N)
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.summary_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
|
||||||
|
result = await self._db.execute(
|
||||||
|
select(MemoryProactive)
|
||||||
|
.where(
|
||||||
|
MemoryProactive.user_id == user_id,
|
||||||
|
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
|
||||||
|
)
|
||||||
|
.order_by(MemoryProactive.confidence.desc())
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
out: list[str] = []
|
||||||
|
for row in rows:
|
||||||
|
plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
|
||||||
|
if plaintext is not None:
|
||||||
|
out.append(plaintext)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
# ── Encryption helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _encrypt(fernet: Fernet, plaintext: str) -> str:
|
||||||
|
return fernet.encrypt(plaintext.encode()).decode()
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
|
||||||
|
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
|
||||||
|
try:
|
||||||
|
return fernet.decrypt(ciphertext.encode()).decode()
|
||||||
|
except (InvalidToken, Exception) as exc:
|
||||||
|
logger.warning("memory: decrypt failed: %s", exc)
|
||||||
|
return None
|
||||||
141
app/core/output_formatter.py
Normal file
141
app/core/output_formatter.py
Normal file
@@ -0,0 +1,141 @@
|
|||||||
|
"""Output Formatter — transforms deep-agent event streams into WS frame sequences.
|
||||||
|
|
||||||
|
Consumes ``(event_type, data)`` tuples yielded by ``deep_agent.run_*_stream()``:
|
||||||
|
* ``("token", str)`` — supervisor text token
|
||||||
|
* ``("tool_end", dict)`` — sub-agent finished: ``{name, result}``
|
||||||
|
* ``("mutations", list)`` — collected CRUD mutations for ``stream_end``
|
||||||
|
|
||||||
|
HomeFormatter:
|
||||||
|
* Streams text tokens as-is → emits ``WsStreamText``
|
||||||
|
(text may contain inline ``<type>[id,...]</type>`` entity tags
|
||||||
|
for the frontend to parse and render as interactive components)
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
|
||||||
|
FloatingFormatter:
|
||||||
|
* Sniffs first ``tool_end`` name → emits ``WsFloatingDomain``
|
||||||
|
* Streams text tokens → emits ``WsStreamText``
|
||||||
|
* Attaches mutations → injects into ``WsStreamEnd``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from collections.abc import AsyncGenerator
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from app.schemas import (
|
||||||
|
WsFloatingDomain,
|
||||||
|
WsStreamEnd,
|
||||||
|
WsStreamStart,
|
||||||
|
WsStreamText,
|
||||||
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Map sub-agent tool name → floating domain / entity type
|
||||||
|
_AGENT_DOMAIN: dict[str, str] = {
|
||||||
|
"task_agent": "tasks",
|
||||||
|
"timeline_agent": "timelines",
|
||||||
|
"note_agent": "notes",
|
||||||
|
"project_agent": "projects",
|
||||||
|
}
|
||||||
|
|
||||||
|
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
|
||||||
|
|
||||||
|
|
||||||
|
class HomeFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Home view.
|
||||||
|
|
||||||
|
Text tokens are forwarded as-is via ``WsStreamText``. The supervisor
|
||||||
|
embeds ``<type>[id1,id2]</type>`` entity tags inline — the frontend
|
||||||
|
is responsible for parsing those and rendering interactive components.
|
||||||
|
Mutations are attached to ``WsStreamEnd``.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "token":
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class FloatingFormatter:
|
||||||
|
"""Consumes a deep-agent event stream and yields WS frames for the Floating view.
|
||||||
|
|
||||||
|
Sniffs the first ``tool_end`` event name to derive the domain (e.g.
|
||||||
|
``task_agent`` → ``"tasks"``), then streams text tokens as plain
|
||||||
|
``WsStreamText``. No block parsing for floating context.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, request_id: str) -> None:
|
||||||
|
self.request_id = request_id
|
||||||
|
self._mutations: list[dict] = []
|
||||||
|
|
||||||
|
async def format(
|
||||||
|
self,
|
||||||
|
event_stream: AsyncGenerator[tuple[str, Any], None],
|
||||||
|
) -> AsyncGenerator[WsFrame, None]:
|
||||||
|
domain_sent = False
|
||||||
|
|
||||||
|
async for event_type, data in event_stream:
|
||||||
|
if event_type == "tool_end" and not domain_sent:
|
||||||
|
# Sniff domain from the first sub-agent that completes
|
||||||
|
name = data.get("name", "")
|
||||||
|
domain = _AGENT_DOMAIN.get(name, "tasks")
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain=domain, # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
|
||||||
|
elif event_type == "token":
|
||||||
|
if not domain_sent:
|
||||||
|
# First token arrived before any tool_end — default domain
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
domain_sent = True
|
||||||
|
if data:
|
||||||
|
yield WsStreamText(request_id=self.request_id, chunk=data)
|
||||||
|
|
||||||
|
elif event_type == "mutations":
|
||||||
|
self._mutations = data or []
|
||||||
|
|
||||||
|
# If no events triggered domain_sent (edge case), still emit structure
|
||||||
|
if not domain_sent:
|
||||||
|
yield WsFloatingDomain(
|
||||||
|
request_id=self.request_id,
|
||||||
|
domain="tasks", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
yield WsStreamStart(request_id=self.request_id)
|
||||||
|
|
||||||
|
yield WsStreamEnd(
|
||||||
|
request_id=self.request_id,
|
||||||
|
mutations=[
|
||||||
|
{"action": m["action"], "table": m["table"], "data": m["data"]}
|
||||||
|
for m in self._mutations
|
||||||
|
],
|
||||||
|
)
|
||||||
100
app/core/ws_context.py
Normal file
100
app/core/ws_context.py
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
"""WebSocket client executor context.
|
||||||
|
|
||||||
|
Holds a per-request async callback that tools call to execute CRUD
|
||||||
|
operations on the Electron client's local SQLite / LanceDB databases.
|
||||||
|
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from contextvars import ContextVar
|
||||||
|
from typing import Any, Callable, Coroutine
|
||||||
|
from uuid import uuid4
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Holds the execute callback for the current WS session.
|
||||||
|
# Set by the chat WS handler before the deep agent runs; cleared after.
|
||||||
|
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
|
||||||
|
"_client_executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Optional collector that captures raw execute_on_client results.
|
||||||
|
# Set by the deep agent tool loop to capture CRUD mutations.
|
||||||
|
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
|
||||||
|
"_tool_result_collector", default=None
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_tool_result_collector(lst: list[dict]) -> None:
|
||||||
|
"""Register *lst* as the collector for this async context."""
|
||||||
|
_tool_result_collector.set(lst)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_tool_result_collector() -> None:
|
||||||
|
"""Clear the collector (best-effort)."""
|
||||||
|
_tool_result_collector.set(None)
|
||||||
|
|
||||||
|
|
||||||
|
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
|
||||||
|
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
|
||||||
|
_client_executor.set(fn)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_client_executor() -> None:
|
||||||
|
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
|
||||||
|
try:
|
||||||
|
_client_executor.set(None) # type: ignore[arg-type]
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
async def execute_on_client(
|
||||||
|
action: str,
|
||||||
|
table: str | None = None,
|
||||||
|
data: dict[str, Any] | None = None,
|
||||||
|
filters: dict[str, Any] | None = None,
|
||||||
|
vector: list[float] | None = None,
|
||||||
|
limit: int | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send a CRUD/vector operation to the Electron client and return the result.
|
||||||
|
|
||||||
|
Builds a ``tool_call`` payload, invokes the per-session WS callback,
|
||||||
|
and returns the ``tool_result`` dict from Electron.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
|
||||||
|
"""
|
||||||
|
callback = _client_executor.get(None)
|
||||||
|
if callback is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"execute_on_client() called outside a WebSocket session — "
|
||||||
|
"no client executor is set."
|
||||||
|
)
|
||||||
|
|
||||||
|
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
|
||||||
|
if table is not None:
|
||||||
|
payload["table"] = table
|
||||||
|
if data is not None:
|
||||||
|
payload["data"] = data
|
||||||
|
if filters is not None:
|
||||||
|
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
|
||||||
|
if vector is not None:
|
||||||
|
payload["vector"] = vector
|
||||||
|
if limit is not None:
|
||||||
|
payload["limit"] = limit
|
||||||
|
|
||||||
|
logger.info("execute_on_client: sending payload action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
result = await callback(payload)
|
||||||
|
if result is None:
|
||||||
|
logger.error("execute_on_client: callback returned None for action=%s table=%s id=%s", action, table, payload["id"])
|
||||||
|
else:
|
||||||
|
logger.info("execute_on_client: got result type=%s keys=%s", type(result).__name__, list(result.keys()) if isinstance(result, dict) else "N/A")
|
||||||
|
collector = _tool_result_collector.get(None)
|
||||||
|
if collector is not None:
|
||||||
|
collector.append({
|
||||||
|
"action": action,
|
||||||
|
"table": table,
|
||||||
|
"data": result,
|
||||||
|
})
|
||||||
|
return result
|
||||||
@@ -1,7 +1,15 @@
|
|||||||
"""Database engine, session factory, and declarative base.
|
"""Database engine, session factory, and base model.
|
||||||
|
|
||||||
All services use the async SQLAlchemy API via ``get_session()``.
|
All app code uses the async SQLAlchemy API. Alembic migrations use the
|
||||||
Alembic migrations use the synchronous psycopg2 URL (see alembic/env.py).
|
synchronous psycopg2 URL for the CLI (see alembic/env.py).
|
||||||
|
|
||||||
|
Usage in routes:
|
||||||
|
from app.db import get_session
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
async def my_route(db: AsyncSession = Depends(get_session)):
|
||||||
|
result = await db.execute(select(User).where(User.email == email))
|
||||||
|
user = result.scalar_one_or_none()
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -11,7 +19,7 @@ from collections.abc import AsyncGenerator
|
|||||||
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
from shared.config import settings
|
from app.config.settings import settings
|
||||||
|
|
||||||
engine = create_async_engine(
|
engine = create_async_engine(
|
||||||
settings.DATABASE_URL,
|
settings.DATABASE_URL,
|
||||||
@@ -1,11 +1,20 @@
|
|||||||
"""Cloud provider integration utilities.
|
"""Cloud provider integration utilities.
|
||||||
|
|
||||||
Adapted for Batch Agent Service: import from shared.config instead of app.config.
|
|
||||||
|
|
||||||
Provides:
|
Provides:
|
||||||
* Shared message dataclasses (EmailMessage, ChatMessage)
|
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
|
||||||
* get_provider() — factory for Gmail/MS Graph clients
|
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
|
||||||
* encrypt_token() / decrypt_token() — Fernet-based OAuth token encryption
|
* ``get_provider()`` — factory that returns the correct client given a
|
||||||
|
provider name and decrypted OAuth credentials dict.
|
||||||
|
* ``encrypt_token()`` / ``decrypt_token()`` — Fernet-based at-rest
|
||||||
|
encryption for OAuth tokens stored in ``cloud_agent_configs``.
|
||||||
|
|
||||||
|
Encryption rationale
|
||||||
|
--------------------
|
||||||
|
Unlike user content (which is E2E-encrypted client-side and **never**
|
||||||
|
decrypted server-side), OAuth tokens *must* be decrypted server-side
|
||||||
|
because the backend makes provider API calls on behalf of the user.
|
||||||
|
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var — it
|
||||||
|
is never returned to clients.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -18,7 +27,7 @@ from typing import TYPE_CHECKING
|
|||||||
|
|
||||||
from cryptography.fernet import Fernet, InvalidToken
|
from cryptography.fernet import Fernet, InvalidToken
|
||||||
|
|
||||||
from shared.config import settings
|
from app.config.settings import settings
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from app.integrations.gmail import GmailClient
|
from app.integrations.gmail import GmailClient
|
||||||
@@ -26,9 +35,13 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Shared message types ──────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class EmailMessage:
|
class EmailMessage:
|
||||||
|
"""A single email message fetched from Gmail or Outlook."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
subject: str
|
subject: str
|
||||||
sender: str
|
sender: str
|
||||||
@@ -38,6 +51,7 @@ class EmailMessage:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def as_text(self) -> str:
|
def as_text(self) -> str:
|
||||||
|
"""Return a human-readable text representation for LLM extraction."""
|
||||||
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
|
||||||
return (
|
return (
|
||||||
@@ -50,6 +64,8 @@ class EmailMessage:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ChatMessage:
|
class ChatMessage:
|
||||||
|
"""A single Teams chat or channel message fetched from MS Graph."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
content: str
|
content: str
|
||||||
sender: str
|
sender: str
|
||||||
@@ -58,6 +74,7 @@ class ChatMessage:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def as_text(self) -> str:
|
def as_text(self) -> str:
|
||||||
|
"""Return a human-readable text representation for LLM extraction."""
|
||||||
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
date_str = self.date.strftime("%Y-%m-%d %H:%M")
|
||||||
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
channel_str = f" [channel: {self.channel}]" if self.channel else ""
|
||||||
return (
|
return (
|
||||||
@@ -67,7 +84,15 @@ class ChatMessage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Fernet helpers ────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def _get_fernet() -> Fernet:
|
def _get_fernet() -> Fernet:
|
||||||
|
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
|
||||||
|
|
||||||
|
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set — callers
|
||||||
|
must ensure this is configured before persisting OAuth tokens.
|
||||||
|
"""
|
||||||
key = settings.OAUTH_ENCRYPTION_KEY
|
key = settings.OAUTH_ENCRYPTION_KEY
|
||||||
if not key:
|
if not key:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@@ -78,6 +103,15 @@ def _get_fernet() -> Fernet:
|
|||||||
|
|
||||||
|
|
||||||
def encrypt_token(token_info: dict) -> str:
|
def encrypt_token(token_info: dict) -> str:
|
||||||
|
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
|
||||||
|
|
||||||
|
Stores the full ``{access_token, refresh_token, token_uri, client_id,
|
||||||
|
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||||
|
ValueError: ``token_info`` is not a non-empty dict.
|
||||||
|
"""
|
||||||
if not isinstance(token_info, dict) or not token_info:
|
if not isinstance(token_info, dict) or not token_info:
|
||||||
raise ValueError("token_info must be a non-empty dict")
|
raise ValueError("token_info must be a non-empty dict")
|
||||||
plaintext = json.dumps(token_info).encode("utf-8")
|
plaintext = json.dumps(token_info).encode("utf-8")
|
||||||
@@ -85,6 +119,13 @@ def encrypt_token(token_info: dict) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def decrypt_token(encrypted: str) -> dict:
|
def decrypt_token(encrypted: str) -> dict:
|
||||||
|
"""Decrypt a Fernet-encrypted token string and return the credential dict.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
|
||||||
|
ValueError: The encrypted string is invalid or was encrypted with a
|
||||||
|
different key.
|
||||||
|
"""
|
||||||
try:
|
try:
|
||||||
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
|
||||||
return json.loads(plaintext)
|
return json.loads(plaintext)
|
||||||
@@ -92,10 +133,25 @@ def decrypt_token(encrypted: str) -> dict:
|
|||||||
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
|
||||||
|
|
||||||
|
|
||||||
|
# ── Provider factory ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
def get_provider(
|
def get_provider(
|
||||||
provider: str,
|
provider: str,
|
||||||
credentials_info: dict,
|
credentials_info: dict,
|
||||||
) -> "GmailClient | MSGraphClient":
|
) -> "GmailClient | MSGraphClient":
|
||||||
|
"""Return the correct provider client for *provider*.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
provider:
|
||||||
|
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
|
||||||
|
credentials_info:
|
||||||
|
Decrypted OAuth credential dict (Google or Microsoft shape).
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: Unknown provider name.
|
||||||
|
"""
|
||||||
if provider == "gmail":
|
if provider == "gmail":
|
||||||
from app.integrations.gmail import GmailClient
|
from app.integrations.gmail import GmailClient
|
||||||
return GmailClient(credentials_info)
|
return GmailClient(credentials_info)
|
||||||
@@ -1,7 +1,26 @@
|
|||||||
"""Gmail API client for cloud agent integration.
|
"""Gmail API client for cloud agent integration.
|
||||||
|
|
||||||
Adapted for Batch Agent Service: import from app.integrations instead of
|
Wraps the Google Gmail REST API to fetch email messages matching a
|
||||||
app.integrations (same relative path within the service).
|
``filter_config`` dict. Uses the official ``google-api-python-client``
|
||||||
|
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
|
||||||
|
blocking the event loop.
|
||||||
|
|
||||||
|
Token refresh is handled transparently: when the stored access token has
|
||||||
|
expired, ``google.auth.transport.requests.Request`` will use the refresh
|
||||||
|
token to obtain a fresh one. The caller is responsible for persisting
|
||||||
|
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
|
||||||
|
(see ``agent_runner.run_cloud_agent``).
|
||||||
|
|
||||||
|
Credential dict shape (Google OAuth2):
|
||||||
|
{
|
||||||
|
"token": "<access_token>",
|
||||||
|
"refresh_token": "<refresh_token>",
|
||||||
|
"token_uri": "https://oauth2.googleapis.com/token",
|
||||||
|
"client_id": "<client_id>",
|
||||||
|
"client_secret": "<client_secret>",
|
||||||
|
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
|
||||||
|
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -19,8 +38,13 @@ from app.integrations import EmailMessage
|
|||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Gmail search date format — e.g. "after:2025/01/01"
|
||||||
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
_GMAIL_DATE_FMT = "%Y/%m/%d"
|
||||||
|
|
||||||
|
# Maximum characters of body text forwarded to the LLM.
|
||||||
_BODY_TRUNCATE = 8_000
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
# Maximum messages retrieved per run (prevents runaway quota usage).
|
||||||
_MAX_MESSAGES = 200
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
|
||||||
@@ -28,9 +52,20 @@ def _build_gmail_query(
|
|||||||
filter_config: dict[str, Any] | None,
|
filter_config: dict[str, Any] | None,
|
||||||
since: datetime | None,
|
since: datetime | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Build a Gmail search query string from *filter_config* and *since*.
|
||||||
|
|
||||||
|
Supported ``filter_config`` keys:
|
||||||
|
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
|
||||||
|
senders (list[str]): Sender addresses or domains to include
|
||||||
|
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
|
||||||
|
|
||||||
|
A hard ``since`` date (from last run) always overrides ``date_range.from``
|
||||||
|
when it is earlier.
|
||||||
|
"""
|
||||||
parts: list[str] = []
|
parts: list[str] = []
|
||||||
cfg = filter_config or {}
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
# Labels — joined with OR when multiple given.
|
||||||
labels: list[str] = cfg.get("labels", [])
|
labels: list[str] = cfg.get("labels", [])
|
||||||
if labels:
|
if labels:
|
||||||
if len(labels) == 1:
|
if len(labels) == 1:
|
||||||
@@ -39,14 +74,17 @@ def _build_gmail_query(
|
|||||||
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
|
||||||
parts.append(f"({label_expr})")
|
parts.append(f"({label_expr})")
|
||||||
|
|
||||||
|
# Senders — each prefixed with "from:".
|
||||||
senders: list[str] = cfg.get("senders", [])
|
senders: list[str] = cfg.get("senders", [])
|
||||||
for sender in senders:
|
for sender in senders:
|
||||||
parts.append(f"from:{sender}")
|
parts.append(f"from:{sender}")
|
||||||
|
|
||||||
|
# Date range.
|
||||||
date_range: dict = cfg.get("date_range", {})
|
date_range: dict = cfg.get("date_range", {})
|
||||||
from_str: str | None = date_range.get("from")
|
from_str: str | None = date_range.get("from")
|
||||||
to_str: str | None = date_range.get("to")
|
to_str: str | None = date_range.get("to")
|
||||||
|
|
||||||
|
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
|
||||||
effective_since: datetime | None = since
|
effective_since: datetime | None = since
|
||||||
if from_str:
|
if from_str:
|
||||||
try:
|
try:
|
||||||
@@ -72,12 +110,18 @@ def _build_gmail_query(
|
|||||||
|
|
||||||
|
|
||||||
def _strip_html(raw_html: str) -> str:
|
def _strip_html(raw_html: str) -> str:
|
||||||
|
"""Remove HTML tags and decode entities to get plain text."""
|
||||||
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
no_tags = re.sub(r"<[^>]+>", " ", raw_html)
|
||||||
decoded = html.unescape(no_tags)
|
decoded = html.unescape(no_tags)
|
||||||
return re.sub(r"\s+", " ", decoded).strip()
|
return re.sub(r"\s+", " ", decoded).strip()
|
||||||
|
|
||||||
|
|
||||||
def _parse_body(payload: dict[str, Any]) -> str:
|
def _parse_body(payload: dict[str, Any]) -> str:
|
||||||
|
"""Recursively extract the plain-text body from a Gmail message payload.
|
||||||
|
|
||||||
|
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
|
||||||
|
Returns an empty string if no body can be extracted.
|
||||||
|
"""
|
||||||
mime_type: str = payload.get("mimeType", "")
|
mime_type: str = payload.get("mimeType", "")
|
||||||
body: dict = payload.get("body", {})
|
body: dict = payload.get("body", {})
|
||||||
parts: list[dict] = payload.get("parts", [])
|
parts: list[dict] = payload.get("parts", [])
|
||||||
@@ -95,6 +139,7 @@ def _parse_body(payload: dict[str, Any]) -> str:
|
|||||||
return _strip_html(raw)
|
return _strip_html(raw)
|
||||||
return ""
|
return ""
|
||||||
|
|
||||||
|
# Multipart — prefer text/plain part, fall back to text/html.
|
||||||
plain_fallback = ""
|
plain_fallback = ""
|
||||||
for part in parts:
|
for part in parts:
|
||||||
part_mime = part.get("mimeType", "")
|
part_mime = part.get("mimeType", "")
|
||||||
@@ -110,6 +155,7 @@ def _parse_body(payload: dict[str, Any]) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _parse_date(raw: str) -> datetime:
|
def _parse_date(raw: str) -> datetime:
|
||||||
|
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
|
||||||
try:
|
try:
|
||||||
parsed = email.utils.parsedate_to_datetime(raw)
|
parsed = email.utils.parsedate_to_datetime(raw)
|
||||||
if parsed.tzinfo is None:
|
if parsed.tzinfo is None:
|
||||||
@@ -120,6 +166,16 @@ def _parse_date(raw: str) -> datetime:
|
|||||||
|
|
||||||
|
|
||||||
class GmailClient:
|
class GmailClient:
|
||||||
|
"""Fetch email messages from a Gmail account via the Gmail REST API.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
credentials_info:
|
||||||
|
Decrypted OAuth2 credential dict. Must contain at minimum
|
||||||
|
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
|
||||||
|
``client_id`` + ``client_secret``.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
from google.oauth2.credentials import Credentials
|
from google.oauth2.credentials import Credentials
|
||||||
|
|
||||||
@@ -144,20 +200,38 @@ class GmailClient:
|
|||||||
expiry=expiry,
|
expiry=expiry,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# ── Public API ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def fetch_messages(
|
async def fetch_messages(
|
||||||
self,
|
self,
|
||||||
filter_config: dict[str, Any] | None = None,
|
filter_config: dict[str, Any] | None = None,
|
||||||
since: datetime | None = None,
|
since: datetime | None = None,
|
||||||
) -> list[EmailMessage]:
|
) -> list[EmailMessage]:
|
||||||
|
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
|
||||||
|
|
||||||
|
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
|
||||||
|
to avoid blocking the async event loop.
|
||||||
|
|
||||||
|
Token refresh is performed automatically when the access token has
|
||||||
|
expired. After the call, ``self.refreshed_credentials`` may be
|
||||||
|
consulted to detect whether new credentials should be persisted.
|
||||||
|
"""
|
||||||
query = _build_gmail_query(filter_config, since)
|
query = _build_gmail_query(filter_config, since)
|
||||||
logger.debug("gmail: executing search query %r", query)
|
logger.debug("gmail: executing search query %r", query)
|
||||||
return await asyncio.to_thread(self._fetch_sync, query)
|
return await asyncio.to_thread(self._fetch_sync, query)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
"""Return updated credential dict if the access token was refreshed.
|
||||||
|
|
||||||
|
If the credentials were refreshed during ``fetch_messages()``, returns
|
||||||
|
a new dict that should be re-encrypted and written back to the DB.
|
||||||
|
Returns ``None`` if no refresh occurred.
|
||||||
|
"""
|
||||||
creds = self._credentials
|
creds = self._credentials
|
||||||
if not creds.valid and creds.expired:
|
if not creds.valid and creds.expired:
|
||||||
return None
|
return None
|
||||||
|
# Check whether the token changed from what was stored.
|
||||||
if creds.token != self._credentials_info.get("token"):
|
if creds.token != self._credentials_info.get("token"):
|
||||||
result = {
|
result = {
|
||||||
"token": creds.token,
|
"token": creds.token,
|
||||||
@@ -172,11 +246,15 @@ class GmailClient:
|
|||||||
return result
|
return result
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# ── Internal sync worker ───────────────────────────────────────────────
|
||||||
|
|
||||||
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
def _fetch_sync(self, query: str) -> list[EmailMessage]:
|
||||||
|
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
|
||||||
import googleapiclient.discovery
|
import googleapiclient.discovery
|
||||||
import googleapiclient.errors
|
import googleapiclient.errors
|
||||||
from google.auth.transport.requests import Request
|
from google.auth.transport.requests import Request
|
||||||
|
|
||||||
|
# Refresh token if needed before building the service.
|
||||||
if self._credentials.expired and self._credentials.refresh_token:
|
if self._credentials.expired and self._credentials.refresh_token:
|
||||||
try:
|
try:
|
||||||
self._credentials.refresh(Request())
|
self._credentials.refresh(Request())
|
||||||
@@ -186,8 +264,9 @@ class GmailClient:
|
|||||||
service = googleapiclient.discovery.build(
|
service = googleapiclient.discovery.build(
|
||||||
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
"gmail", "v1", credentials=self._credentials, cache_discovery=False
|
||||||
)
|
)
|
||||||
user_api = service.users()
|
user_api = service.users() # type: ignore[attr-defined]
|
||||||
|
|
||||||
|
# ── List matching message IDs ──────────────────────────────────────
|
||||||
ids: list[str] = []
|
ids: list[str] = []
|
||||||
page_token: str | None = None
|
page_token: str | None = None
|
||||||
while len(ids) < _MAX_MESSAGES:
|
while len(ids) < _MAX_MESSAGES:
|
||||||
@@ -214,10 +293,12 @@ class GmailClient:
|
|||||||
break
|
break
|
||||||
|
|
||||||
if not ids:
|
if not ids:
|
||||||
|
logger.debug("gmail: no messages matched query %r", query)
|
||||||
return []
|
return []
|
||||||
|
|
||||||
logger.info("gmail: fetching %d message(s)", len(ids))
|
logger.info("gmail: fetching %d message(s)", len(ids))
|
||||||
|
|
||||||
|
# ── Fetch individual message details ──────────────────────────────
|
||||||
messages: list[EmailMessage] = []
|
messages: list[EmailMessage] = []
|
||||||
for msg_id in ids:
|
for msg_id in ids:
|
||||||
try:
|
try:
|
||||||
@@ -245,8 +326,10 @@ class GmailClient:
|
|||||||
date=date,
|
date=date,
|
||||||
labels=labels,
|
labels=labels,
|
||||||
))
|
))
|
||||||
|
except googleapiclient.errors.HttpError as exc:
|
||||||
|
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
logger.warning("gmail: skipping message %s: %s", msg_id, exc)
|
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc)
|
||||||
|
|
||||||
logger.info("gmail: returned %d message(s)", len(messages))
|
logger.info("gmail: returned %d message(s)", len(messages))
|
||||||
return messages
|
return messages
|
||||||
@@ -1,30 +1,52 @@
|
|||||||
"""Microsoft Graph API client for Outlook and Teams.
|
"""Microsoft Graph API client for Outlook and Teams cloud agent integration.
|
||||||
|
|
||||||
Adapted for Batch Agent Service: import settings from shared.config.
|
Handles two data sources:
|
||||||
|
|
||||||
|
* **Outlook email** (``provider="outlook"``) — ``fetch_emails()`` calls
|
||||||
|
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
|
||||||
|
* **Teams messages** (``provider="teams"``) — ``fetch_messages()`` calls
|
||||||
|
``/me/chats/getAllMessages`` filtered by date.
|
||||||
|
|
||||||
|
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
|
||||||
|
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
|
||||||
|
dependency) is used for all API calls.
|
||||||
|
|
||||||
|
Credential dict shape (Microsoft OAuth2 / MSAL):
|
||||||
|
{
|
||||||
|
"access_token": "<access_token>",
|
||||||
|
"refresh_token": "<refresh_token>",
|
||||||
|
"token_type": "Bearer",
|
||||||
|
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
|
||||||
|
"expires_in": 3600
|
||||||
|
}
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import re
|
import re
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timedelta, timezone
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
|
|
||||||
from shared.config import settings
|
from app.config.settings import settings
|
||||||
from app.integrations import ChatMessage, EmailMessage
|
from app.integrations import ChatMessage, EmailMessage
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
_GRAPH_BASE = "https://graph.microsoft.com/v1.0"
|
||||||
|
|
||||||
|
# Max items fetched per run.
|
||||||
_MAX_EMAILS = 200
|
_MAX_EMAILS = 200
|
||||||
_MAX_MESSAGES = 200
|
_MAX_MESSAGES = 200
|
||||||
|
|
||||||
|
# Max characters of body forwarded to the LLM.
|
||||||
_BODY_TRUNCATE = 8_000
|
_BODY_TRUNCATE = 8_000
|
||||||
|
|
||||||
|
|
||||||
def _strip_html(raw: str) -> str:
|
def _strip_html(raw: str) -> str:
|
||||||
|
"""Strip HTML tags and collapse whitespace."""
|
||||||
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
no_tags = re.sub(r"<[^>]+>", " ", raw)
|
||||||
import html as _html
|
import html as _html
|
||||||
decoded = _html.unescape(no_tags)
|
decoded = _html.unescape(no_tags)
|
||||||
@@ -32,6 +54,7 @@ def _strip_html(raw: str) -> str:
|
|||||||
|
|
||||||
|
|
||||||
def _odata_datetime(dt: datetime) -> str:
|
def _odata_datetime(dt: datetime) -> str:
|
||||||
|
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
|
||||||
utc = dt.astimezone(timezone.utc)
|
utc = dt.astimezone(timezone.utc)
|
||||||
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||||
|
|
||||||
@@ -40,14 +63,29 @@ def _build_email_filter(
|
|||||||
filter_config: dict[str, Any] | None,
|
filter_config: dict[str, Any] | None,
|
||||||
since: datetime | None,
|
since: datetime | None,
|
||||||
) -> str:
|
) -> str:
|
||||||
|
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
|
||||||
|
|
||||||
|
Supported ``filter_config`` keys:
|
||||||
|
senders (list[str]): Sender email addresses.
|
||||||
|
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
|
||||||
|
folders (list[str]): Folder display names (not directly filterable
|
||||||
|
via OData, so ignored here — callers iterate
|
||||||
|
folder IDs separately if needed; listed for
|
||||||
|
completeness).
|
||||||
|
|
||||||
|
A hard ``since`` date always overrides ``date_range.from`` when it is
|
||||||
|
earlier.
|
||||||
|
"""
|
||||||
clauses: list[str] = []
|
clauses: list[str] = []
|
||||||
cfg = filter_config or {}
|
cfg = filter_config or {}
|
||||||
|
|
||||||
|
# Senders.
|
||||||
senders: list[str] = cfg.get("senders", [])
|
senders: list[str] = cfg.get("senders", [])
|
||||||
if senders:
|
if senders:
|
||||||
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
|
||||||
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
clauses.append("(" + " or ".join(sender_clauses) + ")")
|
||||||
|
|
||||||
|
# Date range.
|
||||||
date_range: dict = cfg.get("date_range", {})
|
date_range: dict = cfg.get("date_range", {})
|
||||||
from_str: str | None = date_range.get("from")
|
from_str: str | None = date_range.get("from")
|
||||||
|
|
||||||
@@ -79,16 +117,33 @@ def _build_email_filter(
|
|||||||
|
|
||||||
|
|
||||||
class MSGraphClient:
|
class MSGraphClient:
|
||||||
|
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
credentials_info:
|
||||||
|
Decrypted MSAL credential dict.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
def __init__(self, credentials_info: dict[str, Any]) -> None:
|
||||||
self._credentials_info = credentials_info
|
self._credentials_info = credentials_info
|
||||||
self._access_token: str = credentials_info.get("access_token", "")
|
self._access_token: str = credentials_info.get("access_token", "")
|
||||||
self._original_access_token: str = self._access_token
|
self._original_access_token: str = self._access_token
|
||||||
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
self._refresh_token: str | None = credentials_info.get("refresh_token")
|
||||||
|
|
||||||
|
# ── Token management ───────────────────────────────────────────────────
|
||||||
|
|
||||||
def _auth_headers(self) -> dict[str, str]:
|
def _auth_headers(self) -> dict[str, str]:
|
||||||
return {"Authorization": f"Bearer {self._access_token}"}
|
return {"Authorization": f"Bearer {self._access_token}"}
|
||||||
|
|
||||||
async def _refresh_access_token(self) -> None:
|
async def _refresh_access_token(self) -> None:
|
||||||
|
"""Use MSAL to exchange the refresh token for a fresh access token.
|
||||||
|
|
||||||
|
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
RuntimeError: MSAL reports an auth error.
|
||||||
|
"""
|
||||||
import msal
|
import msal
|
||||||
|
|
||||||
app = msal.ConfidentialClientApplication(
|
app = msal.ConfidentialClientApplication(
|
||||||
@@ -109,6 +164,7 @@ class MSGraphClient:
|
|||||||
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
raise RuntimeError(f"MS Graph token refresh failed: {error}")
|
||||||
|
|
||||||
self._access_token = result["access_token"]
|
self._access_token = result["access_token"]
|
||||||
|
# MSAL may issue a new refresh token.
|
||||||
if "refresh_token" in result:
|
if "refresh_token" in result:
|
||||||
self._refresh_token = result["refresh_token"]
|
self._refresh_token = result["refresh_token"]
|
||||||
self._credentials_info["refresh_token"] = result["refresh_token"]
|
self._credentials_info["refresh_token"] = result["refresh_token"]
|
||||||
@@ -116,10 +172,16 @@ class MSGraphClient:
|
|||||||
|
|
||||||
@property
|
@property
|
||||||
def refreshed_credentials(self) -> dict[str, Any] | None:
|
def refreshed_credentials(self) -> dict[str, Any] | None:
|
||||||
|
"""Return updated credential dict if the access token was refreshed.
|
||||||
|
|
||||||
|
Returns ``None`` if no change was made.
|
||||||
|
"""
|
||||||
if self._access_token != self._original_access_token:
|
if self._access_token != self._original_access_token:
|
||||||
return {**self._credentials_info, "access_token": self._access_token}
|
return {**self._credentials_info, "access_token": self._access_token}
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
# ── HTTP helpers ───────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def _get(
|
async def _get(
|
||||||
self,
|
self,
|
||||||
client: httpx.AsyncClient,
|
client: httpx.AsyncClient,
|
||||||
@@ -128,8 +190,10 @@ class MSGraphClient:
|
|||||||
*,
|
*,
|
||||||
retry_on_401: bool = True,
|
retry_on_401: bool = True,
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
|
"""GET *url* with auth; refresh token on 401 and retry once."""
|
||||||
resp = await client.get(url, params=params, headers=self._auth_headers())
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
if resp.status_code == 401 and retry_on_401 and self._refresh_token:
|
||||||
|
logger.debug("ms_graph: 401 on %s — refreshing token", url)
|
||||||
await self._refresh_access_token()
|
await self._refresh_access_token()
|
||||||
resp = await client.get(url, params=params, headers=self._auth_headers())
|
resp = await client.get(url, params=params, headers=self._auth_headers())
|
||||||
if resp.status_code == 429:
|
if resp.status_code == 429:
|
||||||
@@ -137,11 +201,22 @@ class MSGraphClient:
|
|||||||
resp.raise_for_status()
|
resp.raise_for_status()
|
||||||
return resp.json()
|
return resp.json()
|
||||||
|
|
||||||
|
# ── Public API ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
async def fetch_emails(
|
async def fetch_emails(
|
||||||
self,
|
self,
|
||||||
filter_config: dict[str, Any] | None = None,
|
filter_config: dict[str, Any] | None = None,
|
||||||
since: datetime | None = None,
|
since: datetime | None = None,
|
||||||
) -> list[EmailMessage]:
|
) -> list[EmailMessage]:
|
||||||
|
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
filter_config:
|
||||||
|
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
|
||||||
|
since:
|
||||||
|
Hard lower-bound on email date (from last agent run).
|
||||||
|
"""
|
||||||
odata_filter = _build_email_filter(filter_config, since)
|
odata_filter = _build_email_filter(filter_config, since)
|
||||||
params: dict[str, Any] = {
|
params: dict[str, Any] = {
|
||||||
"$top": 50,
|
"$top": 50,
|
||||||
@@ -162,7 +237,7 @@ class MSGraphClient:
|
|||||||
if len(emails) >= _MAX_EMAILS:
|
if len(emails) >= _MAX_EMAILS:
|
||||||
break
|
break
|
||||||
url = data.get("@odata.nextLink", "")
|
url = data.get("@odata.nextLink", "")
|
||||||
params = {}
|
params = {} # nextLink already contains encoded params.
|
||||||
|
|
||||||
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
|
||||||
return emails
|
return emails
|
||||||
@@ -172,6 +247,13 @@ class MSGraphClient:
|
|||||||
filter_config: dict[str, Any] | None = None,
|
filter_config: dict[str, Any] | None = None,
|
||||||
since: datetime | None = None,
|
since: datetime | None = None,
|
||||||
) -> list[ChatMessage]:
|
) -> list[ChatMessage]:
|
||||||
|
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
|
||||||
|
|
||||||
|
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
|
||||||
|
The ``filter_config.channels`` key is checked as a text-filter on
|
||||||
|
the channel name post-fetch (the API doesn't support channel OData
|
||||||
|
filter directly on ``getAllMessages``).
|
||||||
|
"""
|
||||||
cfg = filter_config or {}
|
cfg = filter_config or {}
|
||||||
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
|
||||||
params: dict[str, Any] = {"$top": 50}
|
params: dict[str, Any] = {"$top": 50}
|
||||||
@@ -186,9 +268,11 @@ class MSGraphClient:
|
|||||||
try:
|
try:
|
||||||
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
|
||||||
except httpx.HTTPStatusError as exc:
|
except httpx.HTTPStatusError as exc:
|
||||||
|
# getAllMessages requires specific licensing; degrade gracefully.
|
||||||
if exc.response.status_code in (403, 404):
|
if exc.response.status_code in (403, 404):
|
||||||
logger.warning(
|
logger.warning(
|
||||||
"ms_graph: /me/chats/getAllMessages not available (%d)",
|
"ms_graph: /me/chats/getAllMessages not available (%d) — "
|
||||||
|
"check Teams license or permissions",
|
||||||
exc.response.status_code,
|
exc.response.status_code,
|
||||||
)
|
)
|
||||||
break
|
break
|
||||||
@@ -208,6 +292,8 @@ class MSGraphClient:
|
|||||||
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
# ── Parsers ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
def _parse_email(item: dict[str, Any]) -> EmailMessage:
|
||||||
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
subject: str = item.get("subject", "(no subject)") or "(no subject)"
|
||||||
71
app/main.py
Normal file
71
app/main.py
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
from contextlib import asynccontextmanager
|
||||||
|
import logging
|
||||||
|
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
|
||||||
|
)
|
||||||
|
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
|
||||||
|
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
|
||||||
|
|
||||||
|
from app.api.middleware.rate_limit import TierRateLimitMiddleware
|
||||||
|
from app.api.middleware.sanitizer import SanitizerMiddleware
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
@asynccontextmanager
|
||||||
|
async def lifespan(app: FastAPI):
|
||||||
|
# Startup: initialise DB connection pool
|
||||||
|
yield
|
||||||
|
|
||||||
|
# Shutdown: dispose SQLAlchemy connection pool
|
||||||
|
from app.db import engine
|
||||||
|
await engine.dispose()
|
||||||
|
|
||||||
|
|
||||||
|
def create_app() -> FastAPI:
|
||||||
|
app = FastAPI(
|
||||||
|
title="Adiuva Cloud API",
|
||||||
|
version="0.1.0",
|
||||||
|
docs_url="/docs" if settings.ENV == "dev" else None,
|
||||||
|
redoc_url=None,
|
||||||
|
lifespan=lifespan,
|
||||||
|
)
|
||||||
|
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=settings.CORS_ORIGINS,
|
||||||
|
allow_credentials=True,
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
|
||||||
|
# Request flow: TierRateLimit → Sanitizer → CORS → Router
|
||||||
|
# Response flow: Router → CORS → Sanitizer → TierRateLimit
|
||||||
|
app.add_middleware(SanitizerMiddleware)
|
||||||
|
app.add_middleware(TierRateLimitMiddleware)
|
||||||
|
|
||||||
|
from app.api.routes import agent_setup, agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
|
||||||
|
|
||||||
|
app.include_router(auth.router, prefix="/api/v1")
|
||||||
|
app.include_router(chat.router, prefix="/api/v1")
|
||||||
|
app.include_router(storage.router, prefix="/api/v1")
|
||||||
|
app.include_router(vectors.router, prefix="/api/v1")
|
||||||
|
app.include_router(backup.router, prefix="/api/v1")
|
||||||
|
app.include_router(plugins.router, prefix="/api/v1")
|
||||||
|
app.include_router(billing.router, prefix="/api/v1")
|
||||||
|
app.include_router(agents.router, prefix="/api/v1")
|
||||||
|
app.include_router(agent_setup.router, prefix="/api/v1")
|
||||||
|
app.include_router(device_ws.router, prefix="/api/v1")
|
||||||
|
|
||||||
|
@app.get("/api/v1/health", tags=["health"])
|
||||||
|
async def health() -> dict:
|
||||||
|
return {"status": "ok", "version": app.version}
|
||||||
|
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
app = create_app()
|
||||||
7
app/marketplace/__init__.py
Normal file
7
app/marketplace/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
"""Plugin marketplace package.
|
||||||
|
|
||||||
|
Three service classes introduced in Step 10:
|
||||||
|
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
|
||||||
|
- ``ReviewQueue`` — approval workflow + security checklist
|
||||||
|
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
|
||||||
|
"""
|
||||||
212
app/marketplace/plugin_registry.py
Normal file
212
app/marketplace/plugin_registry.py
Normal file
@@ -0,0 +1,212 @@
|
|||||||
|
"""Plugin catalog registry backed by PostgreSQL.
|
||||||
|
|
||||||
|
Maintains the authoritative list of plugins, their review status, and
|
||||||
|
aggregate install counts. All data is persisted in the ``plugins`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy import select, func
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.models import Plugin
|
||||||
|
from app.schemas import PluginListResponse, PluginManifest
|
||||||
|
|
||||||
|
_PAGE_SIZE = 20
|
||||||
|
|
||||||
|
|
||||||
|
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
|
||||||
|
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
|
||||||
|
try:
|
||||||
|
permissions = json.loads(p.permissions) if p.permissions else []
|
||||||
|
except (json.JSONDecodeError, TypeError):
|
||||||
|
permissions = []
|
||||||
|
return PluginManifest(
|
||||||
|
id=p.id,
|
||||||
|
name=p.name,
|
||||||
|
description=p.description,
|
||||||
|
version=p.version,
|
||||||
|
author=p.author_name,
|
||||||
|
permissions=permissions,
|
||||||
|
category=p.category,
|
||||||
|
price_cents=p.price_cents,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PluginRegistry:
|
||||||
|
"""PostgreSQL-backed plugin catalog.
|
||||||
|
|
||||||
|
All methods accept an ``AsyncSession`` parameter so the calling route
|
||||||
|
controls the session lifecycle.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Queries ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def list_plugins(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
category: str | None = None,
|
||||||
|
query: str | None = None,
|
||||||
|
page: int = 1,
|
||||||
|
sort: Literal["rating", "installs", "newest"] = "newest",
|
||||||
|
) -> PluginListResponse:
|
||||||
|
"""Return a page of approved plugins, optionally filtered and sorted."""
|
||||||
|
base = select(Plugin).where(Plugin.status == "approved")
|
||||||
|
|
||||||
|
if category:
|
||||||
|
base = base.where(Plugin.category == category)
|
||||||
|
if query:
|
||||||
|
pattern = f"%{query}%"
|
||||||
|
base = base.where(
|
||||||
|
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Count
|
||||||
|
count_q = select(func.count()).select_from(base.subquery())
|
||||||
|
total = (await db.execute(count_q)).scalar_one()
|
||||||
|
|
||||||
|
# Sort
|
||||||
|
if sort == "installs":
|
||||||
|
base = base.order_by(Plugin.install_count.desc())
|
||||||
|
elif sort == "rating":
|
||||||
|
base = base.order_by(Plugin.avg_rating.desc())
|
||||||
|
else: # newest
|
||||||
|
base = base.order_by(Plugin.created_at.desc())
|
||||||
|
|
||||||
|
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
|
||||||
|
rows = (await db.execute(base)).scalars().all()
|
||||||
|
|
||||||
|
return PluginListResponse(
|
||||||
|
plugins=[_plugin_to_manifest(r) for r in rows],
|
||||||
|
total=total,
|
||||||
|
page=page,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
|
||||||
|
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
p = result.scalar_one_or_none()
|
||||||
|
if p is None:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"manifest": _plugin_to_manifest(p),
|
||||||
|
"status": p.status,
|
||||||
|
"install_count": p.install_count,
|
||||||
|
"avg_rating": p.avg_rating,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Mutations ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def submit_plugin(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
manifest: PluginManifest,
|
||||||
|
package_s3_key: str,
|
||||||
|
) -> str:
|
||||||
|
"""Add *manifest* to the catalog with ``status='pending_review'``.
|
||||||
|
|
||||||
|
Returns the plugin_id. If a plugin with the same id already exists
|
||||||
|
it is overwritten (re-submission after rejection).
|
||||||
|
"""
|
||||||
|
plugin_id = manifest.id
|
||||||
|
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = existing.scalar_one_or_none()
|
||||||
|
|
||||||
|
if row is not None:
|
||||||
|
row.name = manifest.name
|
||||||
|
row.description = manifest.description
|
||||||
|
row.version = manifest.version
|
||||||
|
row.author_name = manifest.author
|
||||||
|
row.category = manifest.category
|
||||||
|
row.price_cents = manifest.price_cents
|
||||||
|
row.permissions = json.dumps(manifest.permissions)
|
||||||
|
row.status = "pending_review"
|
||||||
|
row.s3_package_key = package_s3_key
|
||||||
|
row.rejection_reason = None
|
||||||
|
else:
|
||||||
|
row = Plugin(
|
||||||
|
id=plugin_id,
|
||||||
|
name=manifest.name,
|
||||||
|
description=manifest.description,
|
||||||
|
version=manifest.version,
|
||||||
|
author_name=manifest.author,
|
||||||
|
category=manifest.category,
|
||||||
|
price_cents=manifest.price_cents,
|
||||||
|
permissions=json.dumps(manifest.permissions),
|
||||||
|
status="pending_review",
|
||||||
|
s3_package_key=package_s3_key,
|
||||||
|
install_count=0,
|
||||||
|
avg_rating=0.0,
|
||||||
|
)
|
||||||
|
db.add(row)
|
||||||
|
await db.commit()
|
||||||
|
return plugin_id
|
||||||
|
|
||||||
|
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'approved'``.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
row.status = "approved"
|
||||||
|
row.rejection_reason = None
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
|
||||||
|
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
|
||||||
|
|
||||||
|
Raises ``KeyError`` if the plugin is not found.
|
||||||
|
"""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is None:
|
||||||
|
raise KeyError(f"Plugin not found: {plugin_id}")
|
||||||
|
row.status = "rejected"
|
||||||
|
row.rejection_reason = reason
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Increment the install count for *plugin_id* (no-op if not found)."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = row.install_count + 1
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
|
||||||
|
"""Decrement the install count for *plugin_id*, floored at 0."""
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
row = result.scalar_one_or_none()
|
||||||
|
if row is not None:
|
||||||
|
row.install_count = max(0, row.install_count - 1)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
# ── Internal helpers used by ReviewQueue ─────────────────────────
|
||||||
|
|
||||||
|
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
|
"""Return all entries with status='pending_review'."""
|
||||||
|
result = await db.execute(
|
||||||
|
select(Plugin).where(Plugin.status == "pending_review")
|
||||||
|
)
|
||||||
|
rows = result.scalars().all()
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"manifest": _plugin_to_manifest(r),
|
||||||
|
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
|
||||||
|
}
|
||||||
|
for r in rows
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
registry = PluginRegistry()
|
||||||
125
app/marketplace/plugin_review.py
Normal file
125
app/marketplace/plugin_review.py
Normal file
@@ -0,0 +1,125 @@
|
|||||||
|
"""Plugin review workflow backed by PostgreSQL.
|
||||||
|
|
||||||
|
Manages the approval queue for newly submitted plugins and enforces a
|
||||||
|
security checklist before any plugin is made visible in the marketplace.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.plugin_review import review_queue
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import re
|
||||||
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import PluginReview as PluginReviewModel
|
||||||
|
from app.schemas import PluginManifest
|
||||||
|
|
||||||
|
# ── Security policy ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
|
||||||
|
{
|
||||||
|
"read:tasks",
|
||||||
|
"write:tasks",
|
||||||
|
"read:projects",
|
||||||
|
"write:projects",
|
||||||
|
"read:notes",
|
||||||
|
"write:notes",
|
||||||
|
"read:timelines",
|
||||||
|
"write:timelines",
|
||||||
|
"read:calendar",
|
||||||
|
"write:calendar",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
|
||||||
|
|
||||||
|
|
||||||
|
def validate_manifest(manifest: PluginManifest) -> None:
|
||||||
|
"""Enforce the plugin security checklist.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``ValueError`` on the first violation found. Callers should catch
|
||||||
|
this and return HTTP 422 / reject the submission.
|
||||||
|
|
||||||
|
Checks:
|
||||||
|
1. Plugin id matches ``^[a-z0-9-]+$``
|
||||||
|
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
|
||||||
|
3. No manifest field contains raw binary data
|
||||||
|
"""
|
||||||
|
if not _PLUGIN_ID_RE.match(manifest.id):
|
||||||
|
raise ValueError(
|
||||||
|
f"Invalid plugin id format: '{manifest.id}'. "
|
||||||
|
"Only lowercase letters, digits, and hyphens are allowed."
|
||||||
|
)
|
||||||
|
|
||||||
|
for perm in manifest.permissions:
|
||||||
|
if perm not in ALLOWED_PERMISSIONS:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown permission: '{perm}'. "
|
||||||
|
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
|
||||||
|
)
|
||||||
|
|
||||||
|
for field_name, value in manifest.model_dump().items():
|
||||||
|
if isinstance(value, (bytes, bytearray)):
|
||||||
|
raise ValueError(
|
||||||
|
f"Binary content is not allowed in manifest field '{field_name}'."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ReviewQueue:
|
||||||
|
"""Approval queue for pending plugin submissions.
|
||||||
|
|
||||||
|
Delegates status changes to the shared ``PluginRegistry`` singleton.
|
||||||
|
Review records are persisted in the ``plugin_reviews`` table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
|
||||||
|
"""Return all plugins currently awaiting review.
|
||||||
|
|
||||||
|
Each item is ``{plugin_id, manifest, submitted_at}``.
|
||||||
|
"""
|
||||||
|
entries = await registry.get_pending_entries(db)
|
||||||
|
return [
|
||||||
|
{
|
||||||
|
"plugin_id": e["manifest"].id,
|
||||||
|
"manifest": e["manifest"],
|
||||||
|
"submitted_at": e["submitted_at"],
|
||||||
|
}
|
||||||
|
for e in entries
|
||||||
|
]
|
||||||
|
|
||||||
|
async def submit_review(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
plugin_id: str,
|
||||||
|
reviewer_id: str,
|
||||||
|
decision: Literal["approved", "rejected"],
|
||||||
|
notes: str = "",
|
||||||
|
) -> None:
|
||||||
|
"""Record a review decision and update the plugin's status.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``KeyError`` if *plugin_id* is not found in the registry.
|
||||||
|
"""
|
||||||
|
if decision == "approved":
|
||||||
|
await registry.approve_plugin(db, plugin_id)
|
||||||
|
else:
|
||||||
|
await registry.reject_plugin(db, plugin_id, reason=notes)
|
||||||
|
|
||||||
|
review = PluginReviewModel(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
reviewer_id=reviewer_id,
|
||||||
|
decision=decision,
|
||||||
|
notes=notes,
|
||||||
|
)
|
||||||
|
db.add(review)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
review_queue = ReviewQueue()
|
||||||
233
app/marketplace/revenue_share.py
Normal file
233
app/marketplace/revenue_share.py
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
|
||||||
|
|
||||||
|
Records every plugin installation as a revenue event and facilitates
|
||||||
|
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
|
||||||
|
in the ``revenue_events`` table.
|
||||||
|
|
||||||
|
Module-level singleton::
|
||||||
|
|
||||||
|
from app.marketplace.revenue_share import revenue_share
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import stripe as stripe_lib
|
||||||
|
from sqlalchemy import extract, func, select
|
||||||
|
from sqlalchemy.ext.asyncio import AsyncSession
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.marketplace.plugin_registry import registry
|
||||||
|
from app.models import Plugin, RevenueEvent
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Revenue split constants ───────────────────────────────────────────
|
||||||
|
|
||||||
|
DEVELOPER_SHARE: float = 0.70
|
||||||
|
PLATFORM_SHARE: float = 0.30
|
||||||
|
|
||||||
|
|
||||||
|
class RevenueShare:
|
||||||
|
"""Records installation revenue events and coordinates developer payouts.
|
||||||
|
|
||||||
|
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
|
||||||
|
is not configured, consistent with the rest of the billing layer.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# ── Helpers ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe_configured() -> bool:
|
||||||
|
return bool(settings.STRIPE_SECRET_KEY)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _stripe() -> Any:
|
||||||
|
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
|
||||||
|
return stripe_lib
|
||||||
|
|
||||||
|
# ── Core operations ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def record_install(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
plugin_id: str,
|
||||||
|
user_id: str,
|
||||||
|
amount_cents: int,
|
||||||
|
) -> None:
|
||||||
|
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
|
||||||
|
|
||||||
|
For free plugins (``amount_cents == 0``) no payment is initiated but
|
||||||
|
the event is still recorded for analytics.
|
||||||
|
|
||||||
|
For paid plugins the developer receives 70 % via a Stripe Connect
|
||||||
|
destination charge. If Stripe is not configured or the charge fails
|
||||||
|
the installation still succeeds (the event is recorded and the install
|
||||||
|
count is incremented) — a warning is logged for monitoring.
|
||||||
|
"""
|
||||||
|
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
|
||||||
|
stripe_transfer_id: str | None = None
|
||||||
|
|
||||||
|
if amount_cents > 0 and self._stripe_configured():
|
||||||
|
# Look up the plugin's author Stripe account from the DB
|
||||||
|
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = result.scalar_one_or_none()
|
||||||
|
developer_stripe_account: str | None = None
|
||||||
|
if plugin_row and plugin_row.author_id:
|
||||||
|
# Future: look up user.stripe_connect_account_id
|
||||||
|
developer_stripe_account = None # no real account yet
|
||||||
|
|
||||||
|
if developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
transfer = s.Transfer.create(
|
||||||
|
amount=developer_share_cents,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Revenue share for plugin {plugin_id}",
|
||||||
|
metadata={"plugin_id": plugin_id, "user_id": user_id},
|
||||||
|
)
|
||||||
|
stripe_transfer_id = transfer["id"]
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning(
|
||||||
|
"Stripe Connect transfer failed for plugin %s: %s",
|
||||||
|
plugin_id,
|
||||||
|
exc,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.debug(
|
||||||
|
"No Stripe account on file for plugin %s developer; "
|
||||||
|
"skipping transfer.",
|
||||||
|
plugin_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
event = RevenueEvent(
|
||||||
|
plugin_id=plugin_id,
|
||||||
|
user_id=user_id,
|
||||||
|
amount_cents=amount_cents,
|
||||||
|
developer_share_cents=developer_share_cents,
|
||||||
|
stripe_transfer_id=stripe_transfer_id,
|
||||||
|
)
|
||||||
|
db.add(event)
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
await registry.record_install(db, plugin_id)
|
||||||
|
|
||||||
|
async def get_earnings(
|
||||||
|
self,
|
||||||
|
db: AsyncSession,
|
||||||
|
developer_id: str,
|
||||||
|
period: str | None = None,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Return aggregated earnings for *developer_id*.
|
||||||
|
|
||||||
|
``period`` is an optional ``YYYY-MM`` string to restrict the window.
|
||||||
|
|
||||||
|
Returns::
|
||||||
|
|
||||||
|
{
|
||||||
|
"developer_id": str,
|
||||||
|
"period": str | None,
|
||||||
|
"total_installs": int,
|
||||||
|
"total_revenue_cents": int,
|
||||||
|
"developer_share_cents": int,
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
# Find plugin ids belonging to this developer (by author_name match)
|
||||||
|
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
|
||||||
|
plugin_result = await db.execute(plugin_q)
|
||||||
|
developer_plugin_ids = [row[0] for row in plugin_result.all()]
|
||||||
|
|
||||||
|
if not developer_plugin_ids:
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": 0,
|
||||||
|
"total_revenue_cents": 0,
|
||||||
|
"developer_share_cents": 0,
|
||||||
|
}
|
||||||
|
|
||||||
|
query = select(
|
||||||
|
func.count().label("total_installs"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
|
||||||
|
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
|
||||||
|
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
|
||||||
|
|
||||||
|
if period:
|
||||||
|
# Filter by YYYY-MM: extract year and month from created_at
|
||||||
|
try:
|
||||||
|
year, month = period.split("-")
|
||||||
|
query = query.where(
|
||||||
|
extract("year", RevenueEvent.created_at) == int(year),
|
||||||
|
extract("month", RevenueEvent.created_at) == int(month),
|
||||||
|
)
|
||||||
|
except ValueError:
|
||||||
|
pass # invalid period format — return all
|
||||||
|
|
||||||
|
result = await db.execute(query)
|
||||||
|
row = result.one()
|
||||||
|
|
||||||
|
return {
|
||||||
|
"developer_id": developer_id,
|
||||||
|
"period": period,
|
||||||
|
"total_installs": row.total_installs,
|
||||||
|
"total_revenue_cents": row.total_revenue,
|
||||||
|
"developer_share_cents": row.dev_share,
|
||||||
|
}
|
||||||
|
|
||||||
|
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
|
||||||
|
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
|
||||||
|
|
||||||
|
Marks processed events with ``paid_at`` timestamp.
|
||||||
|
Stubs gracefully when Stripe is not configured.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
year, month = period.split("-")
|
||||||
|
year_int, month_int = int(year), int(month)
|
||||||
|
except ValueError:
|
||||||
|
logger.warning("Invalid period format: %s", period)
|
||||||
|
return
|
||||||
|
|
||||||
|
result = await db.execute(
|
||||||
|
select(RevenueEvent).where(
|
||||||
|
RevenueEvent.plugin_id == plugin_id,
|
||||||
|
RevenueEvent.paid_at.is_(None),
|
||||||
|
extract("year", RevenueEvent.created_at) == year_int,
|
||||||
|
extract("month", RevenueEvent.created_at) == month_int,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
unpaid = list(result.scalars().all())
|
||||||
|
|
||||||
|
total_dev_share = sum(e.developer_share_cents for e in unpaid)
|
||||||
|
if total_dev_share <= 0 or not unpaid:
|
||||||
|
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
|
||||||
|
return
|
||||||
|
|
||||||
|
if self._stripe_configured():
|
||||||
|
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
|
||||||
|
plugin_row = plugin_result.scalar_one_or_none()
|
||||||
|
developer_stripe_account: str | None = None # Future: fetch from DB
|
||||||
|
if plugin_row and developer_stripe_account:
|
||||||
|
try:
|
||||||
|
s = self._stripe()
|
||||||
|
s.Transfer.create(
|
||||||
|
amount=total_dev_share,
|
||||||
|
currency="eur",
|
||||||
|
destination=developer_stripe_account,
|
||||||
|
description=f"Payout for plugin {plugin_id} period {period}",
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
|
||||||
|
return
|
||||||
|
|
||||||
|
paid_ts = datetime.now(timezone.utc)
|
||||||
|
for event in unpaid:
|
||||||
|
event.paid_at = paid_ts
|
||||||
|
await db.commit()
|
||||||
|
|
||||||
|
|
||||||
|
# Module-level singleton
|
||||||
|
revenue_share = RevenueShare()
|
||||||
@@ -1,14 +1,23 @@
|
|||||||
"""SQLAlchemy ORM models for all persistent tables.
|
"""SQLAlchemy ORM models for all persistent tables.
|
||||||
|
|
||||||
Centralized here so that Alembic migrations and all services share
|
Only auth, billing, storage metadata, and marketplace data live here.
|
||||||
the same model definitions. Each service only queries the tables it owns.
|
User content (notes, tasks, etc.) is NEVER persisted server-side —
|
||||||
|
it lives in E2E-encrypted blobs in S3, referenced by storage_records.
|
||||||
|
|
||||||
Ownership:
|
Table inventory:
|
||||||
Auth Service → users, refresh_tokens, subscriptions
|
users — account credentials + tier
|
||||||
Chat Service → memory_core, memory_associative, memory_episodic, memory_proactive
|
refresh_tokens — hashed refresh token store
|
||||||
Batch Agent → local_agent_configs, cloud_agent_configs, agent_run_logs
|
subscriptions — Stripe subscription records
|
||||||
Billing Service → subscriptions (shared write with Auth)
|
storage_records — S3 blob metadata (no plaintext)
|
||||||
(excluded MVP) → storage_records, backup_metadata, plugins, plugin_*, revenue_events
|
backup_metadata — encrypted backup manifests
|
||||||
|
plugins — marketplace plugin catalog
|
||||||
|
plugin_installations — per-user install records
|
||||||
|
plugin_reviews — admin review decisions
|
||||||
|
revenue_events — Stripe Connect 70/30 split ledger
|
||||||
|
memory_core — per-user persistent key/value preferences (encrypted)
|
||||||
|
memory_associative — per-user semantic memory with embeddings (encrypted)
|
||||||
|
memory_episodic — per-user session summaries (encrypted)
|
||||||
|
memory_proactive — per-user behavioral patterns (encrypted)
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -33,7 +42,7 @@ from sqlalchemy import (
|
|||||||
)
|
)
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
from shared.db import Base
|
from app.db import Base
|
||||||
|
|
||||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
@@ -56,7 +65,7 @@ AgentStatusEnum = Enum("running", "success", "error", "partial", name="agent_run
|
|||||||
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
CloudProviderEnum = Enum("gmail", "teams", "outlook", name="cloud_provider")
|
||||||
|
|
||||||
|
|
||||||
# ── Auth models ───────────────────────────────────────────────────────────
|
# ── Models ────────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
@@ -71,6 +80,8 @@ class User(Base):
|
|||||||
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
password_hash: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
tier: Mapped[str] = mapped_column(TierEnum, nullable=False, default="free")
|
||||||
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
stripe_customer_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
|
# Per-user Fernet key (base64-urlsafe, 44 chars). Generated on registration.
|
||||||
|
# Used to encrypt/decrypt all memory rows for this user.
|
||||||
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
encryption_key: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now()
|
||||||
@@ -126,9 +137,6 @@ class Subscription(Base):
|
|||||||
user: Mapped[User] = relationship(back_populates="subscription")
|
user: Mapped[User] = relationship(back_populates="subscription")
|
||||||
|
|
||||||
|
|
||||||
# ── Storage models (excluded from MVP, kept for Alembic) ──────────────
|
|
||||||
|
|
||||||
|
|
||||||
class StorageRecord(Base):
|
class StorageRecord(Base):
|
||||||
__tablename__ = "storage_records"
|
__tablename__ = "storage_records"
|
||||||
|
|
||||||
@@ -169,9 +177,6 @@ class BackupMetadata(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Plugin models (excluded from MVP, kept for Alembic) ───────────────
|
|
||||||
|
|
||||||
|
|
||||||
class Plugin(Base):
|
class Plugin(Base):
|
||||||
__tablename__ = "plugins"
|
__tablename__ = "plugins"
|
||||||
|
|
||||||
@@ -179,13 +184,14 @@ class Plugin(Base):
|
|||||||
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
name: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||||
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
description: Mapped[str] = mapped_column(Text, nullable=False, default="")
|
||||||
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
version: Mapped[str] = mapped_column(String(50), nullable=False, default="1.0.0")
|
||||||
|
# nullable until developer account system is built
|
||||||
author_id: Mapped[str | None] = mapped_column(
|
author_id: Mapped[str | None] = mapped_column(
|
||||||
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="SET NULL"), nullable=True
|
||||||
)
|
)
|
||||||
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
author_name: Mapped[str] = mapped_column(String(255), nullable=False, default="")
|
||||||
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
category: Mapped[str] = mapped_column(String(100), nullable=False, default="")
|
||||||
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
price_cents: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]")
|
permissions: Mapped[str] = mapped_column(Text, nullable=False, default="[]") # JSON list
|
||||||
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
status: Mapped[str] = mapped_column(PluginStatusEnum, nullable=False, default="pending_review")
|
||||||
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
s3_package_key: Mapped[str | None] = mapped_column(String(500), nullable=True)
|
||||||
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
install_count: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||||
@@ -276,9 +282,6 @@ class RevenueEvent(Base):
|
|||||||
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
plugin: Mapped[Plugin] = relationship(back_populates="revenue_events")
|
||||||
|
|
||||||
|
|
||||||
# ── Agent models ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class LocalAgentConfig(Base):
|
class LocalAgentConfig(Base):
|
||||||
__tablename__ = "local_agent_configs"
|
__tablename__ = "local_agent_configs"
|
||||||
|
|
||||||
@@ -353,6 +356,8 @@ class AgentRunLog(Base):
|
|||||||
id: Mapped[str] = mapped_column(
|
id: Mapped[str] = mapped_column(
|
||||||
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
Uuid(as_uuid=False), primary_key=True, default=_uuid
|
||||||
)
|
)
|
||||||
|
# Plain string — not a FK because it references either local_agent_configs or cloud_agent_configs
|
||||||
|
# depending on agent_type. Query by (agent_id, agent_type) to locate the source config.
|
||||||
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
|
||||||
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
|
||||||
user_id: Mapped[str] = mapped_column(
|
user_id: Mapped[str] = mapped_column(
|
||||||
@@ -381,11 +386,15 @@ class AgentRunLog(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Memory models ─────────────────────────────────────────────────────────
|
# ── Memory models ─────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
class MemoryCore(Base):
|
class MemoryCore(Base):
|
||||||
"""Per-user persistent key/value preferences, encrypted at rest."""
|
"""Per-user persistent key/value preferences, encrypted at rest.
|
||||||
|
|
||||||
|
Examples: preferred_language, timezone, work_style.
|
||||||
|
Decrypted in-memory only using User.encryption_key.
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "memory_core"
|
__tablename__ = "memory_core"
|
||||||
|
|
||||||
@@ -402,7 +411,11 @@ class MemoryCore(Base):
|
|||||||
|
|
||||||
|
|
||||||
class MemoryAssociative(Base):
|
class MemoryAssociative(Base):
|
||||||
"""Per-user semantic memory: encrypted content + pgvector embedding."""
|
"""Per-user semantic memory: encrypted content + pgvector embedding for similarity search.
|
||||||
|
|
||||||
|
Production: ``embedding`` column is ``vector(1536)`` via pgvector.
|
||||||
|
Tests (SQLite): stored as JSON list.
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "memory_associative"
|
__tablename__ = "memory_associative"
|
||||||
|
|
||||||
@@ -412,6 +425,7 @@ class MemoryAssociative(Base):
|
|||||||
nullable=False, index=True,
|
nullable=False, index=True,
|
||||||
)
|
)
|
||||||
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
content_encrypted: Mapped[str] = mapped_column(Text, nullable=False)
|
||||||
|
# JSON-encoded float list in SQLite tests; vector(1536) in Postgres via migration.
|
||||||
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
embedding: Mapped[list | None] = mapped_column(JSON, nullable=True)
|
||||||
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
entity_type: Mapped[str | None] = mapped_column(String(100), nullable=True)
|
||||||
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
entity_id: Mapped[str | None] = mapped_column(String(255), nullable=True)
|
||||||
@@ -421,7 +435,10 @@ class MemoryAssociative(Base):
|
|||||||
|
|
||||||
|
|
||||||
class MemoryEpisodic(Base):
|
class MemoryEpisodic(Base):
|
||||||
"""Per-user session summaries, encrypted at rest."""
|
"""Per-user session summaries, encrypted at rest.
|
||||||
|
|
||||||
|
One row per session interaction; used to recall recent conversations.
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "memory_episodic"
|
__tablename__ = "memory_episodic"
|
||||||
|
|
||||||
@@ -438,7 +455,11 @@ class MemoryEpisodic(Base):
|
|||||||
|
|
||||||
|
|
||||||
class MemoryProactive(Base):
|
class MemoryProactive(Base):
|
||||||
"""Per-user inferred behavioral patterns, encrypted at rest."""
|
"""Per-user inferred behavioral patterns, encrypted at rest.
|
||||||
|
|
||||||
|
Confidence in [0.0, 1.0]; only patterns above threshold are injected.
|
||||||
|
Source: 'inferred' (from episodes) or 'explicit' (user-stated).
|
||||||
|
"""
|
||||||
|
|
||||||
__tablename__ = "memory_proactive"
|
__tablename__ = "memory_proactive"
|
||||||
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
"""Pydantic schemas — API request/response contracts.
|
"""Pydantic schemas — API request/response contracts.
|
||||||
|
|
||||||
Shared across all services. Mirrors the TypeScript types from
|
Mirrors the TypeScript types from the Electron app (src/shared/api-types.ts).
|
||||||
the Electron app (src/shared/api-types.ts).
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
@@ -87,7 +86,7 @@ class StorageRecordUpdate(BaseModel):
|
|||||||
|
|
||||||
class VectorItem(BaseModel):
|
class VectorItem(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
blob: bytes
|
blob: bytes # encrypted vector + metadata — backend never decrypts
|
||||||
checksum: str
|
checksum: str
|
||||||
|
|
||||||
|
|
||||||
@@ -96,7 +95,7 @@ class VectorUpsertRequest(BaseModel):
|
|||||||
|
|
||||||
|
|
||||||
class VectorSearchRequest(BaseModel):
|
class VectorSearchRequest(BaseModel):
|
||||||
query_blob: bytes
|
query_blob: bytes # encrypted query — backend never decrypts
|
||||||
top_k: int = 10
|
top_k: int = 10
|
||||||
|
|
||||||
|
|
||||||
@@ -143,6 +142,9 @@ class WsFrameType(str, Enum):
|
|||||||
tool_result = "tool_result"
|
tool_result = "tool_result"
|
||||||
final = "final"
|
final = "final"
|
||||||
ping = "ping"
|
ping = "ping"
|
||||||
|
agent_run = "agent_run"
|
||||||
|
agent_data = "agent_data"
|
||||||
|
agent_complete = "agent_complete"
|
||||||
device_hello = "device_hello"
|
device_hello = "device_hello"
|
||||||
# ── v3 frame types ─────────────────────────────────────────────────
|
# ── v3 frame types ─────────────────────────────────────────────────
|
||||||
home_request = "home_request"
|
home_request = "home_request"
|
||||||
@@ -154,10 +156,6 @@ class WsFrameType(str, Enum):
|
|||||||
data_request = "data_request"
|
data_request = "data_request"
|
||||||
data_response = "data_response"
|
data_response = "data_response"
|
||||||
mutation = "mutation"
|
mutation = "mutation"
|
||||||
# ── v4 journey frame types ────────────────────────────────────────
|
|
||||||
journey_start = "journey_start"
|
|
||||||
journey_message = "journey_message"
|
|
||||||
journey_reply = "journey_reply"
|
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -210,10 +208,36 @@ class WsDeviceHello(BaseModel):
|
|||||||
agent_ids: list[str] = Field(default_factory=list)
|
agent_ids: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentRun(BaseModel):
|
||||||
|
"""Server → Client: trigger an agent run on the connected device."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_run] = WsFrameType.agent_run
|
||||||
|
run_id: str
|
||||||
|
agent_id: str
|
||||||
|
config: dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentData(BaseModel):
|
||||||
|
"""Client → Server: files read by the local agent."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_data] = WsFrameType.agent_data
|
||||||
|
run_id: str
|
||||||
|
files: list[dict[str, Any]]
|
||||||
|
|
||||||
|
|
||||||
|
class WsAgentComplete(BaseModel):
|
||||||
|
"""Client → Server: Electron signals it has finished reading files."""
|
||||||
|
|
||||||
|
type: Literal[WsFrameType.agent_complete] = WsFrameType.agent_complete
|
||||||
|
run_id: str
|
||||||
|
files_read: int
|
||||||
|
errors: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
# ── WebSocket v3 Frame Models ─────────────────────────────────────────
|
||||||
|
|
||||||
class WsFloatingScope(BaseModel):
|
class WsFloatingScope(BaseModel):
|
||||||
"""Scope for a floating request."""
|
"""Scope for a floating request — narrows the agent to a specific entity."""
|
||||||
|
|
||||||
type: Literal["task", "project", "note", "timeline"]
|
type: Literal["task", "project", "note", "timeline"]
|
||||||
id: str | None = None
|
id: str | None = None
|
||||||
@@ -255,14 +279,7 @@ class WsStreamEnd(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
type: Literal[WsFrameType.stream_end] = WsFrameType.stream_end
|
||||||
request_id: str
|
request_id: str
|
||||||
|
mutations: list[dict[str, Any]] = Field(default_factory=list)
|
||||||
|
|
||||||
class WsDomain(BaseModel):
|
|
||||||
"""Structured floating domain payload for UI routing decisions."""
|
|
||||||
|
|
||||||
type: Literal["task", "timeline", "project", "node"]
|
|
||||||
id: str | None = None
|
|
||||||
section: Literal["task", "timeline", "note"] | None = None
|
|
||||||
|
|
||||||
|
|
||||||
class WsFloatingDomain(BaseModel):
|
class WsFloatingDomain(BaseModel):
|
||||||
@@ -270,7 +287,7 @@ class WsFloatingDomain(BaseModel):
|
|||||||
|
|
||||||
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
|
||||||
request_id: str
|
request_id: str
|
||||||
domain: WsDomain
|
domain: Literal["tasks", "timelines", "notes", "projects"]
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Catalog ─────────────────────────────────────────────────────
|
# ── Agent Catalog ─────────────────────────────────────────────────────
|
||||||
@@ -279,28 +296,84 @@ class AgentCatalogItem(BaseModel):
|
|||||||
type: str
|
type: str
|
||||||
name: str
|
name: str
|
||||||
description: str
|
description: str
|
||||||
|
config_schema: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckRequest(BaseModel):
|
# ── Local Agent Config ────────────────────────────────────────────────
|
||||||
active_agents: int = Field(ge=0, default=0)
|
|
||||||
|
class LocalAgentConfigCreate(BaseModel):
|
||||||
|
name: str
|
||||||
|
device_id: str
|
||||||
|
directory_paths: list[str]
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
file_extensions: list[str]
|
||||||
|
schedule_cron: str
|
||||||
|
|
||||||
|
|
||||||
class AgentCreationCheckResponse(BaseModel):
|
class LocalAgentConfigUpdate(BaseModel):
|
||||||
allowed: bool
|
name: str | None = None
|
||||||
tier: BillingTier
|
device_id: str | None = None
|
||||||
active_agents: int
|
directory_paths: list[str] | None = None
|
||||||
limit: int
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
file_extensions: list[str] | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
class AgentTriggerRequest(BaseModel):
|
class LocalAgentConfigResponse(BaseModel):
|
||||||
directory: str = Field(min_length=1)
|
id: str
|
||||||
device_id: str = Field(default="")
|
name: str
|
||||||
agent_id: str | None = None
|
device_id: str
|
||||||
what_to_extract: list[str] = Field(min_length=1)
|
directory_paths: list[str]
|
||||||
actions_by_type: dict[str, list[str]] | None = None
|
data_types: list[str]
|
||||||
batch_interval: str = Field(min_length=1)
|
prompt_template: str
|
||||||
custom_agent_prompt: str = Field(min_length=1)
|
file_extensions: list[str]
|
||||||
active_agents: int = Field(ge=0, default=0)
|
schedule_cron: str
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud Agent Config ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class CloudAgentConfigCreate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
oauth_token_encrypted: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigUpdate(BaseModel):
|
||||||
|
provider: Literal["gmail", "teams", "outlook"] | None = None
|
||||||
|
name: str | None = None
|
||||||
|
data_types: list[str] | None = None
|
||||||
|
prompt_template: str | None = None
|
||||||
|
oauth_token_encrypted: str | None = None
|
||||||
|
schedule_cron: str | None = None
|
||||||
|
filter_config: dict[str, Any] | None = None
|
||||||
|
enabled: bool | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class CloudAgentConfigResponse(BaseModel):
|
||||||
|
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
||||||
|
|
||||||
|
id: str
|
||||||
|
provider: Literal["gmail", "teams", "outlook"]
|
||||||
|
name: str
|
||||||
|
data_types: list[str]
|
||||||
|
prompt_template: str
|
||||||
|
schedule_cron: str
|
||||||
|
filter_config: dict[str, Any] | None
|
||||||
|
enabled: bool
|
||||||
|
last_run_at: int | None
|
||||||
|
created_at: int
|
||||||
|
updated_at: int
|
||||||
|
|
||||||
|
|
||||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||||
@@ -315,3 +388,22 @@ class AgentRunLogResponse(BaseModel):
|
|||||||
errors: list[str]
|
errors: list[str]
|
||||||
started_at: int
|
started_at: int
|
||||||
completed_at: int | None
|
completed_at: int | None
|
||||||
|
|
||||||
|
|
||||||
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
class JourneyStartRequest(BaseModel):
|
||||||
|
agent_type: Literal["local", "cloud"]
|
||||||
|
agent_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyMessageRequest(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
|
||||||
|
|
||||||
|
class JourneyResponse(BaseModel):
|
||||||
|
session_id: str
|
||||||
|
message: str
|
||||||
|
done: bool
|
||||||
|
prompt_template: str | None = None
|
||||||
1
app/storage/__init__.py
Normal file
1
app/storage/__init__.py
Normal file
@@ -0,0 +1 @@
|
|||||||
|
"""Cloud storage layer — E2E encrypted blobs and vectors."""
|
||||||
106
app/storage/blob_store.py
Normal file
106
app/storage/blob_store.py
Normal file
@@ -0,0 +1,106 @@
|
|||||||
|
"""S3-backed store for E2E-encrypted blobs.
|
||||||
|
|
||||||
|
Keys are structured as ``{user_id}/{table}/{record_id}``.
|
||||||
|
The backend never inspects blob content — it stores and retrieves opaque bytes.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import boto3
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
|
||||||
|
|
||||||
|
class BlobStore:
|
||||||
|
"""Thin wrapper around boto3 S3.
|
||||||
|
|
||||||
|
All blobs must be E2E encrypted by the client before upload.
|
||||||
|
The backend adds SSE-S3 as an extra layer of at-rest encryption
|
||||||
|
but cannot decrypt the inner client-side payload.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _client(self) -> Any:
|
||||||
|
kwargs: dict[str, Any] = {
|
||||||
|
"region_name": settings.S3_REGION,
|
||||||
|
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
|
||||||
|
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
|
||||||
|
}
|
||||||
|
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
|
||||||
|
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
|
||||||
|
return boto3.client("s3", **kwargs)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _key(user_id: str, table: str, record_id: str) -> str:
|
||||||
|
return f"{user_id}/{table}/{record_id}"
|
||||||
|
|
||||||
|
async def upload(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
table: str,
|
||||||
|
record_id: str,
|
||||||
|
blob: bytes,
|
||||||
|
checksum: str,
|
||||||
|
) -> str:
|
||||||
|
"""Store *blob* in S3 and return the S3 key.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Owner of the blob (used as key prefix).
|
||||||
|
table: Logical table name (e.g. ``"tasks"``).
|
||||||
|
record_id: Record UUID.
|
||||||
|
blob: Raw bytes (pre-encrypted by client).
|
||||||
|
checksum: SHA-256 hex digest supplied by the client; stored as
|
||||||
|
object metadata for download-time verification.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The S3 key under which the blob was stored.
|
||||||
|
"""
|
||||||
|
key = self._key(user_id, table, record_id)
|
||||||
|
self._client().put_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=key,
|
||||||
|
Body=blob,
|
||||||
|
ServerSideEncryption="AES256", # SSE-S3 at rest
|
||||||
|
Metadata={"checksum": checksum},
|
||||||
|
)
|
||||||
|
return key
|
||||||
|
|
||||||
|
async def download(self, user_id: str, s3_key: str) -> bytes:
|
||||||
|
"""Retrieve the blob stored at *s3_key*.
|
||||||
|
|
||||||
|
*user_id* is retained in the signature so higher-level code can
|
||||||
|
enforce ownership without re-parsing the key.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
|
||||||
|
object does not exist.
|
||||||
|
"""
|
||||||
|
response = self._client().get_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
)
|
||||||
|
return response["Body"].read()
|
||||||
|
|
||||||
|
async def delete(self, user_id: str, s3_key: str) -> None:
|
||||||
|
"""Delete the object at *s3_key*.
|
||||||
|
|
||||||
|
S3 ``delete_object`` is idempotent — it succeeds even if the key does
|
||||||
|
not exist.
|
||||||
|
"""
|
||||||
|
self._client().delete_object(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Key=s3_key,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def list_keys(self, user_id: str, table: str) -> list[str]:
|
||||||
|
"""Return all S3 keys for a given user + table combination.
|
||||||
|
|
||||||
|
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
|
||||||
|
"""
|
||||||
|
prefix = f"{user_id}/{table}/"
|
||||||
|
response = self._client().list_objects_v2(
|
||||||
|
Bucket=settings.S3_BUCKET,
|
||||||
|
Prefix=prefix,
|
||||||
|
)
|
||||||
|
return [obj["Key"] for obj in response.get("Contents", [])]
|
||||||
32
app/storage/encryption.py
Normal file
32
app/storage/encryption.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Integrity verification only — the backend NEVER decrypts user data."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import hashlib
|
||||||
|
import hmac
|
||||||
|
|
||||||
|
from fastapi import HTTPException
|
||||||
|
|
||||||
|
|
||||||
|
def verify_checksum(blob: bytes, checksum: str) -> bool:
|
||||||
|
"""Return ``True`` if SHA-256(blob) matches *checksum*.
|
||||||
|
|
||||||
|
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
|
||||||
|
timing-based side-channel attacks.
|
||||||
|
"""
|
||||||
|
computed = hashlib.sha256(blob).hexdigest()
|
||||||
|
return hmac.compare_digest(computed, checksum)
|
||||||
|
|
||||||
|
|
||||||
|
def reject_if_tampered(blob: bytes, checksum: str) -> None:
|
||||||
|
"""Raise ``HTTP 400`` if the blob does not match its checksum.
|
||||||
|
|
||||||
|
Call this before storing or forwarding any client-provided blob.
|
||||||
|
The backend never holds decryption keys — this check only verifies
|
||||||
|
that the opaque bytes arrived intact.
|
||||||
|
"""
|
||||||
|
if not verify_checksum(blob, checksum):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=400,
|
||||||
|
detail="Checksum mismatch: blob integrity check failed",
|
||||||
|
)
|
||||||
205
app/storage/vector_store.py
Normal file
205
app/storage/vector_store.py
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
|
||||||
|
|
||||||
|
Vectors are pre-encrypted blobs from the client. The backend stores them
|
||||||
|
alongside a deterministic 32-dim float representation derived from the blob's
|
||||||
|
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
|
||||||
|
is a known trade-off documented in the backend plan.
|
||||||
|
|
||||||
|
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
|
||||||
|
``user_id`` payload field on a shared collection.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import base64
|
||||||
|
import hashlib
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from pinecone import Pinecone
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
|
||||||
|
|
||||||
|
from app.config.settings import settings
|
||||||
|
from app.schemas import VectorItem, VectorSearchResult
|
||||||
|
|
||||||
|
_QDRANT_COLLECTION = "adiuva_vectors"
|
||||||
|
|
||||||
|
|
||||||
|
def _blob_to_vector(blob: bytes) -> list[float]:
|
||||||
|
"""Derive a 32-dim float vector from *blob* for storage purposes only.
|
||||||
|
|
||||||
|
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
|
||||||
|
normalises each byte to the range [-1.0, 1.0]. This vector carries no
|
||||||
|
semantic meaning on encrypted data.
|
||||||
|
"""
|
||||||
|
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
|
||||||
|
|
||||||
|
|
||||||
|
class VectorStore:
|
||||||
|
"""Thin wrapper around Pinecone or Qdrant.
|
||||||
|
|
||||||
|
The backend to use is selected at runtime:
|
||||||
|
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
|
||||||
|
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _use_pinecone(self) -> bool:
|
||||||
|
return bool(settings.PINECONE_API_KEY)
|
||||||
|
|
||||||
|
# ── Pinecone helpers ──────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _pinecone_index(self) -> Any:
|
||||||
|
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
|
||||||
|
return pc.Index(settings.PINECONE_INDEX)
|
||||||
|
|
||||||
|
# ── Qdrant helpers ────────────────────────────────────────────────
|
||||||
|
|
||||||
|
def _qdrant_client(self) -> Any:
|
||||||
|
return QdrantClient(
|
||||||
|
url=settings.QDRANT_URL,
|
||||||
|
api_key=settings.QDRANT_API_KEY or None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Public API ────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
"""Store encrypted vectors in the backend.
|
||||||
|
|
||||||
|
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
|
||||||
|
so it can be returned verbatim during search.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Used as Pinecone namespace or Qdrant payload field.
|
||||||
|
vectors: List of encrypted vector items from the client.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
await self._pinecone_upsert(user_id, vectors)
|
||||||
|
else:
|
||||||
|
await self._qdrant_upsert(user_id, vectors)
|
||||||
|
|
||||||
|
async def search(
|
||||||
|
self,
|
||||||
|
user_id: str,
|
||||||
|
query_blob: bytes,
|
||||||
|
top_k: int,
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
"""Query the vector store and return encrypted result blobs.
|
||||||
|
|
||||||
|
The query vector is derived from *query_blob* using the same
|
||||||
|
deterministic mapping as upsert.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Scopes the search to this user's namespace.
|
||||||
|
query_blob: Encrypted query from the client.
|
||||||
|
top_k: Maximum number of results to return.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
return await self._pinecone_search(user_id, query_blob, top_k)
|
||||||
|
return await self._qdrant_search(user_id, query_blob, top_k)
|
||||||
|
|
||||||
|
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
"""Remove vectors by ID, scoped to *user_id*.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
user_id: Namespace / payload filter to prevent cross-user deletion.
|
||||||
|
vector_ids: List of vector IDs to remove.
|
||||||
|
"""
|
||||||
|
if self._use_pinecone():
|
||||||
|
await self._pinecone_delete(user_id, vector_ids)
|
||||||
|
else:
|
||||||
|
await self._qdrant_delete(user_id, vector_ids)
|
||||||
|
|
||||||
|
# ── Pinecone implementation ───────────────────────────────────────
|
||||||
|
|
||||||
|
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
records = [
|
||||||
|
{
|
||||||
|
"id": v.id,
|
||||||
|
"values": _blob_to_vector(v.blob),
|
||||||
|
"metadata": {
|
||||||
|
"blob": base64.b64encode(v.blob).decode(),
|
||||||
|
"checksum": v.checksum,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for v in vectors
|
||||||
|
]
|
||||||
|
index.upsert(vectors=records, namespace=user_id)
|
||||||
|
|
||||||
|
async def _pinecone_search(
|
||||||
|
self, user_id: str, query_blob: bytes, top_k: int
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
query_vector = _blob_to_vector(query_blob)
|
||||||
|
response = index.query(
|
||||||
|
vector=query_vector,
|
||||||
|
top_k=top_k,
|
||||||
|
namespace=user_id,
|
||||||
|
include_metadata=True,
|
||||||
|
)
|
||||||
|
results: list[VectorSearchResult] = []
|
||||||
|
for match in response.get("matches", []):
|
||||||
|
blob_bytes = base64.b64decode(match["metadata"]["blob"])
|
||||||
|
results.append(
|
||||||
|
VectorSearchResult(
|
||||||
|
id=match["id"],
|
||||||
|
score=match["score"],
|
||||||
|
blob=blob_bytes,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return results
|
||||||
|
|
||||||
|
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
index = self._pinecone_index()
|
||||||
|
index.delete(ids=vector_ids, namespace=user_id)
|
||||||
|
|
||||||
|
# ── Qdrant implementation ─────────────────────────────────────────
|
||||||
|
|
||||||
|
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
points = [
|
||||||
|
PointStruct(
|
||||||
|
id=v.id,
|
||||||
|
vector=_blob_to_vector(v.blob),
|
||||||
|
payload={
|
||||||
|
"blob": base64.b64encode(v.blob).decode(),
|
||||||
|
"checksum": v.checksum,
|
||||||
|
"user_id": user_id,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
for v in vectors
|
||||||
|
]
|
||||||
|
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
|
||||||
|
|
||||||
|
async def _qdrant_search(
|
||||||
|
self, user_id: str, query_blob: bytes, top_k: int
|
||||||
|
) -> list[VectorSearchResult]:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
query_vector = _blob_to_vector(query_blob)
|
||||||
|
hits = client.search(
|
||||||
|
collection_name=_QDRANT_COLLECTION,
|
||||||
|
query_vector=query_vector,
|
||||||
|
query_filter=Filter(
|
||||||
|
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
|
||||||
|
),
|
||||||
|
limit=top_k,
|
||||||
|
)
|
||||||
|
return [
|
||||||
|
VectorSearchResult(
|
||||||
|
id=str(hit.id),
|
||||||
|
score=hit.score,
|
||||||
|
blob=base64.b64decode(hit.payload["blob"]),
|
||||||
|
)
|
||||||
|
for hit in hits
|
||||||
|
]
|
||||||
|
|
||||||
|
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
|
||||||
|
client = self._qdrant_client()
|
||||||
|
client.delete(
|
||||||
|
collection_name=_QDRANT_COLLECTION,
|
||||||
|
points_selector=PointIdsList(points=vector_ids),
|
||||||
|
)
|
||||||
@@ -1,34 +1,27 @@
|
|||||||
# ── Adiuva Microservices ─────────────────────────────────────────────
|
|
||||||
# docker compose up --build
|
|
||||||
# docker compose up --build auth ws-gateway chat # subset
|
|
||||||
|
|
||||||
services:
|
services:
|
||||||
|
app:
|
||||||
# ═══════════════════════════════════════════════════════════════════
|
build: .
|
||||||
# Infrastructure
|
|
||||||
# ═══════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
traefik:
|
|
||||||
image: traefik:v3.1
|
|
||||||
ports:
|
ports:
|
||||||
- "80:80"
|
- "8080:8000"
|
||||||
- "443:443"
|
env_file:
|
||||||
- "8080:8080" # dashboard (dev only)
|
- path: .env
|
||||||
|
required: false
|
||||||
environment:
|
environment:
|
||||||
CF_DNS_API_TOKEN: ${CF_DNS_API_TOKEN:-}
|
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
||||||
|
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
|
||||||
volumes:
|
volumes:
|
||||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
- copilot_tokens:/root/.config/litellm/github_copilot
|
||||||
- ./traefik/traefik.yml:/etc/traefik/traefik.yml:ro
|
depends_on:
|
||||||
- ./traefik/dynamic:/etc/traefik/dynamic:ro
|
db:
|
||||||
- traefik_acme:/etc/traefik/acme
|
condition: service_healthy
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
db:
|
db:
|
||||||
image: pgvector/pgvector:pg16
|
image: pgvector/pgvector:pg16
|
||||||
environment:
|
environment:
|
||||||
POSTGRES_USER: ${POSTGRES_USER:-postgres}
|
POSTGRES_USER: postgres
|
||||||
POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres}
|
POSTGRES_PASSWORD: postgres
|
||||||
POSTGRES_DB: ${POSTGRES_DB:-adiuva}
|
POSTGRES_DB: adiuva
|
||||||
volumes:
|
volumes:
|
||||||
- postgres_data:/var/lib/postgresql/data
|
- postgres_data:/var/lib/postgresql/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
@@ -38,161 +31,42 @@ services:
|
|||||||
retries: 5
|
retries: 5
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
redis:
|
# Optional Redis for future rate-limit or caching needs
|
||||||
image: redis:7-alpine
|
# redis:
|
||||||
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
# image: redis:7-alpine
|
||||||
|
# restart: unless-stopped
|
||||||
|
|
||||||
|
# ── Local S3-compatible storage (MinIO) ──
|
||||||
|
minio:
|
||||||
|
image: minio/minio:latest
|
||||||
|
command: server /data --console-address ":9001"
|
||||||
|
ports:
|
||||||
|
- "9000:9000"
|
||||||
|
- "9001:9001"
|
||||||
|
environment:
|
||||||
|
MINIO_ROOT_USER: minioadmin
|
||||||
|
MINIO_ROOT_PASSWORD: minioadmin
|
||||||
volumes:
|
volumes:
|
||||||
- redis_data:/data
|
- minio_data:/data
|
||||||
healthcheck:
|
healthcheck:
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
test: ["CMD", "mc", "ready", "local"]
|
||||||
interval: 5s
|
interval: 5s
|
||||||
timeout: 3s
|
timeout: 5s
|
||||||
retries: 5
|
retries: 5
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
# ── Optional infrastructure (uncomment as needed) ────────────────
|
# ── Local vector store (Qdrant) ──
|
||||||
|
qdrant:
|
||||||
# minio:
|
image: qdrant/qdrant:latest
|
||||||
# image: minio/minio:latest
|
ports:
|
||||||
# command: server /data --console-address ":9001"
|
- "6333:6333"
|
||||||
# ports:
|
- "6334:6334"
|
||||||
# - "9000:9000"
|
volumes:
|
||||||
# - "9001:9001"
|
- qdrant_data:/qdrant/storage
|
||||||
# environment:
|
|
||||||
# MINIO_ROOT_USER: minioadmin
|
|
||||||
# MINIO_ROOT_PASSWORD: minioadmin
|
|
||||||
# volumes:
|
|
||||||
# - minio_data:/data
|
|
||||||
# healthcheck:
|
|
||||||
# test: ["CMD", "mc", "ready", "local"]
|
|
||||||
# interval: 5s
|
|
||||||
# timeout: 5s
|
|
||||||
# retries: 5
|
|
||||||
# restart: unless-stopped
|
|
||||||
|
|
||||||
# qdrant:
|
|
||||||
# image: qdrant/qdrant:latest
|
|
||||||
# ports:
|
|
||||||
# - "6333:6333"
|
|
||||||
# - "6334:6334"
|
|
||||||
# volumes:
|
|
||||||
# - qdrant_data:/qdrant/storage
|
|
||||||
# restart: unless-stopped
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════
|
|
||||||
# Migrations (run once, then exit)
|
|
||||||
# ═══════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
migrate:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: Dockerfile
|
|
||||||
command: ["python", "-m", "alembic", "upgrade", "head"]
|
|
||||||
env_file:
|
|
||||||
- path: .env
|
|
||||||
required: false
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
restart: "no"
|
|
||||||
|
|
||||||
# ═══════════════════════════════════════════════════════════════════
|
|
||||||
# Application Services
|
|
||||||
# ═══════════════════════════════════════════════════════════════════
|
|
||||||
|
|
||||||
auth:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: services/auth/Dockerfile
|
|
||||||
env_file:
|
|
||||||
- path: .env
|
|
||||||
required: false
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
|
||||||
REDIS_URL: redis://redis:6379/0
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
migrate:
|
|
||||||
condition: service_completed_successfully
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
ws-gateway:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: services/ws-gateway/Dockerfile
|
|
||||||
env_file:
|
|
||||||
- path: .env
|
|
||||||
required: false
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
|
||||||
REDIS_URL: redis://redis:6379/0
|
|
||||||
depends_on:
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
auth:
|
|
||||||
condition: service_started
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
chat:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: services/chat/Dockerfile
|
|
||||||
env_file:
|
|
||||||
- path: .env
|
|
||||||
required: false
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
|
||||||
REDIS_URL: redis://redis:6379/0
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
migrate:
|
|
||||||
condition: service_completed_successfully
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
batch-agent:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: services/batch-agent/Dockerfile
|
|
||||||
env_file:
|
|
||||||
- path: .env
|
|
||||||
required: false
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
|
||||||
REDIS_URL: redis://redis:6379/0
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
migrate:
|
|
||||||
condition: service_completed_successfully
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
billing:
|
|
||||||
build:
|
|
||||||
context: .
|
|
||||||
dockerfile: services/billing/Dockerfile
|
|
||||||
env_file:
|
|
||||||
- path: .env
|
|
||||||
required: false
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
migrate:
|
|
||||||
condition: service_completed_successfully
|
|
||||||
restart: unless-stopped
|
restart: unless-stopped
|
||||||
|
|
||||||
volumes:
|
volumes:
|
||||||
postgres_data:
|
postgres_data:
|
||||||
redis_data:
|
minio_data:
|
||||||
traefik_acme:
|
qdrant_data:
|
||||||
# minio_data:
|
copilot_tokens:
|
||||||
# qdrant_data:
|
|
||||||
|
|||||||
@@ -1,941 +0,0 @@
|
|||||||
# Adiuva — Architettura Microservizi (MVP)
|
|
||||||
|
|
||||||
## Panoramica
|
|
||||||
|
|
||||||
Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
|
|
||||||
|
|
||||||
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐
|
|
||||||
│ Cloudflare │
|
|
||||||
│ (DNS + CDN) │
|
|
||||||
└──────┬───────┘
|
|
||||||
│ HTTPS / WSS
|
|
||||||
┌──────▼───────┐
|
|
||||||
│ Traefik │
|
|
||||||
│ API Gateway │
|
|
||||||
│ (routing, │
|
|
||||||
│ TLS, rate │
|
|
||||||
│ limiting) │
|
|
||||||
└──────┬───────┘
|
|
||||||
│
|
|
||||||
┌──────────┬───────────┼───────────┐
|
|
||||||
│ │ │ │
|
|
||||||
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
|
|
||||||
│ Auth │ │ Chat │ │ Agent │ │Billing │
|
|
||||||
│ Service │ │Service│ │ Service │ │Service │
|
|
||||||
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
|
|
||||||
│ │ │ │
|
|
||||||
┌─────▼──────────▼──────────▼───────────▼────┐
|
|
||||||
│ Infrastruttura │
|
|
||||||
│ PostgreSQL │ Redis │ Qdrant │
|
|
||||||
└─────────────────────────────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 1. Suddivisione dei Servizi
|
|
||||||
|
|
||||||
### 1.1 Auth Service (`auth-service`)
|
|
||||||
|
|
||||||
**Responsabilità**: Registrazione, login, refresh token, profilo utente, encryption key.
|
|
||||||
|
|
||||||
| Endpoint originale | Metodo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/auth/register` | POST |
|
|
||||||
| `/api/v1/auth/login` | POST |
|
|
||||||
| `/api/v1/auth/refresh` | POST |
|
|
||||||
| `/api/v1/auth/me` | GET / PUT |
|
|
||||||
|
|
||||||
**Database**: Tabelle `users`, `refresh_tokens` (PostgreSQL condiviso, schema `auth`).
|
|
||||||
|
|
||||||
**Modifica chiave — JWT con RS256**:
|
|
||||||
Il monolite usa un `SECRET_KEY` simmetrico (HS256). Con i microservizi, passare a **RS256** (asimmetrico):
|
|
||||||
- L'Auth Service firma i JWT con la **chiave privata**.
|
|
||||||
- Tutti gli altri servizi verificano i JWT con la **chiave pubblica** senza mai contattare l'Auth Service.
|
|
||||||
- La chiave pubblica viene esposta via `GET /api/v1/auth/.well-known/jwks.json` oppure montata come volume condiviso.
|
|
||||||
|
|
||||||
```python
|
|
||||||
# auth-service/app/auth/jwt.py
|
|
||||||
from cryptography.hazmat.primitives.asymmetric import rsa
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PRIVATE_KEY = ... # Da env/secret
|
|
||||||
PUBLIC_KEY = ... # Derivata o da env
|
|
||||||
|
|
||||||
def create_access_token(user_id: str, tier: str) -> str:
|
|
||||||
return jwt.encode(
|
|
||||||
{"sub": user_id, "tier": tier, "exp": ...},
|
|
||||||
PRIVATE_KEY,
|
|
||||||
algorithm="RS256",
|
|
||||||
)
|
|
||||||
```
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/auth.py (usato da tutti gli altri servizi)
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PUBLIC_KEY = ... # Volume montato o fetched da JWKS endpoint
|
|
||||||
|
|
||||||
def verify_token(token: str) -> dict:
|
|
||||||
return jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
|
||||||
```
|
|
||||||
|
|
||||||
**Scaling**: 2 repliche sufficienti, stateless. Rate-limit dedicato su `/login` e `/register`.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.2 Chat Service (`chat-service`) ⭐ Real-time
|
|
||||||
|
|
||||||
**Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
|
|
||||||
|
|
||||||
Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
|
|
||||||
|
|
||||||
| Endpoint | Tipo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/ws/device` | WebSocket (connessione persistente) |
|
|
||||||
| `/api/v1/chat` | POST (REST fallback) |
|
|
||||||
|
|
||||||
**Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
|
|
||||||
|
|
||||||
**Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
|
|
||||||
|
|
||||||
**Scaling**: 2–N repliche. Sticky cookies per le WS + Redis per cross-instance.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.3 Agent Service (`agent-service`) ⭐ Batch
|
|
||||||
|
|
||||||
**Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
|
|
||||||
|
|
||||||
Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
|
|
||||||
|
|
||||||
| Endpoint | Tipo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/agents/catalog` | GET |
|
|
||||||
| `/api/v1/agents/can-create` | POST |
|
|
||||||
| `/api/v1/agents/trigger` | POST |
|
|
||||||
| `/api/v1/agents/journey/start` | POST (o WS relay) |
|
|
||||||
| `/api/v1/agents/journey/message` | POST (o WS relay) |
|
|
||||||
|
|
||||||
**Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
|
|
||||||
|
|
||||||
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────┐ ┌──────────────┐ ┌──────────┐
|
|
||||||
│ Agent Service│ │ Redis │ │ Chat │
|
|
||||||
│ (batch run) │ │ │ │ Service │
|
|
||||||
│ │ │ │ │ (ha WS) │
|
|
||||||
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
|
|
||||||
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
|
|
||||||
│ from │ │ │ │ al │
|
|
||||||
│ device │ │ │ │ device│
|
|
||||||
│ │ │ │ │ via WS│
|
|
||||||
│ │ SUBSCRIBE │ │ PUBLISH │ │
|
|
||||||
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
|
|
||||||
│ risultato │ │ │ │ reply │
|
|
||||||
└──────────────┘ └──────────────┘ └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
**Scaling**: 1–N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
|
|
||||||
|
|
||||||
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.4 Billing Service (`billing-service`)
|
|
||||||
|
|
||||||
**Responsabilità**: Stripe checkout, webhook, subscription management.
|
|
||||||
|
|
||||||
| Endpoint originale | Metodo |
|
|
||||||
|---|---|
|
|
||||||
| `/api/v1/billing/checkout` | POST |
|
|
||||||
| `/api/v1/billing/webhook` | POST |
|
|
||||||
| `/api/v1/billing/subscription` | GET / DELETE |
|
|
||||||
|
|
||||||
**Database**: Tabelle `subscriptions` (schema `billing`).
|
|
||||||
|
|
||||||
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
|
|
||||||
|
|
||||||
**Scaling**: 1 replica sufficiente. Basso traffico.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
### 1.5 Servizi esclusi dall'MVP
|
|
||||||
|
|
||||||
I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
|
|
||||||
|
|
||||||
| Servizio | Responsabilità | Note |
|
|
||||||
|---|---|---|
|
|
||||||
| **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
|
|
||||||
| **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 2. Tier Check — Dove e Come
|
|
||||||
|
|
||||||
Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
|
|
||||||
|
|
||||||
### Strategia: Tier nel JWT
|
|
||||||
|
|
||||||
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
|
|
||||||
|
|
||||||
```json
|
|
||||||
{
|
|
||||||
"sub": "user_123",
|
|
||||||
"tier": "pro",
|
|
||||||
"exp": 1742515200,
|
|
||||||
"iat": 1742511600
|
|
||||||
}
|
|
||||||
```
|
|
||||||
|
|
||||||
Ogni servizio:
|
|
||||||
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
|
|
||||||
2. Legge `payload["tier"]` — **zero chiamate extra**
|
|
||||||
3. Applica le sue regole di enforcement localmente
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/auth.py — dependency FastAPI condivisa
|
|
||||||
from fastapi import Depends, HTTPException, Request
|
|
||||||
from jose import jwt
|
|
||||||
|
|
||||||
PUBLIC_KEY = ...
|
|
||||||
|
|
||||||
class CurrentUser:
|
|
||||||
def __init__(self, user_id: str, tier: str):
|
|
||||||
self.user_id = user_id
|
|
||||||
self.tier = tier
|
|
||||||
|
|
||||||
async def get_current_user(request: Request) -> CurrentUser:
|
|
||||||
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
|
|
||||||
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
|
|
||||||
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
|
|
||||||
|
|
||||||
def require_tier(*allowed_tiers: str):
|
|
||||||
"""Dependency che blocca se il tier non è tra quelli ammessi."""
|
|
||||||
async def check(user: CurrentUser = Depends(get_current_user)):
|
|
||||||
if user.tier not in allowed_tiers:
|
|
||||||
raise HTTPException(403, "Tier insufficient")
|
|
||||||
return user
|
|
||||||
return check
|
|
||||||
```
|
|
||||||
|
|
||||||
### Cosa succede quando il tier cambia (upgrade/downgrade)?
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
|
|
||||||
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
|
|
||||||
│ │ │ Service │ (Redis pub/sub) │ Service │
|
|
||||||
└──────────┘ └──────────┘ └────┬─────┘
|
|
||||||
│
|
|
||||||
UPDATE users
|
|
||||||
SET tier = 'power'
|
|
||||||
│
|
|
||||||
Al prossimo /refresh
|
|
||||||
il JWT conterrà tier='power'
|
|
||||||
```
|
|
||||||
|
|
||||||
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 15–30 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
|
|
||||||
|
|
||||||
### Dove si applica in ciascun servizio
|
|
||||||
|
|
||||||
| Servizio | Enforcement |
|
|
||||||
|---|---|
|
|
||||||
| **Auth Service** | Nessuno (è lui che scrive il tier) |
|
|
||||||
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
|
|
||||||
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
|
|
||||||
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
|
|
||||||
|
|
||||||
### Rate-limit distribuito via Redis
|
|
||||||
|
|
||||||
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
|
|
||||||
|
|
||||||
```python
|
|
||||||
# shared/middleware/rate_limit.py
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
|
|
||||||
class DistributedRateLimiter:
|
|
||||||
def __init__(self, redis: aioredis.Redis):
|
|
||||||
self._redis = redis
|
|
||||||
|
|
||||||
async def check(self, user_id: str, tier: str, service: str) -> bool:
|
|
||||||
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
|
|
||||||
max_req = limits.get(tier, 20)
|
|
||||||
key = f"rate:{service}:{user_id}"
|
|
||||||
|
|
||||||
pipe = self._redis.pipeline()
|
|
||||||
pipe.incr(key)
|
|
||||||
pipe.expire(key, 60)
|
|
||||||
count, _ = await pipe.execute()
|
|
||||||
|
|
||||||
return count <= max_req
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
|
|
||||||
|
|
||||||
`DeviceConnectionManager` è un **singleton in-memory**:
|
|
||||||
|
|
||||||
```python
|
|
||||||
class DeviceConnectionManager:
|
|
||||||
def __init__(self):
|
|
||||||
self._connections: dict[str, DeviceConnection] = {} # ← In-memory!
|
|
||||||
```
|
|
||||||
|
|
||||||
Con N istanze del Chat Service, il device si connette a **una sola** istanza. Quando un'altra istanza deve inviare un `tool_call` a quel device (es. un agent trigger da un'API call), non trova la connessione.
|
|
||||||
|
|
||||||
### La soluzione: Redis Pub/Sub + Registry
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────────────────────────────────────────────────────────┐
|
|
||||||
│ Redis │
|
|
||||||
│ │
|
|
||||||
│ Hash: ws:connections │
|
|
||||||
│ user_123 → instance_A │
|
|
||||||
│ user_456 → instance_B │
|
|
||||||
│ │
|
|
||||||
│ Pub/Sub channels: │
|
|
||||||
│ tool_call:{user_id} → tool call payloads │
|
|
||||||
│ tool_result:{call_id} → tool result payloads │
|
|
||||||
│ stream:{user_id} → text_chunk streaming │
|
|
||||||
└──────────────────────────────────────────────────────────────┘
|
|
||||||
|
|
||||||
Instance A (ha WS di user_123) Instance B (deve chiamare tool su user_123)
|
|
||||||
┌───────────────────────┐ ┌───────────────────────┐
|
|
||||||
│ 1. Sottoscrive a │ │ 1. Lookup Redis Hash │
|
|
||||||
│ tool_call:user_123│ │ → user_123 è su A │
|
|
||||||
│ │ │ │
|
|
||||||
│ 2. Riceve tool_call │◄─────────│ 2. PUBLISH │
|
|
||||||
│ da Redis channel │ │ tool_call:user_123 │
|
|
||||||
│ │ │ {id, action, ...} │
|
|
||||||
│ 3. Invia al device │ │ │
|
|
||||||
│ via WS │ │ 4. SUBSCRIBE │
|
|
||||||
│ │ │ tool_result:{id} │
|
|
||||||
│ 4. Device risponde │ │ │
|
|
||||||
│ tool_result │──────────│► 5. Riceve risultato │
|
|
||||||
│ │ │ │
|
|
||||||
│ 5. PUBLISH │ │ │
|
|
||||||
│ tool_result:{id} │ │ │
|
|
||||||
└───────────────────────┘ └───────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### Implementazione: `RedisDeviceManager`
|
|
||||||
|
|
||||||
```python
|
|
||||||
# chat-service/app/core/device_manager.py
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import redis.asyncio as aioredis
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from fastapi import WebSocket
|
|
||||||
|
|
||||||
INSTANCE_ID = os.environ.get("INSTANCE_ID", os.urandom(8).hex())
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class LocalConnection:
|
|
||||||
ws: WebSocket
|
|
||||||
device_id: str
|
|
||||||
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
|
|
||||||
class RedisDeviceManager:
|
|
||||||
"""Device manager backed by Redis for cross-instance communication."""
|
|
||||||
|
|
||||||
def __init__(self, redis_url: str = "redis://redis:6379"):
|
|
||||||
self._redis = aioredis.from_url(redis_url)
|
|
||||||
self._pubsub = self._redis.pubsub()
|
|
||||||
self._local: dict[str, LocalConnection] = {} # Solo connessioni locali
|
|
||||||
self._remote_futures: dict[str, asyncio.Future[dict]] = {}
|
|
||||||
|
|
||||||
async def start(self):
|
|
||||||
"""Avvia il listener Redis per tool_call in arrivo."""
|
|
||||||
asyncio.create_task(self._listen_tool_calls())
|
|
||||||
|
|
||||||
# ── Registrazione ──
|
|
||||||
|
|
||||||
async def register(self, user_id: str, device_id: str, ws: WebSocket):
|
|
||||||
# Registra localmente
|
|
||||||
self._local[user_id] = LocalConnection(ws=ws, device_id=device_id)
|
|
||||||
# Registra in Redis quale istanza ha la connessione
|
|
||||||
await self._redis.hset("ws:connections", user_id, INSTANCE_ID)
|
|
||||||
# Sottoscrivi ai tool_call per questo utente
|
|
||||||
await self._pubsub.subscribe(f"tool_call:{user_id}")
|
|
||||||
|
|
||||||
async def unregister(self, user_id: str):
|
|
||||||
conn = self._local.pop(user_id, None)
|
|
||||||
if conn:
|
|
||||||
for fut in conn.pending_calls.values():
|
|
||||||
if not fut.done():
|
|
||||||
fut.cancel()
|
|
||||||
await self._redis.hdel("ws:connections", user_id)
|
|
||||||
await self._pubsub.unsubscribe(f"tool_call:{user_id}")
|
|
||||||
|
|
||||||
# ── Presenza ──
|
|
||||||
|
|
||||||
async def is_online(self, user_id: str) -> bool:
|
|
||||||
return await self._redis.hexists("ws:connections", user_id)
|
|
||||||
|
|
||||||
# ── Tool-call round-trip (cross-instance) ──
|
|
||||||
|
|
||||||
async def execute_tool_call(self, user_id: str, payload: dict) -> dict:
|
|
||||||
"""
|
|
||||||
Invia un tool_call al device dell'utente.
|
|
||||||
Funziona sia che la WS sia locale che su un'altra istanza.
|
|
||||||
"""
|
|
||||||
call_id = payload["id"]
|
|
||||||
|
|
||||||
# Caso 1: connessione locale → invio diretto
|
|
||||||
if user_id in self._local:
|
|
||||||
conn = self._local[user_id]
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
fut: asyncio.Future[dict] = loop.create_future()
|
|
||||||
conn.pending_calls[call_id] = fut
|
|
||||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **payload}))
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
|
|
||||||
# Caso 2: connessione remota → Redis pub/sub
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
fut = loop.create_future()
|
|
||||||
self._remote_futures[call_id] = fut
|
|
||||||
|
|
||||||
# Sottoscrivi al canale di risposta
|
|
||||||
result_channel = f"tool_result:{call_id}"
|
|
||||||
await self._pubsub.subscribe(result_channel)
|
|
||||||
|
|
||||||
# Pubblica il tool_call
|
|
||||||
await self._redis.publish(
|
|
||||||
f"tool_call:{user_id}",
|
|
||||||
json.dumps(payload),
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
return await asyncio.wait_for(fut, timeout=30.0)
|
|
||||||
finally:
|
|
||||||
self._remote_futures.pop(call_id, None)
|
|
||||||
await self._pubsub.unsubscribe(result_channel)
|
|
||||||
|
|
||||||
# ── Risoluzione tool_result (da WS locale) ──
|
|
||||||
|
|
||||||
def resolve_local(self, user_id: str, call_id: str, result: dict):
|
|
||||||
conn = self._local.get(user_id)
|
|
||||||
if conn:
|
|
||||||
fut = conn.pending_calls.pop(call_id, None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(result)
|
|
||||||
|
|
||||||
async def resolve_and_publish(self, user_id: str, call_id: str, result: dict):
|
|
||||||
"""Chiamato quando il device locale invia un tool_result."""
|
|
||||||
self.resolve_local(user_id, call_id, result)
|
|
||||||
# Pubblica anche su Redis per l'istanza remota che aspetta
|
|
||||||
await self._redis.publish(
|
|
||||||
f"tool_result:{call_id}",
|
|
||||||
json.dumps(result),
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Listener Redis ──
|
|
||||||
|
|
||||||
async def _listen_tool_calls(self):
|
|
||||||
"""Loop che ascolta i tool_call in arrivo da altre istanze."""
|
|
||||||
async for message in self._pubsub.listen():
|
|
||||||
if message["type"] != "message":
|
|
||||||
continue
|
|
||||||
channel = message["channel"]
|
|
||||||
if isinstance(channel, bytes):
|
|
||||||
channel = channel.decode()
|
|
||||||
|
|
||||||
data = json.loads(message["data"])
|
|
||||||
|
|
||||||
if channel.startswith("tool_call:"):
|
|
||||||
# Un'altra istanza vuole che inviamo un tool_call al nostro device
|
|
||||||
user_id = channel.split(":", 1)[1]
|
|
||||||
conn = self._local.get(user_id)
|
|
||||||
if conn:
|
|
||||||
await conn.ws.send_text(json.dumps({"type": "tool_call", **data}))
|
|
||||||
|
|
||||||
elif channel.startswith("tool_result:"):
|
|
||||||
# Risposta a un tool_call che abbiamo inviato tramite Redis
|
|
||||||
call_id = channel.split(":", 1)[1]
|
|
||||||
fut = self._remote_futures.pop(call_id, None)
|
|
||||||
if fut and not fut.done():
|
|
||||||
fut.set_result(data)
|
|
||||||
|
|
||||||
# ── Stream cross-instance ──
|
|
||||||
|
|
||||||
async def publish_stream_chunk(self, user_id: str, chunk: dict):
|
|
||||||
"""Pubblica un chunk di streaming su Redis (per REST→WS relay)."""
|
|
||||||
await self._redis.publish(f"stream:{user_id}", json.dumps(chunk))
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 4. Struttura Directory Proposta (MVP)
|
|
||||||
|
|
||||||
```
|
|
||||||
adiuva-api/
|
|
||||||
├── docker-compose.yml # Orchestrazione completa
|
|
||||||
├── docker-compose.dev.yml # Override per sviluppo locale
|
|
||||||
├── shared/ # Codice condiviso (montato come volume)
|
|
||||||
│ ├── auth.py # JWT verification (chiave pubblica)
|
|
||||||
│ ├── schemas.py # Pydantic schemas condivisi
|
|
||||||
│ ├── middleware/
|
|
||||||
│ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
|
|
||||||
│ │ └── sanitizer.py
|
|
||||||
│ └── models/
|
|
||||||
│ └── base.py # SQLAlchemy base condivisa
|
|
||||||
│
|
|
||||||
├── auth-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # users, refresh_tokens
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ └── auth.py
|
|
||||||
│ └── services/
|
|
||||||
│ ├── jwt_service.py # RS256 signing
|
|
||||||
│ └── user_service.py
|
|
||||||
│
|
|
||||||
├── chat-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # memory_*
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ ├── device_ws.py # WS connection owner
|
|
||||||
│ │ └── chat.py # REST fallback
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── device_manager.py # RedisDeviceManager
|
|
||||||
│ │ ├── deep_agent.py # Home + floating chat
|
|
||||||
│ │ ├── memory_middleware.py
|
|
||||||
│ │ ├── ws_context.py
|
|
||||||
│ │ ├── output_formatter.py
|
|
||||||
│ │ └── llm.py
|
|
||||||
│ └── agents/ # Tool definitions (used by deep_agent)
|
|
||||||
│ ├── task_agent.py
|
|
||||||
│ ├── project_agent.py
|
|
||||||
│ ├── note_agent.py
|
|
||||||
│ └── timeline_agent.py
|
|
||||||
│
|
|
||||||
├── agent-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # agent_run_logs, local/cloud_agent_configs
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ ├── agents.py # catalog, can-create, trigger
|
|
||||||
│ │ └── agent_setup.py # journey start/message
|
|
||||||
│ ├── core/
|
|
||||||
│ │ ├── agent_runner.py # Batch classify → process
|
|
||||||
│ │ ├── agent_registry.py
|
|
||||||
│ │ ├── redis_executor.py # execute_on_client via Redis pub/sub
|
|
||||||
│ │ └── llm.py
|
|
||||||
│ └── agents/
|
|
||||||
│ ├── task_agent.py # Tool definitions (batch context)
|
|
||||||
│ ├── project_agent.py
|
|
||||||
│ ├── note_agent.py
|
|
||||||
│ ├── timeline_agent.py
|
|
||||||
│ └── filesystem_agent.py
|
|
||||||
│
|
|
||||||
├── billing-service/
|
|
||||||
│ ├── Dockerfile
|
|
||||||
│ ├── requirements.txt
|
|
||||||
│ └── app/
|
|
||||||
│ ├── main.py
|
|
||||||
│ ├── config.py
|
|
||||||
│ ├── db.py
|
|
||||||
│ ├── models.py # subscriptions
|
|
||||||
│ ├── routes/
|
|
||||||
│ │ └── billing.py
|
|
||||||
│ └── services/
|
|
||||||
│ ├── stripe_service.py
|
|
||||||
│ └── tier_manager.py
|
|
||||||
│
|
|
||||||
└── infra/
|
|
||||||
├── traefik/
|
|
||||||
│ └── traefik.yml
|
|
||||||
├── keys/
|
|
||||||
│ ├── jwt_private.pem # Solo auth-service
|
|
||||||
│ └── jwt_public.pem # Tutti i servizi
|
|
||||||
└── alembic/ # Migrazioni condivise o per-servizio
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 5. Docker Compose — Configurazione MVP
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# docker-compose.yml
|
|
||||||
|
|
||||||
services:
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# API Gateway
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
traefik:
|
|
||||||
image: traefik:v3.2
|
|
||||||
command:
|
|
||||||
- "--api.insecure=true"
|
|
||||||
- "--providers.docker=true"
|
|
||||||
- "--providers.docker.exposedbydefault=false"
|
|
||||||
- "--entrypoints.web.address=:80"
|
|
||||||
- "--entrypoints.websecure.address=:443"
|
|
||||||
- "--entrypoints.web.http.redirections.entrypoint.to=websecure"
|
|
||||||
ports:
|
|
||||||
- "80:80"
|
|
||||||
- "443:443"
|
|
||||||
- "8080:8080" # Dashboard Traefik (disabilitare in prod)
|
|
||||||
volumes:
|
|
||||||
- /var/run/docker.sock:/var/run/docker.sock:ro
|
|
||||||
- ./infra/certs:/certs:ro
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Auth Service (2 repliche)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
auth-service:
|
|
||||||
build: ./auth-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
|
|
||||||
SERVICE_NAME: auth
|
|
||||||
secrets:
|
|
||||||
- jwt_private_key
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
|
|
||||||
- "traefik.http.services.auth.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Chat Service — Real-time WS + Chat (scalabile)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
chat-service:
|
|
||||||
build: ./chat-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: chat
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
# REST chat endpoint
|
|
||||||
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
|
|
||||||
- "traefik.http.services.chat.loadbalancer.server.port=8000"
|
|
||||||
# WebSocket route con sticky session
|
|
||||||
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
|
|
||||||
- "traefik.http.routers.ws.service=chat-ws"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.server.port=8000"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.name=ws_affinity"
|
|
||||||
- "traefik.http.services.chat-ws.loadbalancer.sticky.cookie.httpOnly=true"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Agent Service — Batch processing (scalabile indipendentemente)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
agent-service:
|
|
||||||
build: ./agent-service
|
|
||||||
deploy:
|
|
||||||
replicas: 2
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: agent
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
|
|
||||||
- "traefik.http.services.agents.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Billing Service (1 replica)
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
billing-service:
|
|
||||||
build: ./billing-service
|
|
||||||
deploy:
|
|
||||||
replicas: 1
|
|
||||||
env_file: .env
|
|
||||||
environment:
|
|
||||||
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
|
|
||||||
REDIS_URL: redis://redis:6379
|
|
||||||
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
|
|
||||||
SERVICE_NAME: billing
|
|
||||||
secrets:
|
|
||||||
- jwt_public_key
|
|
||||||
labels:
|
|
||||||
- "traefik.enable=true"
|
|
||||||
- "traefik.http.routers.billing.rule=PathPrefix(`/api/v1/billing`)"
|
|
||||||
- "traefik.http.services.billing.loadbalancer.server.port=8000"
|
|
||||||
depends_on:
|
|
||||||
db:
|
|
||||||
condition: service_healthy
|
|
||||||
redis:
|
|
||||||
condition: service_healthy
|
|
||||||
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
# Infrastruttura
|
|
||||||
# ══════════════════════════════════════════════════════════
|
|
||||||
db:
|
|
||||||
image: pgvector/pgvector:pg16
|
|
||||||
environment:
|
|
||||||
POSTGRES_USER: postgres
|
|
||||||
POSTGRES_PASSWORD: postgres
|
|
||||||
POSTGRES_DB: adiuva
|
|
||||||
volumes:
|
|
||||||
- postgres_data:/var/lib/postgresql/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD-SHELL", "pg_isready -U postgres"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 5s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
redis:
|
|
||||||
image: redis:7-alpine
|
|
||||||
command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
|
|
||||||
volumes:
|
|
||||||
- redis_data:/data
|
|
||||||
healthcheck:
|
|
||||||
test: ["CMD", "redis-cli", "ping"]
|
|
||||||
interval: 5s
|
|
||||||
timeout: 3s
|
|
||||||
retries: 5
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
qdrant:
|
|
||||||
image: qdrant/qdrant:latest
|
|
||||||
volumes:
|
|
||||||
- qdrant_data:/qdrant/storage
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
secrets:
|
|
||||||
jwt_private_key:
|
|
||||||
file: ./infra/keys/jwt_private.pem
|
|
||||||
jwt_public_key:
|
|
||||||
file: ./infra/keys/jwt_public.pem
|
|
||||||
|
|
||||||
volumes:
|
|
||||||
postgres_data:
|
|
||||||
redis_data:
|
|
||||||
qdrant_data:
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 6. Configurazione Cloudflare + VPS
|
|
||||||
|
|
||||||
### 6.1 DNS
|
|
||||||
|
|
||||||
```
|
|
||||||
api.tuodominio.com → A record → IP del VPS
|
|
||||||
→ Proxy: ON (orange cloud)
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.2 Cloudflare Settings
|
|
||||||
|
|
||||||
| Setting | Valore | Motivo |
|
|
||||||
|---------|--------|--------|
|
|
||||||
| SSL/TLS mode | **Full (Strict)** | Cloudflare ↔ VPS con certificato valido |
|
|
||||||
| WebSocket | **ON** | Necessario per `/api/v1/ws/device` |
|
|
||||||
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
|
|
||||||
| Under Attack Mode | Off (attivare se necessario) | |
|
|
||||||
|
|
||||||
### 6.3 TLS sul VPS
|
|
||||||
|
|
||||||
Due opzioni:
|
|
||||||
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
|
|
||||||
- **Opzione B**: Let's Encrypt via Traefik (con DNS challenge Cloudflare)
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# traefik.yml — con Cloudflare Origin Certificate
|
|
||||||
entryPoints:
|
|
||||||
websecure:
|
|
||||||
address: ":443"
|
|
||||||
|
|
||||||
tls:
|
|
||||||
certificates:
|
|
||||||
- certFile: /certs/origin.pem
|
|
||||||
keyFile: /certs/origin-key.pem
|
|
||||||
```
|
|
||||||
|
|
||||||
### 6.4 Rete VPS
|
|
||||||
|
|
||||||
```bash
|
|
||||||
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443
|
|
||||||
# https://www.cloudflare.com/ips/
|
|
||||||
ufw default deny incoming
|
|
||||||
ufw allow from 173.245.48.0/20 to any port 443
|
|
||||||
ufw allow from 103.21.244.0/22 to any port 443
|
|
||||||
# ... (tutti gli IP range di Cloudflare)
|
|
||||||
ufw allow ssh
|
|
||||||
ufw enable
|
|
||||||
```
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 7. Comunicazione Inter-Servizio
|
|
||||||
|
|
||||||
### 7.1 Redis Pub/Sub — Event Bus
|
|
||||||
|
|
||||||
```
|
|
||||||
┌──────────┐ tier_changed:user_123 ┌──────────┐
|
|
||||||
│ Billing │ ────────────────────────► │ Auth │
|
|
||||||
│ Service │ │ Service │
|
|
||||||
└──────────┘ └──────────┘
|
|
||||||
|
|
||||||
┌──────────┐ tool_call:user_123 ┌──────────┐
|
|
||||||
│ Agent │ ────────────────────────► │ Chat │
|
|
||||||
│ Service │ │ Service │
|
|
||||||
│ (batch) │ ◄────────────────────────│ (ha WS) │
|
|
||||||
└──────────┘ tool_result:{call_id} └──────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
### 7.2 Health Checks e Service Discovery
|
|
||||||
|
|
||||||
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
|
|
||||||
- **Redis pub/sub** (tool-call cross-instance, tier events)
|
|
||||||
- **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
|
|
||||||
- **PostgreSQL** (dati persistenti condivisi)
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 8. Piano di Migrazione Incrementale (MVP)
|
|
||||||
|
|
||||||
### Fase 1 — Preparazione (nel monolite attuale)
|
|
||||||
1. Aggiungere Redis al `docker-compose.yml` attuale
|
|
||||||
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
|
|
||||||
3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
|
|
||||||
4. Estrarre `shared/` con auth verification, schemas, middleware
|
|
||||||
|
|
||||||
### Fase 2 — Auth Service (primo split)
|
|
||||||
1. Estrarre `auth.py` routes + models in `auth-service/`
|
|
||||||
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
|
|
||||||
3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
|
|
||||||
4. Il monolite continua a servire tutto il resto
|
|
||||||
|
|
||||||
### Fase 3 — Billing Service
|
|
||||||
1. Estrarre billing routes, Stripe service, tier manager
|
|
||||||
2. Configurare Redis pub/sub per `tier_changed` events
|
|
||||||
3. Routare via Traefik
|
|
||||||
|
|
||||||
### Fase 4 — Split Chat + Agent (il più delicato)
|
|
||||||
1. Il monolite residuo contiene WS + chat + agents
|
|
||||||
2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
|
|
||||||
3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
|
|
||||||
4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
|
|
||||||
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
|
|
||||||
|
|
||||||
### Fase 5 — Scaling test
|
|
||||||
1. Scalare Chat Service a 2 repliche, verificare sticky sessions
|
|
||||||
2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
|
|
||||||
3. Monitoring (Prometheus + Grafana) per ogni servizio
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 9. Monitoraggio e Logging
|
|
||||||
|
|
||||||
```yaml
|
|
||||||
# Aggiungere al docker-compose.yml
|
|
||||||
|
|
||||||
prometheus:
|
|
||||||
image: prom/prometheus:latest
|
|
||||||
volumes:
|
|
||||||
- ./infra/prometheus/prometheus.yml:/etc/prometheus/prometheus.yml
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
grafana:
|
|
||||||
image: grafana/grafana:latest
|
|
||||||
ports:
|
|
||||||
- "3000:3000"
|
|
||||||
volumes:
|
|
||||||
- grafana_data:/var/lib/grafana
|
|
||||||
restart: unless-stopped
|
|
||||||
|
|
||||||
loki:
|
|
||||||
image: grafana/loki:latest
|
|
||||||
restart: unless-stopped
|
|
||||||
```
|
|
||||||
|
|
||||||
Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) raccolti da Loki.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## 10. Sizing VPS Minimo Consigliato (MVP)
|
|
||||||
|
|
||||||
| Componente | CPU | RAM | Note |
|
|
||||||
|---|---|---|---|
|
|
||||||
| Traefik | 0.25 | 128MB | |
|
|
||||||
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
|
|
||||||
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
|
|
||||||
| Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
|
|
||||||
| Billing Service | 0.25 | 128MB | |
|
|
||||||
| PostgreSQL | 1.0 | 1GB | |
|
|
||||||
| Redis | 0.25 | 256MB | |
|
|
||||||
| Qdrant | 0.5 | 512MB | |
|
|
||||||
| **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
|
|
||||||
|
|
||||||
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
|
|
||||||
|
|
||||||
---
|
|
||||||
|
|
||||||
## Riepilogo Architettura MVP
|
|
||||||
|
|
||||||
| Servizio | Repliche | Proprietario di |
|
|
||||||
|---|---|---|
|
|
||||||
| **Traefik** | 1 | Routing, TLS, sticky sessions |
|
|
||||||
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
|
|
||||||
| **Chat Service** | 2–N | WebSocket, home/floating chat, streaming |
|
|
||||||
| **Agent Service** | 2–N | Batch processing, directory scan, agent setup |
|
|
||||||
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
|
|
||||||
|
|
||||||
| Decisione | Scelta | Motivazione |
|
|
||||||
|---|---|---|
|
|
||||||
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
|
|
||||||
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
|
|
||||||
| Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
|
|
||||||
| WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
|
|
||||||
| Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
|
|
||||||
| Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
|
|
||||||
| Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
|
|
||||||
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
|
|
||||||
| TLS | Cloudflare Origin Certificate | Zero maintenance |
|
|
||||||
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
|
|
||||||
| Storage / Plugin | Post-MVP | Non critici per il lancio |
|
|
||||||
36
requirements.txt
Normal file
36
requirements.txt
Normal file
@@ -0,0 +1,36 @@
|
|||||||
|
fastapi>=0.115.0
|
||||||
|
uvicorn[standard]>=0.34.0
|
||||||
|
gunicorn>=22.0.0
|
||||||
|
langchain>=0.3.0
|
||||||
|
langchain-openai>=0.3.0
|
||||||
|
langchain-litellm>=0.1.0
|
||||||
|
langgraph>=0.3.0
|
||||||
|
litellm>=1.50.0
|
||||||
|
pydantic>=2.10.0
|
||||||
|
pydantic-settings>=2.7.0
|
||||||
|
python-jose[cryptography]>=3.3.0
|
||||||
|
stripe>=11.0.0
|
||||||
|
boto3>=1.35.0
|
||||||
|
slowapi>=0.1.9
|
||||||
|
sqlalchemy>=2.0.0
|
||||||
|
asyncpg>=0.30.0
|
||||||
|
alembic>=1.14.0
|
||||||
|
bcrypt>=4.2.0
|
||||||
|
python-dotenv>=1.0.0
|
||||||
|
httpx>=0.28.0
|
||||||
|
websockets>=14.0
|
||||||
|
psycopg2-binary>=2.9.0
|
||||||
|
pytest>=8.0.0
|
||||||
|
pytest-asyncio>=0.24.0
|
||||||
|
aiosqlite>=0.20.0
|
||||||
|
moto[s3]>=5.0.0
|
||||||
|
pinecone>=5.0.0
|
||||||
|
qdrant-client>=1.7.0
|
||||||
|
croniter>=3.0.0
|
||||||
|
google-api-python-client>=2.130.0
|
||||||
|
google-auth>=2.29.0
|
||||||
|
google-auth-oauthlib>=1.2.0
|
||||||
|
google-auth-httplib2>=0.2.0
|
||||||
|
msal>=1.28.0
|
||||||
|
cryptography>=42.0.0
|
||||||
|
ruff>=0.8.0
|
||||||
@@ -1,19 +0,0 @@
|
|||||||
# ── Auth Service ──────────────────────────────────────────────────────────────
|
|
||||||
# This file contains env vars specific to the Auth Service.
|
|
||||||
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
|
|
||||||
# or from docker-compose environment.
|
|
||||||
|
|
||||||
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
|
|
||||||
# Generate keypair:
|
|
||||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
|
||||||
# openssl rsa -in private.pem -pubout -out public.pem
|
|
||||||
#
|
|
||||||
# Paste PEM content with literal \n for newlines:
|
|
||||||
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
|
|
||||||
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
|
|
||||||
|
|
||||||
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
|
|
||||||
JWT_PRIVATE_KEY=
|
|
||||||
|
|
||||||
# PUBLIC KEY — used to VERIFY JWTs.
|
|
||||||
JWT_PUBLIC_KEY=
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# ── builder ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS builder
|
|
||||||
|
|
||||||
WORKDIR /build
|
|
||||||
|
|
||||||
# Install shared + service deps in one layer
|
|
||||||
COPY services/auth/requirements.txt ./requirements.txt
|
|
||||||
RUN pip install --upgrade pip && \
|
|
||||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
|
||||||
|
|
||||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS runtime
|
|
||||||
|
|
||||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY --from=builder /install /usr/local
|
|
||||||
|
|
||||||
# Copy shared module (available to all services)
|
|
||||||
COPY shared/ shared/
|
|
||||||
|
|
||||||
# Copy service source
|
|
||||||
COPY services/auth/app/ app/
|
|
||||||
|
|
||||||
RUN chown -R appuser:appgroup /app
|
|
||||||
|
|
||||||
USER appuser
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
CMD ["gunicorn", "app.main:app", \
|
|
||||||
"-k", "uvicorn.workers.UvicornWorker", \
|
|
||||||
"--bind", "0.0.0.0:8000", \
|
|
||||||
"--workers", "2", \
|
|
||||||
"--timeout", "30"]
|
|
||||||
@@ -1,16 +0,0 @@
|
|||||||
# Auth Service
|
|
||||||
|
|
||||||
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
|
|
||||||
|
|
||||||
## Tables owned
|
|
||||||
- `users`
|
|
||||||
- `refresh_tokens`
|
|
||||||
- `subscriptions` (read; Billing Service writes)
|
|
||||||
|
|
||||||
## Endpoints
|
|
||||||
- `POST /auth/register`
|
|
||||||
- `POST /auth/login`
|
|
||||||
- `POST /auth/refresh`
|
|
||||||
- `GET /auth/me`
|
|
||||||
- `PUT /auth/me`
|
|
||||||
- `GET /auth/verify` (ForwardAuth for Traefik)
|
|
||||||
@@ -1,34 +0,0 @@
|
|||||||
"""Auth Service — local configuration.
|
|
||||||
|
|
||||||
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
|
|
||||||
These are NOT in shared/config.py to prevent other services from accessing them.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from pydantic import field_validator
|
|
||||||
from pydantic_settings import BaseSettings, SettingsConfigDict
|
|
||||||
|
|
||||||
|
|
||||||
class AuthSettings(BaseSettings):
|
|
||||||
# RS256 private key (PEM format). Used to SIGN JWTs.
|
|
||||||
# Only the Auth Service has this. Generate with:
|
|
||||||
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
|
|
||||||
# Then set the env var (newlines as \n):
|
|
||||||
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
|
|
||||||
JWT_PRIVATE_KEY: str = ""
|
|
||||||
|
|
||||||
# RS256 public key (PEM format). Used to VERIFY JWTs.
|
|
||||||
# Derived from the private key:
|
|
||||||
# openssl rsa -in private.pem -pubout -out public.pem
|
|
||||||
JWT_PUBLIC_KEY: str = ""
|
|
||||||
|
|
||||||
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def _expand_pem_newlines(cls, v: str) -> str:
|
|
||||||
if isinstance(v, str) and r"\n" in v:
|
|
||||||
return v.replace(r"\n", "\n")
|
|
||||||
return v
|
|
||||||
|
|
||||||
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
|
|
||||||
|
|
||||||
|
|
||||||
auth_settings = AuthSettings()
|
|
||||||
@@ -1,62 +0,0 @@
|
|||||||
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
|
|
||||||
|
|
||||||
Standalone FastAPI service extracted from the adiuva-api monolith.
|
|
||||||
Owns: users, refresh_tokens, subscriptions (read).
|
|
||||||
"""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Ensure the repo root is on sys.path so "shared" is importable.
|
|
||||||
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
|
|
||||||
# In local dev, we need to add the repo root (two levels up from this file).
|
|
||||||
_repo_root = str(Path(__file__).resolve().parents[3])
|
|
||||||
if _repo_root not in sys.path:
|
|
||||||
sys.path.insert(0, _repo_root)
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI):
|
|
||||||
yield
|
|
||||||
from shared.db import engine
|
|
||||||
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
|
|
||||||
def create_app() -> FastAPI:
|
|
||||||
app = FastAPI(
|
|
||||||
title="Adiuva Auth Service",
|
|
||||||
version="0.1.0",
|
|
||||||
docs_url="/docs" if settings.ENV == "dev" else None,
|
|
||||||
redoc_url=None,
|
|
||||||
lifespan=lifespan,
|
|
||||||
)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=settings.CORS_ORIGINS,
|
|
||||||
allow_credentials=True,
|
|
||||||
allow_methods=["*"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
from app.routes import router
|
|
||||||
from app.verify import router as verify_router
|
|
||||||
|
|
||||||
app.include_router(router, prefix="/api/v1")
|
|
||||||
app.include_router(verify_router, prefix="/api/v1")
|
|
||||||
|
|
||||||
@app.get("/api/v1/health", tags=["health"])
|
|
||||||
async def health() -> dict:
|
|
||||||
return {"status": "ok", "service": "auth", "version": app.version}
|
|
||||||
|
|
||||||
return app
|
|
||||||
|
|
||||||
|
|
||||||
app = create_app()
|
|
||||||
@@ -1,66 +0,0 @@
|
|||||||
"""ForwardAuth verification endpoint for Traefik.
|
|
||||||
|
|
||||||
Traefik calls GET /api/v1/auth/verify on every request to a protected
|
|
||||||
service. This endpoint validates the JWT from the Authorization header
|
|
||||||
and returns identity headers that Traefik injects into downstream requests.
|
|
||||||
|
|
||||||
Downstream services NEVER validate JWTs themselves — they trust the
|
|
||||||
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Request, Response
|
|
||||||
from fastapi import status as http_status
|
|
||||||
from jose import JWTError, jwt
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
from shared.db import async_session
|
|
||||||
from shared.models import Subscription
|
|
||||||
|
|
||||||
from app.config import auth_settings
|
|
||||||
|
|
||||||
router = APIRouter(tags=["auth"])
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/auth/verify")
|
|
||||||
async def verify(request: Request) -> Response:
|
|
||||||
"""Validate JWT and return identity headers for Traefik ForwardAuth.
|
|
||||||
|
|
||||||
Returns 200 with X-User-* headers on success, 401 on failure.
|
|
||||||
Traefik copies response headers to the downstream request.
|
|
||||||
"""
|
|
||||||
auth_header = request.headers.get("Authorization", "")
|
|
||||||
if not auth_header.startswith("Bearer "):
|
|
||||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
|
||||||
|
|
||||||
token = auth_header[7:] # strip "Bearer "
|
|
||||||
|
|
||||||
try:
|
|
||||||
payload = jwt.decode(
|
|
||||||
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
|
|
||||||
)
|
|
||||||
user_id: str | None = payload.get("sub")
|
|
||||||
email: str | None = payload.get("email")
|
|
||||||
if not user_id or not email:
|
|
||||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
|
||||||
except JWTError:
|
|
||||||
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
|
|
||||||
|
|
||||||
# Live tier lookup from subscriptions table
|
|
||||||
async with async_session() as db:
|
|
||||||
result = await db.execute(
|
|
||||||
select(Subscription.tier).where(Subscription.user_id == user_id)
|
|
||||||
)
|
|
||||||
default_tier = "power" if settings.ENV == "dev" else "free"
|
|
||||||
tier: str = result.scalar_one_or_none() or default_tier
|
|
||||||
|
|
||||||
return Response(
|
|
||||||
status_code=http_status.HTTP_200_OK,
|
|
||||||
headers={
|
|
||||||
"X-User-Id": user_id,
|
|
||||||
"X-User-Email": email,
|
|
||||||
"X-User-Tier": tier,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
@@ -1,11 +0,0 @@
|
|||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
gunicorn>=22.0.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
pydantic-settings>=2.7.0
|
|
||||||
python-jose[cryptography]>=3.3.0
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
bcrypt>=4.2.0
|
|
||||||
cryptography>=42.0.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# ── builder ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS builder
|
|
||||||
|
|
||||||
WORKDIR /build
|
|
||||||
|
|
||||||
COPY services/batch-agent/requirements.txt ./requirements.txt
|
|
||||||
RUN pip install --upgrade pip && \
|
|
||||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
|
||||||
|
|
||||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS runtime
|
|
||||||
|
|
||||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY --from=builder /install /usr/local
|
|
||||||
|
|
||||||
# Shared module
|
|
||||||
COPY shared/ shared/
|
|
||||||
|
|
||||||
# Service source
|
|
||||||
COPY services/batch-agent/app/ app/
|
|
||||||
|
|
||||||
RUN chown -R appuser:appgroup /app
|
|
||||||
|
|
||||||
USER appuser
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
|
|
||||||
CMD ["gunicorn", "app.main:app", \
|
|
||||||
"-k", "uvicorn.workers.UvicornWorker", \
|
|
||||||
"--bind", "0.0.0.0:8000", \
|
|
||||||
"--workers", "2", \
|
|
||||||
"--timeout", "300"]
|
|
||||||
@@ -1,23 +0,0 @@
|
|||||||
# Batch Agent Service
|
|
||||||
|
|
||||||
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
|
|
||||||
|
|
||||||
## Tables owned
|
|
||||||
- `local_agent_configs`
|
|
||||||
- `cloud_agent_configs`
|
|
||||||
- `agent_run_logs`
|
|
||||||
|
|
||||||
## Endpoints
|
|
||||||
- `GET /agents/catalog`
|
|
||||||
- `POST /agents/can-create`
|
|
||||||
- `POST /agents/trigger`
|
|
||||||
- `GET /agents/{id}/history`
|
|
||||||
|
|
||||||
## Redis channels
|
|
||||||
- Subscribe: `batch:request:{user_id}`
|
|
||||||
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
|
|
||||||
- BRPOP: `tool:result:{call_id}` (30s timeout)
|
|
||||||
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
|
|
||||||
|
|
||||||
## TODO
|
|
||||||
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.
|
|
||||||
@@ -1,910 +0,0 @@
|
|||||||
"""Agent run orchestrator — adapted for Batch Agent Service.
|
|
||||||
|
|
||||||
Key changes from monolith app/core/agent_runner.py:
|
|
||||||
- No DeviceConnectionManager — tool calls go through Redis ws_context.
|
|
||||||
- set_current_user / clear_current_user replace set_client_executor.
|
|
||||||
- run_local_agent accepts a serialized dict (from Redis / REST) instead
|
|
||||||
of SQLAlchemy model objects.
|
|
||||||
- _finalize_run writes to PostgreSQL via shared.db.async_session.
|
|
||||||
- Cloud agent import path changed to app.integrations.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timedelta, timezone
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
||||||
from sqlalchemy import select
|
|
||||||
|
|
||||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
|
||||||
from shared.agents.note_agent import NOTE_TOOLS
|
|
||||||
from shared.agents.project_agent import PROJECT_TOOLS
|
|
||||||
from shared.agents.task_agent import TASK_TOOLS
|
|
||||||
from shared.agents.timeline_agent import TIMELINE_TOOLS
|
|
||||||
from shared.llm import get_llm
|
|
||||||
from shared.ws_context import execute_on_client, set_current_user, clear_current_user
|
|
||||||
import app.tracing as tracing
|
|
||||||
from shared.db import async_session
|
|
||||||
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
|
||||||
from shared.redis import redis_client, ws_out_channel
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── Concurrency guard ─────────────────────────────────────────────────────
|
|
||||||
_running_agents: set[str] = set()
|
|
||||||
|
|
||||||
|
|
||||||
def is_agent_running(agent_id: str) -> bool:
|
|
||||||
return agent_id in _running_agents
|
|
||||||
|
|
||||||
|
|
||||||
# ── Timeouts ───────────────────────────────────────────────────────────────
|
|
||||||
_TOOL_CALL_TIMEOUT: int = 30
|
|
||||||
_MAX_PROCESSING_STEPS: int = 12
|
|
||||||
_MAX_SCAN_DEPTH: int = 5
|
|
||||||
|
|
||||||
# ── Data-type to tool mapping ─────────────────────────────────────────────
|
|
||||||
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
|
|
||||||
"tasks": TASK_TOOLS,
|
|
||||||
"notes": NOTE_TOOLS,
|
|
||||||
"timelines": TIMELINE_TOOLS,
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── Step 1: Classification prompt ─────────────────────────────────────────
|
|
||||||
|
|
||||||
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
|
|
||||||
"tasks": (
|
|
||||||
"Action items, to-dos, deliverables — anything that describes work to be done, "
|
|
||||||
"assigned to someone, or tracked with a due date or status."
|
|
||||||
),
|
|
||||||
"notes": (
|
|
||||||
"Documentation, meeting notes, summaries, reference material — "
|
|
||||||
"written content meant to be read and referenced rather than acted on."
|
|
||||||
),
|
|
||||||
"timelines": (
|
|
||||||
"Project milestones, deadlines, scheduled events — "
|
|
||||||
"specific dates that mark a point in the progress of a project."
|
|
||||||
),
|
|
||||||
"projects": (
|
|
||||||
"High-level project entities — only relevant if the file clearly introduces "
|
|
||||||
"a new project or updates the scope of an existing one."
|
|
||||||
),
|
|
||||||
}
|
|
||||||
|
|
||||||
_STEP1_SYSTEM_PROMPT = """\
|
|
||||||
You are a file classifier for a freelance project management tool.
|
|
||||||
|
|
||||||
Your job is to match a file to an existing project and identify which data domains to extract.
|
|
||||||
|
|
||||||
## Project matching rules (STRICT — follow in order)
|
|
||||||
|
|
||||||
1. Search the file content for any mention of a project name, client name, acronym, or topic
|
|
||||||
that overlaps with the existing projects listed below.
|
|
||||||
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
|
|
||||||
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
|
|
||||||
when the file has zero meaningful connection to any listed project.
|
|
||||||
4. When in doubt, pick the closest match from the list.
|
|
||||||
|
|
||||||
## Response format
|
|
||||||
|
|
||||||
Respond ONLY with a JSON object — no markdown, no explanation:
|
|
||||||
|
|
||||||
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
|
|
||||||
|
|
||||||
## Domain definitions (only consider domains in the allowed list)
|
|
||||||
|
|
||||||
{domain_definitions}
|
|
||||||
|
|
||||||
## Existing projects
|
|
||||||
|
|
||||||
{projects_list}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Step 2: Processing prompt ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
_PROCESSING_SYSTEM_PROMPT = """\
|
|
||||||
You are a data extraction assistant for a freelance project management tool.
|
|
||||||
|
|
||||||
Your task: extract structured data from the file content and persist it using the available tools.
|
|
||||||
|
|
||||||
## Mandatory process — follow this order for EVERY item you extract
|
|
||||||
|
|
||||||
1. READ the existing records listed below for the relevant domain.
|
|
||||||
2. SEARCH for a match by title, topic, or semantic similarity.
|
|
||||||
3. If a match exists → call the update_* tool with the existing record's id.
|
|
||||||
4. If no match exists → call the create_* tool and set isAiSuggested=1.
|
|
||||||
|
|
||||||
NEVER call create_* without first checking the existing records.
|
|
||||||
NEVER duplicate a record that already exists under a different wording.
|
|
||||||
|
|
||||||
## Existing records (source of truth)
|
|
||||||
|
|
||||||
{existing_context}
|
|
||||||
|
|
||||||
## Context
|
|
||||||
|
|
||||||
Project: {project_context}
|
|
||||||
Domains to extract: {data_types}
|
|
||||||
|
|
||||||
{custom_prompt_section}
|
|
||||||
"""
|
|
||||||
|
|
||||||
# ── Cloud processing prompt ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
_CLOUD_PROCESSING_PROMPT = """\
|
|
||||||
You are a data extraction and management assistant for a freelance project
|
|
||||||
management tool.
|
|
||||||
|
|
||||||
Available tools:
|
|
||||||
Filesystem : read_file_content, list_directory, get_file_metadata
|
|
||||||
Tasks : list_tasks, create_task, update_task, add_task_comment
|
|
||||||
Notes : list_notes, get_note, create_note, update_note
|
|
||||||
Timelines : list_timelines, create_timeline, update_timeline
|
|
||||||
Projects : list_all_projects, get_project, create_project, update_project
|
|
||||||
|
|
||||||
Your task:
|
|
||||||
1. Read the full content of each file below using read_file_content.
|
|
||||||
2. For each piece of information found, ALWAYS try to match and update an
|
|
||||||
existing record before creating a new one.
|
|
||||||
3. ONLY act on these entity types: {data_types}.
|
|
||||||
4. Do NOT invent data. Only extract what is clearly present in the files.
|
|
||||||
5. If a file contains no relevant data for the target entity types, skip it.
|
|
||||||
|
|
||||||
{project_context}
|
|
||||||
|
|
||||||
Files to process:
|
|
||||||
{file_list}
|
|
||||||
|
|
||||||
{custom_prompt_section}
|
|
||||||
|
|
||||||
After processing all files, respond with a brief summary of what you updated
|
|
||||||
and what you created.
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
# ── LLM tool-calling loop ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _as_text(content: Any) -> str:
|
|
||||||
if content is None:
|
|
||||||
return ""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
parts: list[str] = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, str):
|
|
||||||
parts.append(item)
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
text = item.get("text")
|
|
||||||
if isinstance(text, str):
|
|
||||||
parts.append(text)
|
|
||||||
return "".join(parts)
|
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_agent_with_tools(
|
|
||||||
*,
|
|
||||||
system_prompt: str,
|
|
||||||
user_message: str,
|
|
||||||
tools: list[Any],
|
|
||||||
max_steps: int,
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Run an LLM agent with tool-calling, returning the final text response."""
|
|
||||||
callbacks = [langfuse_handler] if langfuse_handler else None
|
|
||||||
llm = get_llm(callbacks=callbacks)
|
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
|
||||||
messages: list[Any] = [
|
|
||||||
SystemMessage(content=system_prompt),
|
|
||||||
HumanMessage(content=user_message),
|
|
||||||
]
|
|
||||||
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
|
||||||
|
|
||||||
for _ in range(max_steps):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
return _as_text(response.content)
|
|
||||||
|
|
||||||
for call in response.tool_calls:
|
|
||||||
call_id = str(call.get("id", ""))
|
|
||||||
call_name = str(call.get("name", ""))
|
|
||||||
call_args = call.get("args", {})
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: tool_call name=%s args=%s",
|
|
||||||
call_name,
|
|
||||||
json.dumps(call_args, ensure_ascii=True)[:800],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
|
||||||
if tool_fn is None:
|
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
|
||||||
else:
|
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: tool_result name=%s output=%s",
|
|
||||||
call_name,
|
|
||||||
str(tool_output)[:200],
|
|
||||||
)
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
||||||
|
|
||||||
final = await llm.ainvoke(messages)
|
|
||||||
return _as_text(final.content)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tool list builder ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _build_processing_tools(data_types: list[str]) -> list[Any]:
|
|
||||||
tools: list[Any] = list(FILESYSTEM_TOOLS)
|
|
||||||
for dt in data_types:
|
|
||||||
dt_tools = _DATA_TYPE_TOOLS.get(dt)
|
|
||||||
if dt_tools:
|
|
||||||
tools.extend(dt_tools)
|
|
||||||
return tools
|
|
||||||
|
|
||||||
|
|
||||||
# ── Code-based directory scanner ─────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _scan_directories(
|
|
||||||
paths: list[str],
|
|
||||||
extensions: list[str],
|
|
||||||
last_run_at: datetime | None,
|
|
||||||
) -> list[str]:
|
|
||||||
all_files: list[str] = []
|
|
||||||
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
|
|
||||||
|
|
||||||
async def _walk(path: str, depth: int) -> None:
|
|
||||||
if depth > _MAX_SCAN_DEPTH:
|
|
||||||
return
|
|
||||||
try:
|
|
||||||
result = await execute_on_client(action="list_directory", data={"path": path})
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
|
|
||||||
return
|
|
||||||
for entry in result.get("entries", []):
|
|
||||||
entry_path = entry.get("path", "")
|
|
||||||
if not entry_path:
|
|
||||||
continue
|
|
||||||
if entry.get("type") == "directory":
|
|
||||||
await _walk(entry_path, depth + 1)
|
|
||||||
elif entry.get("type") == "file":
|
|
||||||
if ext_set:
|
|
||||||
dot_pos = entry_path.rfind(".")
|
|
||||||
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
|
|
||||||
if file_ext not in ext_set:
|
|
||||||
continue
|
|
||||||
all_files.append(entry_path)
|
|
||||||
|
|
||||||
for root in paths:
|
|
||||||
await _walk(root, depth=0)
|
|
||||||
|
|
||||||
if last_run_at is None:
|
|
||||||
return all_files
|
|
||||||
|
|
||||||
last_run_ms = int(last_run_at.timestamp() * 1000)
|
|
||||||
filtered: list[str] = []
|
|
||||||
for file_path in all_files:
|
|
||||||
try:
|
|
||||||
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
|
|
||||||
modified_at = meta.get("modifiedAt")
|
|
||||||
if modified_at is None:
|
|
||||||
filtered.append(file_path)
|
|
||||||
continue
|
|
||||||
if isinstance(modified_at, (int, float)):
|
|
||||||
mod_ms = int(modified_at)
|
|
||||||
else:
|
|
||||||
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
|
|
||||||
if mod_ms > last_run_ms:
|
|
||||||
filtered.append(file_path)
|
|
||||||
except Exception:
|
|
||||||
filtered.append(file_path)
|
|
||||||
|
|
||||||
return filtered
|
|
||||||
|
|
||||||
|
|
||||||
# ── Code-based entity fetchers ────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_projects() -> list[dict]:
|
|
||||||
try:
|
|
||||||
result = await execute_on_client(action="select", table="projects")
|
|
||||||
return result.get("rows", [])
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_runner: failed to fetch projects: %s", exc)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
_DOMAIN_TABLE: dict[str, str] = {
|
|
||||||
"tasks": "tasks",
|
|
||||||
"notes": "notes",
|
|
||||||
"timelines": "timelines",
|
|
||||||
"projects": "projects",
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
|
|
||||||
table = _DOMAIN_TABLE.get(domain)
|
|
||||||
if not table:
|
|
||||||
return []
|
|
||||||
filters: dict[str, Any] = {}
|
|
||||||
if project_id != "standalone" and domain != "projects":
|
|
||||||
filters["projectId"] = project_id
|
|
||||||
try:
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="select",
|
|
||||||
table=table,
|
|
||||||
filters=filters if filters else None,
|
|
||||||
)
|
|
||||||
return result.get("rows", [])
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
|
|
||||||
return []
|
|
||||||
|
|
||||||
|
|
||||||
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
|
|
||||||
if not rows:
|
|
||||||
return f"No existing {domain}."
|
|
||||||
lines: list[str] = []
|
|
||||||
for r in rows:
|
|
||||||
if domain == "tasks":
|
|
||||||
desc = r.get("description") or ""
|
|
||||||
desc_part = f" — {desc[:120]}" if desc else ""
|
|
||||||
assignee = r.get("assignee") or r.get("assignees") or ""
|
|
||||||
due = r.get("dueDate") or r.get("due_date") or ""
|
|
||||||
meta = ", ".join(filter(None, [
|
|
||||||
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
|
|
||||||
f"assignee: {assignee}" if assignee else "",
|
|
||||||
f"due: {due}" if due else "",
|
|
||||||
]))
|
|
||||||
lines.append(
|
|
||||||
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
|
|
||||||
f" ({meta}, id: {r['id']})"
|
|
||||||
)
|
|
||||||
elif domain == "notes":
|
|
||||||
snippet = (r.get("content") or "")[:200].replace("\n", " ")
|
|
||||||
snippet_part = f"\n Preview: {snippet}" if snippet else ""
|
|
||||||
lines.append(
|
|
||||||
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
|
|
||||||
)
|
|
||||||
elif domain == "timelines":
|
|
||||||
lines.append(
|
|
||||||
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
|
|
||||||
)
|
|
||||||
elif domain == "projects":
|
|
||||||
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
|
|
||||||
summary_part = f" — {summary}" if summary else ""
|
|
||||||
lines.append(
|
|
||||||
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
|
|
||||||
f" (id: {r['id']})"
|
|
||||||
)
|
|
||||||
return f"Existing {domain}:\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Step 1: LLM file classifier ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _classify_file(
|
|
||||||
file_path: str,
|
|
||||||
file_content: str,
|
|
||||||
projects: list[dict],
|
|
||||||
config_data_types: list[str],
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
custom_system_prompt: str | None = None,
|
|
||||||
) -> tuple[str, list[str], str | None]:
|
|
||||||
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
|
|
||||||
|
|
||||||
if not file_content.strip():
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
valid_project_ids = {p["id"] for p in projects}
|
|
||||||
|
|
||||||
def _fmt_project(p: dict) -> str:
|
|
||||||
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
|
|
||||||
summary_part = f" — {summary[:100]}" if summary else ""
|
|
||||||
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
|
|
||||||
|
|
||||||
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
|
|
||||||
|
|
||||||
domain_definitions = "\n".join(
|
|
||||||
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
|
|
||||||
for d in config_data_types
|
|
||||||
if d in _DOMAIN_DESCRIPTIONS
|
|
||||||
)
|
|
||||||
|
|
||||||
if custom_system_prompt:
|
|
||||||
# Fixture-provided prompt takes absolute priority
|
|
||||||
system = custom_system_prompt.format_map(
|
|
||||||
{"domain_definitions": domain_definitions, "projects_list": projects_list}
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
system = tracing.compile_prompt(
|
|
||||||
"batch_file_classifier",
|
|
||||||
fallback=_STEP1_SYSTEM_PROMPT,
|
|
||||||
variables={
|
|
||||||
"domain_definitions": domain_definitions,
|
|
||||||
"projects_list": projects_list,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
llm = get_llm(callbacks=[langfuse_handler] if langfuse_handler else None)
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke([
|
|
||||||
SystemMessage(content=system),
|
|
||||||
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
|
|
||||||
])
|
|
||||||
raw = _as_text(response.content).strip()
|
|
||||||
if raw.startswith("```"):
|
|
||||||
raw = raw.split("```")[1]
|
|
||||||
if raw.startswith("json"):
|
|
||||||
raw = raw[4:]
|
|
||||||
parsed = json.loads(raw.strip())
|
|
||||||
raw_project_id: str = str(parsed.get("project_id") or "new")
|
|
||||||
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
|
|
||||||
new_project_name: str | None = (
|
|
||||||
str(parsed["new_project_name"]).strip() or None
|
|
||||||
if project_id == "new" and parsed.get("new_project_name")
|
|
||||||
else None
|
|
||||||
)
|
|
||||||
domains: list[str] = [
|
|
||||||
d for d in parsed.get("domains", [])
|
|
||||||
if d in config_data_types
|
|
||||||
]
|
|
||||||
if not domains:
|
|
||||||
domains = list(config_data_types)
|
|
||||||
return project_id, domains, new_project_name
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"agent_runner: step1 classification failed for %r: %s", file_path, exc
|
|
||||||
)
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
|
|
||||||
# ── Local agent runner (two-step per file) ────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def run_local_agent(user_id: str, trigger_data: dict[str, Any], *, langfuse_handler: Any | None = None) -> None:
|
|
||||||
"""Execute a local directory agent run.
|
|
||||||
|
|
||||||
In the microservice world, trigger_data is a serialized dict from
|
|
||||||
the REST route (forwarded via Redis), containing the agent config
|
|
||||||
fields and run_context.
|
|
||||||
|
|
||||||
set_current_user() must be called BEFORE this function.
|
|
||||||
"""
|
|
||||||
run_context: dict = trigger_data.get("run_context", {})
|
|
||||||
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
|
|
||||||
run_id = run_context.get("run_id")
|
|
||||||
|
|
||||||
_running_agents.add(agent_id)
|
|
||||||
|
|
||||||
# Extract config from trigger payload
|
|
||||||
directory_paths: list[str] = trigger_data.get("directory_paths", [])
|
|
||||||
if not directory_paths:
|
|
||||||
directory = trigger_data.get("directory", "")
|
|
||||||
if directory:
|
|
||||||
directory_paths = [directory]
|
|
||||||
|
|
||||||
data_types: list[str] = trigger_data.get("data_types", [])
|
|
||||||
file_extensions: list[str] = trigger_data.get("file_extensions", [])
|
|
||||||
prompt_template: str = trigger_data.get("prompt_template", "")
|
|
||||||
last_run_at_raw = trigger_data.get("last_run_at")
|
|
||||||
last_run_at: datetime | None = None
|
|
||||||
if last_run_at_raw:
|
|
||||||
if isinstance(last_run_at_raw, str):
|
|
||||||
last_run_at = datetime.fromisoformat(last_run_at_raw)
|
|
||||||
elif isinstance(last_run_at_raw, (int, float)):
|
|
||||||
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
|
|
||||||
|
|
||||||
errors: list[str] = []
|
|
||||||
items_processed = 0
|
|
||||||
items_created = 0
|
|
||||||
|
|
||||||
custom_section = (
|
|
||||||
f"User instructions:\n{prompt_template}"
|
|
||||||
if prompt_template
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create or load run log
|
|
||||||
run_log_id = run_id
|
|
||||||
if not run_log_id:
|
|
||||||
async with async_session() as db:
|
|
||||||
run_log = AgentRunLog(
|
|
||||||
agent_id=agent_id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=user_id,
|
|
||||||
status="running",
|
|
||||||
)
|
|
||||||
db.add(run_log)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(run_log)
|
|
||||||
run_log_id = run_log.id
|
|
||||||
|
|
||||||
try:
|
|
||||||
# ── Scan directories ─────────────────────────────────────────
|
|
||||||
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
|
|
||||||
file_paths = await _scan_directories(
|
|
||||||
paths=directory_paths,
|
|
||||||
extensions=file_extensions,
|
|
||||||
last_run_at=last_run_at,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
|
|
||||||
)
|
|
||||||
|
|
||||||
if not file_paths:
|
|
||||||
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
|
|
||||||
return
|
|
||||||
|
|
||||||
# ── Fetch all projects once ──────────────────────────────────
|
|
||||||
projects = await _fetch_projects()
|
|
||||||
|
|
||||||
for file_path in file_paths:
|
|
||||||
try:
|
|
||||||
file_result = await execute_on_client(
|
|
||||||
action="read_file_content", data={"path": file_path}
|
|
||||||
)
|
|
||||||
file_content: str = file_result.get("content", "")
|
|
||||||
if not file_content:
|
|
||||||
continue
|
|
||||||
|
|
||||||
items_processed += 1
|
|
||||||
|
|
||||||
# Step 1 — classify file
|
|
||||||
project_id, domains, new_project_name = await _classify_file(
|
|
||||||
file_path=file_path,
|
|
||||||
file_content=file_content,
|
|
||||||
projects=projects,
|
|
||||||
config_data_types=data_types,
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Step 2 — resolve project_id, fetch entities, process
|
|
||||||
if project_id == "new":
|
|
||||||
proj_name = new_project_name or "Untitled Project"
|
|
||||||
try:
|
|
||||||
proj_result = await execute_on_client(
|
|
||||||
action="insert",
|
|
||||||
table="projects",
|
|
||||||
data={"name": proj_name, "clientId": None},
|
|
||||||
)
|
|
||||||
created = proj_result.get("row", {})
|
|
||||||
effective_project_id = created.get("id", "standalone")
|
|
||||||
if "id" in created:
|
|
||||||
projects.append(created)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
|
|
||||||
effective_project_id = "standalone"
|
|
||||||
proj_name = "unknown"
|
|
||||||
project_context = (
|
|
||||||
f"Project: {proj_name} (id: {effective_project_id}). "
|
|
||||||
"Always set projectId to this id on every record you create."
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
effective_project_id = project_id
|
|
||||||
proj = next((p for p in projects if p["id"] == project_id), None)
|
|
||||||
proj_name = proj.get("name", project_id) if proj else project_id
|
|
||||||
project_context = (
|
|
||||||
f"Project: {proj_name} (id: {project_id}). "
|
|
||||||
"Always set projectId to this id on every record you create."
|
|
||||||
)
|
|
||||||
|
|
||||||
domains = [d for d in domains if d != "projects"]
|
|
||||||
|
|
||||||
existing_blocks: list[str] = []
|
|
||||||
for domain in domains:
|
|
||||||
rows = await _fetch_domain_entities(domain, effective_project_id)
|
|
||||||
existing_blocks.append(_format_entities_for_context(domain, rows))
|
|
||||||
|
|
||||||
existing_context = "\n\n".join(existing_blocks)
|
|
||||||
|
|
||||||
system_prompt = tracing.compile_prompt(
|
|
||||||
"batch_processing",
|
|
||||||
fallback=_PROCESSING_SYSTEM_PROMPT,
|
|
||||||
variables={
|
|
||||||
"existing_context": existing_context,
|
|
||||||
"project_context": project_context,
|
|
||||||
"data_types": ", ".join(domains),
|
|
||||||
"custom_prompt_section": custom_section,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
processing_tools = _build_processing_tools(domains)
|
|
||||||
|
|
||||||
result_text = await _run_agent_with_tools(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
user_message=(
|
|
||||||
f"Process this file and extract relevant information.\n\n"
|
|
||||||
f"File: {file_path}\n\nContent:\n{file_content}"
|
|
||||||
),
|
|
||||||
tools=processing_tools,
|
|
||||||
max_steps=_MAX_PROCESSING_STEPS,
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
)
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: run=%s file=%r result=%s",
|
|
||||||
run_log_id, file_path, result_text[:200],
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
errors.append(f"Error processing '{file_path}': {exc}")
|
|
||||||
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
errors.append(f"Agent run failed: {exc}")
|
|
||||||
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
|
|
||||||
finally:
|
|
||||||
_running_agents.discard(agent_id)
|
|
||||||
|
|
||||||
# ── Finalise ────────────────────────────────────────────────────
|
|
||||||
if errors and items_processed == 0:
|
|
||||||
final_status = "error"
|
|
||||||
elif errors:
|
|
||||||
final_status = "partial"
|
|
||||||
else:
|
|
||||||
final_status = "success"
|
|
||||||
|
|
||||||
await _finalize_run(
|
|
||||||
run_log_id,
|
|
||||||
status=final_status,
|
|
||||||
items_processed=items_processed,
|
|
||||||
items_created=items_created,
|
|
||||||
errors=errors,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Notify Electron that the run is complete via Redis
|
|
||||||
if run_context:
|
|
||||||
try:
|
|
||||||
channel = ws_out_channel(user_id)
|
|
||||||
await redis_client.publish(channel, json.dumps({
|
|
||||||
"type": "run_complete",
|
|
||||||
"run_context": run_context,
|
|
||||||
"status": final_status,
|
|
||||||
}))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Cloud agent runner ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
|
|
||||||
|
|
||||||
|
|
||||||
async def run_cloud_agent(user_id: str, config_id: str, *, langfuse_handler: Any | None = None) -> None:
|
|
||||||
"""Execute a cloud connector agent run.
|
|
||||||
|
|
||||||
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
|
|
||||||
messages from the provider, and runs LLM extraction.
|
|
||||||
|
|
||||||
set_current_user() must be called BEFORE this function.
|
|
||||||
"""
|
|
||||||
from app.integrations import decrypt_token, encrypt_token, get_provider
|
|
||||||
|
|
||||||
async with async_session() as db:
|
|
||||||
result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
|
||||||
)
|
|
||||||
config = result.scalar_one_or_none()
|
|
||||||
if config is None:
|
|
||||||
logger.error("agent_runner: cloud config %s not found", config_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
# Create run log
|
|
||||||
run_log = AgentRunLog(
|
|
||||||
agent_id=config.id,
|
|
||||||
agent_type="cloud",
|
|
||||||
user_id=user_id,
|
|
||||||
status="running",
|
|
||||||
)
|
|
||||||
db.add(run_log)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(run_log)
|
|
||||||
run_log_id = run_log.id
|
|
||||||
|
|
||||||
# ── Decrypt OAuth token ────────────────────────────────────────
|
|
||||||
if not config.oauth_token_encrypted:
|
|
||||||
await _finalize_run(
|
|
||||||
run_log_id,
|
|
||||||
status="error",
|
|
||||||
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
credentials_info = decrypt_token(config.oauth_token_encrypted)
|
|
||||||
except ValueError as exc:
|
|
||||||
await _finalize_run(
|
|
||||||
run_log_id,
|
|
||||||
status="error",
|
|
||||||
errors=[f"Failed to decrypt OAuth token: {exc}"],
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
# ── Instantiate provider ──────────────────────────────────────
|
|
||||||
try:
|
|
||||||
provider = get_provider(config.provider, credentials_info)
|
|
||||||
except ValueError as exc:
|
|
||||||
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
|
|
||||||
return
|
|
||||||
|
|
||||||
# ── Fetch messages ────────────────────────────────────────────
|
|
||||||
since: datetime | None = config.last_run_at
|
|
||||||
if since is None:
|
|
||||||
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
|
|
||||||
if since.tzinfo is None:
|
|
||||||
since = since.replace(tzinfo=timezone.utc)
|
|
||||||
|
|
||||||
errors: list[str] = []
|
|
||||||
items_processed = 0
|
|
||||||
|
|
||||||
try:
|
|
||||||
if config.provider == "gmail":
|
|
||||||
raw_messages = await provider.fetch_messages(
|
|
||||||
filter_config=config.filter_config,
|
|
||||||
since=since,
|
|
||||||
)
|
|
||||||
elif config.provider == "outlook":
|
|
||||||
raw_messages = await provider.fetch_emails(
|
|
||||||
filter_config=config.filter_config,
|
|
||||||
since=since,
|
|
||||||
)
|
|
||||||
elif config.provider == "teams":
|
|
||||||
raw_messages = await provider.fetch_messages(
|
|
||||||
filter_config=config.filter_config,
|
|
||||||
since=since,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
raw_messages = []
|
|
||||||
except RuntimeError as exc:
|
|
||||||
await _finalize_run(
|
|
||||||
run_log_id,
|
|
||||||
status="error",
|
|
||||||
errors=[f"Provider fetch failed: {exc}"],
|
|
||||||
update_config_last_run=True,
|
|
||||||
config_id=config.id,
|
|
||||||
config_type="cloud",
|
|
||||||
)
|
|
||||||
return
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"agent_runner: cloud agent %s fetched %d item(s) from %s",
|
|
||||||
config.id, len(raw_messages), config.provider,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Extract + insert via LLM ─────────────────────────────────
|
|
||||||
try:
|
|
||||||
processing_tools = _build_processing_tools(config.data_types)
|
|
||||||
custom_section = (
|
|
||||||
f"User instructions:\n{config.prompt_template}"
|
|
||||||
if config.prompt_template
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
|
|
||||||
for msg in raw_messages:
|
|
||||||
content_text = msg.as_text
|
|
||||||
if not content_text:
|
|
||||||
continue
|
|
||||||
items_processed += 1
|
|
||||||
|
|
||||||
processing_prompt = tracing.compile_prompt(
|
|
||||||
"batch_cloud_processing",
|
|
||||||
fallback=_CLOUD_PROCESSING_PROMPT,
|
|
||||||
variables={
|
|
||||||
"data_types": ", ".join(config.data_types),
|
|
||||||
"project_context": "Determine the appropriate project from the message context.",
|
|
||||||
"file_list": f"Message from {config.provider} (id: {msg.id})",
|
|
||||||
"custom_prompt_section": custom_section,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
await _run_agent_with_tools(
|
|
||||||
system_prompt=processing_prompt,
|
|
||||||
user_message=f"Process this message content:\n\n{content_text[:8000]}",
|
|
||||||
tools=processing_tools,
|
|
||||||
max_steps=_MAX_PROCESSING_STEPS,
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
|
|
||||||
except Exception as exc:
|
|
||||||
errors.append(f"Agent run failed: {exc}")
|
|
||||||
|
|
||||||
# ── Persist refreshed token ───────────────────────────────────
|
|
||||||
refreshed = getattr(provider, "refreshed_credentials", None)
|
|
||||||
if refreshed:
|
|
||||||
try:
|
|
||||||
new_encrypted = encrypt_token(refreshed)
|
|
||||||
async with async_session() as db:
|
|
||||||
cfg_result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
|
|
||||||
)
|
|
||||||
cfg_row = cfg_result.scalar_one_or_none()
|
|
||||||
if cfg_row:
|
|
||||||
cfg_row.oauth_token_encrypted = new_encrypted
|
|
||||||
await db.commit()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
|
|
||||||
|
|
||||||
# ── Finalise ──────────────────────────────────────────────────
|
|
||||||
if errors and items_processed == 0:
|
|
||||||
final_status = "error"
|
|
||||||
elif errors:
|
|
||||||
final_status = "partial"
|
|
||||||
else:
|
|
||||||
final_status = "success"
|
|
||||||
|
|
||||||
await _finalize_run(
|
|
||||||
run_log_id,
|
|
||||||
status=final_status,
|
|
||||||
items_processed=items_processed,
|
|
||||||
items_created=0,
|
|
||||||
errors=errors,
|
|
||||||
update_config_last_run=True,
|
|
||||||
config_id=config.id,
|
|
||||||
config_type="cloud",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Internal helper ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _finalize_run(
|
|
||||||
run_log_id: int | str,
|
|
||||||
*,
|
|
||||||
status: str,
|
|
||||||
items_processed: int = 0,
|
|
||||||
items_created: int = 0,
|
|
||||||
errors: list[str] | None = None,
|
|
||||||
update_config_last_run: bool = False,
|
|
||||||
config_id: str | None = None,
|
|
||||||
config_type: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Persist the run outcome and optionally update last_run_at on the config."""
|
|
||||||
now = datetime.now(timezone.utc)
|
|
||||||
try:
|
|
||||||
async with async_session() as db:
|
|
||||||
result = await db.execute(
|
|
||||||
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
|
|
||||||
)
|
|
||||||
managed = result.scalar_one_or_none()
|
|
||||||
if managed is None:
|
|
||||||
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
|
|
||||||
return
|
|
||||||
|
|
||||||
managed.status = status
|
|
||||||
managed.items_processed = items_processed
|
|
||||||
managed.items_created = items_created
|
|
||||||
managed.errors = errors or []
|
|
||||||
managed.completed_at = now
|
|
||||||
|
|
||||||
if update_config_last_run and config_id:
|
|
||||||
if config_type == "local":
|
|
||||||
cfg_result = await db.execute(
|
|
||||||
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
|
||||||
)
|
|
||||||
cfg = cfg_result.scalar_one_or_none()
|
|
||||||
if cfg:
|
|
||||||
cfg.last_run_at = now
|
|
||||||
elif config_type == "cloud":
|
|
||||||
cfg_result = await db.execute(
|
|
||||||
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
|
|
||||||
)
|
|
||||||
cfg = cfg_result.scalar_one_or_none()
|
|
||||||
if cfg:
|
|
||||||
cfg.last_run_at = now
|
|
||||||
|
|
||||||
await db.commit()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Batch Agent Service domain agents and filesystem tools."""
|
|
||||||
@@ -1,83 +0,0 @@
|
|||||||
"""Filesystem agent — tools for reading local directories and files on Electron.
|
|
||||||
|
|
||||||
Adapted for Batch Agent Service: import from app.ws_context.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.tools import tool
|
|
||||||
|
|
||||||
from shared.ws_context import execute_on_client
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def list_directory(path: str) -> str:
|
|
||||||
"""List files and folders in a local directory on the user's device.
|
|
||||||
|
|
||||||
Returns a formatted listing of entries with name, type (file/directory),
|
|
||||||
and full path.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="list_directory",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
entries: list[dict[str, Any]] = result.get("entries", [])
|
|
||||||
if not entries:
|
|
||||||
return f"Directory '{path}' is empty or does not exist."
|
|
||||||
lines: list[str] = []
|
|
||||||
for entry in entries:
|
|
||||||
entry_type = entry.get("type", "unknown")
|
|
||||||
entry_name = entry.get("name", "")
|
|
||||||
entry_path = entry.get("path", "")
|
|
||||||
lines.append(f"- [{entry_type}] {entry_name} ({entry_path})")
|
|
||||||
return f"Directory listing for '{path}' ({len(entries)} entries):\n" + "\n".join(lines)
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def read_file_content(path: str) -> str:
|
|
||||||
"""Read the text content of a local file on the user's device.
|
|
||||||
|
|
||||||
Returns the file content as a string. Large files may be truncated
|
|
||||||
by the Electron client.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="read_file_content",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
content: str = result.get("content", "")
|
|
||||||
if not content:
|
|
||||||
return f"File '{path}' is empty or could not be read."
|
|
||||||
return content
|
|
||||||
|
|
||||||
|
|
||||||
@tool
|
|
||||||
async def get_file_metadata(path: str) -> str:
|
|
||||||
"""Get metadata for a local file: size, creation date, modification date, extension.
|
|
||||||
|
|
||||||
Returns a formatted summary of the file's metadata.
|
|
||||||
"""
|
|
||||||
result = await execute_on_client(
|
|
||||||
action="get_file_metadata",
|
|
||||||
data={"path": path},
|
|
||||||
)
|
|
||||||
size = result.get("size", "unknown")
|
|
||||||
created = result.get("createdAt", "unknown")
|
|
||||||
modified = result.get("modifiedAt", "unknown")
|
|
||||||
extension = result.get("extension", "unknown")
|
|
||||||
name = result.get("name", path)
|
|
||||||
return (
|
|
||||||
f"File: {name}\n"
|
|
||||||
f" Extension: {extension}\n"
|
|
||||||
f" Size: {size} bytes\n"
|
|
||||||
f" Created: {created}\n"
|
|
||||||
f" Modified: {modified}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
FILESYSTEM_TOOLS: list[Any] = [
|
|
||||||
list_directory,
|
|
||||||
read_file_content,
|
|
||||||
get_file_metadata,
|
|
||||||
]
|
|
||||||
@@ -1,395 +0,0 @@
|
|||||||
"""Chatbot Journey — guided conversation to build an agent prompt_template.
|
|
||||||
|
|
||||||
Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
|
|
||||||
and app.llm instead of monolith paths. Session state is in-memory (could
|
|
||||||
be moved to Redis for horizontal scaling in the future).
|
|
||||||
|
|
||||||
Journey flow:
|
|
||||||
1. Redis consumer dispatches ``journey_start`` with basic agent config.
|
|
||||||
2. Server creates an in-memory session, runs the setup LLM with
|
|
||||||
file-system tools to explore the directory, returns first question.
|
|
||||||
3. ``journey_message`` frames drive the conversation.
|
|
||||||
4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
|
|
||||||
5. Server parses the block and returns ``journey_reply`` with ``done=True``.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
|
|
||||||
|
|
||||||
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
|
|
||||||
from shared.llm import get_llm
|
|
||||||
import app.tracing as tracing
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── Session TTL ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
|
|
||||||
|
|
||||||
# Sentinel strings used to delimit the LLM-produced prompt_template.
|
|
||||||
_TEMPLATE_START = "PROMPT_TEMPLATE_START"
|
|
||||||
_TEMPLATE_END = "PROMPT_TEMPLATE_END"
|
|
||||||
|
|
||||||
_MIN_TURNS_BEFORE_NUDGE: int = 3
|
|
||||||
_MAX_TURNS: int = 15
|
|
||||||
_MAX_TOOL_STEPS: int = 6
|
|
||||||
|
|
||||||
# ── In-memory session store ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class JourneySession:
|
|
||||||
session_id: str
|
|
||||||
user_id: str
|
|
||||||
agent_type: str # "local" | "cloud"
|
|
||||||
directory: str
|
|
||||||
data_types: list[str]
|
|
||||||
history: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
system_prompt: str = ""
|
|
||||||
created_at: float = field(default_factory=time.monotonic)
|
|
||||||
|
|
||||||
def is_expired(self) -> bool:
|
|
||||||
return (time.monotonic() - self.created_at) > _SESSION_TTL_SECONDS
|
|
||||||
|
|
||||||
|
|
||||||
# session_id → session
|
|
||||||
_sessions: dict[str, JourneySession] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
|
|
||||||
"""Retrieve session; return None on missing, expired, or wrong owner."""
|
|
||||||
s = _sessions.get(session_id)
|
|
||||||
if s is None or s.is_expired():
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
return None
|
|
||||||
if s.user_id != user_id:
|
|
||||||
return None
|
|
||||||
return s
|
|
||||||
|
|
||||||
|
|
||||||
# ── System prompt builder ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_SYSTEM_PROMPT_TEMPLATE = """\
|
|
||||||
You are a friendly assistant helping a freelancer configure a data-extraction agent.
|
|
||||||
Your job is to understand exactly what data the user wants to extract from their
|
|
||||||
local directory and produce a concise prompt_template that a separate AI will use
|
|
||||||
as its instruction set.
|
|
||||||
|
|
||||||
You have access to file-system tools to explore the user's directory:
|
|
||||||
- list_directory: to see folder structure
|
|
||||||
- read_file_content: to peek at file contents
|
|
||||||
- get_file_metadata: to check file info
|
|
||||||
|
|
||||||
The user's configured directory is: {directory}
|
|
||||||
Target data types: {data_types}
|
|
||||||
|
|
||||||
IMPORTANT — project assignment is handled automatically. You MUST NOT ask the user
|
|
||||||
about projects, projectId, or how to link records to projects. Never include
|
|
||||||
projectId logic or project creation instructions in the generated prompt_template.
|
|
||||||
|
|
||||||
Start by exploring the directory to understand its structure. Then ask concise,
|
|
||||||
focused questions one at a time. Cover only the topics relevant to the target
|
|
||||||
data types listed above:
|
|
||||||
|
|
||||||
1. Content type and format — confirmed by your exploration.
|
|
||||||
2. For TASKS (if in scope): field mapping for title, status, priority, content,
|
|
||||||
dueDate (where is the date found? what's the fallback when absent?),
|
|
||||||
and assignee (is there a person name to assign?).
|
|
||||||
3. For NOTES when TASKS are also in scope: note vs task distinction —
|
|
||||||
what makes something a note rather than a task?
|
|
||||||
4. For TIMELINES (if in scope): the date source — what marks a milestone or event?
|
|
||||||
5. Exclusions and special handling applicable to the target data types.
|
|
||||||
|
|
||||||
Keep asking focused questions until you are at least 90% confident. Then stop and
|
|
||||||
output the final prompt_template immediately, wrapped between these exact markers
|
|
||||||
on their own lines:
|
|
||||||
|
|
||||||
{template_start}
|
|
||||||
<the complete extraction prompt here>
|
|
||||||
{template_end}
|
|
||||||
|
|
||||||
The prompt_template must be concise (bullet points, ~15–25 lines maximum).
|
|
||||||
Specify only:
|
|
||||||
- Scope: what files/content qualify and what entity types to create.
|
|
||||||
- Field mapping rules per entity type (camelCase fields: title, status, priority,
|
|
||||||
dueDate, content, assignee, etc.).
|
|
||||||
- dueDate rule (if tasks in scope): source and fallback behaviour.
|
|
||||||
- Note vs task rule (if both in scope): the criterion that separates them.
|
|
||||||
- Timeline date rule (if timelines in scope): what constitutes a timeline event.
|
|
||||||
- Exclusion/filtering rules.
|
|
||||||
- 2–3 concrete mapping examples based on what you discovered.
|
|
||||||
|
|
||||||
{existing_section}Begin by exploring the directory, then ask your first question.\
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
def _build_system_prompt(
|
|
||||||
directory: str,
|
|
||||||
data_types: list[str],
|
|
||||||
existing_template: str | None = None,
|
|
||||||
) -> str:
|
|
||||||
existing_section = (
|
|
||||||
f"\nThe user already has the following prompt_template — refine it based on their answers:\n"
|
|
||||||
f"---\n{existing_template}\n---\n"
|
|
||||||
if existing_template
|
|
||||||
else ""
|
|
||||||
)
|
|
||||||
# Use Langfuse compile_prompt ({{variable}} syntax) with Python .format() fallback
|
|
||||||
return tracing.compile_prompt(
|
|
||||||
"journey_system",
|
|
||||||
fallback=_SYSTEM_PROMPT_TEMPLATE,
|
|
||||||
variables={
|
|
||||||
"directory": directory,
|
|
||||||
"data_types": ", ".join(data_types),
|
|
||||||
"existing_section": existing_section,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Template extraction ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _extract_template(text: str) -> str | None:
|
|
||||||
"""Return the text between PROMPT_TEMPLATE_START and PROMPT_TEMPLATE_END, or None."""
|
|
||||||
if _TEMPLATE_START not in text or _TEMPLATE_END not in text:
|
|
||||||
return None
|
|
||||||
start_idx = text.index(_TEMPLATE_START) + len(_TEMPLATE_START)
|
|
||||||
end_idx = text.index(_TEMPLATE_END)
|
|
||||||
return text[start_idx:end_idx].strip() or None
|
|
||||||
|
|
||||||
|
|
||||||
# ── LLM call with tool support ───────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _as_text(content: Any) -> str:
|
|
||||||
if content is None:
|
|
||||||
return ""
|
|
||||||
if isinstance(content, str):
|
|
||||||
return content
|
|
||||||
if isinstance(content, list):
|
|
||||||
parts: list[str] = []
|
|
||||||
for item in content:
|
|
||||||
if isinstance(item, str):
|
|
||||||
parts.append(item)
|
|
||||||
elif isinstance(item, dict):
|
|
||||||
text = item.get("text")
|
|
||||||
if isinstance(text, str):
|
|
||||||
parts.append(text)
|
|
||||||
return "".join(parts)
|
|
||||||
return str(content)
|
|
||||||
|
|
||||||
|
|
||||||
async def _call_llm_with_tools(
|
|
||||||
system_prompt: str,
|
|
||||||
history: list[dict[str, Any]],
|
|
||||||
tools: list[Any],
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Build LangChain messages from history and invoke the LLM with tools.
|
|
||||||
|
|
||||||
Handles tool-calling loops: if the LLM calls tools, execute them and
|
|
||||||
continue until a final text response is produced.
|
|
||||||
"""
|
|
||||||
messages: list[Any] = [SystemMessage(content=system_prompt)]
|
|
||||||
for turn in history:
|
|
||||||
if turn["role"] == "user":
|
|
||||||
messages.append(HumanMessage(content=turn["content"]))
|
|
||||||
else:
|
|
||||||
messages.append(AIMessage(content=turn["content"]))
|
|
||||||
|
|
||||||
callbacks = [langfuse_handler] if langfuse_handler else None
|
|
||||||
llm = get_llm(model=None, temperature=0.4, callbacks=callbacks)
|
|
||||||
llm_with_tools = llm.bind_tools(tools)
|
|
||||||
tool_map = {tool_def.name: tool_def for tool_def in tools}
|
|
||||||
|
|
||||||
for _ in range(_MAX_TOOL_STEPS):
|
|
||||||
response: AIMessage = await llm_with_tools.ainvoke(messages)
|
|
||||||
messages.append(response)
|
|
||||||
|
|
||||||
if not response.tool_calls:
|
|
||||||
return _as_text(response.content)
|
|
||||||
|
|
||||||
for call in response.tool_calls:
|
|
||||||
call_name = str(call.get("name", ""))
|
|
||||||
call_args = call.get("args", {})
|
|
||||||
logger.info(
|
|
||||||
"journey: tool_call name=%s args=%s",
|
|
||||||
call_name,
|
|
||||||
json.dumps(call_args, ensure_ascii=True)[:500],
|
|
||||||
)
|
|
||||||
|
|
||||||
tool_fn = tool_map.get(call_name)
|
|
||||||
if tool_fn is None:
|
|
||||||
tool_output = f"Unknown tool: {call_name}"
|
|
||||||
else:
|
|
||||||
tool_output = await tool_fn.ainvoke(call_args)
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"journey: tool_result name=%s output=%s",
|
|
||||||
call_name,
|
|
||||||
str(tool_output)[:800],
|
|
||||||
)
|
|
||||||
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
|
|
||||||
|
|
||||||
# Fallback: exceeded max tool steps.
|
|
||||||
final = await llm.ainvoke(messages)
|
|
||||||
return _as_text(final.content)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Journey handlers (called from redis_consumer) ────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_journey_start(
|
|
||||||
user_id: str,
|
|
||||||
frame: dict[str, Any],
|
|
||||||
*,
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Handle a ``journey_start`` request.
|
|
||||||
|
|
||||||
Creates a session, runs the setup LLM with directory exploration,
|
|
||||||
and returns the ``journey_reply`` payload.
|
|
||||||
"""
|
|
||||||
agent_type = frame.get("agent_type", "local")
|
|
||||||
directory = frame.get("directory", "")
|
|
||||||
data_types = frame.get("data_types", [])
|
|
||||||
existing_template = frame.get("existing_template")
|
|
||||||
|
|
||||||
session_id = frame.get("session_id") or str(uuid.uuid4())
|
|
||||||
system_prompt = _build_system_prompt(directory, data_types, existing_template)
|
|
||||||
|
|
||||||
session = JourneySession(
|
|
||||||
session_id=session_id,
|
|
||||||
user_id=user_id,
|
|
||||||
agent_type=agent_type,
|
|
||||||
directory=directory,
|
|
||||||
data_types=data_types,
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
seed_history: list[dict[str, Any]] = [
|
|
||||||
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
|
|
||||||
]
|
|
||||||
ai_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
history=seed_history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
session.history.extend(seed_history)
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
_sessions[session_id] = session
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"journey: session %s started for user %s (directory=%s)",
|
|
||||||
session_id,
|
|
||||||
user_id,
|
|
||||||
directory,
|
|
||||||
)
|
|
||||||
|
|
||||||
prompt_template = _extract_template(ai_reply)
|
|
||||||
done = prompt_template is not None
|
|
||||||
|
|
||||||
display_message = ai_reply
|
|
||||||
if done:
|
|
||||||
display_message = (
|
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
|
||||||
or "Here is your agent configuration. You can save it or continue refining."
|
|
||||||
)
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": display_message,
|
|
||||||
"done": done,
|
|
||||||
"prompt_template": prompt_template,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
async def handle_journey_message(
|
|
||||||
user_id: str,
|
|
||||||
frame: dict[str, Any],
|
|
||||||
*,
|
|
||||||
langfuse_handler: Any | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Handle a ``journey_message`` request.
|
|
||||||
|
|
||||||
Appends the user message, calls the LLM, and returns the
|
|
||||||
``journey_reply`` payload.
|
|
||||||
"""
|
|
||||||
session_id = frame.get("session_id", "")
|
|
||||||
message = frame.get("message", "")
|
|
||||||
|
|
||||||
session = get_journey_session(session_id, user_id)
|
|
||||||
if session is None:
|
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": "Journey session not found or expired. Please start a new setup.",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
}
|
|
||||||
|
|
||||||
session.history.append({"role": "user", "content": message})
|
|
||||||
|
|
||||||
ai_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=session.system_prompt,
|
|
||||||
history=session.history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
)
|
|
||||||
|
|
||||||
session.history.append({"role": "assistant", "content": ai_reply})
|
|
||||||
|
|
||||||
prompt_template = _extract_template(ai_reply)
|
|
||||||
done = prompt_template is not None
|
|
||||||
|
|
||||||
if not done:
|
|
||||||
turns = sum(1 for t in session.history if t["role"] == "user")
|
|
||||||
if turns >= _MAX_TURNS:
|
|
||||||
nudge_content = (
|
|
||||||
"[System: You have enough information. Please generate the final "
|
|
||||||
f"prompt_template now, wrapped in {_TEMPLATE_START} / {_TEMPLATE_END} markers.]"
|
|
||||||
)
|
|
||||||
session.history.append({"role": "user", "content": nudge_content})
|
|
||||||
|
|
||||||
nudge_reply = await _call_llm_with_tools(
|
|
||||||
system_prompt=session.system_prompt,
|
|
||||||
history=session.history,
|
|
||||||
tools=list(FILESYSTEM_TOOLS),
|
|
||||||
langfuse_handler=langfuse_handler,
|
|
||||||
)
|
|
||||||
session.history.append({"role": "assistant", "content": nudge_reply})
|
|
||||||
|
|
||||||
prompt_template = _extract_template(nudge_reply)
|
|
||||||
if prompt_template is not None:
|
|
||||||
done = True
|
|
||||||
ai_reply = nudge_reply
|
|
||||||
|
|
||||||
display_message = ai_reply
|
|
||||||
if done:
|
|
||||||
display_message = (
|
|
||||||
ai_reply[: ai_reply.index(_TEMPLATE_START)].strip()
|
|
||||||
if _TEMPLATE_START in ai_reply
|
|
||||||
else "Here is your agent configuration. You can save it or continue refining."
|
|
||||||
)
|
|
||||||
_sessions.pop(session_id, None)
|
|
||||||
logger.info("journey: session %s completed for user %s", session_id, user_id)
|
|
||||||
|
|
||||||
return {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": display_message,
|
|
||||||
"done": done,
|
|
||||||
"prompt_template": prompt_template,
|
|
||||||
}
|
|
||||||
@@ -1,76 +0,0 @@
|
|||||||
"""LLM factory — centralised model instantiation via LiteLLM.
|
|
||||||
|
|
||||||
Identical to services/chat/app/llm.py. Uses shared.config.settings.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import os
|
|
||||||
import warnings
|
|
||||||
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
import litellm
|
|
||||||
|
|
||||||
from langchain_openai import ChatOpenAI
|
|
||||||
from langchain_litellm import ChatLiteLLM
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
litellm.drop_params = True
|
|
||||||
|
|
||||||
warnings.filterwarnings(
|
|
||||||
"ignore",
|
|
||||||
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
|
|
||||||
category=UserWarning,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def _api_key_for_model(model: str) -> str | None:
|
|
||||||
if model.startswith("anthropic/"):
|
|
||||||
return settings.ANTHROPIC_API_KEY or None
|
|
||||||
if model.startswith("gemini/") or model.startswith("google/"):
|
|
||||||
return settings.GOOGLE_API_KEY or None
|
|
||||||
if model.startswith("cerebras/"):
|
|
||||||
return settings.CEREBRAS_API_KEY or None
|
|
||||||
if model.startswith("github/"):
|
|
||||||
return settings.GITHUB_TOKEN or None
|
|
||||||
if model.startswith("github_copilot/"):
|
|
||||||
return None
|
|
||||||
return settings.OPENAI_API_KEY or None
|
|
||||||
|
|
||||||
|
|
||||||
def get_llm(
|
|
||||||
*,
|
|
||||||
model: str | None = None,
|
|
||||||
temperature: float = 0,
|
|
||||||
callbacks: list | None = None,
|
|
||||||
) -> ChatOpenAI | ChatLiteLLM:
|
|
||||||
model = model or settings.LLM_MODEL
|
|
||||||
|
|
||||||
if settings.GITHUB_COPILOT_TOKEN_DIR:
|
|
||||||
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
|
|
||||||
|
|
||||||
if settings.GITHUB_TOKEN:
|
|
||||||
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
|
|
||||||
|
|
||||||
if "/" in model:
|
|
||||||
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
|
|
||||||
|
|
||||||
return ChatOpenAI(
|
|
||||||
model=model,
|
|
||||||
temperature=temperature,
|
|
||||||
api_key=_api_key_for_model(model),
|
|
||||||
callbacks=callbacks,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
async def embed(text: str) -> list[float]:
|
|
||||||
model = settings.LLM_EMBED_MODEL
|
|
||||||
|
|
||||||
if model.startswith("github_copilot/") or "/" in model:
|
|
||||||
response = await litellm.aembedding(model=model, input=[text])
|
|
||||||
return response.data[0]["embedding"]
|
|
||||||
|
|
||||||
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
|
|
||||||
response = await client.embeddings.create(model=model, input=text)
|
|
||||||
return response.data[0].embedding
|
|
||||||
@@ -1,79 +0,0 @@
|
|||||||
"""Batch Agent Service — FastAPI application.
|
|
||||||
|
|
||||||
Owns: agent_runner (local directory + cloud connectors), journey builder,
|
|
||||||
filesystem_agent, integrations (Gmail, MS Graph).
|
|
||||||
|
|
||||||
Communicates with WS Gateway via Redis:
|
|
||||||
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
|
|
||||||
- Publishes to ws:out:{user_id} (journey replies + tool calls)
|
|
||||||
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
|
|
||||||
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Ensure the repo root is on sys.path so ``shared`` is importable when
|
|
||||||
# running locally (in Docker the COPY already places it at /app/shared/).
|
|
||||||
_repo_root = str(Path(__file__).resolve().parents[3])
|
|
||||||
if _repo_root not in sys.path:
|
|
||||||
sys.path.insert(0, _repo_root)
|
|
||||||
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from app.redis_consumer import start_consumer
|
|
||||||
from app.routes import router
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
||||||
# Initialise Langfuse tracing (no-op if keys are missing)
|
|
||||||
from app.tracing import init_langfuse
|
|
||||||
init_langfuse()
|
|
||||||
|
|
||||||
logger.info("batch-agent: starting Redis consumer")
|
|
||||||
task = asyncio.create_task(start_consumer())
|
|
||||||
yield
|
|
||||||
task.cancel()
|
|
||||||
try:
|
|
||||||
await task
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
pass
|
|
||||||
|
|
||||||
from app.tracing import shutdown as shutdown_langfuse
|
|
||||||
shutdown_langfuse()
|
|
||||||
|
|
||||||
from shared.db import engine
|
|
||||||
await engine.dispose()
|
|
||||||
|
|
||||||
from shared.redis import redis_client
|
|
||||||
await redis_client.aclose()
|
|
||||||
|
|
||||||
logger.info("batch-agent: Redis consumer stopped")
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_methods=["GET", "POST"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health() -> dict[str, str]:
|
|
||||||
return {"status": "ok", "service": "batch-agent"}
|
|
||||||
@@ -1,183 +0,0 @@
|
|||||||
"""Redis consumer for the Batch Agent Service.
|
|
||||||
|
|
||||||
Subscribes to batch:request:* (pattern) and dispatches:
|
|
||||||
- journey_start → handle_journey_start
|
|
||||||
- journey_message → handle_journey_message
|
|
||||||
- agent_trigger → run_local_agent / run_cloud_agent
|
|
||||||
|
|
||||||
Results are published back to ws:out:{user_id} via Redis.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from shared.redis import redis_client, batch_request_channel, ws_out_channel
|
|
||||||
|
|
||||||
import app.tracing as tracing
|
|
||||||
from shared.ws_context import set_current_user, clear_current_user
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
|
|
||||||
"""Publish a frame to the user's WS outbound channel."""
|
|
||||||
channel = ws_out_channel(user_id)
|
|
||||||
await redis_client.publish(channel, json.dumps(payload))
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
|
|
||||||
"""Handle a journey_start request from WS Gateway."""
|
|
||||||
from app.journey import handle_journey_start
|
|
||||||
|
|
||||||
session_id = data.get("session_id", "")
|
|
||||||
set_current_user(user_id)
|
|
||||||
try:
|
|
||||||
with tracing.trace_span(
|
|
||||||
name="journey_start",
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
input=data.get("directory", ""),
|
|
||||||
metadata={"data_types": data.get("data_types", [])},
|
|
||||||
tags=["journey"],
|
|
||||||
) as span:
|
|
||||||
langfuse_handler = tracing.get_langfuse_callback()
|
|
||||||
reply = await handle_journey_start(user_id, data, langfuse_handler=langfuse_handler)
|
|
||||||
tracing.link_prompt_to_trace(span, "journey_system")
|
|
||||||
span.update(output=reply.get("message", "")[:500])
|
|
||||||
await _publish_to_user(user_id, reply)
|
|
||||||
tracing.flush()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
|
|
||||||
await _publish_to_user(user_id, {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": f"Journey setup failed: {exc}",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
})
|
|
||||||
finally:
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
|
|
||||||
"""Handle a journey_message from WS Gateway."""
|
|
||||||
from app.journey import handle_journey_message
|
|
||||||
|
|
||||||
session_id = data.get("session_id", "")
|
|
||||||
set_current_user(user_id)
|
|
||||||
try:
|
|
||||||
with tracing.trace_span(
|
|
||||||
name="journey_message",
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
input=data.get("message", "")[:200],
|
|
||||||
tags=["journey"],
|
|
||||||
) as span:
|
|
||||||
langfuse_handler = tracing.get_langfuse_callback()
|
|
||||||
reply = await handle_journey_message(user_id, data, langfuse_handler=langfuse_handler)
|
|
||||||
tracing.link_prompt_to_trace(span, "journey_system")
|
|
||||||
span.update(output=reply.get("message", "")[:500])
|
|
||||||
await _publish_to_user(user_id, reply)
|
|
||||||
tracing.flush()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
|
|
||||||
await _publish_to_user(user_id, {
|
|
||||||
"type": "journey_reply",
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": f"Journey processing failed: {exc}",
|
|
||||||
"done": True,
|
|
||||||
"prompt_template": None,
|
|
||||||
})
|
|
||||||
finally:
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
|
|
||||||
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
|
|
||||||
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
|
|
||||||
from app.agent_runner import run_local_agent
|
|
||||||
|
|
||||||
run_context = data.get("run_context", {})
|
|
||||||
agent_id = run_context.get("agent_id", "")
|
|
||||||
set_current_user(user_id)
|
|
||||||
try:
|
|
||||||
with tracing.trace_span(
|
|
||||||
name="agent_trigger",
|
|
||||||
user_id=user_id,
|
|
||||||
trace_id=run_context.get("run_id"),
|
|
||||||
input={"agent_id": agent_id, "directory": data.get("directory", "")},
|
|
||||||
metadata={"data_types": data.get("data_types", [])},
|
|
||||||
tags=["batch", "agent_run"],
|
|
||||||
) as span:
|
|
||||||
langfuse_handler = tracing.get_langfuse_callback()
|
|
||||||
await run_local_agent(user_id, data, langfuse_handler=langfuse_handler)
|
|
||||||
tracing.link_prompt_to_trace(span, "batch_processing")
|
|
||||||
span.update(output={"status": "completed"})
|
|
||||||
tracing.flush()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
|
|
||||||
await _publish_to_user(user_id, {
|
|
||||||
"type": "run_complete",
|
|
||||||
"status": "error",
|
|
||||||
"run_context": run_context,
|
|
||||||
})
|
|
||||||
finally:
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
|
|
||||||
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
|
|
||||||
"""Route a batch request to the correct handler."""
|
|
||||||
msg_type = message_data.get("type", "")
|
|
||||||
|
|
||||||
if msg_type == "journey_start":
|
|
||||||
await _handle_journey_start(user_id, message_data)
|
|
||||||
elif msg_type == "journey_message":
|
|
||||||
await _handle_journey_message(user_id, message_data)
|
|
||||||
elif msg_type == "agent_trigger":
|
|
||||||
await _handle_agent_trigger(user_id, message_data)
|
|
||||||
elif msg_type == "device_online":
|
|
||||||
logger.info("batch-agent: device_online user=%s device=%s", user_id, message_data.get("device_id", "?"))
|
|
||||||
else:
|
|
||||||
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
|
|
||||||
|
|
||||||
|
|
||||||
async def start_consumer() -> None:
|
|
||||||
"""Subscribe to batch:request:* and dispatch incoming frames."""
|
|
||||||
pubsub = redis_client.pubsub()
|
|
||||||
await pubsub.psubscribe("batch:request:*")
|
|
||||||
logger.info("batch-agent: subscribed to batch:request:*")
|
|
||||||
|
|
||||||
try:
|
|
||||||
async for message in pubsub.listen():
|
|
||||||
if message["type"] != "pmessage":
|
|
||||||
continue
|
|
||||||
|
|
||||||
channel: str = message["channel"]
|
|
||||||
if isinstance(channel, bytes):
|
|
||||||
channel = channel.decode()
|
|
||||||
|
|
||||||
# Extract user_id from channel: batch:request:{user_id}
|
|
||||||
parts = channel.split(":", 2)
|
|
||||||
if len(parts) < 3:
|
|
||||||
continue
|
|
||||||
user_id = parts[2]
|
|
||||||
|
|
||||||
raw = message["data"]
|
|
||||||
if isinstance(raw, bytes):
|
|
||||||
raw = raw.decode()
|
|
||||||
|
|
||||||
try:
|
|
||||||
data = json.loads(raw)
|
|
||||||
except json.JSONDecodeError:
|
|
||||||
logger.warning("batch-agent: invalid JSON on channel %s", channel)
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Dispatch in a separate task to avoid blocking the consumer
|
|
||||||
asyncio.create_task(_dispatch(user_id, data))
|
|
||||||
except asyncio.CancelledError:
|
|
||||||
logger.info("batch-agent: consumer shutting down")
|
|
||||||
finally:
|
|
||||||
await pubsub.punsubscribe("batch:request:*")
|
|
||||||
@@ -1,208 +0,0 @@
|
|||||||
"""Agent REST routes — catalog, billing checks, trigger.
|
|
||||||
|
|
||||||
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
|
|
||||||
Agent trigger dispatches via Redis to the consumer instead of spawning
|
|
||||||
an in-process background task.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import uuid
|
|
||||||
from datetime import datetime, timezone
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Header, HTTPException, status
|
|
||||||
from sqlalchemy import func, select
|
|
||||||
from sqlalchemy.ext.asyncio import AsyncSession
|
|
||||||
|
|
||||||
from shared.db import async_session
|
|
||||||
from shared.models import AgentRunLog
|
|
||||||
from shared.redis import redis_client, batch_request_channel
|
|
||||||
|
|
||||||
from app.agent_runner import is_agent_running
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/agents", tags=["agents"])
|
|
||||||
|
|
||||||
# ── Tier feature limits ───────────────────────────────────────────────
|
|
||||||
# Mirrors app/billing/tier_manager.py FEATURES dict.
|
|
||||||
FEATURES: dict[str, dict] = {
|
|
||||||
"free": {"batch_active": 1, "batch_runs_per_day": 3},
|
|
||||||
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
|
|
||||||
"power": {"batch_active": 20, "batch_runs_per_day": 100},
|
|
||||||
"team": {"batch_active": -1, "batch_runs_per_day": -1},
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _dt_ms(dt: datetime) -> int:
|
|
||||||
return int(dt.timestamp() * 1000)
|
|
||||||
|
|
||||||
|
|
||||||
def _dt_ms_opt(dt: datetime | None) -> int | None:
|
|
||||||
return int(dt.timestamp() * 1000) if dt else None
|
|
||||||
|
|
||||||
|
|
||||||
def _to_data_types(values: list[str]) -> list[str]:
|
|
||||||
normalize = {
|
|
||||||
"task": "tasks", "tasks": "tasks",
|
|
||||||
"note": "notes", "notes": "notes",
|
|
||||||
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
|
|
||||||
"project": "projects", "projects": "projects",
|
|
||||||
}
|
|
||||||
seen: set[str] = set()
|
|
||||||
result: list[str] = []
|
|
||||||
for v in values:
|
|
||||||
mapped = normalize.get(v)
|
|
||||||
if mapped and mapped not in seen:
|
|
||||||
seen.add(mapped)
|
|
||||||
result.append(mapped)
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
|
||||||
if limit != -1 and current_count >= limit:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
|
||||||
)
|
|
||||||
return limit
|
|
||||||
|
|
||||||
|
|
||||||
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
|
|
||||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
|
|
||||||
if limit == -1:
|
|
||||||
return
|
|
||||||
today_start = datetime.now(timezone.utc).replace(
|
|
||||||
hour=0, minute=0, second=0, microsecond=0
|
|
||||||
)
|
|
||||||
async with async_session() as db:
|
|
||||||
result = await db.execute(
|
|
||||||
select(func.count(AgentRunLog.id)).where(
|
|
||||||
AgentRunLog.user_id == user_id,
|
|
||||||
AgentRunLog.started_at >= today_start,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
runs_today: int = result.scalar_one()
|
|
||||||
|
|
||||||
if runs_today >= limit:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_402_PAYMENT_REQUIRED,
|
|
||||||
detail=f"Daily batch run limit ({limit}) reached for your tier.",
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Catalog ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/catalog")
|
|
||||||
async def get_agent_catalog(
|
|
||||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
|
||||||
) -> list[dict]:
|
|
||||||
return [
|
|
||||||
{
|
|
||||||
"type": "local_directory",
|
|
||||||
"name": "Local Directory Monitor",
|
|
||||||
"description": "Watches local directories, extracts data from files using AI",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "gmail",
|
|
||||||
"name": "Gmail Connector",
|
|
||||||
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "teams",
|
|
||||||
"name": "Microsoft Teams Connector",
|
|
||||||
"description": "Monitors Teams messages, extracts action items",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"type": "outlook",
|
|
||||||
"name": "Outlook Connector",
|
|
||||||
"description": "Scans Outlook inbox, extracts tasks/notes",
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
|
|
||||||
# ── Can-create check ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/can-create")
|
|
||||||
async def can_create_agent(
|
|
||||||
body: dict,
|
|
||||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
|
||||||
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
|
||||||
) -> dict:
|
|
||||||
active_agents = body.get("active_agents", 0)
|
|
||||||
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
|
|
||||||
allowed = limit == -1 or active_agents < limit
|
|
||||||
return {
|
|
||||||
"allowed": allowed,
|
|
||||||
"tier": x_user_tier,
|
|
||||||
"active_agents": active_agents,
|
|
||||||
"limit": limit,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Trigger ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
|
|
||||||
async def trigger_agent_run(
|
|
||||||
body: dict,
|
|
||||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
|
||||||
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
|
||||||
) -> dict:
|
|
||||||
"""Trigger a local agent run — creates run log and dispatches via Redis."""
|
|
||||||
active_agents = body.get("active_agents", 0)
|
|
||||||
_enforce_agent_limit(x_user_tier, active_agents)
|
|
||||||
await _enforce_run_frequency(x_user_tier, x_user_id)
|
|
||||||
|
|
||||||
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
|
|
||||||
|
|
||||||
if is_agent_running(stable_agent_id):
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_409_CONFLICT,
|
|
||||||
detail="Agent is already running.",
|
|
||||||
)
|
|
||||||
|
|
||||||
# Create run log in DB
|
|
||||||
async with async_session() as db:
|
|
||||||
run_log = AgentRunLog(
|
|
||||||
agent_id=stable_agent_id,
|
|
||||||
agent_type="local",
|
|
||||||
user_id=x_user_id,
|
|
||||||
status="running",
|
|
||||||
)
|
|
||||||
db.add(run_log)
|
|
||||||
await db.commit()
|
|
||||||
await db.refresh(run_log)
|
|
||||||
run_log_id = run_log.id
|
|
||||||
|
|
||||||
run_context = {
|
|
||||||
"type": "agent_batch",
|
|
||||||
"run_id": run_log_id,
|
|
||||||
"agent_id": stable_agent_id,
|
|
||||||
}
|
|
||||||
|
|
||||||
# Dispatch to the Redis consumer for processing
|
|
||||||
trigger_data = {
|
|
||||||
"type": "agent_trigger",
|
|
||||||
"directory": body.get("directory", ""),
|
|
||||||
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
|
|
||||||
"data_types": _to_data_types(body.get("what_to_extract", [])),
|
|
||||||
"file_extensions": body.get("file_extensions", []),
|
|
||||||
"prompt_template": body.get("custom_agent_prompt", ""),
|
|
||||||
"device_id": body.get("device_id", ""),
|
|
||||||
"run_context": run_context,
|
|
||||||
}
|
|
||||||
|
|
||||||
channel = batch_request_channel(x_user_id)
|
|
||||||
await redis_client.publish(channel, json.dumps(trigger_data))
|
|
||||||
|
|
||||||
return {
|
|
||||||
"id": run_log_id,
|
|
||||||
"agent_id": stable_agent_id,
|
|
||||||
"agent_type": "local",
|
|
||||||
"status": "running",
|
|
||||||
"items_processed": 0,
|
|
||||||
"items_created": 0,
|
|
||||||
"errors": [],
|
|
||||||
"started_at": _dt_ms(run_log.started_at),
|
|
||||||
"completed_at": None,
|
|
||||||
}
|
|
||||||
@@ -1,336 +0,0 @@
|
|||||||
"""Langfuse tracing & prompt management for the Batch Agent Service (v4 SDK).
|
|
||||||
|
|
||||||
Provides:
|
|
||||||
- ``init_langfuse()`` — initialise the singleton client at startup
|
|
||||||
- ``trace_span()`` — context manager that creates a trace + span
|
|
||||||
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
|
|
||||||
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
|
|
||||||
- ``flush()`` / ``shutdown()`` — lifecycle management
|
|
||||||
|
|
||||||
All functions gracefully degrade to no-ops when Langfuse is not configured,
|
|
||||||
so the service works identically with or without observability keys.
|
|
||||||
|
|
||||||
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── State ────────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_initialised: bool = False
|
|
||||||
_disabled: bool = False
|
|
||||||
|
|
||||||
|
|
||||||
def _is_configured() -> bool:
|
|
||||||
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
|
|
||||||
|
|
||||||
|
|
||||||
def init_langfuse() -> None:
|
|
||||||
"""Initialise the Langfuse singleton. Call once at startup."""
|
|
||||||
global _initialised, _disabled
|
|
||||||
|
|
||||||
if _initialised or _disabled:
|
|
||||||
return
|
|
||||||
|
|
||||||
if not _is_configured():
|
|
||||||
_disabled = True
|
|
||||||
logger.info("tracing: Langfuse keys not set — tracing disabled")
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import Langfuse
|
|
||||||
|
|
||||||
Langfuse(
|
|
||||||
secret_key=settings.LANGFUSE_SECRET_KEY,
|
|
||||||
public_key=settings.LANGFUSE_PUBLIC_KEY,
|
|
||||||
host=settings.LANGFUSE_HOST,
|
|
||||||
)
|
|
||||||
_initialised = True
|
|
||||||
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
|
|
||||||
except Exception as exc:
|
|
||||||
_disabled = True
|
|
||||||
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_client() -> Any | None:
|
|
||||||
"""Return the singleton Langfuse client, or *None* if disabled."""
|
|
||||||
if _disabled:
|
|
||||||
return None
|
|
||||||
if not _initialised:
|
|
||||||
init_langfuse()
|
|
||||||
if _disabled:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
from langfuse import get_client
|
|
||||||
return get_client()
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
class _NullSpan:
|
|
||||||
"""Drop-in replacement when Langfuse is disabled."""
|
|
||||||
|
|
||||||
def update(self, **_: Any) -> None: ...
|
|
||||||
def set_trace_io(self, **_: Any) -> None: ...
|
|
||||||
def score_trace(self, **_: Any) -> None: ...
|
|
||||||
|
|
||||||
|
|
||||||
# ── Trace context manager ───────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def trace_span(
|
|
||||||
*,
|
|
||||||
name: str,
|
|
||||||
user_id: str,
|
|
||||||
session_id: str | None = None,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
input: Any = None,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
tags: list[str] | None = None,
|
|
||||||
):
|
|
||||||
"""Context manager that creates a Langfuse trace/span.
|
|
||||||
|
|
||||||
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
|
|
||||||
A ``CallbackHandler`` created inside this block auto-inherits the trace
|
|
||||||
context, so there is no need to pass trace IDs manually.
|
|
||||||
"""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None:
|
|
||||||
yield _NullSpan()
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import Langfuse, propagate_attributes
|
|
||||||
|
|
||||||
trace_ctx: dict[str, str] = {}
|
|
||||||
if trace_id is not None:
|
|
||||||
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
|
|
||||||
|
|
||||||
with lf.start_as_current_observation(
|
|
||||||
as_type="span",
|
|
||||||
name=name,
|
|
||||||
input=input,
|
|
||||||
metadata=metadata or {},
|
|
||||||
**({"trace_context": trace_ctx} if trace_ctx else {}),
|
|
||||||
) as span:
|
|
||||||
with propagate_attributes(
|
|
||||||
user_id=user_id,
|
|
||||||
session_id=session_id,
|
|
||||||
tags=tags or [],
|
|
||||||
):
|
|
||||||
yield span
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
|
|
||||||
yield _NullSpan()
|
|
||||||
|
|
||||||
|
|
||||||
# ── LangChain callback handler ──────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def get_langfuse_callback() -> Any | None:
|
|
||||||
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
|
|
||||||
|
|
||||||
Must be called inside a ``trace_span()`` block for proper linking.
|
|
||||||
Returns *None* when Langfuse is disabled.
|
|
||||||
"""
|
|
||||||
if _disabled and not _initialised:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse.langchain import CallbackHandler
|
|
||||||
return CallbackHandler()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# ── Prompt management ────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt(
|
|
||||||
name: str,
|
|
||||||
*,
|
|
||||||
version: int | None = None,
|
|
||||||
label: str | None = None,
|
|
||||||
fallback: str | None = None,
|
|
||||||
cache_ttl_seconds: int = 300,
|
|
||||||
) -> str | None:
|
|
||||||
"""Fetch a managed prompt from Langfuse by name (without variable compilation).
|
|
||||||
|
|
||||||
Returns the raw prompt string, or *fallback* if the prompt is not
|
|
||||||
found or Langfuse is disabled.
|
|
||||||
"""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None:
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
try:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"name": name,
|
|
||||||
"cache_ttl_seconds": cache_ttl_seconds,
|
|
||||||
}
|
|
||||||
if version is not None:
|
|
||||||
kwargs["version"] = version
|
|
||||||
if label is not None:
|
|
||||||
kwargs["label"] = label
|
|
||||||
prompt = lf.get_prompt(**kwargs)
|
|
||||||
return prompt.prompt
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
|
|
||||||
return fallback
|
|
||||||
|
|
||||||
|
|
||||||
def compile_prompt(
|
|
||||||
name: str,
|
|
||||||
*,
|
|
||||||
fallback: str,
|
|
||||||
variables: dict[str, str],
|
|
||||||
version: int | None = None,
|
|
||||||
label: str | None = None,
|
|
||||||
cache_ttl_seconds: int = 300,
|
|
||||||
) -> str:
|
|
||||||
"""Fetch a managed prompt from Langfuse and compile it with ``{{variables}}``.
|
|
||||||
|
|
||||||
If the prompt exists in Langfuse, uses the SDK's ``.compile(**variables)``
|
|
||||||
which replaces ``{{key}}`` placeholders. If Langfuse is disabled or the
|
|
||||||
prompt is not found, falls back to ``fallback.format(**variables)`` (Python
|
|
||||||
``{key}`` placeholders).
|
|
||||||
|
|
||||||
This means:
|
|
||||||
- Langfuse prompts use ``{{variable}}`` syntax.
|
|
||||||
- Hardcoded fallback strings use Python ``{variable}`` syntax.
|
|
||||||
"""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None:
|
|
||||||
return fallback.format(**variables)
|
|
||||||
|
|
||||||
try:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"name": name,
|
|
||||||
"cache_ttl_seconds": cache_ttl_seconds,
|
|
||||||
}
|
|
||||||
if version is not None:
|
|
||||||
kwargs["version"] = version
|
|
||||||
if label is not None:
|
|
||||||
kwargs["label"] = label
|
|
||||||
prompt = lf.get_prompt(**kwargs)
|
|
||||||
return prompt.compile(**variables)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: compile_prompt(%s) failed, using fallback: %s", name, exc)
|
|
||||||
return fallback.format(**variables)
|
|
||||||
|
|
||||||
|
|
||||||
def get_prompt_object(
|
|
||||||
name: str,
|
|
||||||
*,
|
|
||||||
version: int | None = None,
|
|
||||||
label: str | None = None,
|
|
||||||
cache_ttl_seconds: int = 300,
|
|
||||||
) -> Any | None:
|
|
||||||
"""Fetch the raw Langfuse prompt *object* (not the compiled string).
|
|
||||||
|
|
||||||
Returns ``None`` when Langfuse is disabled or the prompt is not found.
|
|
||||||
Use this when you need to pass the prompt to ``start_observation(prompt=...)``
|
|
||||||
for linking the prompt to a trace in the Langfuse UI.
|
|
||||||
"""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
kwargs: dict[str, Any] = {
|
|
||||||
"name": name,
|
|
||||||
"cache_ttl_seconds": cache_ttl_seconds,
|
|
||||||
}
|
|
||||||
if version is not None:
|
|
||||||
kwargs["version"] = version
|
|
||||||
if label is not None:
|
|
||||||
kwargs["label"] = label
|
|
||||||
return lf.get_prompt(**kwargs)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: get_prompt_object(%s) failed: %s", name, exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def link_prompt_to_trace(
|
|
||||||
span: Any,
|
|
||||||
prompt_name: str,
|
|
||||||
*,
|
|
||||||
version: int | None = None,
|
|
||||||
label: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Link a Langfuse managed prompt to a span/observation.
|
|
||||||
|
|
||||||
Uses the SDK v4 ``prompt=`` parameter so that the prompt version
|
|
||||||
appears linked in the Langfuse UI with metrics tracking.
|
|
||||||
"""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None or isinstance(span, _NullSpan):
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
prompt = get_prompt_object(prompt_name, version=version, label=label)
|
|
||||||
if prompt is not None:
|
|
||||||
span.update(prompt=prompt)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Scoring helper ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def score_trace(
|
|
||||||
trace_id: str,
|
|
||||||
name: str,
|
|
||||||
value: float,
|
|
||||||
*,
|
|
||||||
comment: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
try:
|
|
||||||
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: score_trace failed: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Shutdown ─────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def flush() -> None:
|
|
||||||
"""Flush pending Langfuse events."""
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is not None:
|
|
||||||
try:
|
|
||||||
lf.flush()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: flush failed: %s", exc)
|
|
||||||
|
|
||||||
|
|
||||||
def shutdown() -> None:
|
|
||||||
"""Flush and close the Langfuse client."""
|
|
||||||
global _initialised, _disabled
|
|
||||||
lf = _get_client()
|
|
||||||
if lf is not None:
|
|
||||||
try:
|
|
||||||
lf.flush()
|
|
||||||
lf.shutdown()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("tracing: shutdown failed: %s", exc)
|
|
||||||
_initialised = False
|
|
||||||
_disabled = False
|
|
||||||
@@ -1 +0,0 @@
|
|||||||
"""Batch Agent E2E evaluation harness."""
|
|
||||||
@@ -1,5 +0,0 @@
|
|||||||
"""Allow running the eval package as ``python -m eval``."""
|
|
||||||
|
|
||||||
from eval.cli import main
|
|
||||||
|
|
||||||
main()
|
|
||||||
@@ -1,285 +0,0 @@
|
|||||||
"""CLI entry point for the batch agent evaluation harness.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
# From services/batch-agent/:
|
|
||||||
python -m eval run # all agent fixtures, default model
|
|
||||||
python -m eval run --fixture=classify-invoices # single fixture
|
|
||||||
python -m eval run --models=gpt-4o,gpt-5.3-codex # multiple models
|
|
||||||
python -m eval run --mode=step1 # only step1 fixtures
|
|
||||||
python -m eval run --no-judge # skip LLM judge scoring
|
|
||||||
|
|
||||||
python -m eval interactive # interactive journey session
|
|
||||||
python -m eval interactive --fixture=journey-invoice-setup
|
|
||||||
python -m eval interactive --model=gpt-4o
|
|
||||||
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
|
|
||||||
|
|
||||||
python -m eval list # list all fixtures
|
|
||||||
python -m eval sync # sync fixtures to Langfuse datasets
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import asyncio
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
# Ensure the service root and repo root are in sys.path.
|
|
||||||
# Service root must come BEFORE repo root so its ``app/`` package
|
|
||||||
# shadows the monolith ``app/`` in the repo root.
|
|
||||||
_SERVICE_ROOT = Path(__file__).resolve().parent.parent
|
|
||||||
_REPO_ROOT = _SERVICE_ROOT.parent.parent
|
|
||||||
_sr = str(_SERVICE_ROOT)
|
|
||||||
_rr = str(_REPO_ROOT)
|
|
||||||
if _rr not in sys.path:
|
|
||||||
sys.path.insert(0, _rr)
|
|
||||||
# Always force service root to position 0 (python -m may have already
|
|
||||||
# added CWD further down the list, which loses to repo root).
|
|
||||||
if _sr in sys.path:
|
|
||||||
sys.path.remove(_sr)
|
|
||||||
sys.path.insert(0, _sr)
|
|
||||||
|
|
||||||
from eval.config import discover_fixtures, discover_journey_fixtures
|
|
||||||
from eval.runner import run_fixture_eval, print_results
|
|
||||||
from eval.interactive import run_interactive
|
|
||||||
from eval import langfuse_eval
|
|
||||||
|
|
||||||
|
|
||||||
def _setup_logging(verbose: bool) -> None:
|
|
||||||
level = logging.DEBUG if verbose else logging.INFO
|
|
||||||
logging.basicConfig(
|
|
||||||
level=level,
|
|
||||||
format="%(asctime)s %(name)-20s %(levelname)-5s %(message)s",
|
|
||||||
datefmt="%H:%M:%S",
|
|
||||||
)
|
|
||||||
# Quiet noisy libraries
|
|
||||||
for name in ("httpx", "httpcore", "openai", "litellm", "urllib3"):
|
|
||||||
logging.getLogger(name).setLevel(logging.WARNING)
|
|
||||||
|
|
||||||
|
|
||||||
def _parse_args() -> argparse.Namespace:
|
|
||||||
parser = argparse.ArgumentParser(
|
|
||||||
description="Batch Agent E2E evaluation harness",
|
|
||||||
prog="python -m eval",
|
|
||||||
)
|
|
||||||
sub = parser.add_subparsers(dest="command", required=True)
|
|
||||||
|
|
||||||
# ── run ───────────────────────────────────────────────────────
|
|
||||||
run_cmd = sub.add_parser("run", help="Run evaluations")
|
|
||||||
run_cmd.add_argument(
|
|
||||||
"--fixture", "-f",
|
|
||||||
help="Run only the named fixture (default: all)",
|
|
||||||
)
|
|
||||||
run_cmd.add_argument(
|
|
||||||
"--models", "-m",
|
|
||||||
default="github_copilot/gpt-5.3-codex",
|
|
||||||
help="Comma-separated list of models to test (default: github_copilot/gpt-5.3-codex)",
|
|
||||||
)
|
|
||||||
run_cmd.add_argument(
|
|
||||||
"--mode",
|
|
||||||
default=None,
|
|
||||||
choices=["step1", "step2", "full"],
|
|
||||||
help="Only run fixtures with this mode (default: all)",
|
|
||||||
)
|
|
||||||
run_cmd.add_argument(
|
|
||||||
"--no-judge",
|
|
||||||
action="store_true",
|
|
||||||
help="Skip LLM-as-judge scoring",
|
|
||||||
)
|
|
||||||
run_cmd.add_argument(
|
|
||||||
"--judge-model",
|
|
||||||
default="gpt-4o",
|
|
||||||
help="Model for LLM judge (default: gpt-4o)",
|
|
||||||
)
|
|
||||||
run_cmd.add_argument(
|
|
||||||
"--fixtures-dir",
|
|
||||||
default=None,
|
|
||||||
help="Path to fixtures directory (default: eval/fixtures/)",
|
|
||||||
)
|
|
||||||
run_cmd.add_argument("-v", "--verbose", action="store_true")
|
|
||||||
|
|
||||||
# ── list ──────────────────────────────────────────────────────
|
|
||||||
list_cmd = sub.add_parser("list", help="List available fixtures")
|
|
||||||
list_cmd.add_argument("--fixtures-dir", default=None)
|
|
||||||
list_cmd.add_argument("-v", "--verbose", action="store_true")
|
|
||||||
|
|
||||||
# ── sync ──────────────────────────────────────────────────────
|
|
||||||
sync_cmd = sub.add_parser("sync", help="Sync fixtures to Langfuse datasets")
|
|
||||||
sync_cmd.add_argument("--fixture", "-f", default=None, help="Sync only the named fixture")
|
|
||||||
sync_cmd.add_argument("--fixtures-dir", default=None)
|
|
||||||
sync_cmd.add_argument("-v", "--verbose", action="store_true")
|
|
||||||
|
|
||||||
# ── interactive ───────────────────────────────────────────────
|
|
||||||
inter_cmd = sub.add_parser("interactive", help="Interactive journey session (human-in-the-loop)")
|
|
||||||
inter_cmd.add_argument(
|
|
||||||
"--fixture", "-f",
|
|
||||||
help="Journey fixture to use (default: pick interactively)",
|
|
||||||
)
|
|
||||||
inter_cmd.add_argument(
|
|
||||||
"--model", "-m",
|
|
||||||
default="github_copilot/gpt-5.3-codex",
|
|
||||||
help="Model for the journey AI (default: github_copilot/gpt-5.3-codex)",
|
|
||||||
)
|
|
||||||
inter_cmd.add_argument(
|
|
||||||
"--judge-model",
|
|
||||||
default="gpt-4o",
|
|
||||||
help="Model for LLM judge (default: gpt-4o)",
|
|
||||||
)
|
|
||||||
inter_cmd.add_argument(
|
|
||||||
"--fixtures-dir",
|
|
||||||
default=None,
|
|
||||||
help="Path to fixtures directory (default: eval/fixtures/)",
|
|
||||||
)
|
|
||||||
inter_cmd.add_argument(
|
|
||||||
"--data-dir",
|
|
||||||
default=None,
|
|
||||||
help="Override sample data directory (e.g. path to private test files not in git)",
|
|
||||||
)
|
|
||||||
inter_cmd.add_argument("-v", "--verbose", action="store_true")
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def _fixtures_dir(arg: str | None) -> Path | None:
|
|
||||||
if arg:
|
|
||||||
return Path(arg)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
async def _cmd_run(args: argparse.Namespace) -> None:
|
|
||||||
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
|
||||||
if not fixtures:
|
|
||||||
print("No fixtures found. Create YAML files in eval/fixtures/.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if args.fixture:
|
|
||||||
fixtures = [f for f in fixtures if f.name == args.fixture]
|
|
||||||
if not fixtures:
|
|
||||||
print(f"Fixture '{args.fixture}' not found.")
|
|
||||||
return
|
|
||||||
|
|
||||||
models = [m.strip() for m in args.models.split(",")]
|
|
||||||
|
|
||||||
all_results = []
|
|
||||||
for fixture in fixtures:
|
|
||||||
if args.mode and fixture.mode != args.mode:
|
|
||||||
continue
|
|
||||||
results = await run_fixture_eval(
|
|
||||||
fixture,
|
|
||||||
models=models,
|
|
||||||
use_llm_judge=not args.no_judge,
|
|
||||||
judge_model=args.judge_model,
|
|
||||||
)
|
|
||||||
all_results.extend(results)
|
|
||||||
|
|
||||||
print_results(all_results)
|
|
||||||
|
|
||||||
|
|
||||||
def _cmd_list(args: argparse.Namespace) -> None:
|
|
||||||
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
|
||||||
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
|
||||||
|
|
||||||
if not fixtures and not journey_fixtures:
|
|
||||||
print("No fixtures found.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if fixtures:
|
|
||||||
print(f"\n{'[Agent Fixtures]'}")
|
|
||||||
print(f"{'Name':<30} {'Mode':<6} {'Types':<25} {'Expected'}")
|
|
||||||
print("-" * 90)
|
|
||||||
for f in fixtures:
|
|
||||||
types = ", ".join(f.data_types)
|
|
||||||
n_expected = len(f.expected) + len(f.expected_classification)
|
|
||||||
print(f"{f.name:<30} {f.mode:<6} {types:<25} {n_expected}")
|
|
||||||
|
|
||||||
if journey_fixtures:
|
|
||||||
print(f"\n{'[Journey Fixtures]'}")
|
|
||||||
print(f"{'Name':<30} {'Types':<25} {'Messages':<10} {'Criteria'}")
|
|
||||||
print("-" * 90)
|
|
||||||
for f in journey_fixtures:
|
|
||||||
types = ", ".join(f.data_types)
|
|
||||||
print(f"{f.name:<30} {types:<25} {len(f.user_messages):<10} {len(f.expected_template_criteria)}")
|
|
||||||
|
|
||||||
print()
|
|
||||||
|
|
||||||
|
|
||||||
def _cmd_sync(args: argparse.Namespace) -> None:
|
|
||||||
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
|
|
||||||
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
|
||||||
|
|
||||||
if args.fixture:
|
|
||||||
fixtures = [f for f in fixtures if f.name == args.fixture]
|
|
||||||
journey_fixtures = [f for f in journey_fixtures if f.name == args.fixture]
|
|
||||||
|
|
||||||
if not fixtures and not journey_fixtures:
|
|
||||||
print("No fixtures to sync.")
|
|
||||||
return
|
|
||||||
|
|
||||||
for fixture in fixtures:
|
|
||||||
name = langfuse_eval.sync_fixture_to_dataset(fixture)
|
|
||||||
if name:
|
|
||||||
print(f"Synced: {fixture.name} → {name}")
|
|
||||||
else:
|
|
||||||
print(f"Skipped: {fixture.name} (Langfuse not configured)")
|
|
||||||
|
|
||||||
for fixture in journey_fixtures:
|
|
||||||
name = langfuse_eval.sync_journey_fixture_to_dataset(fixture)
|
|
||||||
if name:
|
|
||||||
print(f"Synced: {fixture.name} → {name}")
|
|
||||||
else:
|
|
||||||
print(f"Skipped: {fixture.name} (Langfuse not configured)")
|
|
||||||
|
|
||||||
|
|
||||||
async def _cmd_interactive(args: argparse.Namespace) -> None:
|
|
||||||
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
|
|
||||||
if not journey_fixtures:
|
|
||||||
print("No journey fixtures found. Create YAML files with type: journey in eval/fixtures/.")
|
|
||||||
return
|
|
||||||
|
|
||||||
if args.fixture:
|
|
||||||
fixtures = [f for f in journey_fixtures if f.name == args.fixture]
|
|
||||||
if not fixtures:
|
|
||||||
print(f"Journey fixture '{args.fixture}' not found.")
|
|
||||||
return
|
|
||||||
fixture = fixtures[0]
|
|
||||||
elif len(journey_fixtures) == 1:
|
|
||||||
fixture = journey_fixtures[0]
|
|
||||||
else:
|
|
||||||
# Let user pick
|
|
||||||
print("\nAvailable journey fixtures:")
|
|
||||||
for i, f in enumerate(journey_fixtures, 1):
|
|
||||||
print(f" {i}. {f.name} — {f.description[:60]}")
|
|
||||||
print()
|
|
||||||
try:
|
|
||||||
choice = int(input("Pick a fixture number: ").strip()) - 1
|
|
||||||
fixture = journey_fixtures[choice]
|
|
||||||
except (ValueError, IndexError, EOFError, KeyboardInterrupt):
|
|
||||||
print("Invalid choice.")
|
|
||||||
return
|
|
||||||
|
|
||||||
await run_interactive(
|
|
||||||
fixture,
|
|
||||||
model=args.model,
|
|
||||||
judge_model=args.judge_model,
|
|
||||||
data_dir=Path(args.data_dir).resolve() if args.data_dir else None,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> None:
|
|
||||||
args = _parse_args()
|
|
||||||
_setup_logging(args.verbose)
|
|
||||||
|
|
||||||
if args.command == "run":
|
|
||||||
asyncio.run(_cmd_run(args))
|
|
||||||
elif args.command == "interactive":
|
|
||||||
asyncio.run(_cmd_interactive(args))
|
|
||||||
elif args.command == "list":
|
|
||||||
_cmd_list(args)
|
|
||||||
elif args.command == "sync":
|
|
||||||
_cmd_sync(args)
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
main()
|
|
||||||
@@ -1,220 +0,0 @@
|
|||||||
"""Eval configuration — YAML fixture loader and dataclasses.
|
|
||||||
|
|
||||||
Fixtures come in two families:
|
|
||||||
|
|
||||||
1. **Agent fixtures** — test the batch agent pipeline.
|
|
||||||
Three modes controlled by ``mode``:
|
|
||||||
|
|
||||||
``step1`` — classification prompt only.
|
|
||||||
``step2`` — processing prompt only.
|
|
||||||
``full`` — both steps in sequence.
|
|
||||||
|
|
||||||
2. **Journey fixtures** — test the prompt-template builder conversation
|
|
||||||
(unchanged).
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any, Literal
|
|
||||||
|
|
||||||
import yaml
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
EvalMode = Literal["step1", "step2", "full"]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExpectedRecord:
|
|
||||||
"""A single expected extraction result.
|
|
||||||
|
|
||||||
Only the fields specified are checked — unspecified fields are ignored.
|
|
||||||
"""
|
|
||||||
|
|
||||||
table: str # tasks | notes | timelines | projects
|
|
||||||
fields: dict[str, Any] # field_name → expected_value
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class ExpectedClassification:
|
|
||||||
"""Expected output of step-1 classification for one file."""
|
|
||||||
|
|
||||||
file: str # relative path to the sample file
|
|
||||||
project_id: str # expected matched project id, or "new"
|
|
||||||
domains: list[str] # expected domain list
|
|
||||||
new_project_name: str | None = None
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EvalFixture:
|
|
||||||
"""A complete test scenario loaded from YAML.
|
|
||||||
|
|
||||||
``mode`` determines which pipeline steps are exercised:
|
|
||||||
|
|
||||||
- **step1**: only ``_classify_file``
|
|
||||||
- **step2**: only the processing LLM + tool loop
|
|
||||||
- **full**: both steps in sequence (``run_local_agent``)
|
|
||||||
"""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
mode: EvalMode
|
|
||||||
directory: str # relative path to sample files
|
|
||||||
data_types: list[str]
|
|
||||||
file_extensions: list[str]
|
|
||||||
models: list[str] # if empty, use CLI default
|
|
||||||
fixture_path: Path = field(default_factory=lambda: Path("."))
|
|
||||||
|
|
||||||
# ── Step-1 inputs (classification) ───────────────────────────
|
|
||||||
domain_definitions: str = ""
|
|
||||||
projects_list: list[dict[str, Any]] = field(default_factory=list)
|
|
||||||
custom_step1_prompt: str = ""
|
|
||||||
|
|
||||||
# ── Step-2 inputs (processing) ───────────────────────────────
|
|
||||||
existing_context: str = ""
|
|
||||||
project_context: str = ""
|
|
||||||
custom_prompt_section: str = ""
|
|
||||||
|
|
||||||
# ── Seed records for mock executor ───────────────────────────
|
|
||||||
seed_records: dict[str, list[dict]] = field(default_factory=dict)
|
|
||||||
|
|
||||||
# ── Expected outputs ─────────────────────────────────────────
|
|
||||||
expected_classification: list[ExpectedClassification] = field(default_factory=list)
|
|
||||||
expected: list[ExpectedRecord] = field(default_factory=list)
|
|
||||||
|
|
||||||
@property
|
|
||||||
def fixture_dir(self) -> Path:
|
|
||||||
"""Absolute path to the sample files directory."""
|
|
||||||
return self.fixture_path.parent / self.directory
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_yaml(cls, path: Path) -> "EvalFixture":
|
|
||||||
"""Load a fixture from a YAML file."""
|
|
||||||
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
mode: EvalMode = raw.get("mode", "full")
|
|
||||||
|
|
||||||
# Parse expected records (step2/full)
|
|
||||||
expected: list[ExpectedRecord] = []
|
|
||||||
for table, records in (raw.get("expected") or {}).items():
|
|
||||||
for rec in records:
|
|
||||||
expected.append(ExpectedRecord(table=table, fields=rec))
|
|
||||||
|
|
||||||
# Parse expected classification (step1/full)
|
|
||||||
expected_classification: list[ExpectedClassification] = []
|
|
||||||
for item in raw.get("expected_classification") or []:
|
|
||||||
expected_classification.append(ExpectedClassification(
|
|
||||||
file=item["file"],
|
|
||||||
project_id=item["project_id"],
|
|
||||||
domains=item.get("domains", []),
|
|
||||||
new_project_name=item.get("new_project_name"),
|
|
||||||
))
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
name=raw["name"],
|
|
||||||
description=raw.get("description", ""),
|
|
||||||
mode=mode,
|
|
||||||
directory=raw.get("directory", "sample_files"),
|
|
||||||
data_types=raw.get("data_types", ["tasks"]),
|
|
||||||
file_extensions=raw.get("file_extensions", []),
|
|
||||||
models=raw.get("models", []),
|
|
||||||
fixture_path=path,
|
|
||||||
# Step-1 inputs
|
|
||||||
domain_definitions=raw.get("domain_definitions", ""),
|
|
||||||
projects_list=raw.get("projects_list", []),
|
|
||||||
# Step-2 inputs
|
|
||||||
existing_context=raw.get("existing_context", ""),
|
|
||||||
project_context=raw.get("project_context", ""),
|
|
||||||
custom_prompt_section=raw.get("custom_prompt_section", ""),
|
|
||||||
# Shared
|
|
||||||
seed_records=raw.get("seed_records", {}),
|
|
||||||
expected_classification=expected_classification,
|
|
||||||
expected=expected,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def discover_fixtures(fixtures_dir: Path | None = None) -> list[EvalFixture]:
|
|
||||||
"""Find and load all YAML fixtures in the fixtures directory."""
|
|
||||||
if fixtures_dir is None:
|
|
||||||
fixtures_dir = Path(__file__).parent / "fixtures"
|
|
||||||
|
|
||||||
fixtures: list[EvalFixture] = []
|
|
||||||
if not fixtures_dir.is_dir():
|
|
||||||
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
|
|
||||||
return fixtures
|
|
||||||
|
|
||||||
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
|
|
||||||
try:
|
|
||||||
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
|
||||||
if raw.get("type") == "journey":
|
|
||||||
continue # Skip journey fixtures
|
|
||||||
fixtures.append(EvalFixture.from_yaml(yaml_path))
|
|
||||||
logger.info("eval: loaded fixture %s from %s", fixtures[-1].name, yaml_path.name)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("eval: failed to load fixture %s: %s", yaml_path.name, exc)
|
|
||||||
|
|
||||||
return fixtures
|
|
||||||
|
|
||||||
|
|
||||||
# ── Journey fixtures ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class JourneyFixture:
|
|
||||||
"""A journey test scenario — tests the prompt_template builder conversation."""
|
|
||||||
|
|
||||||
name: str
|
|
||||||
description: str
|
|
||||||
directory: str # relative path to sample files
|
|
||||||
data_types: list[str]
|
|
||||||
expected_template_criteria: list[str] # what the template should contain/satisfy
|
|
||||||
user_messages: list[str] = field(default_factory=list) # for automated journey runs (unused in interactive mode)
|
|
||||||
models: list[str] = field(default_factory=list)
|
|
||||||
fixture_path: Path = field(default_factory=lambda: Path("."))
|
|
||||||
|
|
||||||
@property
|
|
||||||
def fixture_dir(self) -> Path:
|
|
||||||
"""Absolute path to the sample files directory."""
|
|
||||||
return self.fixture_path.parent / self.directory
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def from_yaml(cls, path: Path) -> "JourneyFixture":
|
|
||||||
"""Load a journey fixture from a YAML file."""
|
|
||||||
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
|
|
||||||
|
|
||||||
return cls(
|
|
||||||
name=raw["name"],
|
|
||||||
description=raw.get("description", ""),
|
|
||||||
directory=raw.get("directory", "sample_files"),
|
|
||||||
data_types=raw.get("data_types", ["tasks"]),
|
|
||||||
user_messages=raw.get("user_messages", []),
|
|
||||||
expected_template_criteria=raw.get("expected_template_criteria", []),
|
|
||||||
models=raw.get("models", []),
|
|
||||||
fixture_path=path,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def discover_journey_fixtures(fixtures_dir: Path | None = None) -> list[JourneyFixture]:
|
|
||||||
"""Find and load all journey YAML fixtures in the fixtures directory."""
|
|
||||||
if fixtures_dir is None:
|
|
||||||
fixtures_dir = Path(__file__).parent / "fixtures"
|
|
||||||
|
|
||||||
fixtures: list[JourneyFixture] = []
|
|
||||||
if not fixtures_dir.is_dir():
|
|
||||||
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
|
|
||||||
return fixtures
|
|
||||||
|
|
||||||
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
|
|
||||||
try:
|
|
||||||
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
|
|
||||||
if raw.get("type") != "journey":
|
|
||||||
continue
|
|
||||||
fixtures.append(JourneyFixture.from_yaml(yaml_path))
|
|
||||||
logger.info("eval: loaded journey fixture %s from %s", fixtures[-1].name, yaml_path.name)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("eval: failed to load journey fixture %s: %s", yaml_path.name, exc)
|
|
||||||
|
|
||||||
return fixtures
|
|
||||||
@@ -1,40 +0,0 @@
|
|||||||
# Fixture: classify-invoices (step1)
|
|
||||||
# Tests _STEP1_SYSTEM_PROMPT — file classification and project matching.
|
|
||||||
# Verifies that the LLM correctly matches files to existing projects
|
|
||||||
# and identifies the right data domains.
|
|
||||||
|
|
||||||
name: classify-invoices
|
|
||||||
mode: step1
|
|
||||||
description: >
|
|
||||||
Test file classification on Italian freelance invoices and meeting notes.
|
|
||||||
Verifies project matching and domain identification.
|
|
||||||
|
|
||||||
directory: sample_files/invoices
|
|
||||||
data_types: [tasks, notes, timelines]
|
|
||||||
file_extensions: [txt, md]
|
|
||||||
|
|
||||||
# ── Step-1 prompt variables ──────────────────────────────────────
|
|
||||||
domain_definitions: |
|
|
||||||
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
|
|
||||||
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
|
|
||||||
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
|
|
||||||
|
|
||||||
projects_list:
|
|
||||||
- id: "proj-web-redesign"
|
|
||||||
name: "Redesign Sito Web Corporate"
|
|
||||||
status: "active"
|
|
||||||
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
|
||||||
- id: "proj-ecommerce"
|
|
||||||
name: "E-Commerce FashionStore"
|
|
||||||
status: "active"
|
|
||||||
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
|
||||||
|
|
||||||
# ── Expected classification results ─────────────────────────────
|
|
||||||
expected_classification:
|
|
||||||
- file: "sample_files/invoices/fattura_042.txt"
|
|
||||||
project_id: "proj-web-redesign"
|
|
||||||
domains: [tasks, notes, timelines]
|
|
||||||
|
|
||||||
- file: "sample_files/invoices/meeting_ecommerce.md"
|
|
||||||
project_id: "proj-ecommerce"
|
|
||||||
domains: [tasks, notes, timelines]
|
|
||||||
@@ -1,108 +0,0 @@
|
|||||||
# Fixture: full-invoices (full)
|
|
||||||
# Tests both _STEP1_SYSTEM_PROMPT and _PROCESSING_SYSTEM_PROMPT in sequence
|
|
||||||
# via run_local_agent(). Verifies end-to-end classification + extraction.
|
|
||||||
|
|
||||||
name: full-invoices
|
|
||||||
mode: full
|
|
||||||
description: >
|
|
||||||
End-to-end test: classify Italian invoices/meeting notes into the
|
|
||||||
correct project, then extract tasks, notes, and timeline events.
|
|
||||||
|
|
||||||
directory: sample_files/invoices
|
|
||||||
data_types: [tasks, notes, timelines]
|
|
||||||
file_extensions: [txt, md]
|
|
||||||
|
|
||||||
# ── Step-1 prompt variables ──────────────────────────────────────
|
|
||||||
domain_definitions: |
|
|
||||||
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
|
|
||||||
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
|
|
||||||
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
|
|
||||||
|
|
||||||
projects_list:
|
|
||||||
- id: "proj-web-redesign"
|
|
||||||
name: "Redesign Sito Web Corporate"
|
|
||||||
status: "active"
|
|
||||||
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
|
||||||
- id: "proj-ecommerce"
|
|
||||||
name: "E-Commerce FashionStore"
|
|
||||||
status: "active"
|
|
||||||
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
|
||||||
|
|
||||||
# ── Step-2 prompt variables ──────────────────────────────────────
|
|
||||||
existing_context: |
|
|
||||||
Existing tasks:
|
|
||||||
(none)
|
|
||||||
|
|
||||||
Existing notes:
|
|
||||||
(none)
|
|
||||||
|
|
||||||
Existing timelines:
|
|
||||||
(none)
|
|
||||||
|
|
||||||
project_context: ""
|
|
||||||
|
|
||||||
custom_prompt_section: |
|
|
||||||
User instructions:
|
|
||||||
Estrai i dati dai file come segue:
|
|
||||||
- TASK: ogni azione da fare, deliverable, o item con scadenza.
|
|
||||||
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
|
|
||||||
Mappa "media priorità" → priority: medium.
|
|
||||||
Mappa "bassa priorità" → priority: low.
|
|
||||||
Se un item è marcato come "completato" o [x], impostalo status: done.
|
|
||||||
Altrimenti status: todo.
|
|
||||||
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
|
|
||||||
- TIMELINE: date di scadenza, milestone, meeting futuri.
|
|
||||||
Imposta sempre isAiSuggested=1.
|
|
||||||
|
|
||||||
# ── Seed records (pre-existing DB state) ─────────────────────────
|
|
||||||
seed_records:
|
|
||||||
projects:
|
|
||||||
- id: "proj-web-redesign"
|
|
||||||
name: "Redesign Sito Web Corporate"
|
|
||||||
status: "active"
|
|
||||||
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
|
|
||||||
- id: "proj-ecommerce"
|
|
||||||
name: "E-Commerce FashionStore"
|
|
||||||
status: "active"
|
|
||||||
aiSummary: "Next.js e-commerce platform for FashionStore srl"
|
|
||||||
tasks: []
|
|
||||||
notes: []
|
|
||||||
timelines: []
|
|
||||||
|
|
||||||
# ── Expected classification (step 1) ─────────────────────────────
|
|
||||||
expected_classification:
|
|
||||||
- file: "sample_files/invoices/fattura_042.txt"
|
|
||||||
project_id: "proj-web-redesign"
|
|
||||||
domains: [tasks, notes, timelines]
|
|
||||||
|
|
||||||
- file: "sample_files/invoices/meeting_ecommerce.md"
|
|
||||||
project_id: "proj-ecommerce"
|
|
||||||
domains: [tasks, notes, timelines]
|
|
||||||
|
|
||||||
# ── Expected extractions (step 2) ────────────────────────────────
|
|
||||||
expected:
|
|
||||||
tasks:
|
|
||||||
- title: "Sviluppo frontend React"
|
|
||||||
priority: "high"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Integrazione API backend"
|
|
||||||
priority: "medium"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Testing cross-browser e fix bug responsive"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Preparare wireframe homepage"
|
|
||||||
priority: "high"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Setup progetto Next.js e configurare CI/CD"
|
|
||||||
priority: "medium"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Ricerca plugin Stripe per gestione abbonamenti"
|
|
||||||
priority: "low"
|
|
||||||
status: "todo"
|
|
||||||
|
|
||||||
notes:
|
|
||||||
- title: "Meeting Kickoff Progetto E-Commerce"
|
|
||||||
|
|
||||||
timelines:
|
|
||||||
- title: "MVP E-Commerce pronto"
|
|
||||||
- title: "Meeting di revisione"
|
|
||||||
@@ -1,28 +0,0 @@
|
|||||||
# Journey Fixture: journey-invoice-setup
|
|
||||||
# Used by `python -m eval interactive` for human-in-the-loop testing
|
|
||||||
# of the journey chatbot's prompt-building conversation.
|
|
||||||
|
|
||||||
type: journey
|
|
||||||
name: journey-invoice-setup
|
|
||||||
description: >
|
|
||||||
Interactive test for the journey chatbot — explore a directory of
|
|
||||||
Italian invoices and meeting notes, answer the chatbot's questions,
|
|
||||||
and verify it produces a well-structured prompt_template for data
|
|
||||||
extraction.
|
|
||||||
|
|
||||||
directory: sample_files/invoices
|
|
||||||
data_types: [tasks, notes, timelines, projects]
|
|
||||||
|
|
||||||
# Criteria the generated prompt_template must satisfy
|
|
||||||
# Each is scored 0-1 by an LLM judge
|
|
||||||
expected_template_criteria:
|
|
||||||
- "Mentions creating tasks from action items and work descriptions"
|
|
||||||
- "Mentions creating notes from meeting summaries"
|
|
||||||
- "Mentions extracting timeline events from deadlines and meeting dates"
|
|
||||||
- "Mentions creating projects from relevant information"
|
|
||||||
- "Sets isAiSuggested=1 on all created records"
|
|
||||||
- "Does NOT include projectId assignment logic"
|
|
||||||
- "Uses camelCase field names (title, status, priority, dueDate, content)"
|
|
||||||
|
|
||||||
# Models to test (empty = use CLI --models default)
|
|
||||||
models: []
|
|
||||||
@@ -1,81 +0,0 @@
|
|||||||
# Fixture: process-invoices (step2)
|
|
||||||
# Tests _PROCESSING_SYSTEM_PROMPT — data extraction & tool calling.
|
|
||||||
# The classification step is skipped; prompt variables are injected directly.
|
|
||||||
|
|
||||||
name: process-invoices
|
|
||||||
mode: step2
|
|
||||||
description: >
|
|
||||||
Test data extraction from Italian freelance invoices.
|
|
||||||
Verifies correct record creation via tool calls with the right
|
|
||||||
fields, priorities, and status values.
|
|
||||||
|
|
||||||
directory: sample_files/invoices
|
|
||||||
data_types: [tasks, notes, timelines]
|
|
||||||
file_extensions: [txt, md]
|
|
||||||
|
|
||||||
# ── Step-2 prompt variables ──────────────────────────────────────
|
|
||||||
existing_context: |
|
|
||||||
Existing tasks:
|
|
||||||
(none)
|
|
||||||
|
|
||||||
Existing notes:
|
|
||||||
(none)
|
|
||||||
|
|
||||||
Existing timelines:
|
|
||||||
(none)
|
|
||||||
|
|
||||||
project_context: >
|
|
||||||
Project: Redesign Sito Web Corporate (id: proj-web-redesign).
|
|
||||||
Always set projectId to this id on every record you create.
|
|
||||||
|
|
||||||
custom_prompt_section: |
|
|
||||||
User instructions:
|
|
||||||
Estrai i dati dai file come segue:
|
|
||||||
- TASK: ogni azione da fare, deliverable, o item con scadenza.
|
|
||||||
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
|
|
||||||
Mappa "media priorità" → priority: medium.
|
|
||||||
Mappa "bassa priorità" → priority: low.
|
|
||||||
Se un item è marcato come "completato" o [x], impostalo status: done.
|
|
||||||
Altrimenti status: todo.
|
|
||||||
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
|
|
||||||
Il titolo deve essere descrittivo. Il content deve includere tutti i dettagli.
|
|
||||||
- TIMELINE: date di scadenza, milestone, meeting futuri.
|
|
||||||
Imposta sempre isAiSuggested=1.
|
|
||||||
|
|
||||||
# ── Seed records (pre-existing DB state) ─────────────────────────
|
|
||||||
seed_records:
|
|
||||||
projects:
|
|
||||||
- id: "proj-web-redesign"
|
|
||||||
name: "Redesign Sito Web Corporate"
|
|
||||||
status: "active"
|
|
||||||
tasks: []
|
|
||||||
notes: []
|
|
||||||
timelines: []
|
|
||||||
|
|
||||||
# ── Expected extractions ─────────────────────────────────────────
|
|
||||||
expected:
|
|
||||||
tasks:
|
|
||||||
- title: "Sviluppo frontend React"
|
|
||||||
priority: "high"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Integrazione API backend"
|
|
||||||
priority: "medium"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Testing cross-browser e fix bug responsive"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Preparare wireframe homepage"
|
|
||||||
priority: "high"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Setup progetto Next.js e configurare CI/CD"
|
|
||||||
priority: "medium"
|
|
||||||
status: "todo"
|
|
||||||
- title: "Ricerca plugin Stripe per gestione abbonamenti"
|
|
||||||
priority: "low"
|
|
||||||
status: "todo"
|
|
||||||
|
|
||||||
notes:
|
|
||||||
- title: "Meeting Kickoff Progetto E-Commerce"
|
|
||||||
|
|
||||||
timelines:
|
|
||||||
- title: "MVP E-Commerce pronto"
|
|
||||||
- title: "Meeting di revisione"
|
|
||||||
@@ -1,18 +0,0 @@
|
|||||||
FATTURA N. 2026-0042
|
|
||||||
Data: 15 Marzo 2026
|
|
||||||
Cliente: Studio Architettura Bianchi
|
|
||||||
|
|
||||||
Progetto: Redesign Sito Web Corporate
|
|
||||||
|
|
||||||
Descrizione lavori:
|
|
||||||
- Sviluppo frontend React (40 ore) — URGENTE, completare entro 20 marzo
|
|
||||||
- Integrazione API backend (20 ore) — priorità media
|
|
||||||
- Design UI/UX mockup homepage (8 ore) — completato
|
|
||||||
- Testing cross-browser e fix bug responsive (12 ore) — da iniziare
|
|
||||||
|
|
||||||
Totale: €4.800,00 + IVA
|
|
||||||
|
|
||||||
Note:
|
|
||||||
Meeting di revisione previsto per il 18 marzo alle 10:00.
|
|
||||||
Il cliente ha richiesto modifiche al layout mobile della sezione contatti.
|
|
||||||
Attendere conferma budget aggiuntivo per sezione blog.
|
|
||||||
@@ -1,25 +0,0 @@
|
|||||||
# Meeting Notes - Kickoff Progetto E-Commerce
|
|
||||||
|
|
||||||
**Data:** 10 Marzo 2026
|
|
||||||
**Partecipanti:** Marco R., Giulia T., Cliente (FashionStore srl)
|
|
||||||
|
|
||||||
## Decisioni prese
|
|
||||||
|
|
||||||
1. **Piattaforma**: Next.js + Stripe per i pagamenti
|
|
||||||
2. **Timeline**: MVP pronto entro 30 aprile 2026
|
|
||||||
3. **Budget**: €12.000 totale, €4.000 anticipo già ricevuto
|
|
||||||
|
|
||||||
## Action items
|
|
||||||
|
|
||||||
- [ ] Marco: preparare wireframe homepage entro 14 marzo — ALTA PRIORITÀ
|
|
||||||
- [ ] Giulia: setup progetto Next.js e configurare CI/CD — media priorità
|
|
||||||
- [ ] Marco: ricerca plugin Stripe per gestione abbonamenti — bassa priorità
|
|
||||||
- [x] Giulia: inviare contratto firmato al cliente — COMPLETATO
|
|
||||||
|
|
||||||
## Note aggiuntive
|
|
||||||
|
|
||||||
Il cliente vuole un design minimalista, ispirato a Zara.com.
|
|
||||||
Colori primari: nero, bianco, oro.
|
|
||||||
Font: Inter per body, Playfair Display per headings.
|
|
||||||
|
|
||||||
Prossimo meeting: 24 marzo 2026 ore 15:00.
|
|
||||||
@@ -1,471 +0,0 @@
|
|||||||
"""Interactive journey session — human-in-the-loop CLI conversation.
|
|
||||||
|
|
||||||
Flow:
|
|
||||||
1. Show the system prompt used by the journey AI.
|
|
||||||
2. Start the journey (AI explores files, asks first question).
|
|
||||||
3. User types responses in the terminal — AI replies.
|
|
||||||
4. User types `/done` to end the conversation.
|
|
||||||
5. User writes a comment about the interaction quality.
|
|
||||||
6. LLM judge scores the conversation + generated template.
|
|
||||||
7. Results are reported to Langfuse.
|
|
||||||
|
|
||||||
Usage::
|
|
||||||
|
|
||||||
python -m eval interactive # pick a fixture interactively
|
|
||||||
python -m eval interactive --fixture=journey-invoice-setup
|
|
||||||
python -m eval interactive --model=gpt-4o
|
|
||||||
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
from eval.config import JourneyFixture, discover_journey_fixtures
|
|
||||||
from eval.mock_executor import MockExecutor
|
|
||||||
from eval import langfuse_eval
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
# ── Special commands ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_CMD_DONE = "/done"
|
|
||||||
_CMD_QUIT = "/quit"
|
|
||||||
_CMD_TEMPLATE = "/template"
|
|
||||||
_CMD_HELP = "/help"
|
|
||||||
|
|
||||||
_HELP_TEXT = f"""\
|
|
||||||
{_CMD_DONE} — End the conversation and proceed to evaluation
|
|
||||||
{_CMD_QUIT} — Abort without evaluation
|
|
||||||
{_CMD_TEMPLATE} — Show the generated template (if any)
|
|
||||||
{_CMD_HELP} — Show this help"""
|
|
||||||
|
|
||||||
# ── Terminal colours (ANSI) ──────────────────────────────────────────────
|
|
||||||
|
|
||||||
_C_RESET = "\033[0m"
|
|
||||||
_C_BOLD = "\033[1m"
|
|
||||||
_C_DIM = "\033[2m"
|
|
||||||
_C_CYAN = "\033[36m"
|
|
||||||
_C_GREEN = "\033[32m"
|
|
||||||
_C_YELLOW = "\033[33m"
|
|
||||||
_C_MAGENTA = "\033[35m"
|
|
||||||
_C_RED = "\033[31m"
|
|
||||||
_C_BLUE = "\033[34m"
|
|
||||||
|
|
||||||
|
|
||||||
def _print_header(text: str) -> None:
|
|
||||||
print(f"\n{_C_BOLD}{_C_CYAN}{'═' * 80}")
|
|
||||||
print(f" {text}")
|
|
||||||
print(f"{'═' * 80}{_C_RESET}\n")
|
|
||||||
|
|
||||||
|
|
||||||
def _print_ai(text: str) -> None:
|
|
||||||
print(f"\n{_C_GREEN}{_C_BOLD}AI:{_C_RESET} {text}\n")
|
|
||||||
|
|
||||||
|
|
||||||
def _print_system(text: str) -> None:
|
|
||||||
print(f"{_C_DIM}{text}{_C_RESET}")
|
|
||||||
|
|
||||||
|
|
||||||
def _print_score(label: str, score: float) -> None:
|
|
||||||
if score >= 0.7:
|
|
||||||
color = _C_GREEN
|
|
||||||
tag = "PASS"
|
|
||||||
elif score >= 0.4:
|
|
||||||
color = _C_YELLOW
|
|
||||||
tag = "PARTIAL"
|
|
||||||
else:
|
|
||||||
color = _C_RED
|
|
||||||
tag = "FAIL"
|
|
||||||
print(f" {color}{tag:>7}{_C_RESET} ({score:.1f}) {label}")
|
|
||||||
|
|
||||||
|
|
||||||
# ── Result type ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class InteractiveResult:
|
|
||||||
fixture_name: str
|
|
||||||
model: str
|
|
||||||
judge_model: str
|
|
||||||
prompt_template: str | None
|
|
||||||
conversation: list[dict[str, str]]
|
|
||||||
user_comment: str
|
|
||||||
done: bool
|
|
||||||
criteria_scores: dict[str, float]
|
|
||||||
overall_score: float
|
|
||||||
judge_reasoning: str
|
|
||||||
elapsed_seconds: float
|
|
||||||
|
|
||||||
def summary(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"fixture": self.fixture_name,
|
|
||||||
"model": self.model,
|
|
||||||
"judge_model": self.judge_model,
|
|
||||||
"done": self.done,
|
|
||||||
"turns": len([c for c in self.conversation if c["role"] == "user"]),
|
|
||||||
"overall_score": round(self.overall_score, 3),
|
|
||||||
"user_comment": self.user_comment,
|
|
||||||
"criteria_scores": {k: round(v, 3) for k, v in self.criteria_scores.items()},
|
|
||||||
"elapsed_s": round(self.elapsed_seconds, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── LLM judge ────────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_INTERACTIVE_JUDGE_SYSTEM = """\
|
|
||||||
You are an evaluation judge for AI-generated prompt templates produced during
|
|
||||||
an interactive conversation between a human and a journey chatbot.
|
|
||||||
|
|
||||||
The chatbot explored a directory and through multi-turn conversation with the
|
|
||||||
user produced a prompt_template — an instruction set for a data-extraction agent.
|
|
||||||
|
|
||||||
You have access to:
|
|
||||||
- The full conversation transcript
|
|
||||||
- The generated prompt_template (if any)
|
|
||||||
- The user's own comment about the interaction
|
|
||||||
- A list of quality criteria
|
|
||||||
|
|
||||||
Score each criterion from 0 to 1:
|
|
||||||
- 1.0: Fully satisfied
|
|
||||||
- 0.5: Partially satisfied
|
|
||||||
- 0.0: Not satisfied
|
|
||||||
|
|
||||||
Also provide an overall_quality score (0-1) evaluating the conversation flow,
|
|
||||||
how well the AI understood the user, and the template quality.
|
|
||||||
|
|
||||||
Respond with ONLY a JSON object:
|
|
||||||
{
|
|
||||||
"criteria_scores": {"criterion_1": 0.8, ...},
|
|
||||||
"overall_quality": 0.85,
|
|
||||||
"reasoning": "Brief explanation covering both conversation quality and template accuracy"
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
async def _judge_interactive(
|
|
||||||
conversation: list[dict[str, str]],
|
|
||||||
prompt_template: str | None,
|
|
||||||
user_comment: str,
|
|
||||||
criteria: list[str],
|
|
||||||
*,
|
|
||||||
judge_model: str = "gpt-4o-mini",
|
|
||||||
) -> tuple[dict[str, float], float, str]:
|
|
||||||
"""Score an interactive session. Returns (criteria_scores, overall_quality, reasoning)."""
|
|
||||||
from shared.llm import get_llm
|
|
||||||
|
|
||||||
llm = get_llm(model=judge_model, temperature=0)
|
|
||||||
|
|
||||||
conv_text = "\n".join(
|
|
||||||
f"{'USER' if t['role'] == 'user' else 'AI'}: {t['content']}"
|
|
||||||
for t in conversation
|
|
||||||
)
|
|
||||||
criteria_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(criteria))
|
|
||||||
|
|
||||||
user_content = (
|
|
||||||
f"## Conversation transcript\n```\n{conv_text}\n```\n\n"
|
|
||||||
f"## Generated prompt_template\n```\n{prompt_template or '(none — conversation did not complete)'}\n```\n\n"
|
|
||||||
f"## User's comment\n{user_comment}\n\n"
|
|
||||||
f"## Criteria to evaluate\n{criteria_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke([
|
|
||||||
SystemMessage(content=_INTERACTIVE_JUDGE_SYSTEM),
|
|
||||||
HumanMessage(content=user_content),
|
|
||||||
])
|
|
||||||
raw = response.content.strip()
|
|
||||||
if raw.startswith("```"):
|
|
||||||
raw = raw.split("```")[1]
|
|
||||||
if raw.startswith("json"):
|
|
||||||
raw = raw[4:]
|
|
||||||
parsed = json.loads(raw.strip())
|
|
||||||
|
|
||||||
scores_raw = parsed.get("criteria_scores", parsed.get("scores", {}))
|
|
||||||
criteria_scores: dict[str, float] = {}
|
|
||||||
for i, criterion in enumerate(criteria):
|
|
||||||
key_candidates = [f"criterion_{i+1}", criterion, criterion[:50], str(i + 1)]
|
|
||||||
score = 0.0
|
|
||||||
for key in key_candidates:
|
|
||||||
if key in scores_raw:
|
|
||||||
score = float(scores_raw[key])
|
|
||||||
break
|
|
||||||
if score == 0.0 and i < len(scores_raw):
|
|
||||||
score = float(list(scores_raw.values())[i])
|
|
||||||
criteria_scores[criterion] = score
|
|
||||||
|
|
||||||
overall = float(parsed.get("overall_quality", 0.0))
|
|
||||||
reasoning = str(parsed.get("reasoning", ""))
|
|
||||||
return criteria_scores, overall, reasoning
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("interactive judge failed: %s", exc)
|
|
||||||
return {c: 0.0 for c in criteria}, 0.0, f"Judge error: {exc}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Interactive session ──────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def run_interactive(
|
|
||||||
fixture: JourneyFixture,
|
|
||||||
*,
|
|
||||||
model: str = "gpt-4o",
|
|
||||||
judge_model: str = "gpt-4o-mini",
|
|
||||||
data_dir: Path | None = None,
|
|
||||||
) -> InteractiveResult:
|
|
||||||
"""Run an interactive journey session in the terminal.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
data_dir :
|
|
||||||
If set, overrides the fixture's sample-file directory. The LLM
|
|
||||||
will explore this folder instead of the default
|
|
||||||
``fixtures/sample_files/…``. Useful for private test data that
|
|
||||||
shouldn't be committed to git.
|
|
||||||
"""
|
|
||||||
from shared.config import settings
|
|
||||||
from shared.ws_context import set_current_user, clear_current_user
|
|
||||||
from app.journey import (
|
|
||||||
handle_journey_start,
|
|
||||||
handle_journey_message,
|
|
||||||
_build_system_prompt,
|
|
||||||
)
|
|
||||||
|
|
||||||
# When --data-dir is given, the MockExecutor's root becomes
|
|
||||||
# data_dir's parent and the journey directory is data_dir's name.
|
|
||||||
# This way the LLM sees a meaningful directory name (not ".") and
|
|
||||||
# MockExecutor resolves paths correctly.
|
|
||||||
# Otherwise, use the fixture's YAML parent and its relative path.
|
|
||||||
if data_dir:
|
|
||||||
mock_root = data_dir.parent
|
|
||||||
journey_directory = data_dir.name
|
|
||||||
else:
|
|
||||||
mock_root = fixture.fixture_path.parent
|
|
||||||
journey_directory = fixture.directory
|
|
||||||
|
|
||||||
mock = MockExecutor(
|
|
||||||
fixture_dir=mock_root,
|
|
||||||
seed_records={},
|
|
||||||
)
|
|
||||||
|
|
||||||
original_model = settings.LLM_MODEL
|
|
||||||
settings.LLM_MODEL = model
|
|
||||||
eval_user_id = f"interactive-{uuid.uuid4().hex[:8]}"
|
|
||||||
|
|
||||||
# ── Show system prompt ───────────────────────────────────────
|
|
||||||
system_prompt = _build_system_prompt(journey_directory, fixture.data_types)
|
|
||||||
|
|
||||||
_print_header("SYSTEM PROMPT")
|
|
||||||
print(f"{_C_DIM}{system_prompt}{_C_RESET}")
|
|
||||||
|
|
||||||
_print_header(f"INTERACTIVE JOURNEY | fixture: {fixture.name} | model: {model}")
|
|
||||||
print(f" Data dir: {mock_root}")
|
|
||||||
print(f" Type your responses. Commands: {_CMD_DONE}, {_CMD_QUIT}, {_CMD_TEMPLATE}, {_CMD_HELP}")
|
|
||||||
print(f" Judge model: {judge_model}")
|
|
||||||
print(f" Criteria: {len(fixture.expected_template_criteria)}")
|
|
||||||
print()
|
|
||||||
|
|
||||||
conversation: list[dict[str, str]] = []
|
|
||||||
prompt_template: str | None = None
|
|
||||||
done = False
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
try:
|
|
||||||
set_current_user(eval_user_id)
|
|
||||||
|
|
||||||
with mock.patch():
|
|
||||||
# ── Start ────────────────────────────────────────────
|
|
||||||
_print_system("Starting journey... (AI is exploring your files)")
|
|
||||||
|
|
||||||
start_frame: dict[str, Any] = {
|
|
||||||
"agent_type": "local",
|
|
||||||
"directory": journey_directory,
|
|
||||||
"data_types": fixture.data_types,
|
|
||||||
"session_id": f"interactive-{uuid.uuid4().hex[:8]}",
|
|
||||||
}
|
|
||||||
|
|
||||||
reply = await handle_journey_start(eval_user_id, start_frame)
|
|
||||||
session_id = reply["session_id"]
|
|
||||||
conversation.append({"role": "assistant", "content": reply["message"]})
|
|
||||||
_print_ai(reply["message"])
|
|
||||||
|
|
||||||
if reply["done"]:
|
|
||||||
prompt_template = reply.get("prompt_template")
|
|
||||||
done = True
|
|
||||||
_print_system("Journey completed on first reply (template generated).")
|
|
||||||
|
|
||||||
# ── Conversation loop ────────────────────────────────
|
|
||||||
while not done:
|
|
||||||
try:
|
|
||||||
user_input = input(f"{_C_BOLD}{_C_BLUE}YOU:{_C_RESET} ").strip()
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
print()
|
|
||||||
user_input = _CMD_QUIT
|
|
||||||
|
|
||||||
if not user_input:
|
|
||||||
continue
|
|
||||||
|
|
||||||
# Handle commands
|
|
||||||
if user_input.lower() == _CMD_QUIT:
|
|
||||||
_print_system("Aborted — no evaluation will be performed.")
|
|
||||||
settings.LLM_MODEL = original_model
|
|
||||||
clear_current_user()
|
|
||||||
return InteractiveResult(
|
|
||||||
fixture_name=fixture.name, model=model, judge_model=judge_model,
|
|
||||||
prompt_template=None, conversation=conversation,
|
|
||||||
user_comment="(aborted)", done=False,
|
|
||||||
criteria_scores={}, overall_score=0.0,
|
|
||||||
judge_reasoning="Session aborted by user.",
|
|
||||||
elapsed_seconds=time.time() - start_time,
|
|
||||||
)
|
|
||||||
|
|
||||||
if user_input.lower() == _CMD_HELP:
|
|
||||||
print(_HELP_TEXT)
|
|
||||||
continue
|
|
||||||
|
|
||||||
if user_input.lower() == _CMD_TEMPLATE:
|
|
||||||
if prompt_template:
|
|
||||||
print(f"\n{_C_MAGENTA}{prompt_template}{_C_RESET}\n")
|
|
||||||
else:
|
|
||||||
_print_system("No template generated yet.")
|
|
||||||
continue
|
|
||||||
|
|
||||||
if user_input.lower() == _CMD_DONE:
|
|
||||||
_print_system("Ending conversation...")
|
|
||||||
break
|
|
||||||
|
|
||||||
# ── Send message to AI ───────────────────────────
|
|
||||||
conversation.append({"role": "user", "content": user_input})
|
|
||||||
_print_system("AI is thinking...")
|
|
||||||
|
|
||||||
msg_frame: dict[str, Any] = {
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": user_input,
|
|
||||||
}
|
|
||||||
reply = await handle_journey_message(eval_user_id, msg_frame)
|
|
||||||
conversation.append({"role": "assistant", "content": reply["message"]})
|
|
||||||
_print_ai(reply["message"])
|
|
||||||
|
|
||||||
if reply["done"]:
|
|
||||||
prompt_template = reply.get("prompt_template")
|
|
||||||
done = True
|
|
||||||
_print_system("Journey completed — template generated!")
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("interactive journey failed: %s", exc)
|
|
||||||
_print_system(f"Error: {exc}")
|
|
||||||
finally:
|
|
||||||
settings.LLM_MODEL = original_model
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
turns = len([c for c in conversation if c["role"] == "user"])
|
|
||||||
|
|
||||||
# ── Show template if generated ───────────────────────────────
|
|
||||||
if prompt_template:
|
|
||||||
_print_header("GENERATED TEMPLATE")
|
|
||||||
print(f"{_C_MAGENTA}{prompt_template}{_C_RESET}\n")
|
|
||||||
else:
|
|
||||||
_print_system("No template was generated during this session.")
|
|
||||||
|
|
||||||
# ── User comment ─────────────────────────────────────────────
|
|
||||||
_print_header("YOUR EVALUATION")
|
|
||||||
print(" Write your comment about this interaction (press Enter twice to finish):")
|
|
||||||
print()
|
|
||||||
comment_lines: list[str] = []
|
|
||||||
try:
|
|
||||||
while True:
|
|
||||||
line = input()
|
|
||||||
if line == "" and comment_lines and comment_lines[-1] == "":
|
|
||||||
comment_lines.pop() # remove trailing empty
|
|
||||||
break
|
|
||||||
comment_lines.append(line)
|
|
||||||
except (EOFError, KeyboardInterrupt):
|
|
||||||
pass
|
|
||||||
user_comment = "\n".join(comment_lines).strip() or "(no comment)"
|
|
||||||
|
|
||||||
# ── Judge ────────────────────────────────────────────────────
|
|
||||||
_print_header("LLM JUDGE EVALUATION")
|
|
||||||
_print_system(f"Scoring with {judge_model}...")
|
|
||||||
|
|
||||||
criteria_scores, overall_quality, judge_reasoning = await _judge_interactive(
|
|
||||||
conversation=conversation,
|
|
||||||
prompt_template=prompt_template,
|
|
||||||
user_comment=user_comment,
|
|
||||||
criteria=fixture.expected_template_criteria,
|
|
||||||
judge_model=judge_model,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Display scores ───────────────────────────────────────────
|
|
||||||
print()
|
|
||||||
for criterion, score in criteria_scores.items():
|
|
||||||
_print_score(criterion, score)
|
|
||||||
|
|
||||||
overall = (
|
|
||||||
sum(criteria_scores.values()) / len(criteria_scores)
|
|
||||||
if criteria_scores
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"\n {_C_BOLD}Criteria avg: {overall:.2f}{_C_RESET}")
|
|
||||||
print(f" {_C_BOLD}Overall quality: {overall_quality:.2f}{_C_RESET}")
|
|
||||||
print(f" {_C_BOLD}Turns: {turns}{_C_RESET}")
|
|
||||||
print(f" {_C_BOLD}Time: {elapsed:.1f}s{_C_RESET}")
|
|
||||||
print(f"\n {_C_DIM}Judge: {judge_reasoning}{_C_RESET}")
|
|
||||||
print(f" {_C_DIM}Your comment: {user_comment}{_C_RESET}\n")
|
|
||||||
|
|
||||||
result = InteractiveResult(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
judge_model=judge_model,
|
|
||||||
prompt_template=prompt_template,
|
|
||||||
conversation=conversation,
|
|
||||||
user_comment=user_comment,
|
|
||||||
done=done,
|
|
||||||
criteria_scores=criteria_scores,
|
|
||||||
overall_score=overall_quality,
|
|
||||||
judge_reasoning=judge_reasoning,
|
|
||||||
elapsed_seconds=elapsed,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Report to Langfuse ───────────────────────────────────────
|
|
||||||
trace_id = langfuse_eval.log_eval_trace(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
prompt_variant="interactive",
|
|
||||||
prompt_template=prompt_template or "(not generated)",
|
|
||||||
actual_mutations=[{
|
|
||||||
"conversation": conversation[:30],
|
|
||||||
"user_comment": user_comment,
|
|
||||||
}],
|
|
||||||
scores_summary=result.summary(),
|
|
||||||
langfuse_prompt_names=["journey_system"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if trace_id:
|
|
||||||
from eval.scorer import EvalScores
|
|
||||||
scores_obj = EvalScores(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
prompt_variant="interactive",
|
|
||||||
precision=overall,
|
|
||||||
recall=float(done),
|
|
||||||
f1=overall,
|
|
||||||
llm_judge_score=overall_quality,
|
|
||||||
llm_judge_reasoning=judge_reasoning,
|
|
||||||
)
|
|
||||||
langfuse_eval.post_eval_scores(scores_obj, trace_id=trace_id)
|
|
||||||
_print_system(f"Results reported to Langfuse (trace: {trace_id})")
|
|
||||||
else:
|
|
||||||
_print_system("Langfuse not configured — results not reported.")
|
|
||||||
|
|
||||||
return result
|
|
||||||
@@ -1,385 +0,0 @@
|
|||||||
"""Journey eval runner — tests the prompt_template builder conversation.
|
|
||||||
|
|
||||||
For each (journey_fixture × model) combination:
|
|
||||||
1. Build a MockExecutor (for filesystem tools used during journey)
|
|
||||||
2. Patch execute_on_client
|
|
||||||
3. Override LLM_MODEL
|
|
||||||
4. Call handle_journey_start to kick off the conversation
|
|
||||||
5. Feed simulated user_messages via handle_journey_message
|
|
||||||
6. Collect the generated prompt_template
|
|
||||||
7. Score it against expected_template_criteria (via LLM judge)
|
|
||||||
8. Report to Langfuse
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import asyncio
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
from eval.config import JourneyFixture
|
|
||||||
from eval.mock_executor import MockExecutor
|
|
||||||
from eval import langfuse_eval
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Result type ──────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class JourneyEvalResult:
|
|
||||||
"""Result of one journey eval run."""
|
|
||||||
|
|
||||||
fixture_name: str
|
|
||||||
model: str
|
|
||||||
prompt_template: str | None # the generated template (None if journey failed)
|
|
||||||
conversation_turns: int
|
|
||||||
done: bool # whether journey reached completion
|
|
||||||
criteria_scores: dict[str, float] # criterion → 0-1 score
|
|
||||||
overall_score: float # average of criteria scores
|
|
||||||
judge_reasoning: str
|
|
||||||
elapsed_seconds: float
|
|
||||||
|
|
||||||
def summary(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"fixture": self.fixture_name,
|
|
||||||
"model": self.model,
|
|
||||||
"done": self.done,
|
|
||||||
"turns": self.conversation_turns,
|
|
||||||
"overall_score": round(self.overall_score, 3),
|
|
||||||
"criteria_scores": {k: round(v, 3) for k, v in self.criteria_scores.items()},
|
|
||||||
"elapsed_s": round(self.elapsed_seconds, 1),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── LLM judge for template quality ──────────────────────────────────────
|
|
||||||
|
|
||||||
_JOURNEY_JUDGE_SYSTEM = """\
|
|
||||||
You are an evaluation judge for AI-generated prompt templates.
|
|
||||||
|
|
||||||
A journey chatbot explored a user's directory structure and through
|
|
||||||
conversation produced a prompt_template — an instruction set for a
|
|
||||||
data-extraction agent.
|
|
||||||
|
|
||||||
Your task: evaluate the generated template against a list of criteria.
|
|
||||||
Score each criterion from 0 to 1:
|
|
||||||
- 1.0: Fully satisfied, clearly present in the template
|
|
||||||
- 0.5: Partially satisfied or ambiguously addressed
|
|
||||||
- 0.0: Not satisfied, missing from the template
|
|
||||||
|
|
||||||
Respond with ONLY a JSON object:
|
|
||||||
{
|
|
||||||
"scores": {"criterion_1": 0.8, "criterion_2": 1.0, ...},
|
|
||||||
"reasoning": "Brief explanation"
|
|
||||||
}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
async def _judge_template(
|
|
||||||
prompt_template: str,
|
|
||||||
criteria: list[str],
|
|
||||||
*,
|
|
||||||
judge_model: str = "gpt-4o-mini",
|
|
||||||
) -> tuple[dict[str, float], str]:
|
|
||||||
"""Use an LLM to evaluate a generated prompt_template against criteria.
|
|
||||||
|
|
||||||
Returns (criteria_scores, reasoning).
|
|
||||||
"""
|
|
||||||
from shared.llm import get_llm
|
|
||||||
|
|
||||||
llm = get_llm(model=judge_model, temperature=0)
|
|
||||||
|
|
||||||
criteria_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(criteria))
|
|
||||||
user_content = (
|
|
||||||
f"## Generated prompt_template\n```\n{prompt_template}\n```\n\n"
|
|
||||||
f"## Criteria to evaluate\n{criteria_text}"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke([
|
|
||||||
SystemMessage(content=_JOURNEY_JUDGE_SYSTEM),
|
|
||||||
HumanMessage(content=user_content),
|
|
||||||
])
|
|
||||||
raw = response.content.strip()
|
|
||||||
if raw.startswith("```"):
|
|
||||||
raw = raw.split("```")[1]
|
|
||||||
if raw.startswith("json"):
|
|
||||||
raw = raw[4:]
|
|
||||||
parsed = json.loads(raw.strip())
|
|
||||||
|
|
||||||
scores_raw = parsed.get("scores", {})
|
|
||||||
# Map criterion keys back to the original criteria text
|
|
||||||
criteria_scores: dict[str, float] = {}
|
|
||||||
for i, criterion in enumerate(criteria):
|
|
||||||
# Try matching by index key or exact criterion text
|
|
||||||
key_candidates = [
|
|
||||||
f"criterion_{i+1}",
|
|
||||||
criterion,
|
|
||||||
criterion[:50],
|
|
||||||
str(i + 1),
|
|
||||||
]
|
|
||||||
score = 0.0
|
|
||||||
for key in key_candidates:
|
|
||||||
if key in scores_raw:
|
|
||||||
score = float(scores_raw[key])
|
|
||||||
break
|
|
||||||
# If no match found, try values in order
|
|
||||||
if score == 0.0 and i < len(scores_raw):
|
|
||||||
score = float(list(scores_raw.values())[i])
|
|
||||||
criteria_scores[criterion] = score
|
|
||||||
|
|
||||||
reasoning = str(parsed.get("reasoning", ""))
|
|
||||||
return criteria_scores, reasoning
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("journey_eval: LLM judge failed: %s", exc)
|
|
||||||
return {c: 0.0 for c in criteria}, f"Judge error: {exc}"
|
|
||||||
|
|
||||||
|
|
||||||
# ── Journey runner ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def run_single_journey_eval(
|
|
||||||
fixture: JourneyFixture,
|
|
||||||
model: str,
|
|
||||||
*,
|
|
||||||
judge_model: str = "gpt-4o-mini",
|
|
||||||
data_dir: Path | None = None,
|
|
||||||
) -> JourneyEvalResult:
|
|
||||||
"""Execute one journey eval: start \u2192 messages \u2192 score template."""
|
|
||||||
from shared.config import settings
|
|
||||||
|
|
||||||
# When data_dir is given, use its parent as MockExecutor root
|
|
||||||
# and its name as the journey directory so the LLM sees a
|
|
||||||
# meaningful path (not ".").
|
|
||||||
if data_dir:
|
|
||||||
mock_root = data_dir.parent
|
|
||||||
journey_directory = data_dir.name
|
|
||||||
else:
|
|
||||||
mock_root = fixture.fixture_path.parent
|
|
||||||
journey_directory = fixture.directory
|
|
||||||
|
|
||||||
mock = MockExecutor(
|
|
||||||
fixture_dir=mock_root,
|
|
||||||
seed_records={},
|
|
||||||
)
|
|
||||||
|
|
||||||
original_model = settings.LLM_MODEL
|
|
||||||
settings.LLM_MODEL = model
|
|
||||||
|
|
||||||
eval_user_id = f"eval-journey-{uuid.uuid4().hex[:8]}"
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"journey_eval: starting %s | model=%s",
|
|
||||||
fixture.name, model,
|
|
||||||
)
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
prompt_template: str | None = None
|
|
||||||
conversation: list[dict[str, str]] = []
|
|
||||||
done = False
|
|
||||||
|
|
||||||
try:
|
|
||||||
from shared.ws_context import set_current_user, clear_current_user
|
|
||||||
from app.journey import handle_journey_start, handle_journey_message, _sessions
|
|
||||||
|
|
||||||
set_current_user(eval_user_id)
|
|
||||||
with mock.patch():
|
|
||||||
# ── Start the journey ────────────────────────────────
|
|
||||||
start_frame: dict[str, Any] = {
|
|
||||||
"agent_type": "local",
|
|
||||||
"directory": journey_directory,
|
|
||||||
"data_types": fixture.data_types,
|
|
||||||
"session_id": f"eval-{uuid.uuid4().hex[:8]}",
|
|
||||||
}
|
|
||||||
|
|
||||||
reply = await handle_journey_start(eval_user_id, start_frame)
|
|
||||||
session_id = reply["session_id"]
|
|
||||||
conversation.append({"role": "assistant", "content": reply["message"]})
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"journey_eval: start reply (%d chars), done=%s",
|
|
||||||
len(reply["message"]), reply["done"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if reply["done"]:
|
|
||||||
prompt_template = reply.get("prompt_template")
|
|
||||||
done = True
|
|
||||||
else:
|
|
||||||
# ── Send user messages ───────────────────────────
|
|
||||||
for i, user_msg in enumerate(fixture.user_messages):
|
|
||||||
if done:
|
|
||||||
break
|
|
||||||
|
|
||||||
conversation.append({"role": "user", "content": user_msg})
|
|
||||||
|
|
||||||
msg_frame: dict[str, Any] = {
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": user_msg,
|
|
||||||
}
|
|
||||||
reply = await handle_journey_message(eval_user_id, msg_frame)
|
|
||||||
conversation.append({"role": "assistant", "content": reply["message"]})
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"journey_eval: turn %d reply (%d chars), done=%s",
|
|
||||||
i + 1, len(reply["message"]), reply["done"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if reply["done"]:
|
|
||||||
prompt_template = reply.get("prompt_template")
|
|
||||||
done = True
|
|
||||||
|
|
||||||
# If not done after all user messages, send a final nudge
|
|
||||||
if not done:
|
|
||||||
nudge = "Please generate the final prompt_template now. I'm satisfied with the configuration."
|
|
||||||
conversation.append({"role": "user", "content": nudge})
|
|
||||||
|
|
||||||
nudge_frame: dict[str, Any] = {
|
|
||||||
"session_id": session_id,
|
|
||||||
"message": nudge,
|
|
||||||
}
|
|
||||||
reply = await handle_journey_message(eval_user_id, nudge_frame)
|
|
||||||
conversation.append({"role": "assistant", "content": reply["message"]})
|
|
||||||
if reply["done"]:
|
|
||||||
prompt_template = reply.get("prompt_template")
|
|
||||||
done = True
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("journey_eval: pipeline failed for %s/%s: %s", fixture.name, model, exc)
|
|
||||||
finally:
|
|
||||||
settings.LLM_MODEL = original_model
|
|
||||||
from shared.ws_context import clear_current_user
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
turns = len([c for c in conversation if c["role"] == "user"])
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"journey_eval: completed in %.1fs — %d turns, done=%s, template=%s",
|
|
||||||
elapsed, turns, done, "yes" if prompt_template else "no",
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Score the template ───────────────────────────────────────
|
|
||||||
criteria_scores: dict[str, float] = {}
|
|
||||||
judge_reasoning = ""
|
|
||||||
|
|
||||||
if prompt_template and fixture.expected_template_criteria:
|
|
||||||
criteria_scores, judge_reasoning = await _judge_template(
|
|
||||||
prompt_template,
|
|
||||||
fixture.expected_template_criteria,
|
|
||||||
judge_model=judge_model,
|
|
||||||
)
|
|
||||||
elif not prompt_template:
|
|
||||||
criteria_scores = {c: 0.0 for c in fixture.expected_template_criteria}
|
|
||||||
judge_reasoning = "No prompt_template was generated — journey did not complete."
|
|
||||||
|
|
||||||
overall = (
|
|
||||||
sum(criteria_scores.values()) / len(criteria_scores)
|
|
||||||
if criteria_scores
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
|
|
||||||
result = JourneyEvalResult(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
prompt_template=prompt_template,
|
|
||||||
conversation_turns=turns,
|
|
||||||
done=done,
|
|
||||||
criteria_scores=criteria_scores,
|
|
||||||
overall_score=overall,
|
|
||||||
judge_reasoning=judge_reasoning,
|
|
||||||
elapsed_seconds=elapsed,
|
|
||||||
)
|
|
||||||
|
|
||||||
# ── Report to Langfuse ───────────────────────────────────────
|
|
||||||
trace_id = langfuse_eval.log_eval_trace(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
prompt_variant="journey",
|
|
||||||
prompt_template=prompt_template or "(not generated)",
|
|
||||||
actual_mutations=[{"conversation": conversation[:20]}],
|
|
||||||
scores_summary=result.summary(),
|
|
||||||
langfuse_prompt_names=["journey_system"],
|
|
||||||
)
|
|
||||||
|
|
||||||
if trace_id:
|
|
||||||
from eval.scorer import EvalScores
|
|
||||||
scores_obj = EvalScores(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
prompt_variant="journey",
|
|
||||||
precision=overall,
|
|
||||||
recall=float(done),
|
|
||||||
f1=overall,
|
|
||||||
llm_judge_score=overall,
|
|
||||||
llm_judge_reasoning=judge_reasoning,
|
|
||||||
)
|
|
||||||
langfuse_eval.post_eval_scores(scores_obj, trace_id=trace_id)
|
|
||||||
|
|
||||||
return result
|
|
||||||
|
|
||||||
|
|
||||||
async def run_journey_fixture_eval(
|
|
||||||
fixture: JourneyFixture,
|
|
||||||
models: list[str],
|
|
||||||
*,
|
|
||||||
judge_model: str = "gpt-4o-mini",
|
|
||||||
data_dir: Path | None = None,
|
|
||||||
) -> list[JourneyEvalResult]:
|
|
||||||
"""Run all models for a journey fixture."""
|
|
||||||
langfuse_eval.sync_journey_fixture_to_dataset(fixture)
|
|
||||||
|
|
||||||
results: list[JourneyEvalResult] = []
|
|
||||||
for model in models:
|
|
||||||
result = await run_single_journey_eval(
|
|
||||||
fixture, model, judge_model=judge_model,
|
|
||||||
data_dir=data_dir,
|
|
||||||
)
|
|
||||||
results.append(result)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def print_journey_results(results: list[JourneyEvalResult]) -> None:
|
|
||||||
"""Print a formatted summary of journey eval results."""
|
|
||||||
if not results:
|
|
||||||
print("\nNo journey eval results.")
|
|
||||||
return
|
|
||||||
|
|
||||||
print("\n" + "=" * 95)
|
|
||||||
print(f"{'Fixture':<25} {'Model':<25} {'Done':>5} {'Turns':>6} {'Score':>7} {'Time':>7}")
|
|
||||||
print("-" * 95)
|
|
||||||
|
|
||||||
for r in results:
|
|
||||||
done_str = "yes" if r.done else "NO"
|
|
||||||
print(
|
|
||||||
f"{r.fixture_name:<25} {r.model:<25} {done_str:>5} "
|
|
||||||
f"{r.conversation_turns:>6} {r.overall_score:>7.2f} {r.elapsed_seconds:>6.1f}s"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("=" * 95)
|
|
||||||
|
|
||||||
# Criteria breakdown
|
|
||||||
for r in results:
|
|
||||||
if r.criteria_scores:
|
|
||||||
print(f"\n[{r.model}] Criteria scores:")
|
|
||||||
for criterion, score in r.criteria_scores.items():
|
|
||||||
indicator = "PASS" if score >= 0.7 else "PARTIAL" if score >= 0.4 else "FAIL"
|
|
||||||
print(f" {indicator:>7} ({score:.1f}) {criterion}")
|
|
||||||
|
|
||||||
if r.judge_reasoning:
|
|
||||||
print(f" Judge: {r.judge_reasoning}")
|
|
||||||
|
|
||||||
if r.prompt_template:
|
|
||||||
preview = r.prompt_template[:200].replace("\n", " ")
|
|
||||||
print(f" Template preview: {preview}...")
|
|
||||||
|
|
||||||
print()
|
|
||||||
@@ -1,327 +0,0 @@
|
|||||||
"""Langfuse evaluation integration — datasets, runs, and scoring.
|
|
||||||
|
|
||||||
Uses the Langfuse Python SDK v4 (OpenTelemetry-based) to:
|
|
||||||
|
|
||||||
1. **Sync fixtures → Langfuse datasets**: Each YAML fixture becomes a dataset,
|
|
||||||
each prompt variant + expected pair becomes a dataset item.
|
|
||||||
|
|
||||||
2. **Track eval runs**: Each (fixture × model × prompt_variant) execution
|
|
||||||
is recorded as a trace with linked scores.
|
|
||||||
|
|
||||||
3. **Post scores**: precision, recall, F1, field_accuracy, llm_judge are
|
|
||||||
posted as numeric scores on the trace.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import os
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from shared.config import settings
|
|
||||||
from eval.config import EvalFixture
|
|
||||||
from eval.scorer import EvalScores
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
def _get_langfuse():
|
|
||||||
"""Get or create a Langfuse client instance (SDK v4)."""
|
|
||||||
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
|
|
||||||
return None
|
|
||||||
try:
|
|
||||||
os.environ.setdefault("LANGFUSE_SECRET_KEY", settings.LANGFUSE_SECRET_KEY)
|
|
||||||
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", settings.LANGFUSE_PUBLIC_KEY)
|
|
||||||
if settings.LANGFUSE_HOST:
|
|
||||||
os.environ.setdefault("LANGFUSE_HOST", settings.LANGFUSE_HOST)
|
|
||||||
from langfuse import get_client
|
|
||||||
return get_client()
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse_eval: failed to create client: %s", exc)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def sync_fixture_to_dataset(fixture: EvalFixture) -> str | None:
|
|
||||||
"""Create or update a Langfuse dataset from a fixture.
|
|
||||||
|
|
||||||
Each prompt variant becomes a separate dataset item with:
|
|
||||||
- input: {directory, data_types, prompt_template, seed_records}
|
|
||||||
- expected_output: {expected records}
|
|
||||||
|
|
||||||
Returns the dataset name, or None if Langfuse is unavailable.
|
|
||||||
"""
|
|
||||||
lf = _get_langfuse()
|
|
||||||
if lf is None:
|
|
||||||
logger.info("langfuse_eval: Langfuse not configured — skipping dataset sync")
|
|
||||||
return None
|
|
||||||
|
|
||||||
dataset_name = f"batch-eval-{fixture.name}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
lf.create_dataset(
|
|
||||||
name=dataset_name,
|
|
||||||
description=fixture.description,
|
|
||||||
metadata={
|
|
||||||
"data_types": ",".join(fixture.data_types),
|
|
||||||
"file_extensions": ",".join(fixture.file_extensions) if fixture.file_extensions else "",
|
|
||||||
},
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
# Dataset may already exist — that's fine
|
|
||||||
pass
|
|
||||||
|
|
||||||
# Build expected_output appropriate to the fixture's mode
|
|
||||||
expected_output: dict[str, Any] = {}
|
|
||||||
if fixture.mode in ("step1", "full") and fixture.expected_classification:
|
|
||||||
expected_output["classifications"] = [
|
|
||||||
{"file": ec.file, "project_id": ec.project_id, "domains": ec.domains}
|
|
||||||
for ec in fixture.expected_classification
|
|
||||||
]
|
|
||||||
if fixture.mode in ("step2", "full") and fixture.expected:
|
|
||||||
for rec in fixture.expected:
|
|
||||||
expected_output.setdefault(rec.table, []).append(rec.fields)
|
|
||||||
|
|
||||||
item_id = f"{fixture.name}--{fixture.mode}"
|
|
||||||
try:
|
|
||||||
lf.create_dataset_item(
|
|
||||||
dataset_name=dataset_name,
|
|
||||||
id=item_id,
|
|
||||||
input={
|
|
||||||
"directory": fixture.directory,
|
|
||||||
"data_types": fixture.data_types,
|
|
||||||
"mode": fixture.mode,
|
|
||||||
"seed_records": fixture.seed_records,
|
|
||||||
},
|
|
||||||
expected_output=expected_output,
|
|
||||||
metadata={"mode": fixture.mode},
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning(
|
|
||||||
"langfuse_eval: failed to upsert dataset item %s: %s", item_id, exc
|
|
||||||
)
|
|
||||||
|
|
||||||
lf.flush()
|
|
||||||
logger.info("langfuse_eval: synced fixture '%s' → dataset '%s'", fixture.name, dataset_name)
|
|
||||||
return dataset_name
|
|
||||||
|
|
||||||
|
|
||||||
def sync_journey_fixture_to_dataset(fixture) -> str | None:
|
|
||||||
"""Create or update a Langfuse dataset from a journey fixture.
|
|
||||||
|
|
||||||
Each journey fixture becomes a single dataset item with:
|
|
||||||
- input: {directory, data_types, user_messages}
|
|
||||||
- expected_output: {criteria}
|
|
||||||
"""
|
|
||||||
lf = _get_langfuse()
|
|
||||||
if lf is None:
|
|
||||||
logger.info("langfuse_eval: Langfuse not configured — skipping journey dataset sync")
|
|
||||||
return None
|
|
||||||
|
|
||||||
dataset_name = f"journey-eval-{fixture.name}"
|
|
||||||
|
|
||||||
try:
|
|
||||||
lf.create_dataset(
|
|
||||||
name=dataset_name,
|
|
||||||
description=fixture.description,
|
|
||||||
metadata={"type": "journey", "data_types": ",".join(fixture.data_types)},
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass # Dataset may already exist
|
|
||||||
|
|
||||||
item_id = f"{fixture.name}--journey"
|
|
||||||
try:
|
|
||||||
lf.create_dataset_item(
|
|
||||||
dataset_name=dataset_name,
|
|
||||||
id=item_id,
|
|
||||||
input={
|
|
||||||
"directory": fixture.directory,
|
|
||||||
"data_types": fixture.data_types,
|
|
||||||
"user_messages": fixture.user_messages,
|
|
||||||
},
|
|
||||||
expected_output={
|
|
||||||
"criteria": fixture.expected_template_criteria,
|
|
||||||
},
|
|
||||||
metadata={"type": "journey"},
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse_eval: failed to upsert journey dataset item %s: %s", item_id, exc)
|
|
||||||
|
|
||||||
lf.flush()
|
|
||||||
logger.info("langfuse_eval: synced journey fixture '%s' → dataset '%s'", fixture.name, dataset_name)
|
|
||||||
return dataset_name
|
|
||||||
|
|
||||||
|
|
||||||
def create_eval_run(
|
|
||||||
dataset_name: str,
|
|
||||||
run_name: str,
|
|
||||||
*,
|
|
||||||
metadata: dict[str, Any] | None = None,
|
|
||||||
) -> str:
|
|
||||||
"""Create a dataset run in Langfuse. Returns the run name.
|
|
||||||
|
|
||||||
Note: In SDK v4, dataset runs are created implicitly via
|
|
||||||
dataset.run_experiment(). This function is kept for backwards
|
|
||||||
compatibility but may not create a run.
|
|
||||||
"""
|
|
||||||
lf = _get_langfuse()
|
|
||||||
if lf is None:
|
|
||||||
return run_name
|
|
||||||
|
|
||||||
try:
|
|
||||||
if hasattr(lf, "create_dataset_run"):
|
|
||||||
lf.create_dataset_run(
|
|
||||||
dataset_name=dataset_name,
|
|
||||||
run_name=run_name,
|
|
||||||
metadata=metadata or {},
|
|
||||||
)
|
|
||||||
lf.flush()
|
|
||||||
else:
|
|
||||||
logger.debug("langfuse_eval: create_dataset_run not available in SDK v4")
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse_eval: failed to create run %s: %s", run_name, exc)
|
|
||||||
|
|
||||||
return run_name
|
|
||||||
|
|
||||||
|
|
||||||
def post_eval_scores(
|
|
||||||
scores: EvalScores,
|
|
||||||
*,
|
|
||||||
trace_id: str | None = None,
|
|
||||||
dataset_name: str | None = None,
|
|
||||||
run_name: str | None = None,
|
|
||||||
) -> None:
|
|
||||||
"""Post evaluation scores to Langfuse.
|
|
||||||
|
|
||||||
If trace_id is provided, scores are attached to that trace.
|
|
||||||
"""
|
|
||||||
lf = _get_langfuse()
|
|
||||||
if lf is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
score_data = [
|
|
||||||
("precision", scores.precision),
|
|
||||||
("recall", scores.recall),
|
|
||||||
("f1", scores.f1),
|
|
||||||
]
|
|
||||||
# Only post field_accuracy when there are field-level scores (step2/full)
|
|
||||||
if scores.field_scores:
|
|
||||||
score_data.append(("field_accuracy", scores.field_accuracy))
|
|
||||||
if scores.llm_judge_score is not None:
|
|
||||||
score_data.append(("llm_judge", scores.llm_judge_score))
|
|
||||||
|
|
||||||
for name, value in score_data:
|
|
||||||
try:
|
|
||||||
lf.create_score(
|
|
||||||
name=name,
|
|
||||||
value=value,
|
|
||||||
trace_id=trace_id,
|
|
||||||
data_type="NUMERIC",
|
|
||||||
comment=f"{scores.fixture_name} | {scores.model} | {scores.prompt_variant}",
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse_eval: failed to post score %s: %s", name, exc)
|
|
||||||
|
|
||||||
lf.flush()
|
|
||||||
logger.info(
|
|
||||||
"langfuse_eval: posted %d scores for %s/%s/%s",
|
|
||||||
len(score_data), scores.fixture_name, scores.model, scores.prompt_variant,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def log_eval_trace(
|
|
||||||
*,
|
|
||||||
fixture_name: str,
|
|
||||||
model: str,
|
|
||||||
prompt_variant: str,
|
|
||||||
prompt_template: str,
|
|
||||||
actual_mutations: list[dict],
|
|
||||||
scores_summary: dict[str, Any],
|
|
||||||
step1_results: list[dict] | None = None,
|
|
||||||
dataset_name: str | None = None,
|
|
||||||
run_name: str | None = None,
|
|
||||||
dataset_item_id: str | None = None,
|
|
||||||
langfuse_prompt_names: list[str] | None = None,
|
|
||||||
) -> str | None:
|
|
||||||
"""Create a Langfuse trace for one eval execution and link it to a dataset run.
|
|
||||||
|
|
||||||
Uses SDK v4 observation API (traces are created implicitly by root spans).
|
|
||||||
``langfuse_prompt_names`` can contain one or two prompt names to link
|
|
||||||
(e.g. ``["batch_file_classifier", "batch_processing"]`` for full mode).
|
|
||||||
Each prompt gets its own generation-type observation for per-version
|
|
||||||
metrics tracking.
|
|
||||||
|
|
||||||
Returns the trace_id, or None if Langfuse is unavailable.
|
|
||||||
"""
|
|
||||||
lf = _get_langfuse()
|
|
||||||
if lf is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
try:
|
|
||||||
from langfuse import propagate_attributes
|
|
||||||
|
|
||||||
# Fetch prompt objects for linking
|
|
||||||
prompt_objs: list[tuple[str, Any]] = []
|
|
||||||
for pname in (langfuse_prompt_names or []):
|
|
||||||
try:
|
|
||||||
obj = lf.get_prompt(name=pname, cache_ttl_seconds=300)
|
|
||||||
prompt_objs.append((pname, obj))
|
|
||||||
logger.info("langfuse_eval: linked prompt '%s' (type=%s)", pname, type(obj).__name__)
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse_eval: prompt '%s' not found — %s", pname, exc)
|
|
||||||
|
|
||||||
# Build trace output dict
|
|
||||||
trace_output: dict[str, Any] = {"scores": scores_summary}
|
|
||||||
if step1_results:
|
|
||||||
trace_output["classifications"] = step1_results
|
|
||||||
if actual_mutations:
|
|
||||||
trace_output["mutations"] = actual_mutations[:50]
|
|
||||||
|
|
||||||
with propagate_attributes(
|
|
||||||
trace_name=f"eval-{fixture_name}",
|
|
||||||
metadata={
|
|
||||||
"eval": "true",
|
|
||||||
"fixture": fixture_name,
|
|
||||||
"model": model,
|
|
||||||
"prompt_variant": prompt_variant,
|
|
||||||
},
|
|
||||||
tags=["eval", f"model:{model}", f"variant:{prompt_variant}"],
|
|
||||||
):
|
|
||||||
# Root span for the eval run
|
|
||||||
span = lf.start_observation(name=f"eval-{fixture_name}")
|
|
||||||
span.update(
|
|
||||||
input={
|
|
||||||
"prompt_template": prompt_template,
|
|
||||||
"model": model,
|
|
||||||
"prompt_variant": prompt_variant,
|
|
||||||
},
|
|
||||||
output=trace_output,
|
|
||||||
)
|
|
||||||
trace_id = span.trace_id
|
|
||||||
|
|
||||||
# Create a generation-type observation per linked prompt
|
|
||||||
for pname, pobj in prompt_objs:
|
|
||||||
gen = lf.start_observation(
|
|
||||||
name=f"prompt-{pname}",
|
|
||||||
prompt=pobj,
|
|
||||||
as_type="generation",
|
|
||||||
)
|
|
||||||
gen.end()
|
|
||||||
|
|
||||||
# Link to dataset run if available
|
|
||||||
if dataset_name and run_name and dataset_item_id:
|
|
||||||
try:
|
|
||||||
dataset = lf.get_dataset(dataset_name)
|
|
||||||
for item in dataset.items:
|
|
||||||
if item.id == dataset_item_id:
|
|
||||||
item.link(span, run_name)
|
|
||||||
break
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse_eval: failed to link trace to dataset run: %s", exc)
|
|
||||||
|
|
||||||
span.end()
|
|
||||||
|
|
||||||
lf.flush()
|
|
||||||
return trace_id
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("langfuse_eval: failed to create eval trace: %s", exc)
|
|
||||||
return None
|
|
||||||
@@ -1,258 +0,0 @@
|
|||||||
"""Mock executor — intercepts execute_on_client for offline E2E testing.
|
|
||||||
|
|
||||||
Patches ``execute_on_client`` at all usage sites so agent pipeline runs don't
|
|
||||||
require a live Electron client or Redis. Instead:
|
|
||||||
|
|
||||||
- **Filesystem actions** (list_directory, read_file_content, get_file_metadata)
|
|
||||||
are served from local fixture files on disk.
|
|
||||||
- **Read actions** (select, get) return preseeded records from an in-memory
|
|
||||||
store provided by the test fixture.
|
|
||||||
- **Write actions** (insert, update, delete) are captured as *mutations* and
|
|
||||||
stored for later comparison against expected results.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import os
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import Any
|
|
||||||
from contextlib import contextmanager, asynccontextmanager
|
|
||||||
from unittest.mock import AsyncMock, patch
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Mutation:
|
|
||||||
"""A single recorded write operation."""
|
|
||||||
|
|
||||||
action: str # insert | update | delete
|
|
||||||
table: str
|
|
||||||
data: dict[str, Any]
|
|
||||||
timestamp: float = field(default_factory=time.time)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Fake DB helpers (used to bypass async_session in full mode) ───────
|
|
||||||
|
|
||||||
class _FakeRow:
|
|
||||||
"""Mimics an AgentRunLog row returned by SQLAlchemy."""
|
|
||||||
id = 0
|
|
||||||
status = "running"
|
|
||||||
items_processed = 0
|
|
||||||
items_created = 0
|
|
||||||
errors: list[str] = []
|
|
||||||
completed_at = None
|
|
||||||
|
|
||||||
def __setattr__(self, name: str, value: Any) -> None:
|
|
||||||
object.__setattr__(self, name, value)
|
|
||||||
|
|
||||||
|
|
||||||
class _FakeResult:
|
|
||||||
"""Mimics a SQLAlchemy ``Result`` with ``scalar_one_or_none``."""
|
|
||||||
def __init__(self, row: _FakeRow) -> None:
|
|
||||||
self._row = row
|
|
||||||
|
|
||||||
def scalar_one_or_none(self) -> _FakeRow:
|
|
||||||
return self._row
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MockExecutor:
|
|
||||||
"""In-memory executor that replaces Redis-based tool round-trip.
|
|
||||||
|
|
||||||
Parameters
|
|
||||||
----------
|
|
||||||
fixture_dir : Path
|
|
||||||
Directory containing sample files for filesystem tool calls.
|
|
||||||
seed_records : dict[str, list[dict]]
|
|
||||||
Pre-existing records per table, e.g. ``{"tasks": [...], "projects": [...]}``.
|
|
||||||
The executor returns these for ``select`` / ``get`` actions and auto-updates
|
|
||||||
them on ``insert`` / ``update`` / ``delete`` so subsequent selects reflect changes.
|
|
||||||
"""
|
|
||||||
|
|
||||||
fixture_dir: Path
|
|
||||||
seed_records: dict[str, list[dict]] = field(default_factory=dict)
|
|
||||||
mutations: list[Mutation] = field(default_factory=list)
|
|
||||||
_id_counter: int = field(default=1000, repr=False)
|
|
||||||
|
|
||||||
# ── Public API ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
def reset(self) -> None:
|
|
||||||
"""Clear recorded mutations (keep seed_records intact)."""
|
|
||||||
self.mutations.clear()
|
|
||||||
|
|
||||||
def get_mutations(self, *, table: str | None = None, action: str | None = None) -> list[Mutation]:
|
|
||||||
"""Filter mutations by table and/or action."""
|
|
||||||
result = self.mutations
|
|
||||||
if table:
|
|
||||||
result = [m for m in result if m.table == table]
|
|
||||||
if action:
|
|
||||||
result = [m for m in result if m.action == action]
|
|
||||||
return result
|
|
||||||
|
|
||||||
def created_records(self, table: str) -> list[dict]:
|
|
||||||
"""Return data dicts of all inserts into *table*."""
|
|
||||||
return [m.data for m in self.mutations if m.table == table and m.action == "insert"]
|
|
||||||
|
|
||||||
def updated_records(self, table: str) -> list[dict]:
|
|
||||||
"""Return data dicts of all updates to *table*."""
|
|
||||||
return [m.data for m in self.mutations if m.table == table and m.action == "update"]
|
|
||||||
|
|
||||||
# ── Context manager for patching ──────────────────────────────
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def patch(self):
|
|
||||||
"""Patch execute_on_client and DB session at all usage sites."""
|
|
||||||
mock_fn = AsyncMock(side_effect=self._handle)
|
|
||||||
targets = [
|
|
||||||
"shared.ws_context.execute_on_client",
|
|
||||||
"app.agent_runner.execute_on_client",
|
|
||||||
"app.agents.filesystem_agent.execute_on_client",
|
|
||||||
]
|
|
||||||
|
|
||||||
# Mock async_session so run_local_agent / _finalize_run skip real DB
|
|
||||||
fake_row = _FakeRow()
|
|
||||||
fake_db = AsyncMock()
|
|
||||||
fake_db.commit = AsyncMock()
|
|
||||||
fake_db.refresh = AsyncMock()
|
|
||||||
fake_db.execute = AsyncMock(return_value=_FakeResult(fake_row))
|
|
||||||
fake_db.add = lambda obj: None # noqa: ARG005
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def _fake_session():
|
|
||||||
yield fake_db
|
|
||||||
|
|
||||||
patches = [patch(t, new=mock_fn) for t in targets]
|
|
||||||
patches.append(patch("app.agent_runner.async_session", _fake_session))
|
|
||||||
for p in patches:
|
|
||||||
p.start()
|
|
||||||
try:
|
|
||||||
yield mock_fn
|
|
||||||
finally:
|
|
||||||
for p in patches:
|
|
||||||
p.stop()
|
|
||||||
|
|
||||||
# ── Internal dispatch ─────────────────────────────────────────
|
|
||||||
|
|
||||||
async def _handle(
|
|
||||||
self,
|
|
||||||
action: str,
|
|
||||||
table: str | None = None,
|
|
||||||
data: dict[str, Any] | None = None,
|
|
||||||
filters: dict[str, Any] | None = None,
|
|
||||||
vector: list[float] | None = None,
|
|
||||||
limit: int | None = None,
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
# Filesystem
|
|
||||||
if action == "list_directory":
|
|
||||||
return self._list_directory(data or {})
|
|
||||||
if action == "read_file_content":
|
|
||||||
return self._read_file(data or {})
|
|
||||||
if action == "get_file_metadata":
|
|
||||||
return self._get_file_metadata(data or {})
|
|
||||||
|
|
||||||
# CRUD
|
|
||||||
if action == "select":
|
|
||||||
return self._select(table or "", filters)
|
|
||||||
if action == "get":
|
|
||||||
return self._get(table or "", data or {})
|
|
||||||
if action == "insert":
|
|
||||||
return self._insert(table or "", data or {})
|
|
||||||
if action == "update":
|
|
||||||
return self._update(table or "", data or {})
|
|
||||||
if action == "delete":
|
|
||||||
return self._delete(table or "", data or {})
|
|
||||||
|
|
||||||
# Vector (no-op for eval)
|
|
||||||
if action in ("vector_upsert", "vector_search"):
|
|
||||||
return {"rows": []}
|
|
||||||
|
|
||||||
return {"error": f"Unknown action: {action}"}
|
|
||||||
|
|
||||||
# ── Filesystem handlers ───────────────────────────────────────
|
|
||||||
|
|
||||||
def _list_directory(self, data: dict) -> dict:
|
|
||||||
rel_path = data.get("path", "")
|
|
||||||
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
|
||||||
if not abs_path.is_dir():
|
|
||||||
return {"entries": []}
|
|
||||||
entries: list[dict] = []
|
|
||||||
for child in sorted(abs_path.iterdir()):
|
|
||||||
entry_type = "directory" if child.is_dir() else "file"
|
|
||||||
# Return paths relative to fixture_dir but with the original prefix
|
|
||||||
entry_path = rel_path.rstrip("/\\") + "/" + child.name
|
|
||||||
entries.append({
|
|
||||||
"name": child.name,
|
|
||||||
"path": entry_path,
|
|
||||||
"type": entry_type,
|
|
||||||
})
|
|
||||||
return {"entries": entries}
|
|
||||||
|
|
||||||
def _read_file(self, data: dict) -> dict:
|
|
||||||
rel_path = data.get("path", "")
|
|
||||||
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
|
||||||
if not abs_path.is_file():
|
|
||||||
return {"content": "", "error": f"File not found: {rel_path}"}
|
|
||||||
return {"content": abs_path.read_text(encoding="utf-8", errors="replace")}
|
|
||||||
|
|
||||||
def _get_file_metadata(self, data: dict) -> dict:
|
|
||||||
rel_path = data.get("path", "")
|
|
||||||
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
|
|
||||||
if not abs_path.exists():
|
|
||||||
return {"error": f"Not found: {rel_path}"}
|
|
||||||
stat = abs_path.stat()
|
|
||||||
return {
|
|
||||||
"path": rel_path,
|
|
||||||
"size": stat.st_size,
|
|
||||||
"modifiedAt": int(stat.st_mtime * 1000),
|
|
||||||
"createdAt": int(stat.st_ctime * 1000),
|
|
||||||
"isDirectory": abs_path.is_dir(),
|
|
||||||
}
|
|
||||||
|
|
||||||
# ── CRUD handlers ─────────────────────────────────────────────
|
|
||||||
|
|
||||||
def _select(self, table: str, filters: dict | None) -> dict:
|
|
||||||
rows = list(self.seed_records.get(table, []))
|
|
||||||
if filters:
|
|
||||||
rows = [
|
|
||||||
r for r in rows
|
|
||||||
if all(r.get(k) == v for k, v in filters.items() if v is not None)
|
|
||||||
]
|
|
||||||
return {"rows": rows}
|
|
||||||
|
|
||||||
def _get(self, table: str, data: dict) -> dict:
|
|
||||||
record_id = data.get("id", "")
|
|
||||||
rows = self.seed_records.get(table, [])
|
|
||||||
for r in rows:
|
|
||||||
if r.get("id") == record_id:
|
|
||||||
return {"row": r}
|
|
||||||
return {"row": None}
|
|
||||||
|
|
||||||
def _insert(self, table: str, data: dict) -> dict:
|
|
||||||
self._id_counter += 1
|
|
||||||
record = {**data, "id": str(self._id_counter)}
|
|
||||||
# Add to seed so subsequent selects can find it
|
|
||||||
self.seed_records.setdefault(table, []).append(record)
|
|
||||||
self.mutations.append(Mutation(action="insert", table=table, data=record))
|
|
||||||
return {"row": record}
|
|
||||||
|
|
||||||
def _update(self, table: str, data: dict) -> dict:
|
|
||||||
record_id = data.get("id", "")
|
|
||||||
rows = self.seed_records.get(table, [])
|
|
||||||
for r in rows:
|
|
||||||
if r.get("id") == record_id:
|
|
||||||
r.update({k: v for k, v in data.items() if v is not None and v != ""})
|
|
||||||
self.mutations.append(Mutation(action="update", table=table, data=dict(r)))
|
|
||||||
return {"row": r}
|
|
||||||
# Record not found — still log the mutation
|
|
||||||
self.mutations.append(Mutation(action="update", table=table, data=data))
|
|
||||||
return {"row": data}
|
|
||||||
|
|
||||||
def _delete(self, table: str, data: dict) -> dict:
|
|
||||||
record_id = data.get("id", "")
|
|
||||||
rows = self.seed_records.get(table, [])
|
|
||||||
self.seed_records[table] = [r for r in rows if r.get("id") != record_id]
|
|
||||||
self.mutations.append(Mutation(action="delete", table=table, data={"id": record_id}))
|
|
||||||
return {"deleted": True}
|
|
||||||
@@ -1,2 +0,0 @@
|
|||||||
# Extra dependencies for the eval harness (on top of the service requirements.txt)
|
|
||||||
pyyaml>=6.0.0
|
|
||||||
@@ -1,545 +0,0 @@
|
|||||||
"""Eval runner — orchestrates fixture → mock → agent pipeline → scoring.
|
|
||||||
|
|
||||||
Supports three eval modes:
|
|
||||||
|
|
||||||
- **step1**: Test classification prompt only (``_STEP1_SYSTEM_PROMPT``).
|
|
||||||
Calls the LLM with fixture-provided ``domain_definitions`` and
|
|
||||||
``projects_list`` and compares output against ``expected_classification``.
|
|
||||||
|
|
||||||
- **step2**: Test processing prompt only (``_PROCESSING_SYSTEM_PROMPT``).
|
|
||||||
Compiles the prompt with fixture-provided ``existing_context``,
|
|
||||||
``project_context``, ``data_types``, and ``custom_prompt_section``,
|
|
||||||
then runs the tool-calling loop. Mutations are scored against
|
|
||||||
``expected`` records.
|
|
||||||
|
|
||||||
- **full**: Run ``run_local_agent()`` end-to-end (both steps).
|
|
||||||
Scored on both classification and extraction.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import copy
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
import time
|
|
||||||
import uuid
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from eval.config import EvalFixture, ExpectedClassification
|
|
||||||
from eval.mock_executor import MockExecutor
|
|
||||||
from eval.scorer import (
|
|
||||||
EvalScores,
|
|
||||||
FieldScore,
|
|
||||||
compute_precision_recall,
|
|
||||||
llm_judge_score,
|
|
||||||
score_field_match,
|
|
||||||
)
|
|
||||||
from eval import langfuse_eval
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Step 1 runner ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_step1(
|
|
||||||
fixture: EvalFixture,
|
|
||||||
model: str,
|
|
||||||
mock: MockExecutor,
|
|
||||||
) -> list[dict[str, Any]]:
|
|
||||||
"""Run step-1 classification for every file in the fixture directory.
|
|
||||||
|
|
||||||
Scans the directory recursively, classifies each file, and returns
|
|
||||||
a list of result dicts:
|
|
||||||
``[{file, project_id, domains, new_project_name}, ...]``
|
|
||||||
"""
|
|
||||||
from app.agent_runner import _classify_file
|
|
||||||
|
|
||||||
# Build project name lookup for display
|
|
||||||
proj_names: dict[str, str] = {
|
|
||||||
p.get("id", ""): p.get("name", "") for p in fixture.projects_list
|
|
||||||
}
|
|
||||||
|
|
||||||
# Discover all files in the fixture directory
|
|
||||||
all_files = await _scan_fixture_files(mock, fixture.directory)
|
|
||||||
print(f"\n Scanning {len(all_files)} files in {fixture.directory}\n")
|
|
||||||
|
|
||||||
results: list[dict[str, Any]] = []
|
|
||||||
for i, file_path in enumerate(all_files, 1):
|
|
||||||
file_result = await mock._handle(
|
|
||||||
action="read_file_content",
|
|
||||||
data={"path": file_path},
|
|
||||||
)
|
|
||||||
file_content: str = file_result.get("content", "")
|
|
||||||
if not file_content.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
project_id, domains, new_name = await _classify_file(
|
|
||||||
file_path=file_path,
|
|
||||||
file_content=file_content,
|
|
||||||
projects=fixture.projects_list,
|
|
||||||
config_data_types=fixture.data_types,
|
|
||||||
custom_system_prompt=fixture.custom_step1_prompt or None,
|
|
||||||
)
|
|
||||||
|
|
||||||
short_name = file_path.rsplit("/", 1)[-1] if "/" in file_path else file_path
|
|
||||||
proj_label = proj_names.get(project_id, new_name or "?")
|
|
||||||
print(f" [{i}/{len(all_files)}] {short_name} → {project_id} ({proj_label}) {domains}")
|
|
||||||
|
|
||||||
results.append({
|
|
||||||
"file": file_path,
|
|
||||||
"project_id": project_id,
|
|
||||||
"domains": domains,
|
|
||||||
"new_project_name": new_name,
|
|
||||||
})
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
async def _scan_fixture_files(mock: MockExecutor, directory: str) -> list[str]:
|
|
||||||
"""Recursively list all files under *directory* via the mock executor."""
|
|
||||||
files: list[str] = []
|
|
||||||
|
|
||||||
async def _walk(path: str) -> None:
|
|
||||||
result = await mock._handle(action="list_directory", data={"path": path})
|
|
||||||
for entry in result.get("entries", []):
|
|
||||||
if entry.get("type") == "directory":
|
|
||||||
await _walk(entry["path"])
|
|
||||||
elif entry.get("type") == "file":
|
|
||||||
files.append(entry["path"])
|
|
||||||
|
|
||||||
await _walk(directory)
|
|
||||||
return sorted(files)
|
|
||||||
|
|
||||||
|
|
||||||
def _score_step1(
|
|
||||||
fixture: EvalFixture,
|
|
||||||
results: list[dict[str, Any]],
|
|
||||||
) -> tuple[float, float, float, str]:
|
|
||||||
"""Score step-1 results. Returns (precision, recall, f1, reasoning).
|
|
||||||
|
|
||||||
Files with expected classifications are scored (OK/FAIL).
|
|
||||||
Files without expectations are shown as informational (INFO).
|
|
||||||
"""
|
|
||||||
if not fixture.expected_classification:
|
|
||||||
return 0.0, 0.0, 0.0, "No expected classifications"
|
|
||||||
|
|
||||||
# Build project name lookup
|
|
||||||
proj_names: dict[str, str] = {
|
|
||||||
p.get("id", ""): p.get("name", "") for p in fixture.projects_list
|
|
||||||
}
|
|
||||||
proj_names["new"] = "(new project)"
|
|
||||||
|
|
||||||
def _proj_label(pid: str, new_name: str | None = None) -> str:
|
|
||||||
name = proj_names.get(pid, "?")
|
|
||||||
if pid == "new" and new_name:
|
|
||||||
return f"new → \"{new_name}\""
|
|
||||||
return f"{pid} ({name})" if name and name != "?" else pid
|
|
||||||
|
|
||||||
def _short_file(path: str) -> str:
|
|
||||||
"""Use just the filename for cleaner display."""
|
|
||||||
return path.rsplit("/", 1)[-1] if "/" in path else path
|
|
||||||
|
|
||||||
expected_files = {ec.file for ec in fixture.expected_classification}
|
|
||||||
total = len(fixture.expected_classification)
|
|
||||||
matched = 0
|
|
||||||
|
|
||||||
scored_lines: list[str] = []
|
|
||||||
info_lines: list[str] = []
|
|
||||||
|
|
||||||
# Score expected files
|
|
||||||
for ec in fixture.expected_classification:
|
|
||||||
actual = next((r for r in results if r["file"] == ec.file), None)
|
|
||||||
fname = _short_file(ec.file)
|
|
||||||
if actual is None:
|
|
||||||
scored_lines.append(f" MISS {fname}")
|
|
||||||
scored_lines.append(f" expected: {_proj_label(ec.project_id)}")
|
|
||||||
continue
|
|
||||||
|
|
||||||
pid_ok = actual["project_id"] == ec.project_id
|
|
||||||
domains_ok = set(actual["domains"]) == set(ec.domains) if ec.domains else True
|
|
||||||
|
|
||||||
if pid_ok and domains_ok:
|
|
||||||
matched += 1
|
|
||||||
scored_lines.append(f" OK {fname}")
|
|
||||||
scored_lines.append(f" project: {_proj_label(actual['project_id'])}")
|
|
||||||
scored_lines.append(f" domains: {actual['domains']}")
|
|
||||||
else:
|
|
||||||
scored_lines.append(f" FAIL {fname}")
|
|
||||||
if not pid_ok:
|
|
||||||
scored_lines.append(f" project: {_proj_label(actual['project_id'])} (expected: {_proj_label(ec.project_id)})")
|
|
||||||
else:
|
|
||||||
scored_lines.append(f" project: {_proj_label(actual['project_id'])}")
|
|
||||||
if not domains_ok:
|
|
||||||
scored_lines.append(f" domains: {actual['domains']} (expected: {ec.domains})")
|
|
||||||
else:
|
|
||||||
scored_lines.append(f" domains: {actual['domains']}")
|
|
||||||
|
|
||||||
# Show unscored files
|
|
||||||
for r in results:
|
|
||||||
if r["file"] not in expected_files:
|
|
||||||
fname = _short_file(r["file"])
|
|
||||||
proj = _proj_label(r["project_id"], r.get("new_project_name"))
|
|
||||||
info_lines.append(f" · {fname}")
|
|
||||||
info_lines.append(f" project: {proj} | domains: {r['domains']}")
|
|
||||||
|
|
||||||
precision = matched / total if total > 0 else 0.0
|
|
||||||
recall = precision
|
|
||||||
f1 = precision
|
|
||||||
|
|
||||||
parts: list[str] = []
|
|
||||||
if scored_lines:
|
|
||||||
parts.append(f"Scored ({matched}/{total}):")
|
|
||||||
parts.extend(scored_lines)
|
|
||||||
if info_lines:
|
|
||||||
parts.append(f"\nOther files ({len(info_lines) // 2}):")
|
|
||||||
parts.extend(info_lines)
|
|
||||||
|
|
||||||
return precision, recall, f1, "\n".join(parts)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Step 2 runner ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_step2(
|
|
||||||
fixture: EvalFixture,
|
|
||||||
model: str,
|
|
||||||
mock: MockExecutor,
|
|
||||||
) -> None:
|
|
||||||
"""Run step-2 processing for each file in the fixture directory.
|
|
||||||
|
|
||||||
Compiles ``_PROCESSING_SYSTEM_PROMPT`` with fixture-provided variables
|
|
||||||
and runs the tool-calling loop. Mutations are captured by the mock.
|
|
||||||
"""
|
|
||||||
from app.agent_runner import (
|
|
||||||
_PROCESSING_SYSTEM_PROMPT,
|
|
||||||
_build_processing_tools,
|
|
||||||
_run_agent_with_tools,
|
|
||||||
_MAX_PROCESSING_STEPS,
|
|
||||||
)
|
|
||||||
from app import tracing
|
|
||||||
|
|
||||||
# Compile the processing prompt with fixture variables
|
|
||||||
system_prompt = tracing.compile_prompt(
|
|
||||||
"batch_processing",
|
|
||||||
fallback=_PROCESSING_SYSTEM_PROMPT,
|
|
||||||
variables={
|
|
||||||
"existing_context": fixture.existing_context,
|
|
||||||
"project_context": fixture.project_context,
|
|
||||||
"data_types": ", ".join(fixture.data_types),
|
|
||||||
"custom_prompt_section": fixture.custom_prompt_section,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
tools = _build_processing_tools(fixture.data_types)
|
|
||||||
|
|
||||||
# Scan files in the fixture directory
|
|
||||||
file_entries = await mock._handle(
|
|
||||||
action="list_directory",
|
|
||||||
data={"path": fixture.directory},
|
|
||||||
)
|
|
||||||
for entry in file_entries.get("entries", []):
|
|
||||||
if entry.get("type") != "file":
|
|
||||||
continue
|
|
||||||
# Filter by extension if specified
|
|
||||||
if fixture.file_extensions:
|
|
||||||
ext = entry["name"].rsplit(".", 1)[-1] if "." in entry["name"] else ""
|
|
||||||
if ext not in fixture.file_extensions:
|
|
||||||
continue
|
|
||||||
|
|
||||||
file_result = await mock._handle(
|
|
||||||
action="read_file_content",
|
|
||||||
data={"path": entry["path"]},
|
|
||||||
)
|
|
||||||
file_content: str = file_result.get("content", "")
|
|
||||||
if not file_content.strip():
|
|
||||||
continue
|
|
||||||
|
|
||||||
await _run_agent_with_tools(
|
|
||||||
system_prompt=system_prompt,
|
|
||||||
user_message=(
|
|
||||||
f"Process this file and extract relevant information.\n\n"
|
|
||||||
f"File: {entry['path']}\n\nContent:\n{file_content}"
|
|
||||||
),
|
|
||||||
tools=tools,
|
|
||||||
max_steps=_MAX_PROCESSING_STEPS,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Full runner ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def _run_full(
|
|
||||||
fixture: EvalFixture,
|
|
||||||
model: str,
|
|
||||||
mock: MockExecutor,
|
|
||||||
user_id: str,
|
|
||||||
) -> None:
|
|
||||||
"""Run the full two-step pipeline via ``run_local_agent``."""
|
|
||||||
from app.agent_runner import run_local_agent
|
|
||||||
|
|
||||||
trigger_data: dict[str, Any] = {
|
|
||||||
"type": "agent_trigger",
|
|
||||||
"directory": fixture.directory,
|
|
||||||
"directory_paths": [fixture.directory],
|
|
||||||
"data_types": fixture.data_types,
|
|
||||||
"file_extensions": fixture.file_extensions,
|
|
||||||
"prompt_template": fixture.custom_prompt_section,
|
|
||||||
"device_id": "eval-harness",
|
|
||||||
"run_context": {
|
|
||||||
"agent_id": f"eval-{fixture.name}",
|
|
||||||
"run_id": None,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
with mock.patch():
|
|
||||||
await run_local_agent(user_id, trigger_data)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Scoring helpers ───────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _score_mutations(
|
|
||||||
fixture: EvalFixture,
|
|
||||||
mock: MockExecutor,
|
|
||||||
) -> tuple[list[FieldScore], float, float, float, int, int]:
|
|
||||||
"""Score mutations against expected records.
|
|
||||||
|
|
||||||
Returns (field_scores, precision, recall, f1, extra, missing).
|
|
||||||
"""
|
|
||||||
all_field_scores: list[FieldScore] = []
|
|
||||||
total_expected = 0
|
|
||||||
total_actual = 0
|
|
||||||
total_matched = 0
|
|
||||||
total_extra = 0
|
|
||||||
total_missing = 0
|
|
||||||
|
|
||||||
expected_by_table: dict[str, list[dict]] = {}
|
|
||||||
for rec in fixture.expected:
|
|
||||||
expected_by_table.setdefault(rec.table, []).append(rec.fields)
|
|
||||||
|
|
||||||
tables = set(expected_by_table.keys()) | {m.table for m in mock.mutations}
|
|
||||||
for table in tables:
|
|
||||||
expected_records = expected_by_table.get(table, [])
|
|
||||||
actual_records = mock.created_records(table) + mock.updated_records(table)
|
|
||||||
|
|
||||||
field_scores, extra, missing = score_field_match(expected_records, actual_records, table)
|
|
||||||
all_field_scores.extend(field_scores)
|
|
||||||
|
|
||||||
matched = sum(1 for s in field_scores if s.best_match is not None)
|
|
||||||
total_expected += len(expected_records)
|
|
||||||
total_actual += len(actual_records)
|
|
||||||
total_matched += matched
|
|
||||||
total_extra += extra
|
|
||||||
total_missing += missing
|
|
||||||
|
|
||||||
precision, recall, f1 = compute_precision_recall(total_expected, total_actual, total_matched)
|
|
||||||
return all_field_scores, precision, recall, f1, total_extra, total_missing
|
|
||||||
|
|
||||||
|
|
||||||
# ── Main entry point ──────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
async def run_single_eval(
|
|
||||||
fixture: EvalFixture,
|
|
||||||
model: str,
|
|
||||||
*,
|
|
||||||
use_llm_judge: bool = True,
|
|
||||||
judge_model: str = "gpt-4o-mini",
|
|
||||||
) -> EvalScores:
|
|
||||||
"""Execute one eval run for a fixture + model. Mode is read from the fixture."""
|
|
||||||
from shared.config import settings
|
|
||||||
from shared.ws_context import set_current_user, clear_current_user
|
|
||||||
|
|
||||||
seed = copy.deepcopy(fixture.seed_records)
|
|
||||||
mock = MockExecutor(
|
|
||||||
fixture_dir=fixture.fixture_path.parent,
|
|
||||||
seed_records=seed,
|
|
||||||
)
|
|
||||||
|
|
||||||
original_model = settings.LLM_MODEL
|
|
||||||
settings.LLM_MODEL = model
|
|
||||||
eval_user_id = str(uuid.uuid4())
|
|
||||||
|
|
||||||
logger.info(
|
|
||||||
"eval: starting %s | mode=%s | model=%s",
|
|
||||||
fixture.name, fixture.mode, model,
|
|
||||||
)
|
|
||||||
start_time = time.time()
|
|
||||||
|
|
||||||
step1_results: list[dict[str, Any]] = []
|
|
||||||
step1_reasoning = ""
|
|
||||||
|
|
||||||
try:
|
|
||||||
set_current_user(eval_user_id)
|
|
||||||
|
|
||||||
if fixture.mode == "step1":
|
|
||||||
with mock.patch():
|
|
||||||
step1_results = await _run_step1(fixture, model, mock)
|
|
||||||
|
|
||||||
elif fixture.mode == "step2":
|
|
||||||
with mock.patch():
|
|
||||||
await _run_step2(fixture, model, mock)
|
|
||||||
|
|
||||||
elif fixture.mode == "full":
|
|
||||||
with mock.patch():
|
|
||||||
# Step 1 — classification (independent from run_local_agent)
|
|
||||||
if fixture.expected_classification:
|
|
||||||
step1_results = await _run_step1(fixture, model, mock)
|
|
||||||
|
|
||||||
# Step 2 — full pipeline (run_local_agent handles both steps)
|
|
||||||
await _run_full(fixture, model, mock, eval_user_id)
|
|
||||||
|
|
||||||
except Exception as exc:
|
|
||||||
logger.error("eval: pipeline failed for %s/%s: %s", fixture.name, model, exc)
|
|
||||||
finally:
|
|
||||||
settings.LLM_MODEL = original_model
|
|
||||||
clear_current_user()
|
|
||||||
|
|
||||||
elapsed = time.time() - start_time
|
|
||||||
logger.info("eval: completed in %.1fs — %d mutations", elapsed, len(mock.mutations))
|
|
||||||
|
|
||||||
# ── Score ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
if fixture.mode == "step1":
|
|
||||||
s1_precision, s1_recall, s1_f1, step1_reasoning = _score_step1(fixture, step1_results)
|
|
||||||
scores = EvalScores(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
prompt_variant=fixture.mode,
|
|
||||||
precision=s1_precision,
|
|
||||||
recall=s1_recall,
|
|
||||||
f1=s1_f1,
|
|
||||||
llm_judge_reasoning=step1_reasoning,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
# step2 or full — score mutations
|
|
||||||
field_scores, precision, recall, f1, extra, missing = _score_mutations(fixture, mock)
|
|
||||||
scores = EvalScores(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
prompt_variant=fixture.mode,
|
|
||||||
field_scores=field_scores,
|
|
||||||
precision=precision,
|
|
||||||
recall=recall,
|
|
||||||
f1=f1,
|
|
||||||
extra_records=extra,
|
|
||||||
missing_records=missing,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Add step1 classification scores for full mode
|
|
||||||
if fixture.mode == "full" and fixture.expected_classification:
|
|
||||||
s1_p, s1_r, s1_f1, step1_reasoning = _score_step1(fixture, step1_results)
|
|
||||||
scores.llm_judge_reasoning = f"Step1 classification:\n{step1_reasoning}"
|
|
||||||
|
|
||||||
# Optional LLM judge for extraction quality
|
|
||||||
if use_llm_judge and fixture.expected:
|
|
||||||
all_expected = [r.fields for r in fixture.expected]
|
|
||||||
all_actual = [m.data for m in mock.mutations if m.action in ("insert", "update")]
|
|
||||||
judge_score, reasoning = await llm_judge_score(
|
|
||||||
all_expected, all_actual, judge_model=judge_model,
|
|
||||||
)
|
|
||||||
scores.llm_judge_score = judge_score
|
|
||||||
if step1_reasoning:
|
|
||||||
scores.llm_judge_reasoning += f"\n\nLLM judge:\n{reasoning}"
|
|
||||||
else:
|
|
||||||
scores.llm_judge_reasoning = reasoning
|
|
||||||
|
|
||||||
# ── Report to Langfuse ────────────────────────────────────────
|
|
||||||
prompt_names = {
|
|
||||||
"step1": ["batch_file_classifier"],
|
|
||||||
"step2": ["batch_processing"],
|
|
||||||
"full": ["batch_file_classifier", "batch_processing"],
|
|
||||||
}.get(fixture.mode, ["batch_processing"])
|
|
||||||
|
|
||||||
trace_id = langfuse_eval.log_eval_trace(
|
|
||||||
fixture_name=fixture.name,
|
|
||||||
model=model,
|
|
||||||
prompt_variant=fixture.mode,
|
|
||||||
prompt_template=fixture.custom_prompt_section or "(default)",
|
|
||||||
actual_mutations=[{"action": m.action, "table": m.table, "data": m.data} for m in mock.mutations],
|
|
||||||
scores_summary=scores.summary(),
|
|
||||||
step1_results=step1_results or None,
|
|
||||||
langfuse_prompt_names=prompt_names,
|
|
||||||
)
|
|
||||||
|
|
||||||
if trace_id:
|
|
||||||
langfuse_eval.post_eval_scores(scores, trace_id=trace_id)
|
|
||||||
|
|
||||||
# For full mode, post classification scores separately
|
|
||||||
if fixture.mode == "full" and fixture.expected_classification:
|
|
||||||
s1_p, s1_r, s1_f1, _ = _score_step1(fixture, step1_results)
|
|
||||||
for name, value in [
|
|
||||||
("classification_precision", s1_p),
|
|
||||||
("classification_recall", s1_r),
|
|
||||||
("classification_f1", s1_f1),
|
|
||||||
]:
|
|
||||||
try:
|
|
||||||
from langfuse import get_client
|
|
||||||
lf = get_client()
|
|
||||||
if lf:
|
|
||||||
lf.create_score(
|
|
||||||
name=name,
|
|
||||||
value=value,
|
|
||||||
trace_id=trace_id,
|
|
||||||
data_type="NUMERIC",
|
|
||||||
comment=f"{fixture.name} | {model} | full",
|
|
||||||
)
|
|
||||||
except Exception:
|
|
||||||
pass
|
|
||||||
|
|
||||||
return scores
|
|
||||||
|
|
||||||
|
|
||||||
async def run_fixture_eval(
|
|
||||||
fixture: EvalFixture,
|
|
||||||
models: list[str],
|
|
||||||
*,
|
|
||||||
use_llm_judge: bool = True,
|
|
||||||
judge_model: str = "gpt-4o-mini",
|
|
||||||
) -> list[EvalScores]:
|
|
||||||
"""Run all models for a fixture."""
|
|
||||||
langfuse_eval.sync_fixture_to_dataset(fixture)
|
|
||||||
|
|
||||||
results: list[EvalScores] = []
|
|
||||||
for model in models:
|
|
||||||
scores = await run_single_eval(
|
|
||||||
fixture, model,
|
|
||||||
use_llm_judge=use_llm_judge,
|
|
||||||
judge_model=judge_model,
|
|
||||||
)
|
|
||||||
results.append(scores)
|
|
||||||
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def print_results(results: list[EvalScores]) -> None:
|
|
||||||
"""Print a formatted summary table of eval results."""
|
|
||||||
if not results:
|
|
||||||
print("\nNo eval results.")
|
|
||||||
return
|
|
||||||
|
|
||||||
W = 90
|
|
||||||
|
|
||||||
print("\n" + "=" * W)
|
|
||||||
print(f"{'Fixture':<25} {'Mode':<6} {'Model':<25} {'P':>6} {'R':>6} {'F1':>6} {'FA':>6} {'LLM':>6}")
|
|
||||||
print("-" * W)
|
|
||||||
|
|
||||||
for s in results:
|
|
||||||
llm_str = f"{s.llm_judge_score:.2f}" if s.llm_judge_score is not None else " --"
|
|
||||||
fa_str = f"{s.field_accuracy:.2f}" if s.field_scores else " --"
|
|
||||||
print(
|
|
||||||
f"{s.fixture_name:<25} {s.prompt_variant:<6} {s.model:<25} "
|
|
||||||
f"{s.precision:>6.2f} {s.recall:>6.2f} {s.f1:>6.2f} "
|
|
||||||
f"{fa_str:>6} {llm_str:>6}"
|
|
||||||
)
|
|
||||||
|
|
||||||
print("=" * W)
|
|
||||||
|
|
||||||
for s in results:
|
|
||||||
if s.llm_judge_reasoning:
|
|
||||||
print(f"\n{'─' * W}")
|
|
||||||
print(f" {s.fixture_name} | {s.model} | {s.prompt_variant}")
|
|
||||||
print(f"{'─' * W}")
|
|
||||||
print(s.llm_judge_reasoning)
|
|
||||||
|
|
||||||
print()
|
|
||||||
@@ -1,268 +0,0 @@
|
|||||||
"""Scoring functions for batch agent evaluation.
|
|
||||||
|
|
||||||
Two scoring strategies:
|
|
||||||
|
|
||||||
1. **FieldMatchScorer** — deterministic check: for each expected record,
|
|
||||||
find the best-matching actual record and compare specified fields.
|
|
||||||
Returns precision, recall, and per-field accuracy.
|
|
||||||
|
|
||||||
2. **LLMJudgeScorer** — uses a secondary LLM to semantically evaluate
|
|
||||||
whether the actual extractions satisfy the expected intent, even if
|
|
||||||
wording differs. Returns a 0-1 score + reasoning.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import json
|
|
||||||
import logging
|
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from difflib import SequenceMatcher
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from langchain_core.messages import HumanMessage, SystemMessage
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
# ── Result types ─────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class FieldScore:
|
|
||||||
"""Score for a single expected record against its best match."""
|
|
||||||
|
|
||||||
expected: dict[str, Any]
|
|
||||||
best_match: dict[str, Any] | None
|
|
||||||
matched_fields: dict[str, bool]
|
|
||||||
similarity: float # 0-1 overall similarity
|
|
||||||
|
|
||||||
@property
|
|
||||||
def field_accuracy(self) -> float:
|
|
||||||
if not self.matched_fields:
|
|
||||||
return 0.0
|
|
||||||
return sum(self.matched_fields.values()) / len(self.matched_fields)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class EvalScores:
|
|
||||||
"""Aggregated scores for one eval run."""
|
|
||||||
|
|
||||||
fixture_name: str
|
|
||||||
model: str
|
|
||||||
prompt_variant: str
|
|
||||||
field_scores: list[FieldScore] = field(default_factory=list)
|
|
||||||
precision: float = 0.0
|
|
||||||
recall: float = 0.0
|
|
||||||
f1: float = 0.0
|
|
||||||
llm_judge_score: float | None = None
|
|
||||||
llm_judge_reasoning: str = ""
|
|
||||||
extra_records: int = 0 # records created but not expected
|
|
||||||
missing_records: int = 0 # expected but not found
|
|
||||||
|
|
||||||
@property
|
|
||||||
def field_accuracy(self) -> float:
|
|
||||||
if not self.field_scores:
|
|
||||||
return 0.0
|
|
||||||
return sum(s.field_accuracy for s in self.field_scores) / len(self.field_scores)
|
|
||||||
|
|
||||||
def summary(self) -> dict[str, Any]:
|
|
||||||
return {
|
|
||||||
"fixture": self.fixture_name,
|
|
||||||
"model": self.model,
|
|
||||||
"prompt_variant": self.prompt_variant,
|
|
||||||
"precision": round(self.precision, 3),
|
|
||||||
"recall": round(self.recall, 3),
|
|
||||||
"f1": round(self.f1, 3),
|
|
||||||
"field_accuracy": round(self.field_accuracy, 3),
|
|
||||||
"llm_judge_score": round(self.llm_judge_score, 3) if self.llm_judge_score is not None else None,
|
|
||||||
"extra_records": self.extra_records,
|
|
||||||
"missing_records": self.missing_records,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Field Match Scorer ───────────────────────────────────────────────────
|
|
||||||
|
|
||||||
|
|
||||||
def _normalize(value: Any) -> str:
|
|
||||||
"""Normalize a value for comparison."""
|
|
||||||
if value is None:
|
|
||||||
return ""
|
|
||||||
return str(value).strip().lower()
|
|
||||||
|
|
||||||
|
|
||||||
def _text_similarity(a: str, b: str) -> float:
|
|
||||||
"""Fuzzy text similarity using SequenceMatcher."""
|
|
||||||
if not a and not b:
|
|
||||||
return 1.0
|
|
||||||
if not a or not b:
|
|
||||||
return 0.0
|
|
||||||
return SequenceMatcher(None, a.lower(), b.lower()).ratio()
|
|
||||||
|
|
||||||
|
|
||||||
def _find_best_match(
|
|
||||||
expected: dict[str, Any],
|
|
||||||
actuals: list[dict[str, Any]],
|
|
||||||
) -> tuple[dict[str, Any] | None, float]:
|
|
||||||
"""Find the actual record most similar to expected, return (match, similarity)."""
|
|
||||||
if not actuals:
|
|
||||||
return None, 0.0
|
|
||||||
|
|
||||||
best_match = None
|
|
||||||
best_score = 0.0
|
|
||||||
|
|
||||||
# Primary matching key: title or name
|
|
||||||
expected_title = _normalize(expected.get("title", expected.get("name", "")))
|
|
||||||
|
|
||||||
for actual in actuals:
|
|
||||||
actual_title = _normalize(actual.get("title", actual.get("name", "")))
|
|
||||||
sim = _text_similarity(expected_title, actual_title)
|
|
||||||
if sim > best_score:
|
|
||||||
best_score = sim
|
|
||||||
best_match = actual
|
|
||||||
|
|
||||||
return best_match, best_score
|
|
||||||
|
|
||||||
|
|
||||||
def _compare_fields(
|
|
||||||
expected: dict[str, Any],
|
|
||||||
actual: dict[str, Any],
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Compare each expected field against the actual record."""
|
|
||||||
results: dict[str, bool] = {}
|
|
||||||
for key, expected_val in expected.items():
|
|
||||||
actual_val = actual.get(key)
|
|
||||||
# Exact match for non-string types
|
|
||||||
if not isinstance(expected_val, str):
|
|
||||||
results[key] = actual_val == expected_val
|
|
||||||
else:
|
|
||||||
# Fuzzy match for strings (threshold: 0.7)
|
|
||||||
results[key] = _text_similarity(
|
|
||||||
_normalize(expected_val), _normalize(actual_val)
|
|
||||||
) >= 0.7
|
|
||||||
return results
|
|
||||||
|
|
||||||
|
|
||||||
def score_field_match(
|
|
||||||
expected_records: list[dict[str, Any]],
|
|
||||||
actual_records: list[dict[str, Any]],
|
|
||||||
table: str,
|
|
||||||
) -> tuple[list[FieldScore], int, int]:
|
|
||||||
"""Score actual extractions against expected records for one table.
|
|
||||||
|
|
||||||
Returns (field_scores, extra_count, missing_count).
|
|
||||||
"""
|
|
||||||
field_scores: list[FieldScore] = []
|
|
||||||
matched_actuals: set[int] = set()
|
|
||||||
|
|
||||||
for exp in expected_records:
|
|
||||||
# Find best match among unmatched actuals
|
|
||||||
candidates = [
|
|
||||||
(i, a) for i, a in enumerate(actual_records) if i not in matched_actuals
|
|
||||||
]
|
|
||||||
if not candidates:
|
|
||||||
field_scores.append(FieldScore(
|
|
||||||
expected=exp, best_match=None, matched_fields={}, similarity=0.0,
|
|
||||||
))
|
|
||||||
continue
|
|
||||||
|
|
||||||
best_idx, best_match = None, None
|
|
||||||
best_sim = 0.0
|
|
||||||
for idx, actual in candidates:
|
|
||||||
_, sim = _find_best_match(exp, [actual])
|
|
||||||
if sim > best_sim:
|
|
||||||
best_sim = sim
|
|
||||||
best_idx = idx
|
|
||||||
best_match = actual
|
|
||||||
|
|
||||||
if best_sim >= 0.5 and best_match is not None:
|
|
||||||
matched_actuals.add(best_idx)
|
|
||||||
matched_fields = _compare_fields(exp, best_match)
|
|
||||||
field_scores.append(FieldScore(
|
|
||||||
expected=exp, best_match=best_match,
|
|
||||||
matched_fields=matched_fields, similarity=best_sim,
|
|
||||||
))
|
|
||||||
else:
|
|
||||||
field_scores.append(FieldScore(
|
|
||||||
expected=exp, best_match=None, matched_fields={}, similarity=0.0,
|
|
||||||
))
|
|
||||||
|
|
||||||
extra_count = len(actual_records) - len(matched_actuals)
|
|
||||||
missing_count = sum(1 for s in field_scores if s.best_match is None)
|
|
||||||
|
|
||||||
return field_scores, extra_count, missing_count
|
|
||||||
|
|
||||||
|
|
||||||
def compute_precision_recall(
|
|
||||||
expected_count: int,
|
|
||||||
actual_count: int,
|
|
||||||
matched_count: int,
|
|
||||||
) -> tuple[float, float, float]:
|
|
||||||
"""Compute precision, recall, F1."""
|
|
||||||
precision = matched_count / actual_count if actual_count > 0 else 0.0
|
|
||||||
recall = matched_count / expected_count if expected_count > 0 else 0.0
|
|
||||||
f1 = (
|
|
||||||
2 * precision * recall / (precision + recall)
|
|
||||||
if (precision + recall) > 0
|
|
||||||
else 0.0
|
|
||||||
)
|
|
||||||
return precision, recall, f1
|
|
||||||
|
|
||||||
|
|
||||||
# ── LLM Judge Scorer ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
_JUDGE_SYSTEM_PROMPT = """\
|
|
||||||
You are an evaluation judge for a data extraction system.
|
|
||||||
|
|
||||||
Your task is to compare the EXPECTED extractions against the ACTUAL extractions
|
|
||||||
produced by an AI agent, and assess quality on a 0-1 scale.
|
|
||||||
|
|
||||||
Scoring criteria:
|
|
||||||
- 1.0: All expected records found with correct fields, no significant extras
|
|
||||||
- 0.8: Most expected records found, minor field differences or extras
|
|
||||||
- 0.6: Core extractions present but some missing or incorrect
|
|
||||||
- 0.4: Partial match — several expected records missing or wrong
|
|
||||||
- 0.2: Poor quality — most expected records missing or incorrect
|
|
||||||
- 0.0: Complete failure — no meaningful overlap
|
|
||||||
|
|
||||||
Consider semantic equivalence: "Send invoice" and "Email the invoice" are matches.
|
|
||||||
Ignore field ordering and formatting differences.
|
|
||||||
|
|
||||||
Respond with ONLY a JSON object:
|
|
||||||
{"score": 0.85, "reasoning": "Brief explanation of the score"}
|
|
||||||
"""
|
|
||||||
|
|
||||||
|
|
||||||
async def llm_judge_score(
|
|
||||||
expected: list[dict[str, Any]],
|
|
||||||
actual: list[dict[str, Any]],
|
|
||||||
*,
|
|
||||||
judge_model: str = "gpt-4o-mini",
|
|
||||||
) -> tuple[float, str]:
|
|
||||||
"""Use an LLM to semantically evaluate extraction quality.
|
|
||||||
|
|
||||||
Returns (score, reasoning).
|
|
||||||
"""
|
|
||||||
from shared.llm import get_llm
|
|
||||||
|
|
||||||
llm = get_llm(model=judge_model, temperature=0)
|
|
||||||
|
|
||||||
user_content = (
|
|
||||||
f"## Expected extractions\n```json\n{json.dumps(expected, indent=2, default=str)}\n```\n\n"
|
|
||||||
f"## Actual extractions\n```json\n{json.dumps(actual, indent=2, default=str)}\n```"
|
|
||||||
)
|
|
||||||
|
|
||||||
try:
|
|
||||||
response = await llm.ainvoke([
|
|
||||||
SystemMessage(content=_JUDGE_SYSTEM_PROMPT),
|
|
||||||
HumanMessage(content=user_content),
|
|
||||||
])
|
|
||||||
raw = response.content.strip()
|
|
||||||
if raw.startswith("```"):
|
|
||||||
raw = raw.split("```")[1]
|
|
||||||
if raw.startswith("json"):
|
|
||||||
raw = raw[4:]
|
|
||||||
parsed = json.loads(raw.strip())
|
|
||||||
return float(parsed.get("score", 0.0)), str(parsed.get("reasoning", ""))
|
|
||||||
except Exception as exc:
|
|
||||||
logger.warning("eval: LLM judge failed: %s", exc)
|
|
||||||
return 0.0, f"Judge error: {exc}"
|
|
||||||
@@ -1,21 +0,0 @@
|
|||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
gunicorn>=22.0.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
pydantic-settings>=2.7.0
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
redis>=5.0.0
|
|
||||||
cryptography>=42.0.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
langchain-core>=0.3.0
|
|
||||||
langchain-openai>=0.3.0
|
|
||||||
langchain-litellm>=0.3.0
|
|
||||||
litellm>=1.50.0
|
|
||||||
openai>=1.50.0
|
|
||||||
httpx>=0.27.0
|
|
||||||
langfuse>=3.0.0
|
|
||||||
croniter>=2.0.0
|
|
||||||
google-api-python-client>=2.130.0
|
|
||||||
google-auth>=2.30.0
|
|
||||||
msal>=1.28.0
|
|
||||||
@@ -1,36 +0,0 @@
|
|||||||
# ── builder ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS builder
|
|
||||||
|
|
||||||
WORKDIR /build
|
|
||||||
|
|
||||||
COPY services/billing/requirements.txt ./requirements.txt
|
|
||||||
RUN pip install --upgrade pip && \
|
|
||||||
pip install --no-cache-dir --prefix=/install -r requirements.txt
|
|
||||||
|
|
||||||
# ── runtime ──────────────────────────────────────────────────────────────────
|
|
||||||
FROM python:3.12-slim AS runtime
|
|
||||||
|
|
||||||
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
|
|
||||||
|
|
||||||
WORKDIR /app
|
|
||||||
|
|
||||||
COPY --from=builder /install /usr/local
|
|
||||||
|
|
||||||
# Shared module
|
|
||||||
COPY shared/ shared/
|
|
||||||
|
|
||||||
# Service source
|
|
||||||
COPY services/billing/app/ app/
|
|
||||||
|
|
||||||
RUN chown -R appuser:appgroup /app
|
|
||||||
|
|
||||||
USER appuser
|
|
||||||
|
|
||||||
EXPOSE 8000
|
|
||||||
|
|
||||||
# Billing is lightweight — single worker is fine
|
|
||||||
CMD ["gunicorn", "app.main:app", \
|
|
||||||
"-k", "uvicorn.workers.UvicornWorker", \
|
|
||||||
"--bind", "0.0.0.0:8000", \
|
|
||||||
"--workers", "1", \
|
|
||||||
"--timeout", "30"]
|
|
||||||
@@ -1,15 +0,0 @@
|
|||||||
# Billing Service
|
|
||||||
|
|
||||||
Owns: Stripe integration, tier management, subscription CRUD.
|
|
||||||
|
|
||||||
## Tables owned (write)
|
|
||||||
- `subscriptions`
|
|
||||||
|
|
||||||
## Endpoints
|
|
||||||
- `POST /billing/checkout`
|
|
||||||
- `POST /billing/webhook` (Stripe, no JWT auth)
|
|
||||||
- `GET /billing/subscription`
|
|
||||||
- `DELETE /billing/subscription`
|
|
||||||
|
|
||||||
## Redis channels
|
|
||||||
- Publish: `tier:changed:{user_id}` on tier change
|
|
||||||
@@ -1,53 +0,0 @@
|
|||||||
"""Billing Service — FastAPI application.
|
|
||||||
|
|
||||||
Owns: Stripe checkout/webhook, subscription management, tier feature matrix,
|
|
||||||
quota enforcement.
|
|
||||||
|
|
||||||
Downstream services query this service (or read the user's tier from
|
|
||||||
the X-User-Tier header injected by Traefik) for billing decisions.
|
|
||||||
The webhook endpoint is exposed WITHOUT ForwardAuth so Stripe can reach it.
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
import logging
|
|
||||||
import sys
|
|
||||||
from contextlib import asynccontextmanager
|
|
||||||
from pathlib import Path
|
|
||||||
from typing import AsyncGenerator
|
|
||||||
|
|
||||||
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
|
|
||||||
_repo_root = str(Path(__file__).resolve().parents[3])
|
|
||||||
if _repo_root not in sys.path:
|
|
||||||
sys.path.insert(0, _repo_root)
|
|
||||||
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
|
|
||||||
from app.routes import router
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
|
||||||
|
|
||||||
|
|
||||||
@asynccontextmanager
|
|
||||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
|
||||||
logger.info("billing: service started")
|
|
||||||
yield
|
|
||||||
logger.info("billing: service stopped")
|
|
||||||
|
|
||||||
|
|
||||||
app = FastAPI(title="Adiuva Billing Service", lifespan=lifespan)
|
|
||||||
|
|
||||||
app.add_middleware(
|
|
||||||
CORSMiddleware,
|
|
||||||
allow_origins=["*"],
|
|
||||||
allow_methods=["GET", "POST", "DELETE"],
|
|
||||||
allow_headers=["*"],
|
|
||||||
)
|
|
||||||
|
|
||||||
app.include_router(router)
|
|
||||||
|
|
||||||
|
|
||||||
@app.get("/health")
|
|
||||||
async def health() -> dict[str, str]:
|
|
||||||
return {"status": "ok", "service": "billing"}
|
|
||||||
@@ -1,134 +0,0 @@
|
|||||||
"""Billing routes: Stripe checkout, webhook, subscription, tier query.
|
|
||||||
|
|
||||||
Adapted for the Billing microservice:
|
|
||||||
- Authenticated routes use Traefik-injected headers (X-User-Id, X-User-Tier)
|
|
||||||
- Webhook route has NO auth (Stripe signature verification only)
|
|
||||||
- Added /tier/{user_id} for internal service-to-service tier lookups
|
|
||||||
- Added /features/{tier} for feature matrix queries
|
|
||||||
"""
|
|
||||||
|
|
||||||
from __future__ import annotations
|
|
||||||
|
|
||||||
from typing import Any
|
|
||||||
|
|
||||||
from fastapi import APIRouter, Header, HTTPException, Request, status
|
|
||||||
from pydantic import BaseModel
|
|
||||||
|
|
||||||
from shared.db import async_session
|
|
||||||
from shared.schemas import BillingTier
|
|
||||||
|
|
||||||
from app.stripe_service import stripe_service
|
|
||||||
from app.tier_manager import tier_manager, FEATURES, RATE_LIMITS
|
|
||||||
|
|
||||||
router = APIRouter(prefix="/billing", tags=["billing"])
|
|
||||||
|
|
||||||
|
|
||||||
# ── Request bodies ─────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
class _CheckoutRequest(BaseModel):
|
|
||||||
tier: BillingTier
|
|
||||||
|
|
||||||
|
|
||||||
# ── Checkout ───────────────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.post("/checkout")
|
|
||||||
async def create_checkout(
|
|
||||||
body: _CheckoutRequest,
|
|
||||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
|
||||||
) -> dict[str, str]:
|
|
||||||
"""Create a Stripe checkout session for a tier upgrade."""
|
|
||||||
url = stripe_service.create_checkout_session(x_user_id, body.tier)
|
|
||||||
return {"checkout_url": url}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Webhook (NO auth — Stripe signature only) ─────────────────────────
|
|
||||||
|
|
||||||
@router.post("/webhook")
|
|
||||||
async def stripe_webhook(
|
|
||||||
request: Request,
|
|
||||||
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Handle Stripe webhook events.
|
|
||||||
|
|
||||||
This endpoint is exposed without ForwardAuth in Traefik config
|
|
||||||
so Stripe can reach it directly.
|
|
||||||
"""
|
|
||||||
payload = await request.body()
|
|
||||||
async with async_session() as db:
|
|
||||||
await stripe_service.handle_webhook(payload, stripe_signature, db)
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Subscription CRUD ─────────────────────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/subscription")
|
|
||||||
async def get_subscription(
|
|
||||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
|
||||||
x_user_tier: str = Header("free", alias="X-User-Tier"),
|
|
||||||
) -> dict[str, Any]:
|
|
||||||
"""Return the current subscription info for the authenticated user."""
|
|
||||||
async with async_session() as db:
|
|
||||||
sub = await stripe_service.get_subscription(x_user_id, db)
|
|
||||||
if sub is None:
|
|
||||||
return {
|
|
||||||
"tier": x_user_tier,
|
|
||||||
"status": "free",
|
|
||||||
"stripe_subscription_id": None,
|
|
||||||
"current_period_end": None,
|
|
||||||
}
|
|
||||||
return sub
|
|
||||||
|
|
||||||
|
|
||||||
@router.delete("/subscription")
|
|
||||||
async def cancel_subscription(
|
|
||||||
x_user_id: str = Header(..., alias="X-User-Id"),
|
|
||||||
) -> dict[str, bool]:
|
|
||||||
"""Cancel the active subscription."""
|
|
||||||
async with async_session() as db:
|
|
||||||
await stripe_service.cancel_subscription(x_user_id, db)
|
|
||||||
return {"ok": True}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Tier query (internal, service-to-service) ─────────────────────────
|
|
||||||
|
|
||||||
@router.get("/tier/{user_id}")
|
|
||||||
async def get_user_tier(user_id: str) -> dict[str, str]:
|
|
||||||
"""Return the billing tier for a given user_id.
|
|
||||||
|
|
||||||
Used by other services for tier lookups. Protected by Traefik
|
|
||||||
ForwardAuth — only internal services should call this.
|
|
||||||
"""
|
|
||||||
async with async_session() as db:
|
|
||||||
tier = await tier_manager.get_tier(user_id, db)
|
|
||||||
return {"user_id": user_id, "tier": tier}
|
|
||||||
|
|
||||||
|
|
||||||
# ── Feature matrix (public, cacheable) ────────────────────────────────
|
|
||||||
|
|
||||||
@router.get("/features/{tier}")
|
|
||||||
async def get_tier_features(tier: str) -> dict[str, Any]:
|
|
||||||
"""Return the feature matrix for a tier."""
|
|
||||||
if tier not in FEATURES:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
|
||||||
detail=f"Unknown tier: {tier}",
|
|
||||||
)
|
|
||||||
return {
|
|
||||||
"tier": tier,
|
|
||||||
"features": FEATURES[tier],
|
|
||||||
"rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]),
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
@router.get("/features")
|
|
||||||
async def get_all_features() -> dict[str, Any]:
|
|
||||||
"""Return the full feature matrix for all tiers."""
|
|
||||||
return {
|
|
||||||
"tiers": {
|
|
||||||
tier: {
|
|
||||||
"features": features,
|
|
||||||
"rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]),
|
|
||||||
}
|
|
||||||
for tier, features in FEATURES.items()
|
|
||||||
},
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
fastapi>=0.115.0
|
|
||||||
uvicorn[standard]>=0.34.0
|
|
||||||
gunicorn>=22.0.0
|
|
||||||
pydantic>=2.10.0
|
|
||||||
pydantic-settings>=2.7.0
|
|
||||||
sqlalchemy>=2.0.0
|
|
||||||
asyncpg>=0.30.0
|
|
||||||
python-dotenv>=1.0.0
|
|
||||||
stripe>=8.0.0
|
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user