30 Commits

Author SHA1 Message Date
Roberto Musso
2b7d302ef2 refactor: remove monolith app/, Dockerfile, requirements.txt
All business logic has been extracted into microservices:
  - services/auth/       (Step 1)
  - services/ws-gateway/ (Step 2)
  - services/chat/       (Step 2)
  - services/batch-agent/ (Step 3)
  - services/billing/    (Step 4)

Shared code lives in shared/.
Migrations remain in alembic/.
Tests in tests/ will need updating to target individual services.
2026-04-06 23:44:12 +02:00
Roberto Musso
7f6ea29525 feat(infra): Docker Compose orchestration + env updates (Step 5)
- Replace monolith docker-compose with full microservices stack
- Services: traefik, db, redis, migrate, auth, ws-gateway, chat, batch-agent, billing
- Traefik API gateway with ForwardAuth, ACME/Cloudflare DNS-01 (from Step 2)
- Centralized migrations via 'migrate' service (run-once)
- All services share .env via env_file + override DATABASE_URL/REDIS_URL
- Health checks on db and redis; service dependency ordering
- MinIO and Qdrant kept as optional (commented out)
- .env.example: add JWT_PRIVATE_KEY, CF_DNS_API_TOKEN, ACME_EMAIL, POSTGRES_ vars
2026-04-06 23:40:14 +02:00
Roberto Musso
48036397f1 fix(billing): auto-detect repo root for shared module import in local dev 2026-04-06 23:32:17 +02:00
Roberto Musso
57b5648915 feat(billing): extract Billing Service (Step 4)
- stripe_service: checkout sessions, webhook handling, subscription CRUD
- tier_manager: feature matrix (4 tiers), quota enforcement, rate limits
- routes: checkout, webhook (no auth), subscription, tier query, features
- Traefik header auth (X-User-Id) replaces get_current_user dependency
- /tier/{user_id} endpoint for internal service-to-service lookups
- /features and /features/{tier} for feature matrix queries
- Dockerfile: single worker, 30s timeout (lightweight service)
2026-04-06 23:07:46 +02:00
Roberto Musso
7e4374c69b feat(eval): add custom system prompt support for step-1 classification 2026-04-06 22:56:30 +02:00
Roberto Musso
fe0dd038ee fix: Langfuse SDK v4 migration, tracing improvements, and LLM config
- Langfuse SDK v4: fix prompt-to-trace linking (as_type=generation)
- tracing: compile_prompt with Langfuse managed prompt fallback
- journey: remove journey CLI subcommand (keep only interactive)
- LLM: add service-specific llm modules for batch-agent and chat
- gitignore: exclude eval private test data
- config: add LANGFUSE settings to shared config
2026-03-24 16:25:51 +01:00
Roberto Musso
d3f7099d93 refactor(eval): 3-mode eval harness (step1/step2/full) with Langfuse fixes
- Rewrite eval config with EvalMode (step1, step2, full) replacing prompt_variants
- Rewrite runner with _run_step1, _run_step2, _run_full dispatch
- CLI: replace --variants with --mode flag
- Add 3 fixture YAMLs: classify_invoices (step1), process_invoices (step2), full_invoices (full)
- Remove old freelance_invoices fixture
- Langfuse: mode-aware dataset items (classifications for step1, extraction for step2, both for full)
- Langfuse: link both prompts (batch_file_classifier + batch_processing) in full mode
- Langfuse: post separate classification_precision/recall/f1 scores for full mode
- Langfuse: skip misleading field_accuracy=0 when field_scores is empty (step1)
- Langfuse: include step1_results in trace output
- MockExecutor: mock async_session to bypass DB in full mode
- Journey fixture: remove user_messages (only interactive test kept)
2026-03-24 16:18:51 +01:00
Roberto Musso
63fa119543 feat(batch-agent): add journey eval to E2E harness
- journey_runner.py: orchestrates journey start → simulated user
  messages → template extraction → LLM judge scoring
- config.py: JourneyFixture dataclass with user_messages and
  expected_template_criteria, discover_journey_fixtures()
- langfuse_eval.py: sync_journey_fixture_to_dataset()
- cli.py: new 'journey' subcommand (python -m eval journey)
  with --fixture, --models, --judge-model flags
- fixtures/journey_invoice_setup.yaml: example journey fixture
  with 4 user messages and 8 quality criteria
2026-03-23 23:16:41 +01:00
Roberto Musso
d856dfd28c refactor: deduplicate shared code into shared/ module
Move duplicated files from chat + batch-agent into shared/:
- shared/ws_context.py — Redis-based tool call round-trip
- shared/llm.py — LiteLLM factory (get_llm, embed)
- shared/agents/ — 4 domain agents (task, note, project, timeline)

Update all service imports to use shared.* instead of app.*.
Delete 12 duplicated files across both services.
2026-03-23 23:01:45 +01:00
Roberto Musso
ccba54ac24 fix(tracing): use Langfuse compile_prompt with {{variable}} syntax
- tracing.py: add compile_prompt() that uses Langfuse .compile(**vars)
  for {{variable}} substitution, falls back to Python .format() for
  hardcoded {variable} templates
- agent_runner.py: replace _get_system_prompt().format() with
  tracing.compile_prompt() for batch_file_classifier, batch_processing,
  batch_cloud_processing prompts
- journey.py: replace get_prompt + .format() with compile_prompt()
  for journey_system prompt
- chat tracing.py: add compile_prompt() for parity (chat prompts
  currently have no variables, but ready for future use)
- Remove unused _get_system_prompt helper
2026-03-23 22:39:27 +01:00
Roberto Musso
55500cc818 feat(batch-agent): add Langfuse prompt management
- _get_system_prompt helper: fetches managed prompts from Langfuse
  with hardcoded fallback (same pattern as chat service)
- journey.py: journey_system prompt manageable via Langfuse
- agent_runner.py: batch_file_classifier, batch_processing,
  batch_cloud_processing prompts all manageable via Langfuse
- redis_consumer.py: link_prompt_to_trace for all three handlers
2026-03-23 22:30:36 +01:00
Roberto Musso
75a826c9d8 feat(batch-agent): add E2E evaluation harness with Langfuse integration
- eval/mock_executor.py: intercepts execute_on_client, serves fixture
  files from disk, records all mutations (insert/update/delete)
- eval/config.py: YAML fixture loader with prompt variants, expected
  results, seed records, model overrides
- eval/scorer.py: FieldMatchScorer (fuzzy title match, per-field
  accuracy, precision/recall/F1) + LLMJudgeScorer (semantic eval)
- eval/langfuse_eval.py: sync fixtures to Langfuse datasets, create
  dataset runs, post scores, link traces to runs
- eval/runner.py: orchestrates fixture → mock → agent pipeline →
  scoring → Langfuse reporting
- eval/cli.py: CLI (python -m eval run/list/sync) with --models,
  --variants, --fixture, --no-judge flags
- eval/fixtures/: example Italian freelance scenario with 3 prompt
  variants (baseline, detailed_italian, minimal)
2026-03-23 08:54:19 +01:00
Roberto Musso
971f1dd84f feat(batch-agent): integrate Langfuse tracing
- tracing.py: init/shutdown, trace_span, get_langfuse_callback, prompt mgmt
- main.py: init_langfuse at startup, shutdown on teardown
- redis_consumer.py: trace_span around journey_start/message/agent_trigger
- agent_runner.py: thread langfuse_handler through classify + processing LLM
- journey.py: thread langfuse_handler through _call_llm_with_tools
- llm.py: accept callbacks param, forward to LLM constructors
- requirements.txt: add langfuse>=3.0.0
2026-03-23 08:43:15 +01:00
Roberto Musso
333bba6fdd feat(batch-agent): extract Batch Agent Service (Step 3)
- agent_runner: local directory + cloud agent orchestration via Redis
- 5 domain agents: filesystem, task, note, project, timeline
- integrations: Gmail, MS Graph (Outlook + Teams)
- journey: guided chatbot conversation to build prompt_template
- routes: REST endpoints (catalog, can-create, trigger)
- redis_consumer: subscribes to batch:request:* pattern
- ws_context: Redis-based execute_on_client for tool round-trip
- Dockerfile with 300s timeout for long-running batch jobs
2026-03-23 07:19:02 +01:00
Roberto Musso
229e20d073 docs: add Langfuse integration TODO for batch-agent service 2026-03-23 00:25:42 +01:00
Roberto Musso
0b491b3643 fix: langfuse v4 SDK compatibility and pass user message as trace input 2026-03-23 00:23:59 +01:00
Roberto Musso
0d5fa3e569 feat(chat): integrate Langfuse tracing, prompt management & generation tracking
- shared/config.py: add LANGFUSE_SECRET_KEY, LANGFUSE_PUBLIC_KEY, LANGFUSE_HOST
- services/chat/app/tracing.py: new module — Langfuse client singleton,
  create_trace(), get_langfuse_callback(), get_prompt(), link_prompt_to_trace(),
  score_trace(), flush/shutdown helpers. Gracefully no-ops when keys are missing.
- services/chat/app/llm.py: add callbacks param to get_llm() for LangChain
  callback handler injection
- services/chat/app/deep_agent.py: accept langfuse_handler in all run_* and
  _run_single_agent* functions, pipe callbacks to LLM calls, fetch managed
  prompts from Langfuse with fallback to hardcoded system prompts
- services/chat/app/redis_consumer.py: create Langfuse trace per request
  (home_request/floating_request), pass callback handler to deep_agent,
  link prompt name to trace, attach output preview, flush after each request
- services/chat/app/main.py: shutdown Langfuse client in lifespan teardown
- services/chat/requirements.txt: add langfuse>=2.0.0

Langfuse prompt names: 'home_system', 'floating_system' — create these in
the Langfuse dashboard to manage prompts. Without them, hardcoded defaults
are used transparently.
2026-03-22 23:15:04 +01:00
Roberto Musso
aff68a9051 fix: shared config loads root .env as fallback for microservices 2026-03-22 22:42:54 +01:00
Roberto Musso
5e9ef2809e fix: add extra=ignore to monolith Settings for strangler fig compat 2026-03-22 22:28:50 +01:00
Roberto Musso
90018af311 feat: add WS Gateway and Chat Service (Step 2)
WS Gateway:
- WebSocket lifecycle handler with RS256 JWT auth
- Redis bridge: device registry, frame publishing, tool_result routing
- Inbound routing: tool_result→LPUSH, home/floating→chat pub/sub
- Outbound: subscribes to ws:out:{user_id}, forwards to Electron
- Single-worker Dockerfile (long-lived WS connections)

Chat Service:
- Redis consumer: subscribes to chat:request:* pattern
- Redis-based ws_context: tool_call→publish, BRPOP tool_result (30s timeout)
- deep_agent: single-agent runner with home/floating/stream variants
- memory_middleware: core/associative/episodic/proactive memory with Fernet
- Domain agents: task (8 tools), note (5), project (6), timeline (4)
- LLM factory via LiteLLM (100+ providers)
- Output formatter (StreamFormatter)
- POST /chat REST fallback with Traefik header auth
- Multi-worker Dockerfile with 120s timeout for LLM calls
2026-03-22 01:20:11 +01:00
Roberto Musso
1e2e395676 fix: PEM newline parsing + shared config extra=ignore
- Add field_validator to expand literal \n in PEM keys (auth config + shared config)
- Set extra='ignore' on shared Settings so service-specific .env vars don't cause ValidationError
- Add *.pem to .gitignore
2026-03-22 01:03:28 +01:00
Roberto Musso
59d3a53980 chore: update .env.example files for RS256 + Redis
- Root .env.example: replace JWT_SECRET/JWT_ALGORITHM with JWT_PUBLIC_KEY, add REDIS_URL
- Auth Service .env.example: JWT_PRIVATE_KEY + JWT_PUBLIC_KEY with generation instructions
2026-03-22 00:51:54 +01:00
Roberto Musso
9feeaa79c8 feat(auth): migrate JWT from HS256 to RS256
- Add services/auth/app/config.py with JWT_PRIVATE_KEY and JWT_PUBLIC_KEY
  (Auth Service local config - private key never leaves this service)
- Update routes.py: sign tokens with RS256 private key
- Update deps.py + verify.py: verify tokens with RS256 public key
- Update shared/config.py: replace JWT_SECRET/JWT_ALGORITHM with
  JWT_PUBLIC_KEY (for optional local verification by other services)
- Add sys.path fix in main.py for local dev without PYTHONPATH
2026-03-22 00:50:36 +01:00
Roberto Musso
aa219a4d08 feat: microservices scaffold + Auth Service (Step 1)
- Add shared/ module: config, db, models, schemas, redis utilities
- Add Auth Service (services/auth/): register, login, refresh, me,
  ForwardAuth /verify endpoint for Traefik
- Add Traefik config: ACME/Cloudflare DNS-01, dynamic routing,
  ForwardAuth middleware, sticky sessions for WS Gateway
- Add service scaffolds: ws-gateway, chat, batch-agent, billing (READMEs)
- Add redis>=5.0.0 to requirements.txt
- Monolith app/ is untouched — strangler fig migration
2026-03-22 00:29:51 +01:00
Roberto Musso
552b8eb305 Fix project creation: code-based in runner, not delegated to Step 2 LLM
Root causes fixed:
1. PROJECT_TOOLS removed from Step 2 tool set — project assignment is now
   exclusively handled by the runner in code, never by the LLM.
2. When Step 1 returns "new", runner calls execute_on_client insert/projects
   directly (before Step 2), gets the created id, and passes it as context.
3. Newly created projects are appended to the local `projects` list so that
   subsequent files in the same run can match to them via Step 1 — prevents
   one project per file when multiple files share the same topic.

Also add tests/test_classify_file.py with pytest cases for _classify_file
and a CLI runner: python -m tests.test_classify_file <file> [project...]

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-21 23:40:38 +01:00
Roberto Musso
0d93b3960d Exclude project/projectId questions from agent setup journey
- Add explicit MUST NOT instruction: never ask about projects, projectId,
  or how to link records; project assignment is handled by the agent runner
- Remove projectId from template field list; remove projects from entity types
- Remove stale isApproved=0 reference (already removed from the data model)

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-21 22:58:05 +01:00
Roberto Musso
f07580574b Replace max_turns cap with 90% confidence stopping criterion in agent setup
- Remove fixed _MAX_TURNS=5 instruction from system prompt; LLM now decides
  when to stop based on self-assessed confidence (>= 90%)
- Add _MIN_TURNS_BEFORE_NUDGE=3 and raise safety cap to _MAX_TURNS=15
- Nudge message and hard cap still act as a safety net for infinite loops

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-21 22:54:34 +01:00
Roberto Musso
1a8bf11f90 update migration plan 2026-03-20 23:48:36 +01:00
Roberto Musso
e7cdce8287 Improve Step 1 project matching and Step 2 update-first enforcement
- Rewrite _STEP1_SYSTEM_PROMPT: lower matching threshold (no longer requires
  "clear" match), strongly prefer existing projects over creating new ones,
  use structured id=|name=|status= format with aiSummary for richer context
- Add code-level UUID validation: reject hallucinated ids not in the fetched
  projects list, fall back to "new" instead of creating a bad link
- Rewrite _PROCESSING_SYSTEM_PROMPT: enforce explicit scan-before-create
  process (read existing → search → update if found → create only if not)
  with hard rule against calling create_* without checking existing records

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 23:45:29 +01:00
Roberto Musso
58bc6efd4b Rewrite run_local_agent: code-based flow, concurrency guard, remove isApproved
- Replace LLM-driven triage with code-based directory scan and project fetch
- Two-step LLM approach: Step 1 classifies file→project+domains, Step 2 processes with tools
- Add domain descriptions to Step 1 prompt for better extraction accuracy
- Add _running_agents set for per-agent concurrency guard (one running instance per agent)
- Return 409 from route before DB write when agent already running
- Remove is_approved from task_agent create/update tools and system prompt
- Remove is_approved from timeline_agent create/update tools and system prompt

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-03-20 22:21:30 +01:00
122 changed files with 8245 additions and 4953 deletions

View File

@@ -4,9 +4,19 @@ ENV=dev
# ── Database ────────────────────────────────────────────────────────────────── # ── Database ──────────────────────────────────────────────────────────────────
DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva DATABASE_URL=postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva
# ── Auth ────────────────────────────────────────────────────────────────────── # ── Redis ─────────────────────────────────────────────────────────────────────
JWT_SECRET=replace-with-a-long-random-secret REDIS_URL=redis://localhost:6379/0
JWT_ALGORITHM=HS256
# ── Auth (JWT RS256) ──────────────────────────────────────────────────────────
# Generate keypair:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# openssl rsa -in private.pem -pubout -out public.pem
# Paste PEM content with literal \n for newlines.
#
# Private key — ONLY used by the Auth Service (JWT signing).
JWT_PRIVATE_KEY=
# Public key — used by all services / Traefik ForwardAuth (JWT verification).
JWT_PUBLIC_KEY=
JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30 JWT_ACCESS_TOKEN_EXPIRE_MINUTES=30
JWT_REFRESH_TOKEN_EXPIRE_DAYS=30 JWT_REFRESH_TOKEN_EXPIRE_DAYS=30
@@ -17,7 +27,6 @@ OPENAI_API_KEY=
ANTHROPIC_API_KEY= ANTHROPIC_API_KEY=
GOOGLE_API_KEY= GOOGLE_API_KEY=
LLM_MODEL=gpt-4o LLM_MODEL=gpt-4o
LLM_ROUTER_MODEL=gpt-4o-mini
# ── Stripe (leave empty to stub billing) ────────────────────────────────────── # ── Stripe (leave empty to stub billing) ──────────────────────────────────────
STRIPE_SECRET_KEY= STRIPE_SECRET_KEY=
@@ -42,3 +51,17 @@ QDRANT_API_KEY=
# ── CORS ────────────────────────────────────────────────────────────────────── # ── CORS ──────────────────────────────────────────────────────────────────────
# Comma-separated list parsed by Settings (override default if needed) # Comma-separated list parsed by Settings (override default if needed)
# CORS_ORIGINS=["app://.","http://localhost:3000"] # CORS_ORIGINS=["app://.","http://localhost:3000"]
# ── Langfuse (observability) ─────────────────────────────────────────────────
LANGFUSE_SECRET_KEY=sk-lf-...
LANGFUSE_PUBLIC_KEY=pk-lf-...
LANGFUSE_HOST=https://cloud.langfuse.com # or self-hosted URL
# ── Cloudflare (Traefik ACME DNS-01 challenge) ───────────────────────────────
CF_DNS_API_TOKEN=
ACME_EMAIL=
# ── PostgreSQL (used by docker-compose) ──────────────────────────────────────
POSTGRES_USER=postgres
POSTGRES_PASSWORD=postgres
POSTGRES_DB=adiuva

6
.gitignore vendored
View File

@@ -13,6 +13,9 @@ env/
# Environment variables # Environment variables
.env .env
# Cryptographic keys
*.pem
# IDE # IDE
.vscode/ .vscode/
.idea/ .idea/
@@ -32,3 +35,6 @@ Thumbs.db
# Claude Code # Claude Code
.claude/ .claude/
logs/ logs/
# Eval private test data
services/batch-agent/eval/fixtures/private_data/

View File

@@ -739,7 +739,7 @@ adiuva-api/
│ │ │ │
│ ├── core/ # Orchestration engine │ ├── core/ # Orchestration engine
│ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry │ │ ├── agent_registry.py # BaseAgent, ChatAgent, AgentRegistry
│ │ ├── llm.py # LiteLLM factory (get_llm, get_router_llm) │ │ ├── llm.py # LiteLLM factory (get_llm)
│ │ ├── orchestrator.py # Intent classification & routing │ │ ├── orchestrator.py # Intent classification & routing
│ │ └── execution_plan.py # Plan builder, templates, cache │ │ └── execution_plan.py # Plan builder, templates, cache
│ │ │ │

View File

@@ -1,5 +0,0 @@
"""Expose tool modules used by deep orchestrator-worker graphs."""
from app.agents import filesystem_agent, timeline_agent, note_agent, project_agent, task_agent
__all__ = ["filesystem_agent", "timeline_agent", "note_agent", "project_agent", "task_agent"]

View File

@@ -1,14 +0,0 @@
"""Shared FastAPI dependencies.
``get_current_user`` and ``oauth2_scheme`` live in ``app.api.middleware.auth``
(the canonical location per Step 9). This module re-exports them so that all
existing route imports (``from app.api.deps import get_current_user``) continue
to work without modification.
Step 12 will update ``get_current_user`` to fetch the live tier from PostgreSQL
instead of reading it from the JWT payload.
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme # noqa: F401
__all__ = ["get_current_user", "oauth2_scheme"]

View File

@@ -1,19 +0,0 @@
"""API middleware package.
Exports the three middleware components introduced in Step 9:
- Auth: ``get_current_user`` FastAPI dependency + ``oauth2_scheme``
- Rate limit: ``TierRateLimitMiddleware`` + ``limiter`` (slowapi Limiter)
- Sanitizer: ``SanitizerMiddleware``
"""
from app.api.middleware.auth import get_current_user, oauth2_scheme
from app.api.middleware.rate_limit import TierRateLimitMiddleware, limiter
from app.api.middleware.sanitizer import SanitizerMiddleware
__all__ = [
"get_current_user",
"oauth2_scheme",
"TierRateLimitMiddleware",
"limiter",
"SanitizerMiddleware",
]

View File

@@ -1,129 +0,0 @@
"""Tier-aware rate limiting middleware.
Uses a per-user sliding-window counter (in-process, no Redis required).
The ``slowapi`` Limiter is also exported for optional route-level decoration.
Limits (requests per minute):
- free: 20
- pro: 60
- power: 120
- team: 200
Exempt paths bypass the limiter entirely:
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
- GET /api/v1/health
"""
from __future__ import annotations
import json
import time
from collections import defaultdict
from fastapi import Request, Response
from jose import JWTError, jwt
from slowapi import Limiter
from slowapi.util import get_remote_address
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
from app.config.settings import settings
_TIER_LIMITS: dict[str, int] = {
"free": 20,
"pro": 60,
"power": 120,
"team": 200,
}
_EXEMPT_PATHS: frozenset[str] = frozenset(
{
"/api/v1/auth/register",
"/api/v1/auth/login",
"/api/v1/billing/webhook",
"/api/v1/health",
}
)
def _get_user_id_from_jwt(request: Request) -> str:
"""Key function for the slowapi Limiter: returns JWT sub or remote IP."""
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return get_remote_address(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
return payload.get("sub") or get_remote_address(request)
except JWTError:
return get_remote_address(request)
# Exported Limiter instance — available for optional route-level decoration.
limiter = Limiter(key_func=_get_user_id_from_jwt)
class TierRateLimitMiddleware(BaseHTTPMiddleware):
"""Sliding-window rate limiter applied globally across all non-exempt routes.
Each authenticated user gets their own 60-second window sized by tier.
Unauthenticated requests pass through (the auth dependency will reject them
with 401 before the route handler runs).
"""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
# user_id → list of request timestamps (float, seconds since epoch)
self._window: dict[str, list[float]] = defaultdict(list)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
if request.url.path in _EXEMPT_PATHS:
return await call_next(request)
# Extract JWT claims — if no valid token, pass through for auth dep to handle.
auth = request.headers.get("Authorization", "")
token = auth.removeprefix("Bearer ").strip()
if not token:
return await call_next(request)
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str = payload.get("sub") or get_remote_address(request)
tier: str = payload.get("tier", "free")
except JWTError:
return await call_next(request)
limit = _TIER_LIMITS.get(tier, _TIER_LIMITS["free"])
now = time.monotonic()
window_start = now - 60.0
# Slide the window: discard timestamps older than 60 seconds.
timestamps = [t for t in self._window[user_id] if t > window_start]
if len(timestamps) >= limit:
retry_after = max(1, int(60 - (now - min(timestamps))))
return Response(
content=json.dumps(
{
"detail": (
f"Rate limit exceeded ({limit} req/min for {tier} tier). "
f"Retry in {retry_after}s."
)
}
),
status_code=429,
headers={
"Retry-After": str(retry_after),
"Content-Type": "application/json",
},
)
timestamps.append(now)
self._window[user_id] = timestamps
return await call_next(request)

View File

@@ -1,139 +0,0 @@
"""Response sanitizer middleware.
Scans JSON responses from the /api/v1/chat endpoint and strips any fragments
that could reveal server-side prompt IP:
- System prompt openers ("You are a/an/the …")
- Agent routing metadata ("Available agents:", "intent classifier", …)
- LangChain tool schema fragments (``"type": "function"``)
- Internal reasoning markers (<thinking>, <reasoning>, [INST], …)
- Exact-match known prompt fingerprints
Binary responses (storage blobs, backup data) are never touched — the
middleware only activates for paths under /api/v1/chat.
Any sanitisation event is logged as a WARNING with the request path and the
names of the fields that were modified.
"""
from __future__ import annotations
import json
import logging
import re
from fastapi import Request, Response
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.types import ASGIApp
logger = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# Detection patterns — order matters: fingerprints checked first (exact),
# then compiled regexes.
# ---------------------------------------------------------------------------
_FINGERPRINTS: tuple[str, ...] = (
"You are an intent classifier",
"Respond with just the agent name",
"Summarize these agent results",
"Available agents:",
"route to:",
)
_PATTERNS: tuple[re.Pattern[str], ...] = (
re.compile(r"You are (a|an|the)\b.{0,200}", re.IGNORECASE | re.DOTALL),
re.compile(r"Available agents\s*:", re.IGNORECASE),
re.compile(r"\bintent classifier\b", re.IGNORECASE),
re.compile(r'"type"\s*:\s*"function"'), # LangChain tool schema
re.compile(r"<(thinking|reasoning|system|prompt)>", re.IGNORECASE),
re.compile(r"\[INST\]|\[/INST\]"), # Llama instruct markers
re.compile(r"route\s+to\s*:", re.IGNORECASE),
re.compile(r"prompt_template\s*:\s*['\"].{10,}", re.IGNORECASE),
)
def _sanitize_text(text: str) -> tuple[str, bool]:
"""Scan *text* for prompt fragments and replace matches with ``[REDACTED]``.
Returns ``(cleaned_text, was_changed)``.
"""
# Fingerprint check — if any exact phrase is present, redact the whole string.
for fp in _FINGERPRINTS:
if fp in text:
return "[REDACTED]", True
changed = False
for pattern in _PATTERNS:
new_text, n = pattern.subn("[REDACTED]", text)
if n:
text = new_text
changed = True
return text, changed
class SanitizerMiddleware(BaseHTTPMiddleware):
"""Strip prompt IP from /api/v1/chat JSON responses."""
def __init__(self, app: ASGIApp) -> None:
super().__init__(app)
async def dispatch(self, request: Request, call_next) -> Response: # type: ignore[override]
response: Response = await call_next(request)
# Only process chat endpoint responses.
if not request.url.path.startswith("/api/v1/chat"):
return response
# Read body — collect streaming chunks.
body_bytes = b""
async for chunk in response.body_iterator:
body_bytes += chunk if isinstance(chunk, bytes) else chunk.encode()
# Skip non-JSON bodies (shouldn't happen on /chat, but be safe).
try:
body = json.loads(body_bytes.decode("utf-8"))
except (json.JSONDecodeError, UnicodeDecodeError):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
if not isinstance(body, dict):
return Response(
content=body_bytes,
status_code=response.status_code,
headers=dict(response.headers),
media_type=response.media_type,
)
# Walk top-level string fields and sanitise.
sanitised_fields: list[str] = []
for key, value in body.items():
if isinstance(value, str):
cleaned, changed = _sanitize_text(value)
if changed:
body[key] = cleaned
sanitised_fields.append(key)
if sanitised_fields:
logger.warning(
"Sanitizer redacted prompt fragments",
extra={
"path": request.url.path,
"fields": sanitised_fields,
},
)
new_body = json.dumps(body).encode("utf-8")
headers = dict(response.headers)
headers["content-length"] = str(len(new_body))
return Response(
content=new_body,
status_code=response.status_code,
headers=headers,
media_type="application/json",
)

View File

@@ -1,216 +0,0 @@
"""Agent routes.
Backend responsibilities are intentionally minimal:
GET /agents/catalog — static catalog for UI display
POST /agents/can-create — billing eligibility check
POST /agents/trigger — trigger a local agent run
Agent configuration is owned by the Electron app and is not persisted
in backend agent-config tables.
"""
from __future__ import annotations
import asyncio
import uuid
from datetime import datetime, timedelta, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import run_local_agent
from app.core.device_manager import device_manager
from app.db import get_session
from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import (
AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse,
AgentTriggerRequest,
UserProfile,
)
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Datetime helpers ──────────────────────────────────────────────────
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
return AgentRunLogResponse(
id=log.id,
agent_id=log.agent_id,
agent_type=log.agent_type, # type: ignore[arg-type]
status=log.status, # type: ignore[arg-type]
items_processed=log.items_processed,
items_created=log.items_created,
errors=log.errors or [],
started_at=_dt_ms(log.started_at),
completed_at=_dt_ms_opt(log.completed_at),
)
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog", response_model=list[AgentCatalogItem])
async def get_agent_catalog(
current_user: UserProfile = Depends(get_current_user),
) -> list[AgentCatalogItem]:
"""Return the static list of available agent types and their descriptions."""
return [
AgentCatalogItem(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
),
AgentCatalogItem(
type="gmail",
name="Gmail Connector",
description="Scans Gmail inbox, extracts tasks/notes from emails",
),
AgentCatalogItem(
type="teams",
name="Microsoft Teams Connector",
description="Monitors Teams messages, extracts action items",
),
AgentCatalogItem(
type="outlook",
name="Outlook Connector",
description="Scans Outlook inbox, extracts tasks/notes",
),
]
@router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
body: AgentCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
) -> AgentCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse:
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=current_user.id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)

View File

@@ -1,171 +0,0 @@
"""Backup routes: upload, download, history, and delete E2E-encrypted backups.
Blobs are stored in S3 via BlobStore. Backup metadata is persisted in the
PostgreSQL ``backup_metadata`` table.
IMPORTANT: GET /history must be declared BEFORE GET / to avoid FastAPI
treating "history" as a ``{backup_id}`` path parameter.
"""
from __future__ import annotations
import uuid
from email.utils import parsedate_to_datetime
from fastapi import APIRouter, Depends, Header, HTTPException, Request, Response, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import BackupMetadata as BackupMetadataModel
from app.schemas import BackupMetadata, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/backup", tags=["backup"])
_blob_store = BlobStore()
async def _current_backup_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total backup bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(BackupMetadataModel.size_bytes), 0)).where(
BackupMetadataModel.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_backup_quota(
user: UserProfile, size_bytes: int, db: AsyncSession
) -> None:
"""Raise HTTP 402 if the upload would exceed the tier's backup limit."""
current = await _current_backup_bytes(user.id, db)
tier_manager.enforce_backup_quota(
user.tier, current_bytes=current, additional_bytes=size_bytes
)
@router.put("")
async def upload_backup(
request: Request,
x_backup_version: int = Header(..., alias="X-Backup-Version"),
x_backup_timestamp: int = Header(..., alias="X-Backup-Timestamp"),
x_backup_checksum: str = Header(..., alias="X-Backup-Checksum"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Upload an E2E-encrypted backup blob.
Metadata is passed via custom headers; the raw body is the encrypted blob.
"""
blob = await request.body()
reject_if_tampered(blob, x_backup_checksum)
await _check_backup_quota(current_user, len(blob), db)
s3_key = await _blob_store.upload(
current_user.id, "backup", str(x_backup_timestamp), blob, x_backup_checksum
)
row = BackupMetadataModel(
id=str(uuid.uuid4()),
user_id=current_user.id,
s3_key=s3_key,
version=x_backup_version,
timestamp=x_backup_timestamp,
checksum=x_backup_checksum,
size_bytes=len(blob),
)
db.add(row)
await db.commit()
return {"ok": True}
@router.get("/history", response_model=list[BackupMetadata])
async def backup_history(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[BackupMetadata]:
"""Return backup metadata records for the authenticated user (no blob bytes)."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
)
rows = result.scalars().all()
return [
BackupMetadata(
version=r.version,
timestamp=r.timestamp,
checksum=r.checksum,
chunk_count=1,
)
for r in rows
]
@router.get("")
async def download_backup(
request: Request,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download the latest backup blob. Supports ``If-Modified-Since``."""
result = await db.execute(
select(BackupMetadataModel)
.where(BackupMetadataModel.user_id == current_user.id)
.order_by(BackupMetadataModel.timestamp.desc())
.limit(1)
)
latest = result.scalar_one_or_none()
if latest is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="No backup found")
ims_header = request.headers.get("If-Modified-Since")
if ims_header:
try:
ims_dt = parsedate_to_datetime(ims_header)
ims_ms = int(ims_dt.timestamp() * 1000)
if latest.timestamp <= ims_ms:
return Response(status_code=status.HTTP_304_NOT_MODIFIED)
except Exception:
pass # malformed header — ignore and serve the blob
blob = await _blob_store.download(current_user.id, latest.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={
"X-Backup-Version": str(latest.version),
"X-Backup-Timestamp": str(latest.timestamp),
"X-Checksum": latest.checksum,
},
)
@router.delete("/{backup_id}", response_model=dict)
async def delete_backup(
backup_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a specific backup by ID."""
result = await db.execute(
select(BackupMetadataModel).where(
BackupMetadataModel.id == backup_id,
BackupMetadataModel.user_id == current_user.id,
)
)
target = result.scalar_one_or_none()
if target is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Backup not found")
await _blob_store.delete(current_user.id, target.s3_key)
await db.delete(target)
await db.commit()
return {"ok": True}

View File

@@ -1,85 +0,0 @@
"""Billing routes: Stripe checkout, webhook, subscription management.
Business logic lives in ``app.billing.stripe_service.StripeService``.
The route layer handles HTTP concerns (request parsing, response shaping)
and delegates everything else to the service singleton.
"""
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Depends, Header, Request, status
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.stripe_service import stripe_service
from app.db import get_session
from app.schemas import BillingTier, UserProfile
router = APIRouter(prefix="/billing", tags=["billing"])
# ── Request bodies ─────────────────────────────────────────────────────
class _CheckoutRequest(BaseModel):
tier: BillingTier
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/checkout", response_model=dict)
async def create_checkout(
body: _CheckoutRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, str]:
"""Create a Stripe checkout session for a tier upgrade.
Returns a stub URL when ``STRIPE_SECRET_KEY`` is not configured.
"""
url = stripe_service.create_checkout_session(current_user.id, body.tier)
return {"checkout_url": url}
@router.post("/webhook", response_model=dict)
async def stripe_webhook(
request: Request,
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Handle Stripe webhook events.
No JWT auth — authenticated via Stripe signature verification instead.
Returns 200 immediately when Stripe is not configured (local dev).
"""
payload = await request.body()
await stripe_service.handle_webhook(payload, stripe_signature, db)
return {"ok": True}
@router.get("/subscription", response_model=dict)
async def get_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Return the current subscription info for the authenticated user."""
sub = await stripe_service.get_subscription(current_user.id, db)
if sub is None:
return {
"tier": current_user.tier,
"status": "free",
"stripe_subscription_id": None,
"current_period_end": None,
}
return sub
@router.delete("/subscription", response_model=dict, status_code=status.HTTP_200_OK)
async def cancel_subscription(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Cancel the active subscription."""
await stripe_service.cancel_subscription(current_user.id, db)
return {"ok": True}

View File

@@ -1,29 +0,0 @@
"""Chat routes: POST /chat (REST fallback).
WebSocket chat is handled by the unified device WS endpoint (/api/v1/ws/device).
"""
from __future__ import annotations
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from app.api.deps import get_current_user
from app.core.deep_agent import run_home
from app.schemas import ChatRequest, UserProfile
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("")
async def chat(
body: ChatRequest,
current_user: UserProfile = Depends(get_current_user),
) -> JSONResponse:
"""REST fallback for home chat when websocket streaming is unavailable."""
response = await run_home(
user_id=current_user.id,
message=body.message,
context=body.context.model_dump(),
)
return JSONResponse(content={"response": response})

View File

@@ -1,417 +0,0 @@
"""Device WebSocket endpoint.
Persistent connection from Electron devices to the backend.
WS /api/v1/ws/device?token=<jwt>
Auth: JWT passed as ``?token=`` query parameter (Bearer header is not
available during the WebSocket handshake).
Protocol:
1. Client connects → JWT validated → connection accepted.
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
3. Backend registers the connection in ``DeviceConnectionManager``.
4. Session enters message dispatch loop + heartbeat.
Incoming frame dispatch:
- ``tool_result`` → resolves a pending tool-call Future.
- ``journey_start`` → starts a guided setup journey session.
- ``journey_message`` → continues a journey conversation.
- ``pong`` → heartbeat acknowledgement (updates last-seen).
- unknown types → logged, ignored.
Outgoing heartbeat: ``{ "type": "ping" }`` every 30 s.
On disconnect:
- Unregisters from DeviceConnectionManager.
- Marks all in-progress AgentRunLog rows for this user as ``error``
with message "device disconnected".
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.deep_agent import run_floating_stream, run_home_stream
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog
from app.schemas import WsFrameType
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/ws", tags=["device-ws"])
_HEARTBEAT_INTERVAL = 30 # seconds
_PONG_TIMEOUT = 10 # seconds — grace window after a ping
@router.websocket("/device")
async def device_ws(websocket: WebSocket) -> None:
"""Persistent WebSocket endpoint for Electron device connections.
Authentication is via ``?token=<jwt>`` query parameter.
"""
# ── 1. Authenticate before accepting ─────────────────────────────
token = websocket.query_params.get("token", "")
try:
payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM]
)
user_id: str | None = payload.get("sub")
if not user_id:
raise JWTError("missing sub")
except JWTError:
await websocket.close(code=1008) # Policy Violation
return
await websocket.accept()
# ── 2. Await device_hello frame ───────────────────────────────────
try:
raw = await asyncio.wait_for(websocket.receive_text(), timeout=15.0)
except (asyncio.TimeoutError, WebSocketDisconnect):
await websocket.close(code=1008)
return
try:
hello = json.loads(raw)
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
await websocket.close(code=1008)
return
# ── 3. Register connection ────────────────────────────────────────
device_manager.register(user_id, device_id, websocket)
logger.info(
"device_ws: connected user=%s device=%s agents=%s",
user_id,
device_id,
agent_ids,
)
# Trigger any overdue agent runs now that the device is connected.
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
# ── 4. Concurrent message loop + heartbeat ────────────────────────
try:
await asyncio.gather(
_message_loop(websocket, user_id),
_heartbeat_loop(websocket),
)
except WebSocketDisconnect:
pass
except Exception as exc:
logger.warning("device_ws: unhandled exception user=%s: %s", user_id, exc)
finally:
device_manager.unregister(user_id)
logger.info("device_ws: disconnected user=%s device=%s", user_id, device_id)
await _mark_runs_disconnected(user_id)
# ── Message dispatch loop ─────────────────────────────────────────────
async def _message_loop(websocket: WebSocket, user_id: str) -> None:
"""Receive frames from Electron and dispatch to the appropriate handler."""
async for raw in websocket.iter_text():
try:
frame: dict = json.loads(raw)
except json.JSONDecodeError:
logger.warning("device_ws: invalid JSON from user=%s", user_id)
continue
frame_type = frame.get("type")
if frame_type == WsFrameType.tool_result:
call_id = frame.get("id")
if call_id:
device_manager.resolve_pending_call(user_id, call_id, frame)
else:
logger.warning(
"device_ws: tool_result missing id from user=%s", user_id
)
elif frame_type == WsFrameType.home_request:
asyncio.create_task(
_handle_home_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.floating_request:
asyncio.create_task(
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_start:
asyncio.create_task(
_handle_journey_start(websocket, user_id, frame)
)
elif frame_type == WsFrameType.journey_message:
asyncio.create_task(
_handle_journey_message(websocket, user_id, frame)
)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
else:
logger.debug(
"device_ws: unknown frame type %r from user=%s", frame_type, user_id
)
# ── v3 Chat Handlers ──────────────────────────────────────────────────
async def _make_ws_executor(websocket: WebSocket, user_id: str):
"""Return a callback that sends tool_call frames and awaits tool_result."""
async def _executor(payload: dict) -> dict:
payload["type"] = WsFrameType.tool_call
await websocket.send_text(json.dumps(payload))
future = device_manager.create_pending_call(user_id, payload["id"])
return await future
return _executor
async def _handle_home_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a home_request frame — streams HomeFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
logger.info(
"device_ws: home_request_start user=%s req=%s session=%s msg=%s",
user_id,
request_id,
session_id,
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_home_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
# Collect text chunks to build the full response for episode storage
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: home_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: home_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
async def _handle_floating_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
user_id,
request_id,
session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id,
message,
trace_id=request_id,
session_id=session_id,
)
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_floating_stream(user_id, message, context)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
if ws_frame.type == "stream_text": # type: ignore[union-attr]
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: floating_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
len("".join(response_chunks)),
)
# ── v4 Journey Handlers ─────────────────────────────────────────────
async def _handle_journey_start(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_start frame — explores directory and sends first question."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_start(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
logger.error(
"device_ws: journey_start failed user=%s: %s", user_id, exc
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": frame.get("session_id", ""),
"message": f"Failed to start journey: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
async def _handle_journey_message(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a journey_message frame — continues the journey conversation."""
executor = await _make_ws_executor(websocket, user_id)
set_client_executor(executor)
try:
reply = await handle_journey_message(user_id, frame)
await websocket.send_text(json.dumps(reply))
except Exception as exc:
session_id = frame.get("session_id", "")
logger.error(
"device_ws: journey_message failed user=%s session=%s: %s",
user_id, session_id, exc,
)
await websocket.send_text(json.dumps({
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey error: {exc}",
"done": True,
"prompt_template": None,
}))
finally:
clear_client_executor()
# ── Heartbeat ─────────────────────────────────────────────────────────
async def _heartbeat_loop(websocket: WebSocket) -> None:
"""Send a ping frame every 30 s to keep the connection alive."""
while True:
await asyncio.sleep(_HEARTBEAT_INTERVAL)
await websocket.send_text(json.dumps({"type": "ping"}))
# ── Disconnect cleanup ────────────────────────────────────────────────
async def _mark_runs_disconnected(user_id: str) -> None:
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
try:
async with async_session() as db:
await db.execute(
update(AgentRunLog)
.where(
AgentRunLog.user_id == user_id,
AgentRunLog.status == "running",
)
.values(
status="error",
errors=["device disconnected"],
)
)
await db.commit()
except Exception as exc:
logger.error(
"device_ws: failed to mark runs as disconnected for user=%s: %s",
user_id,
exc,
)

View File

@@ -1,148 +0,0 @@
"""Plugins routes: browse and install plugins from the marketplace.
Backed by ``PluginRegistry`` and ``RevenueShare`` service classes that
persist data in the PostgreSQL ``plugins`` and ``revenue_events`` tables.
"""
from __future__ import annotations
from typing import Any, Literal
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.db import get_session
from app.marketplace.plugin_registry import registry
from app.marketplace.revenue_share import revenue_share
from app.models import PluginInstallation, PluginReview as PluginReviewModel
from app.schemas import PluginInstallRequest, PluginListResponse, PluginManifest, UserProfile
router = APIRouter(prefix="/plugins", tags=["plugins"])
# ── Tier gate ─────────────────────────────────────────────────────────
def _require_plugin_tier(user: UserProfile) -> None:
"""Raise HTTP 403 for users below Power tier."""
if user.tier not in ("power", "team"):
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="Plugin marketplace requires Power tier or above",
)
# ── Local detail schema ────────────────────────────────────────────────
class _PluginDetail(BaseModel):
plugin: PluginManifest
install_count: int
ratings: list[Any]
# ── Routes ────────────────────────────────────────────────────────────
@router.get("", response_model=PluginListResponse)
async def list_plugins(
category: str | None = Query(default=None),
q: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
sort: Literal["rating", "installs", "newest"] = Query(default="newest"),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> PluginListResponse:
"""Browse the plugin marketplace. Requires Power tier or above."""
_require_plugin_tier(current_user)
return await registry.list_plugins(db, category=category, query=q, page=page, sort=sort)
@router.get("/{plugin_id}", response_model=_PluginDetail)
async def get_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _PluginDetail:
"""Get full plugin details including install count. Requires Power tier or above."""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Fetch review ratings for this plugin
review_result = await db.execute(
select(PluginReviewModel).where(PluginReviewModel.plugin_id == plugin_id)
)
reviews = review_result.scalars().all()
ratings = [
{
"reviewer_id": r.reviewer_id,
"decision": r.decision,
"notes": r.notes,
"reviewed_at": int(r.reviewed_at.timestamp() * 1000) if r.reviewed_at else None,
}
for r in reviews
]
return _PluginDetail(
plugin=entry["manifest"],
install_count=entry["install_count"],
ratings=ratings,
)
@router.post("/{plugin_id}/install", response_model=dict)
async def install_plugin(
plugin_id: str,
body: PluginInstallRequest, # noqa: ARG001 — reserved for future fields
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, Any]:
"""Install a plugin. Triggers Stripe Connect revenue split for paid plugins.
Requires Power tier or above.
"""
_require_plugin_tier(current_user)
entry = await registry.get_plugin(db, plugin_id)
if entry is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Plugin not found")
# Record the installation in plugin_installations
installation = PluginInstallation(
plugin_id=plugin_id,
user_id=current_user.id,
)
db.add(installation)
await db.flush()
await revenue_share.record_install(
db,
plugin_id=plugin_id,
user_id=current_user.id,
amount_cents=entry["manifest"].price_cents,
)
download_url = f"https://cdn.adiuva.app/plugins/{plugin_id}/package.zip"
return {"ok": True, "download_url": download_url}
@router.delete("/{plugin_id}/install", response_model=dict)
async def uninstall_plugin(
plugin_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Unregister a plugin installation."""
result = await db.execute(
select(PluginInstallation).where(
PluginInstallation.plugin_id == plugin_id,
PluginInstallation.user_id == current_user.id,
)
)
installation = result.scalar_one_or_none()
if installation is not None:
await db.delete(installation)
await db.commit()
await registry.record_uninstall(db, plugin_id)
return {"ok": True}

View File

@@ -1,195 +0,0 @@
"""Storage routes: CRUD for E2E-encrypted cloud records.
Blobs are stored in S3 via BlobStore. Record metadata is persisted in the
PostgreSQL ``storage_records`` table.
"""
from __future__ import annotations
import uuid
from fastapi import APIRouter, Depends, HTTPException, Query, Response, status
from pydantic import BaseModel
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import tier_manager
from app.db import get_session
from app.models import StorageRecord
from app.schemas import StorageRecordCreate, StorageRecordUpdate, UserProfile
from app.storage.blob_store import BlobStore
from app.storage.encryption import reject_if_tampered
router = APIRouter(prefix="/storage", tags=["storage"])
_blob_store = BlobStore()
# ── Local response schemas ─────────────────────────────────────────────
class _CreateResponse(BaseModel):
id: str
created_at: int
class _RecordMeta(BaseModel):
id: str
table: str
checksum: str
created_at: int
updated_at: int
# ── Helpers ────────────────────────────────────────────────────────────
async def _current_usage_bytes(user_id: str, db: AsyncSession) -> int:
"""Return total bytes stored by *user_id*."""
result = await db.execute(
select(func.coalesce(func.sum(StorageRecord.size_bytes), 0)).where(
StorageRecord.user_id == user_id
)
)
return int(result.scalar_one())
async def _check_quota(user: UserProfile, additional_bytes: int, db: AsyncSession) -> None:
"""Raise HTTP 402 if adding *additional_bytes* would exceed the tier limit."""
current = await _current_usage_bytes(user.id, db)
tier_manager.enforce_quota(user.tier, current_bytes=current, additional_bytes=additional_bytes)
async def _get_record_for_user(
record_id: str, user_id: str, db: AsyncSession
) -> StorageRecord:
"""Look up a record and verify ownership. Returns 404 on mismatch
to prevent user enumeration attacks."""
result = await db.execute(
select(StorageRecord).where(
StorageRecord.id == record_id, StorageRecord.user_id == user_id
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Record not found")
return record
# ── Routes ─────────────────────────────────────────────────────────────
@router.post("/records", response_model=_CreateResponse, status_code=status.HTTP_201_CREATED)
async def create_record(
body: StorageRecordCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _CreateResponse:
"""Upload a new E2E-encrypted blob. Verifies checksum before storing."""
reject_if_tampered(body.blob, body.checksum)
await _check_quota(current_user, len(body.blob), db)
record_id = str(uuid.uuid4())
s3_key = await _blob_store.upload(
current_user.id, body.table, record_id, body.blob, body.checksum
)
record = StorageRecord(
id=record_id,
user_id=current_user.id,
table_name=body.table,
s3_key=s3_key,
checksum=body.checksum,
size_bytes=len(body.blob),
)
db.add(record)
await db.commit()
await db.refresh(record)
created_at_ms = int(record.created_at.timestamp() * 1000)
return _CreateResponse(id=record_id, created_at=created_at_ms)
@router.get("/records", response_model=list[_RecordMeta])
async def list_records(
table: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=50, ge=1, le=200),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[_RecordMeta]:
"""List record metadata for the authenticated user. Blob bytes are never returned."""
query = select(StorageRecord).where(StorageRecord.user_id == current_user.id)
if table is not None:
query = query.where(StorageRecord.table_name == table)
query = query.offset((page - 1) * limit).limit(limit)
result = await db.execute(query)
rows = result.scalars().all()
return [
_RecordMeta(
id=r.id,
table=r.table_name,
checksum=r.checksum,
created_at=int(r.created_at.timestamp() * 1000),
updated_at=int(r.updated_at.timestamp() * 1000),
)
for r in rows
]
@router.get("/records/{record_id}")
async def download_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> Response:
"""Download an E2E-encrypted blob. Returns raw bytes with ``X-Checksum`` header."""
record = await _get_record_for_user(record_id, current_user.id, db)
blob = await _blob_store.download(current_user.id, record.s3_key)
return Response(
content=blob,
media_type="application/octet-stream",
headers={"X-Checksum": record.checksum},
)
@router.put("/records/{record_id}", response_model=dict)
async def update_record(
record_id: str,
body: StorageRecordUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Replace the blob for an existing record. Verifies checksum before storing."""
record = await _get_record_for_user(record_id, current_user.id, db)
reject_if_tampered(body.blob, body.checksum)
delta = len(body.blob) - record.size_bytes
if delta > 0:
await _check_quota(current_user, delta, db)
s3_key = await _blob_store.upload(
current_user.id, record.table_name, record_id, body.blob, body.checksum
)
record.s3_key = s3_key
record.checksum = body.checksum
record.size_bytes = len(body.blob)
await db.commit()
return {"ok": True}
@router.delete("/records/{record_id}", response_model=dict)
async def delete_record(
record_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a record and its S3 blob."""
record = await _get_record_for_user(record_id, current_user.id, db)
await _blob_store.delete(current_user.id, record.s3_key)
await db.delete(record)
await db.commit()
return {"ok": True}

View File

@@ -1,79 +0,0 @@
"""Vectors routes: upsert, search, delete cloud vector store entries, and embed text."""
from __future__ import annotations
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.core.llm import embed
from app.schemas import (
UserProfile,
VectorSearchRequest,
VectorSearchResponse,
VectorUpsertRequest,
)
from app.storage.encryption import reject_if_tampered
from app.storage.vector_store import VectorStore
router = APIRouter(prefix="/storage", tags=["vectors"])
_vector_store = VectorStore()
class _VectorDeleteRequest(BaseModel):
ids: list[str]
class _EmbedRequest(BaseModel):
text: str
class _EmbedResponse(BaseModel):
vector: list[float]
@router.post("/vectors/upsert", response_model=dict)
async def upsert_vectors(
body: VectorUpsertRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, int]:
"""Verify checksums and store encrypted vectors in the user-scoped namespace."""
for item in body.vectors:
reject_if_tampered(item.blob, item.checksum)
await _vector_store.upsert(current_user.id, body.vectors)
return {"upserted": len(body.vectors)}
@router.post("/vectors/search", response_model=VectorSearchResponse)
async def search_vectors(
body: VectorSearchRequest,
current_user: UserProfile = Depends(get_current_user),
) -> VectorSearchResponse:
"""Search the user-scoped vector namespace with an encrypted query blob."""
results = await _vector_store.search(current_user.id, body.query_blob, body.top_k)
return VectorSearchResponse(results=results)
@router.delete("/vectors", response_model=dict)
async def delete_vectors(
body: _VectorDeleteRequest,
current_user: UserProfile = Depends(get_current_user),
) -> dict[str, bool]:
"""Delete vectors by ID, scoped to the authenticated user."""
await _vector_store.delete(current_user.id, body.ids)
return {"ok": True}
@router.post("/vectors/embed", response_model=_EmbedResponse)
async def embed_text(
body: _EmbedRequest,
current_user: UserProfile = Depends(get_current_user),
) -> _EmbedResponse:
"""Generate a 1536-dim embedding vector for the given text.
Uses ``text-embedding-3-small`` via OpenAI. Auth required (JWT).
Used by backend tools (note_agent) and Electron (vectordb.ts) alike.
"""
vector = await embed(body.text)
return _EmbedResponse(vector=vector)

View File

@@ -1,4 +0,0 @@
from app.billing.stripe_service import stripe_service
from app.billing.tier_manager import tier_manager
__all__ = ["stripe_service", "tier_manager"]

View File

@@ -1,60 +0,0 @@
from typing import Literal
from pydantic_settings import BaseSettings, SettingsConfigDict
class Settings(BaseSettings):
DATABASE_URL: str = "postgresql+asyncpg://postgres:postgres@localhost:5432/adiuva"
JWT_SECRET: str = "change-me-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_ACCESS_TOKEN_EXPIRE_MINUTES: int = 30
JWT_REFRESH_TOKEN_EXPIRE_DAYS: int = 30
STRIPE_SECRET_KEY: str = ""
STRIPE_WEBHOOK_SECRET: str = ""
S3_BUCKET: str = ""
S3_REGION: str = "us-east-1"
S3_ENDPOINT_URL: str = ""
AWS_ACCESS_KEY_ID: str = ""
AWS_SECRET_ACCESS_KEY: str = ""
PINECONE_API_KEY: str = ""
PINECONE_INDEX: str = "adiuva"
QDRANT_URL: str = ""
QDRANT_API_KEY: str = ""
OPENAI_API_KEY: str = ""
ANTHROPIC_API_KEY: str = ""
GOOGLE_API_KEY: str = ""
CEREBRAS_API_KEY: str = ""
LLM_MODEL: str = "gpt-4o"
LLM_ROUTER_MODEL: str = "gpt-4o-mini"
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# GitHub Copilot OAuth token storage directory.
# Leave empty to use the LiteLLM default (~/.config/litellm/github_copilot).
# In Docker, set this to a path backed by a named volume so tokens survive restarts.
GITHUB_COPILOT_TOKEN_DIR: str = ""
# OAuth client credentials — used for Gmail and Microsoft (Outlook/Teams) flows.
GMAIL_CLIENT_ID: str = ""
GMAIL_CLIENT_SECRET: str = ""
MS_CLIENT_ID: str = ""
MS_CLIENT_SECRET: str = ""
# MS_TENANT_ID: set to 'common' to allow multi-tenant (personal + work accounts).
MS_TENANT_ID: str = "common"
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
OAUTH_ENCRYPTION_KEY: str = ""
CORS_ORIGINS: list[str] = ["app://.", "http://localhost:3000", "http://localhost:5173"]
ENV: Literal["dev", "prod"] = "dev"
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
settings = Settings()

View File

@@ -1,30 +0,0 @@
"""Minimal agent base types retained for compatibility with batch runners."""
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any
class BaseAgent(ABC):
"""Common base for non-chat agents still using the old base contract."""
def __init__(
self,
user_id: str = "",
shared_memory: dict[str, Any] | None = None,
vector_store_context: list[str] | None = None,
) -> None:
self.user_id = user_id
self.shared_memory: dict[str, Any] = shared_memory or {}
self.vector_store_context: list[str] = vector_store_context or []
@abstractmethod
def get_name(self) -> str: ...
@abstractmethod
def get_description(self) -> str: ...
@property
def skills(self) -> list[str]:
return []

View File

@@ -1,816 +0,0 @@
"""Agent run orchestrator.
Drives two agent types:
* **Local directory agent** — two-phase execution that mirrors the
``deep_agent.py`` tool-calling pattern. Phase 1 (Triage) explores the
user's directory via file-system tools and groups files by project.
Phase 2 (Processing) reads full file contents and performs CRUD
operations using the standard entity tools (tasks, notes, etc.).
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
Teams, Outlook) and pushes extracted items to Electron.
Usage
-----
Background tasks are spawned with ``asyncio.create_task()``::
asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager))
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
The ``trigger_pending_runs`` function is called by the device WS endpoint
when Electron sends ``device_hello``, so any overdue runs fire immediately
when the device reconnects.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from croniter import croniter
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from sqlalchemy import select
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.device_manager import DeviceConnectionManager
from app.core.llm import get_llm
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
logger = logging.getLogger(__name__)
# ── Timeouts ───────────────────────────────────────────────────────────────
# Max seconds to wait for a single tool-call round-trip (FE → BE).
_TOOL_CALL_TIMEOUT: int = 30
# Max LLM reasoning steps per phase.
_MAX_TRIAGE_STEPS: int = 10
_MAX_PROCESSING_STEPS: int = 12
# ── Data-type to tool mapping ─────────────────────────────────────────────
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS,
"projects": PROJECT_TOOLS,
"notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS,
}
# ── Triage prompt ─────────────────────────────────────────────────────────
_TRIAGE_SYSTEM_PROMPT = """\
You are a file triage assistant for a freelance project management tool.
Your job is to explore a local directory on the user's device, understand its
structure, and group files by project context.
You have access to these tools:
- list_directory: to map folder structure
- get_file_metadata: to check creation/modification dates
- read_file_content: to read brief snippets when needed for categorisation
- list_projects / list_all_projects / get_project: to fetch existing projects
from the user's workspace and match files to them
Instructions:
1. Start by calling list_directory on the configured root path.
2. Explore subdirectories as needed to understand the structure.
3. Use get_file_metadata to check modification dates. Skip files that have
NOT been modified since: {last_run_at}.
4. Call list_all_projects to get the user's existing projects.
5. Match files to existing projects by name, folder structure, or content hints.
6. If files don't match any existing project, group them under "standalone".
{custom_prompt_section}
Target entity types to extract: {data_types}
File extensions to consider: {file_extensions}
When you have finished exploring, output ONLY a JSON object (no markdown
fences, no explanation) mapping project IDs or "standalone" to file path
arrays:
{{"<project_id>": ["<file_path>", ...], "standalone": ["<file_path>", ...]}}
Return ONLY the JSON object as your final message.
"""
# ── Processing prompt ─────────────────────────────────────────────────────
_PROCESSING_BASE_PROMPT = """\
You are a data extraction and management assistant for a freelance project
management tool.
Available tools:
Filesystem : read_file_content, list_directory, get_file_metadata
Tasks : list_tasks, create_task, update_task, add_task_comment
Notes : list_notes, get_note, create_note, update_note
Timelines : list_timelines, create_timeline, update_timeline
Projects : list_all_projects, get_project, create_project, update_project
Your task:
1. Read the full content of each file below using read_file_content.
2. For each piece of information found, ALWAYS try to match and update an
existing record before creating a new one.
3. ONLY act on these entity types: {data_types}.
4. Do NOT invent data. Only extract what is clearly present in the files.
5. If a file contains no relevant data for the target entity types, skip it.
Update-first rules (apply in this order):
Tasks:
- Call list_tasks to find a match by title or context.
- If found: call add_task_comment (author "Adiuva"), update_task to set
assignees, state (ToDo / In Progress / Completed), or other fields.
- If NOT found: call create_task with isAiSuggested=1, isApproved=0.
Timelines:
- Call list_timelines to find a match by title or date.
- If found: call update_timeline to edit fields or mark it complete.
- If NOT found: call create_timeline with isAiSuggested=1, isApproved=0.
Notes:
- Call list_notes to find a match by title or topic, then get_note to
read its current content.
- If found: call update_note with the merged content.
- If NOT found: call create_note with isAiSuggested=1, isApproved=0.
Projects:
- Call list_all_projects to check for a match first.
- Only call create_project if the information is clearly significant and
no existing project matches. Set isAiSuggested=1, isApproved=0.
{project_context}
Files to process:
{file_list}
{custom_prompt_section}
After processing all files, respond with a brief summary of what you updated
and what you created.
"""
# ── Cron helper ────────────────────────────────────────────────────────────
def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
"""Return ``True`` if the next scheduled run time has already passed.
Always validates the cron expression first — an invalid expression returns
``False`` (fail-safe: never trigger an unparseable schedule).
"""
try:
now = datetime.now(timezone.utc)
if last_run_at is None:
# Validate the expression before deciding this is overdue.
croniter(schedule_cron, now)
return True
ts = last_run_at
if ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)
cron = croniter(schedule_cron, ts)
next_run: datetime = cron.get_next(datetime)
return now >= next_run
except Exception as exc:
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
return False # Fail-safe: don't trigger if expression is invalid.
# ── WS executor for agent context ─────────────────────────────────────────
def _make_agent_executor(
user_id: str,
device_mgr: DeviceConnectionManager,
run_context: dict | None = None,
) -> Any:
"""Create a WS callback for ``set_client_executor()`` so that all tools
can use ``execute_on_client()`` during an agent run.
If *run_context* is provided it is attached to every ``tool_call`` frame
so the Electron client can attribute actions to the correct agent run.
"""
async def _executor(payload: dict) -> dict:
payload["type"] = "tool_call"
if run_context:
payload["run_context"] = run_context
call_id = payload["id"]
fut = device_mgr.create_pending_call(user_id, call_id)
await device_mgr.send_frame(user_id, payload)
return await asyncio.wait_for(fut, timeout=_TOOL_CALL_TIMEOUT)
return _executor
# ── LLM tool-calling loop (mirrors deep_agent._run_single_agent) ──────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _run_agent_with_tools(
*,
system_prompt: str,
user_message: str,
tools: list[Any],
max_steps: int,
) -> str:
"""Run an LLM agent with tool-calling, returning the final text response.
Follows the same pattern as ``deep_agent._run_single_agent``:
bind tools → invoke → handle tool calls → repeat until final text.
"""
llm = get_llm()
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(content=user_message),
]
tool_calls_count = 0
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
tool_calls_count += 1
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_runner: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_runner: tool_result name=%s output=%s",
call_name,
str(tool_output)[:1200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max steps, get final response without tools.
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Triage map parser ─────────────────────────────────────────────────────
def _parse_triage_map(raw: str) -> dict[str, list[str]] | None:
"""Extract the JSON triage map from the LLM's final response."""
text = raw.strip()
# Try direct parse first.
try:
parsed = json.loads(text)
if isinstance(parsed, dict):
return {k: v for k, v in parsed.items() if isinstance(v, list)}
except json.JSONDecodeError:
pass
# Try extracting JSON from markdown fences or surrounding text.
import re
match = re.search(r"\{[\s\S]*\}", text)
if match:
try:
parsed = json.loads(match.group(0))
if isinstance(parsed, dict):
return {k: v for k, v in parsed.items() if isinstance(v, list)}
except json.JSONDecodeError:
pass
return None
# ── Tool list builder ─────────────────────────────────────────────────────
def _build_processing_tools(data_types: list[str]) -> list[Any]:
"""Build the tool list for Phase 2 based on user's data_types selection."""
tools: list[Any] = list(FILESYSTEM_TOOLS)
for dt in data_types:
dt_tools = _DATA_TYPE_TOOLS.get(dt)
if dt_tools:
tools.extend(dt_tools)
return tools
# ── Local agent runner (two-phase) ─────────────────────────────────────────
async def run_local_agent(
user_id: str,
config: LocalAgentConfig,
run_log: AgentRunLog,
device_mgr: DeviceConnectionManager,
run_context: dict | None = None,
) -> None:
"""Execute a local directory agent run using two-phase LLM-with-tools.
Phase 1 — Triage:
Explore the directory structure, check metadata, match files to
existing projects. Output: a JSON map of project → file paths.
Phase 2 — Processing:
For each project group, read full file contents and perform CRUD
operations using the standard entity tools.
"""
run_id = run_log.id
# ── Device online check ─────────────────────────────────────────
target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else ""
if target_device_id:
is_online = device_mgr.is_online(user_id, target_device_id)
else:
is_online = device_mgr.is_online(user_id)
if not is_online:
logger.info(
"agent_runner: skip run=%s — device %r offline for user=%s",
run_id,
target_device_id or "<any>",
user_id,
)
await _finalize_run(
run_log,
status="error",
errors=[f"Device {target_device_id or '<any>'!r} is not connected"],
)
return
# ── Set up WS executor for tools ────────────────────────────────
executor = _make_agent_executor(user_id, device_mgr, run_context)
set_client_executor(executor)
errors: list[str] = []
items_processed = 0
items_created = 0
try:
# ── Phase 1: Triage ─────────────────────────────────────────
logger.info("agent_runner: run=%s phase=triage start user=%s", run_id, user_id)
last_run_str = "never (process all files)"
if config.last_run_at:
last_run_str = config.last_run_at.isoformat()
custom_section = ""
if config.prompt_template:
custom_section = f"User instructions:\n{config.prompt_template}"
file_ext_str = ", ".join(config.file_extensions) if config.file_extensions else "all"
triage_prompt = _TRIAGE_SYSTEM_PROMPT.format(
last_run_at=last_run_str,
custom_prompt_section=custom_section,
data_types=", ".join(config.data_types),
file_extensions=file_ext_str,
)
directory_paths = config.directory_paths
triage_user_msg = (
f"Explore these directories and produce the triage map:\n"
f"{json.dumps(directory_paths, ensure_ascii=False)}"
)
triage_tools: list[Any] = list(FILESYSTEM_TOOLS) + list(PROJECT_TOOLS)
triage_response = await _run_agent_with_tools(
system_prompt=triage_prompt,
user_message=triage_user_msg,
tools=triage_tools,
max_steps=_MAX_TRIAGE_STEPS,
)
triage_map = _parse_triage_map(triage_response)
if not triage_map:
errors.append(f"Triage phase failed to produce a valid file map: {triage_response[:500]}")
await _finalize_run(run_log, status="error", errors=errors)
return
logger.info(
"agent_runner: run=%s triage complete groups=%d total_files=%d",
run_id,
len(triage_map),
sum(len(files) for files in triage_map.values()),
)
# ── Phase 2: Processing (per group) ─────────────────────────
processing_tools = _build_processing_tools(config.data_types)
for group_key, file_paths in triage_map.items():
if not file_paths:
continue
logger.info(
"agent_runner: run=%s phase=processing group=%s files=%d",
run_id,
group_key,
len(file_paths),
)
# Build project context for the LLM.
if group_key == "standalone":
project_context = "These files are not associated with any existing project."
else:
project_context = f"These files belong to project ID: {group_key}. Use this project_id when creating records."
file_list_str = "\n".join(f"- {fp}" for fp in file_paths)
processing_prompt = _PROCESSING_BASE_PROMPT.format(
data_types=", ".join(config.data_types),
project_context=project_context,
file_list=file_list_str,
custom_prompt_section=custom_section,
)
items_processed += len(file_paths)
try:
result_text = await _run_agent_with_tools(
system_prompt=processing_prompt,
user_message="Process the listed files now.",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
)
logger.info(
"agent_runner: run=%s group=%s processing_result=%s",
run_id,
group_key,
result_text[:500],
)
# Count created items by scanning tool call results.
# The tools themselves handle creation; we estimate from the
# summary. A more precise count would require intercepting
# tool results, but the summary is sufficient for the run log.
except Exception as exc:
errors.append(f"Processing error for group '{group_key}': {exc}")
logger.error(
"agent_runner: run=%s group=%s processing failed: %s",
run_id,
group_key,
exc,
)
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
logger.error("agent_runner: run=%s failed: %s", run_id, exc)
finally:
clear_client_executor()
# ── Finalise ────────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log,
status=final_status,
items_processed=items_processed,
items_created=items_created,
errors=errors,
update_config_last_run=False,
config_id=config.id,
config_type="local",
)
logger.info(
"agent_runner: run=%s done status=%s processed=%d errors=%d",
run_id,
final_status,
items_processed,
len(errors),
)
# Notify the Electron client that the run is complete so it can close
# the run record in its local SQLite.
if run_context and device_mgr.is_online(user_id):
try:
await device_mgr.send_frame(user_id, {
"type": "run_complete",
"run_context": run_context,
"status": final_status,
})
except Exception as exc:
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_id, exc)
# ── Cloud agent runner ─────────────────────────────────────────────────────
# Default lookback window when an agent has never run before.
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
async def run_cloud_agent(
user_id: str,
config: CloudAgentConfig,
run_log: AgentRunLog,
device_mgr: DeviceConnectionManager,
) -> None:
"""Execute a cloud connector agent run end-to-end.
Steps:
1. Verify the user's device is online — results are pushed to Electron
via WS tool-call frames. If no device is connected, abort.
2. Decrypt the stored OAuth token from ``config.oauth_token_encrypted``.
3. Instantiate the provider client (Gmail or MS Graph).
4. Fetch messages/emails since ``config.last_run_at`` (or 7 days ago for
the first run) applying ``config.filter_config`` filters.
5. For each message/email call the LLM to extract structured items.
6. Push each item to Electron as an ``insert`` tool-call.
7. If the provider refreshed its access token, re-encrypt and write it
back to ``config.oauth_token_encrypted``.
8. Persist the run outcome via ``_finalize_run``.
"""
run_id = run_log.id
# ── 1. Device online check ─────────────────────────────────────────
if not device_mgr.is_online(user_id):
logger.info(
"agent_runner: skip cloud run=%s — no device online for user=%s",
run_id,
user_id,
)
await _finalize_run(
run_log,
status="error",
errors=["No connected device — cloud agent results cannot be delivered"],
)
return
# ── 2. Decrypt OAuth token ─────────────────────────────────────────
from app.integrations import decrypt_token, encrypt_token, get_provider
if not config.oauth_token_encrypted:
await _finalize_run(
run_log,
status="error",
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
)
return
try:
credentials_info = decrypt_token(config.oauth_token_encrypted)
except ValueError as exc:
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
await _finalize_run(
run_log,
status="error",
errors=[f"Failed to decrypt OAuth token: {exc}"],
)
return
# ── 3. Instantiate provider client ────────────────────────────────
try:
provider = get_provider(config.provider, credentials_info)
except ValueError as exc:
await _finalize_run(
run_log,
status="error",
errors=[str(exc)],
)
return
# ── 4. Fetch messages ─────────────────────────────────────────────
since: datetime | None = config.last_run_at
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
if since.tzinfo is None:
since = since.replace(tzinfo=timezone.utc)
errors: list[str] = []
items_processed = 0
items_created = 0
try:
if config.provider == "gmail":
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
filter_config=config.filter_config,
since=since,
)
elif config.provider == "outlook":
raw_messages = await provider.fetch_emails( # type: ignore[union-attr]
filter_config=config.filter_config,
since=since,
)
elif config.provider == "teams":
raw_messages = await provider.fetch_messages( # type: ignore[union-attr]
filter_config=config.filter_config,
since=since,
)
else:
raw_messages = []
except RuntimeError as exc:
logger.error(
"agent_runner: provider fetch failed for cloud agent %s: %s",
config.id,
exc,
)
await _finalize_run(
run_log,
status="error",
errors=[f"Provider fetch failed: {exc}"],
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
return
logger.info(
"agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
config.id,
len(raw_messages),
config.provider,
user_id,
)
# ── 56. Extract + insert via LLM with tools ─────────────────────
executor = _make_agent_executor(user_id, device_mgr)
set_client_executor(executor)
try:
processing_tools = _build_processing_tools(config.data_types)
custom_section = ""
if config.prompt_template:
custom_section = f"User instructions:\n{config.prompt_template}"
for msg in raw_messages:
content_text = msg.as_text
if not content_text:
continue
items_processed += 1
processing_prompt = _PROCESSING_BASE_PROMPT.format(
data_types=", ".join(config.data_types),
project_context="Determine the appropriate project from the message context.",
file_list=f"Message from {config.provider} (id: {msg.id})",
custom_prompt_section=custom_section,
)
try:
await _run_agent_with_tools(
system_prompt=processing_prompt,
user_message=f"Process this message content:\n\n{content_text[:8000]}",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
)
except Exception as exc:
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
finally:
clear_client_executor()
# ── 7. Persist refreshed token (if any) ───────────────────────────
refreshed = getattr(provider, "refreshed_credentials", None)
if refreshed:
try:
new_encrypted = encrypt_token(refreshed)
async with async_session() as db:
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
)
cfg_row = cfg_result.scalar_one_or_none()
if cfg_row:
cfg_row.oauth_token_encrypted = new_encrypted
await db.commit()
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id)
except Exception as exc:
logger.warning("agent_runner: failed to persist refreshed token for agent %s: %s", config.id, exc)
# ── 8. Finalise ────────────────────────────────────────────────────
if errors and items_created == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log,
status=final_status,
items_processed=items_processed,
items_created=items_created,
errors=errors,
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
logger.info(
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
run_id,
final_status,
items_processed,
items_created,
len(errors),
)
# ── Pending-run trigger ─────────────────────────────────────────────────────
async def trigger_pending_runs(
user_id: str,
device_id: str,
device_mgr: DeviceConnectionManager,
) -> None:
"""Dispatch any overdue agent runs after an Electron device connects.
Called as a background task from the device WS endpoint on ``device_hello``.
Scheduling rules:
* **Local agents**: only triggered when ``config.device_id == device_id``.
* **Cloud agents**: triggered on any connected device (no device binding).
* Runs execute **sequentially** to avoid flooding the WS connection.
"""
logger.info(
"agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
user_id,
device_id,
)
return
# ── Internal helper ─────────────────────────────────────────────────────────
async def _finalize_run(
run_log: AgentRunLog,
*,
status: str,
items_processed: int = 0,
items_created: int = 0,
errors: list[str] | None = None,
update_config_last_run: bool = False,
config_id: str | None = None,
config_type: str | None = None,
) -> None:
"""Persist the run outcome and optionally update ``LocalAgentConfig.last_run_at``.
Uses a fresh DB session so this is safe to call from background tasks
after the original request session has closed.
"""
now = datetime.now(timezone.utc)
try:
async with async_session() as db:
managed = await db.merge(run_log)
managed.status = status
managed.items_processed = items_processed
managed.items_created = items_created
managed.errors = errors or []
managed.completed_at = now
if update_config_last_run and config_id:
if config_type == "local":
cfg_result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
elif config_type == "cloud":
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
await db.commit()
except Exception as exc:
logger.error(
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
)

View File

@@ -1,151 +0,0 @@
"""Device connection manager.
Maintains in-memory state for all active Electron → backend WebSocket
connections. One connection per user (latest replaces previous).
The manager handles the **tool-call round-trip** pattern:
- Backend sends ``tool_call`` frame → Electron executes the action →
returns ``tool_result`` frame.
- ``create_pending_call`` registers a Future keyed by ``call_id``.
- ``resolve_pending_call`` fulfils the Future; callers awaiting it
receive the result dict from Electron.
This pattern is used by all tools (CRUD, file-system, etc.) via
``execute_on_client()`` in ``ws_context.py``.
The ``device_manager`` module-level singleton is imported by both the
device WS route and the agent runner.
"""
from __future__ import annotations
import asyncio
import json
import logging
from dataclasses import dataclass, field
from fastapi import WebSocket
logger = logging.getLogger(__name__)
@dataclass
class DeviceConnection:
"""State for a single connected Electron device."""
ws: WebSocket
device_id: str
# Futures indexed by tool_call id — resolved when tool_result arrives.
pending_calls: dict[str, asyncio.Future[dict]] = field(default_factory=dict)
class DeviceConnectionManager:
"""Singleton registry of active Electron WebSocket connections.
Thread/task safety note: asyncio is single-threaded by design. All
mutations happen inside await-points on the main event loop, so no
locking is required for the in-memory dicts.
"""
def __init__(self) -> None:
self._connections: dict[str, DeviceConnection] = {}
# ── Registration ──────────────────────────────────────────────────
def register(self, user_id: str, device_id: str, ws: WebSocket) -> None:
"""Store the active connection for *user_id*, replacing any previous one."""
if user_id in self._connections:
old = self._connections[user_id]
logger.info(
"device_manager: replacing existing connection for user=%s device=%s",
user_id,
old.device_id,
)
# Cancel any futures that were waiting on the old connection.
for fut in old.pending_calls.values():
if not fut.done():
fut.cancel()
self._connections[user_id] = DeviceConnection(ws=ws, device_id=device_id)
logger.info(
"device_manager: registered user=%s device=%s", user_id, device_id
)
def unregister(self, user_id: str) -> None:
"""Remove the connection for *user_id* and cancel any pending futures."""
conn = self._connections.pop(user_id, None)
if conn is None:
return
for fut in conn.pending_calls.values():
if not fut.done():
fut.cancel()
logger.info("device_manager: unregistered user=%s", user_id)
# ── Presence queries ──────────────────────────────────────────────
def get_ws(self, user_id: str) -> WebSocket | None:
"""Return the active WebSocket for *user_id*, or ``None`` if offline."""
conn = self._connections.get(user_id)
return conn.ws if conn else None
def is_online(self, user_id: str, device_id: str | None = None) -> bool:
"""Return ``True`` if the user has an active connection.
If *device_id* is provided also checks that it matches the connected device.
"""
conn = self._connections.get(user_id)
if conn is None:
return False
if device_id is not None:
return conn.device_id == device_id
return True
# ── Frame sending ─────────────────────────────────────────────────
async def send_frame(self, user_id: str, frame: dict) -> None:
"""Send *frame* as a JSON text message to the device.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"send_frame: user {user_id!r} is not connected"
)
await conn.ws.send_text(json.dumps(frame))
# ── Tool-call round-trip ──────────────────────────────────────────
def create_pending_call(
self, user_id: str, call_id: str
) -> asyncio.Future[dict]:
"""Register a Future that will be resolved when the tool_result arrives.
Raises ``RuntimeError`` if the user is not connected.
"""
conn = self._connections.get(user_id)
if conn is None:
raise RuntimeError(
f"create_pending_call: user {user_id!r} is not connected"
)
loop = asyncio.get_event_loop()
fut: asyncio.Future[dict] = loop.create_future()
conn.pending_calls[call_id] = fut
return fut
def resolve_pending_call(
self, user_id: str, call_id: str, result: dict
) -> None:
"""Fulfil the Future registered under *call_id* with the Electron result.
No-ops if the call_id is unknown (already timed out or cancelled).
"""
conn = self._connections.get(user_id)
if conn is None:
return
fut = conn.pending_calls.pop(call_id, None)
if fut is not None and not fut.done():
fut.set_result(result)
# Module-level singleton — import this everywhere.
device_manager = DeviceConnectionManager()

View File

@@ -1,125 +0,0 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Every agent and the orchestrator call ``get_llm()`` or ``get_router_llm()``
instead of directly constructing a provider-specific class. The model string
follows the `LiteLLM model naming convention
<https://docs.litellm.ai/docs/providers>`_:
* OpenAI: ``gpt-4o``, ``gpt-4o-mini``
* Anthropic: ``anthropic/claude-3.5-sonnet``
* Google: ``gemini/gemini-pro``
* Ollama: ``ollama/llama3``
* Bedrock: ``bedrock/anthropic.claude-v2``
Switch providers by changing **LLM_MODEL** / **LLM_ROUTER_MODEL** in ``.env``
— no code changes required.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from litellm import get_supported_openai_params # noqa: F401 validates install
from app.config.settings import settings
# Some models (e.g. gpt-5, o-series) reject unsupported params like temperature.
# Drop them silently instead of raising UnsupportedParamsError.
litellm.drop_params = True
# Some provider responses include a plain dict in the `usage` field where a
# richer Pydantic model is expected. This warning is noisy but non-fatal.
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
"""Return the most appropriate API key for the given LiteLLM model string."""
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github_copilot/"):
# GitHub Copilot uses OAuth device-flow tokens managed by LiteLLM.
# No API key is required; returning None lets LiteLLM handle auth.
return None
# Default: OpenAI-compatible (covers plain model names like "gpt-4o")
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
"""Return a LangChain chat model backed by LiteLLM.
LiteLLM exposes an OpenAI-compatible API, so we use ``ChatOpenAI`` pointed
at the LiteLLM proxy endpoint. In practice, ``litellm`` patches the
``openai`` client transparently when the model string contains a provider
prefix (``anthropic/…``, ``gemini/…``, etc.).
Parameters
----------
model:
LiteLLM model identifier. Defaults to ``settings.LLM_MODEL``.
temperature:
Sampling temperature. ``0`` = deterministic.
"""
model = model or settings.LLM_MODEL
# Point LiteLLM to the custom token directory when configured.
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
# Use ChatLiteLLM for provider-prefixed models (github_copilot/, anthropic/, etc.)
# so LiteLLM handles routing and auth. ChatOpenAI for plain OpenAI model names.
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
)
def get_router_llm(
*,
temperature: float = 0,
) -> ChatOpenAI | ChatLiteLLM:
"""Return the lighter model used for intent classification / routing."""
return get_llm(model=settings.LLM_ROUTER_MODEL, temperature=temperature)
async def embed(text: str) -> list[float]:
"""Return an embedding vector for *text*.
Uses ``settings.LLM_EMBED_MODEL`` so the same provider switch in ``.env``
(e.g. ``github_copilot/text-embedding-3-small``) applies here without any
code changes. Falls back to the raw AsyncOpenAI client for plain OpenAI
model names to preserve existing behaviour.
"""
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
# Use LiteLLM for all provider-prefixed models (Copilot, Bedrock, etc.)
# so the provider's auth mechanism is applied correctly.
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
# Plain OpenAI model name — use the raw AsyncOpenAI client (existing path).
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -1,92 +0,0 @@
"""WebSocket client executor context.
Holds a per-request async callback that tools call to execute CRUD
operations on the Electron client's local SQLite / LanceDB databases.
The callback sends a `tool_call` WS frame and awaits the `tool_result`.
"""
from __future__ import annotations
from contextvars import ContextVar
from typing import Any, Callable, Coroutine
from uuid import uuid4
# Holds the execute callback for the current WS session.
# Set by the chat WS handler before the orchestrator runs; cleared after.
_client_executor: ContextVar[Callable[[dict], Coroutine[Any, Any, dict]]] = ContextVar(
"_client_executor"
)
# Optional collector that captures raw execute_on_client results.
# Set by _tool_loop / _tool_loop_stream to populate ChatAgent.tool_results.
_tool_result_collector: ContextVar[list[dict] | None] = ContextVar(
"_tool_result_collector", default=None
)
def set_tool_result_collector(lst: list[dict]) -> None:
"""Register *lst* as the collector for this async context."""
_tool_result_collector.set(lst)
def clear_tool_result_collector() -> None:
"""Clear the collector (best-effort)."""
_tool_result_collector.set(None)
def set_client_executor(fn: Callable[[dict], Coroutine[Any, Any, dict]]) -> None:
"""Bind *fn* as the executor for the current async context (task/coroutine)."""
_client_executor.set(fn)
def clear_client_executor() -> None:
"""Remove the executor binding (best-effort; ContextVar resets on task exit)."""
try:
_client_executor.set(None) # type: ignore[arg-type]
except Exception:
pass
async def execute_on_client(
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
"""Send a CRUD/vector operation to the Electron client and return the result.
Builds a ``tool_call`` payload, invokes the per-session WS callback,
and returns the ``tool_result`` dict from Electron.
Raises ``RuntimeError`` if no executor is set (i.e. called outside a WS session).
"""
callback = _client_executor.get(None)
if callback is None:
raise RuntimeError(
"execute_on_client() called outside a WebSocket session — "
"no client executor is set."
)
payload: dict[str, Any] = {"id": str(uuid4()), "action": action}
if table is not None:
payload["table"] = table
if data is not None:
payload["data"] = data
if filters is not None:
payload["filters"] = {k: v for k, v in filters.items() if v is not None}
if vector is not None:
payload["vector"] = vector
if limit is not None:
payload["limit"] = limit
result = await callback(payload)
collector = _tool_result_collector.get(None)
if collector is not None:
collector.append({
"action": action,
"table": table,
"data": result,
})
return result

View File

@@ -1,72 +0,0 @@
from contextlib import asynccontextmanager
import logging
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
from app.api.middleware.rate_limit import TierRateLimitMiddleware
from app.api.middleware.sanitizer import SanitizerMiddleware
from app.config.settings import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: ensure agent tool modules are loaded.
import app.agents # noqa: F401
yield
# Shutdown: dispose SQLAlchemy connection pool
from app.db import engine
await engine.dispose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Cloud API",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Middleware stack (Starlette inserts at position 0, so last-added = outermost).
# Request flow: TierRateLimit → Sanitizer → CORS → Router
# Response flow: Router → CORS → Sanitizer → TierRateLimit
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, backup, billing, chat, device_ws, plugins, storage, vectors
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(storage.router, prefix="/api/v1")
app.include_router(vectors.router, prefix="/api/v1")
app.include_router(backup.router, prefix="/api/v1")
app.include_router(plugins.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "version": app.version}
return app
app = create_app()

View File

@@ -1,7 +0,0 @@
"""Plugin marketplace package.
Three service classes introduced in Step 10:
- ``PluginRegistry`` — catalog, submit/approve/reject, install counts
- ``ReviewQueue`` — approval workflow + security checklist
- ``RevenueShare`` — 70/30 split tracking and Stripe Connect payouts
"""

View File

@@ -1,212 +0,0 @@
"""Plugin catalog registry backed by PostgreSQL.
Maintains the authoritative list of plugins, their review status, and
aggregate install counts. All data is persisted in the ``plugins`` table.
Module-level singleton::
from app.marketplace.plugin_registry import registry
"""
from __future__ import annotations
import json
from typing import Any, Literal
from sqlalchemy import select, func
from sqlalchemy.ext.asyncio import AsyncSession
from app.models import Plugin
from app.schemas import PluginListResponse, PluginManifest
_PAGE_SIZE = 20
def _plugin_to_manifest(p: Plugin) -> PluginManifest:
"""Convert an ORM ``Plugin`` row to a Pydantic ``PluginManifest``."""
try:
permissions = json.loads(p.permissions) if p.permissions else []
except (json.JSONDecodeError, TypeError):
permissions = []
return PluginManifest(
id=p.id,
name=p.name,
description=p.description,
version=p.version,
author=p.author_name,
permissions=permissions,
category=p.category,
price_cents=p.price_cents,
)
class PluginRegistry:
"""PostgreSQL-backed plugin catalog.
All methods accept an ``AsyncSession`` parameter so the calling route
controls the session lifecycle.
"""
# ── Queries ──────────────────────────────────────────────────────
async def list_plugins(
self,
db: AsyncSession,
category: str | None = None,
query: str | None = None,
page: int = 1,
sort: Literal["rating", "installs", "newest"] = "newest",
) -> PluginListResponse:
"""Return a page of approved plugins, optionally filtered and sorted."""
base = select(Plugin).where(Plugin.status == "approved")
if category:
base = base.where(Plugin.category == category)
if query:
pattern = f"%{query}%"
base = base.where(
Plugin.name.ilike(pattern) | Plugin.description.ilike(pattern)
)
# Count
count_q = select(func.count()).select_from(base.subquery())
total = (await db.execute(count_q)).scalar_one()
# Sort
if sort == "installs":
base = base.order_by(Plugin.install_count.desc())
elif sort == "rating":
base = base.order_by(Plugin.avg_rating.desc())
else: # newest
base = base.order_by(Plugin.created_at.desc())
base = base.offset((page - 1) * _PAGE_SIZE).limit(_PAGE_SIZE)
rows = (await db.execute(base)).scalars().all()
return PluginListResponse(
plugins=[_plugin_to_manifest(r) for r in rows],
total=total,
page=page,
)
async def get_plugin(self, db: AsyncSession, plugin_id: str) -> dict[str, Any] | None:
"""Return ``{manifest, status, install_count, avg_rating}`` or ``None``."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
p = result.scalar_one_or_none()
if p is None:
return None
return {
"manifest": _plugin_to_manifest(p),
"status": p.status,
"install_count": p.install_count,
"avg_rating": p.avg_rating,
}
# ── Mutations ────────────────────────────────────────────────────
async def submit_plugin(
self,
db: AsyncSession,
manifest: PluginManifest,
package_s3_key: str,
) -> str:
"""Add *manifest* to the catalog with ``status='pending_review'``.
Returns the plugin_id. If a plugin with the same id already exists
it is overwritten (re-submission after rejection).
"""
plugin_id = manifest.id
existing = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = existing.scalar_one_or_none()
if row is not None:
row.name = manifest.name
row.description = manifest.description
row.version = manifest.version
row.author_name = manifest.author
row.category = manifest.category
row.price_cents = manifest.price_cents
row.permissions = json.dumps(manifest.permissions)
row.status = "pending_review"
row.s3_package_key = package_s3_key
row.rejection_reason = None
else:
row = Plugin(
id=plugin_id,
name=manifest.name,
description=manifest.description,
version=manifest.version,
author_name=manifest.author,
category=manifest.category,
price_cents=manifest.price_cents,
permissions=json.dumps(manifest.permissions),
status="pending_review",
s3_package_key=package_s3_key,
install_count=0,
avg_rating=0.0,
)
db.add(row)
await db.commit()
return plugin_id
async def approve_plugin(self, db: AsyncSession, plugin_id: str) -> None:
"""Set *plugin_id* status to ``'approved'``.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "approved"
row.rejection_reason = None
await db.commit()
async def reject_plugin(self, db: AsyncSession, plugin_id: str, reason: str) -> None:
"""Set *plugin_id* status to ``'rejected'`` and record the reason.
Raises ``KeyError`` if the plugin is not found.
"""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is None:
raise KeyError(f"Plugin not found: {plugin_id}")
row.status = "rejected"
row.rejection_reason = reason
await db.commit()
async def record_install(self, db: AsyncSession, plugin_id: str) -> None:
"""Increment the install count for *plugin_id* (no-op if not found)."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = row.install_count + 1
await db.commit()
async def record_uninstall(self, db: AsyncSession, plugin_id: str) -> None:
"""Decrement the install count for *plugin_id*, floored at 0."""
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
row = result.scalar_one_or_none()
if row is not None:
row.install_count = max(0, row.install_count - 1)
await db.commit()
# ── Internal helpers used by ReviewQueue ─────────────────────────
async def get_pending_entries(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all entries with status='pending_review'."""
result = await db.execute(
select(Plugin).where(Plugin.status == "pending_review")
)
rows = result.scalars().all()
return [
{
"manifest": _plugin_to_manifest(r),
"submitted_at": int(r.submitted_at.timestamp()) if r.submitted_at else 0,
}
for r in rows
]
# Module-level singleton
registry = PluginRegistry()

View File

@@ -1,125 +0,0 @@
"""Plugin review workflow backed by PostgreSQL.
Manages the approval queue for newly submitted plugins and enforces a
security checklist before any plugin is made visible in the marketplace.
Module-level singleton::
from app.marketplace.plugin_review import review_queue
"""
from __future__ import annotations
import re
from typing import Any, Literal
from sqlalchemy.ext.asyncio import AsyncSession
from app.marketplace.plugin_registry import registry
from app.models import PluginReview as PluginReviewModel
from app.schemas import PluginManifest
# ── Security policy ───────────────────────────────────────────────────
ALLOWED_PERMISSIONS: frozenset[str] = frozenset(
{
"read:tasks",
"write:tasks",
"read:projects",
"write:projects",
"read:notes",
"write:notes",
"read:timelines",
"write:timelines",
"read:calendar",
"write:calendar",
}
)
_PLUGIN_ID_RE = re.compile(r"^[a-z0-9-]+$")
def validate_manifest(manifest: PluginManifest) -> None:
"""Enforce the plugin security checklist.
Raises:
``ValueError`` on the first violation found. Callers should catch
this and return HTTP 422 / reject the submission.
Checks:
1. Plugin id matches ``^[a-z0-9-]+$``
2. All declared permissions are in ``ALLOWED_PERMISSIONS``
3. No manifest field contains raw binary data
"""
if not _PLUGIN_ID_RE.match(manifest.id):
raise ValueError(
f"Invalid plugin id format: '{manifest.id}'. "
"Only lowercase letters, digits, and hyphens are allowed."
)
for perm in manifest.permissions:
if perm not in ALLOWED_PERMISSIONS:
raise ValueError(
f"Unknown permission: '{perm}'. "
f"Allowed permissions: {sorted(ALLOWED_PERMISSIONS)}"
)
for field_name, value in manifest.model_dump().items():
if isinstance(value, (bytes, bytearray)):
raise ValueError(
f"Binary content is not allowed in manifest field '{field_name}'."
)
class ReviewQueue:
"""Approval queue for pending plugin submissions.
Delegates status changes to the shared ``PluginRegistry`` singleton.
Review records are persisted in the ``plugin_reviews`` table.
"""
async def get_pending(self, db: AsyncSession) -> list[dict[str, Any]]:
"""Return all plugins currently awaiting review.
Each item is ``{plugin_id, manifest, submitted_at}``.
"""
entries = await registry.get_pending_entries(db)
return [
{
"plugin_id": e["manifest"].id,
"manifest": e["manifest"],
"submitted_at": e["submitted_at"],
}
for e in entries
]
async def submit_review(
self,
db: AsyncSession,
plugin_id: str,
reviewer_id: str,
decision: Literal["approved", "rejected"],
notes: str = "",
) -> None:
"""Record a review decision and update the plugin's status.
Raises:
``KeyError`` if *plugin_id* is not found in the registry.
"""
if decision == "approved":
await registry.approve_plugin(db, plugin_id)
else:
await registry.reject_plugin(db, plugin_id, reason=notes)
review = PluginReviewModel(
plugin_id=plugin_id,
reviewer_id=reviewer_id,
decision=decision,
notes=notes,
)
db.add(review)
await db.commit()
# Module-level singleton
review_queue = ReviewQueue()

View File

@@ -1,233 +0,0 @@
"""Revenue share tracking and Stripe Connect payouts backed by PostgreSQL.
Records every plugin installation as a revenue event and facilitates
70 % / 30 % payouts to developers via Stripe Connect. Data is persisted
in the ``revenue_events`` table.
Module-level singleton::
from app.marketplace.revenue_share import revenue_share
"""
from __future__ import annotations
import logging
from datetime import datetime, timezone
from typing import Any
import stripe as stripe_lib
from sqlalchemy import extract, func, select
from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings
from app.marketplace.plugin_registry import registry
from app.models import Plugin, RevenueEvent
logger = logging.getLogger(__name__)
# ── Revenue split constants ───────────────────────────────────────────
DEVELOPER_SHARE: float = 0.70
PLATFORM_SHARE: float = 0.30
class RevenueShare:
"""Records installation revenue events and coordinates developer payouts.
Stripe Connect calls are gracefully stubbed when ``STRIPE_SECRET_KEY``
is not configured, consistent with the rest of the billing layer.
"""
# ── Helpers ──────────────────────────────────────────────────────
@staticmethod
def _stripe_configured() -> bool:
return bool(settings.STRIPE_SECRET_KEY)
@staticmethod
def _stripe() -> Any:
stripe_lib.api_key = settings.STRIPE_SECRET_KEY
return stripe_lib
# ── Core operations ──────────────────────────────────────────────
async def record_install(
self,
db: AsyncSession,
plugin_id: str,
user_id: str,
amount_cents: int,
) -> None:
"""Record a plugin installation and trigger a Stripe Connect charge if paid.
For free plugins (``amount_cents == 0``) no payment is initiated but
the event is still recorded for analytics.
For paid plugins the developer receives 70 % via a Stripe Connect
destination charge. If Stripe is not configured or the charge fails
the installation still succeeds (the event is recorded and the install
count is incremented) — a warning is logged for monitoring.
"""
developer_share_cents = int(amount_cents * DEVELOPER_SHARE)
stripe_transfer_id: str | None = None
if amount_cents > 0 and self._stripe_configured():
# Look up the plugin's author Stripe account from the DB
result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = result.scalar_one_or_none()
developer_stripe_account: str | None = None
if plugin_row and plugin_row.author_id:
# Future: look up user.stripe_connect_account_id
developer_stripe_account = None # no real account yet
if developer_stripe_account:
try:
s = self._stripe()
transfer = s.Transfer.create(
amount=developer_share_cents,
currency="eur",
destination=developer_stripe_account,
description=f"Revenue share for plugin {plugin_id}",
metadata={"plugin_id": plugin_id, "user_id": user_id},
)
stripe_transfer_id = transfer["id"]
except Exception as exc:
logger.warning(
"Stripe Connect transfer failed for plugin %s: %s",
plugin_id,
exc,
)
else:
logger.debug(
"No Stripe account on file for plugin %s developer; "
"skipping transfer.",
plugin_id,
)
event = RevenueEvent(
plugin_id=plugin_id,
user_id=user_id,
amount_cents=amount_cents,
developer_share_cents=developer_share_cents,
stripe_transfer_id=stripe_transfer_id,
)
db.add(event)
await db.commit()
await registry.record_install(db, plugin_id)
async def get_earnings(
self,
db: AsyncSession,
developer_id: str,
period: str | None = None,
) -> dict[str, Any]:
"""Return aggregated earnings for *developer_id*.
``period`` is an optional ``YYYY-MM`` string to restrict the window.
Returns::
{
"developer_id": str,
"period": str | None,
"total_installs": int,
"total_revenue_cents": int,
"developer_share_cents": int,
}
"""
# Find plugin ids belonging to this developer (by author_name match)
plugin_q = select(Plugin.id).where(Plugin.author_name == developer_id)
plugin_result = await db.execute(plugin_q)
developer_plugin_ids = [row[0] for row in plugin_result.all()]
if not developer_plugin_ids:
return {
"developer_id": developer_id,
"period": period,
"total_installs": 0,
"total_revenue_cents": 0,
"developer_share_cents": 0,
}
query = select(
func.count().label("total_installs"),
func.coalesce(func.sum(RevenueEvent.amount_cents), 0).label("total_revenue"),
func.coalesce(func.sum(RevenueEvent.developer_share_cents), 0).label("dev_share"),
).where(RevenueEvent.plugin_id.in_(developer_plugin_ids))
if period:
# Filter by YYYY-MM: extract year and month from created_at
try:
year, month = period.split("-")
query = query.where(
extract("year", RevenueEvent.created_at) == int(year),
extract("month", RevenueEvent.created_at) == int(month),
)
except ValueError:
pass # invalid period format — return all
result = await db.execute(query)
row = result.one()
return {
"developer_id": developer_id,
"period": period,
"total_installs": row.total_installs,
"total_revenue_cents": row.total_revenue,
"developer_share_cents": row.dev_share,
}
async def payout_developer(self, db: AsyncSession, plugin_id: str, period: str) -> None:
"""Aggregate unpaid revenue for *period* and issue a Stripe Transfer.
Marks processed events with ``paid_at`` timestamp.
Stubs gracefully when Stripe is not configured.
"""
try:
year, month = period.split("-")
year_int, month_int = int(year), int(month)
except ValueError:
logger.warning("Invalid period format: %s", period)
return
result = await db.execute(
select(RevenueEvent).where(
RevenueEvent.plugin_id == plugin_id,
RevenueEvent.paid_at.is_(None),
extract("year", RevenueEvent.created_at) == year_int,
extract("month", RevenueEvent.created_at) == month_int,
)
)
unpaid = list(result.scalars().all())
total_dev_share = sum(e.developer_share_cents for e in unpaid)
if total_dev_share <= 0 or not unpaid:
logger.debug("Nothing to pay out for plugin %s in period %s", plugin_id, period)
return
if self._stripe_configured():
plugin_result = await db.execute(select(Plugin).where(Plugin.id == plugin_id))
plugin_row = plugin_result.scalar_one_or_none()
developer_stripe_account: str | None = None # Future: fetch from DB
if plugin_row and developer_stripe_account:
try:
s = self._stripe()
s.Transfer.create(
amount=total_dev_share,
currency="eur",
destination=developer_stripe_account,
description=f"Payout for plugin {plugin_id} period {period}",
)
except Exception as exc:
logger.warning("Payout transfer failed for plugin %s: %s", plugin_id, exc)
return
paid_ts = datetime.now(timezone.utc)
for event in unpaid:
event.paid_at = paid_ts
await db.commit()
# Module-level singleton
revenue_share = RevenueShare()

View File

@@ -1 +0,0 @@
"""Cloud storage layer — E2E encrypted blobs and vectors."""

View File

@@ -1,106 +0,0 @@
"""S3-backed store for E2E-encrypted blobs.
Keys are structured as ``{user_id}/{table}/{record_id}``.
The backend never inspects blob content — it stores and retrieves opaque bytes.
"""
from __future__ import annotations
from typing import Any
import boto3
from app.config.settings import settings
class BlobStore:
"""Thin wrapper around boto3 S3.
All blobs must be E2E encrypted by the client before upload.
The backend adds SSE-S3 as an extra layer of at-rest encryption
but cannot decrypt the inner client-side payload.
"""
def _client(self) -> Any:
kwargs: dict[str, Any] = {
"region_name": settings.S3_REGION,
"aws_access_key_id": settings.AWS_ACCESS_KEY_ID,
"aws_secret_access_key": settings.AWS_SECRET_ACCESS_KEY,
}
if settings.S3_ENDPOINT_URL and isinstance(settings.S3_ENDPOINT_URL, str):
kwargs["endpoint_url"] = settings.S3_ENDPOINT_URL
return boto3.client("s3", **kwargs)
@staticmethod
def _key(user_id: str, table: str, record_id: str) -> str:
return f"{user_id}/{table}/{record_id}"
async def upload(
self,
user_id: str,
table: str,
record_id: str,
blob: bytes,
checksum: str,
) -> str:
"""Store *blob* in S3 and return the S3 key.
Args:
user_id: Owner of the blob (used as key prefix).
table: Logical table name (e.g. ``"tasks"``).
record_id: Record UUID.
blob: Raw bytes (pre-encrypted by client).
checksum: SHA-256 hex digest supplied by the client; stored as
object metadata for download-time verification.
Returns:
The S3 key under which the blob was stored.
"""
key = self._key(user_id, table, record_id)
self._client().put_object(
Bucket=settings.S3_BUCKET,
Key=key,
Body=blob,
ServerSideEncryption="AES256", # SSE-S3 at rest
Metadata={"checksum": checksum},
)
return key
async def download(self, user_id: str, s3_key: str) -> bytes:
"""Retrieve the blob stored at *s3_key*.
*user_id* is retained in the signature so higher-level code can
enforce ownership without re-parsing the key.
Raises:
``botocore.exceptions.ClientError`` with code ``NoSuchKey`` if the
object does not exist.
"""
response = self._client().get_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
return response["Body"].read()
async def delete(self, user_id: str, s3_key: str) -> None:
"""Delete the object at *s3_key*.
S3 ``delete_object`` is idempotent — it succeeds even if the key does
not exist.
"""
self._client().delete_object(
Bucket=settings.S3_BUCKET,
Key=s3_key,
)
async def list_keys(self, user_id: str, table: str) -> list[str]:
"""Return all S3 keys for a given user + table combination.
Uses the prefix ``{user_id}/{table}/`` to scope the listing.
"""
prefix = f"{user_id}/{table}/"
response = self._client().list_objects_v2(
Bucket=settings.S3_BUCKET,
Prefix=prefix,
)
return [obj["Key"] for obj in response.get("Contents", [])]

View File

@@ -1,32 +0,0 @@
"""Integrity verification only — the backend NEVER decrypts user data."""
from __future__ import annotations
import hashlib
import hmac
from fastapi import HTTPException
def verify_checksum(blob: bytes, checksum: str) -> bool:
"""Return ``True`` if SHA-256(blob) matches *checksum*.
Uses ``hmac.compare_digest`` for constant-time comparison to prevent
timing-based side-channel attacks.
"""
computed = hashlib.sha256(blob).hexdigest()
return hmac.compare_digest(computed, checksum)
def reject_if_tampered(blob: bytes, checksum: str) -> None:
"""Raise ``HTTP 400`` if the blob does not match its checksum.
Call this before storing or forwarding any client-provided blob.
The backend never holds decryption keys — this check only verifies
that the opaque bytes arrived intact.
"""
if not verify_checksum(blob, checksum):
raise HTTPException(
status_code=400,
detail="Checksum mismatch: blob integrity check failed",
)

View File

@@ -1,205 +0,0 @@
"""Cloud vector store — wraps Pinecone (default) or Qdrant.
Vectors are pre-encrypted blobs from the client. The backend stores them
alongside a deterministic 32-dim float representation derived from the blob's
SHA-256 hash. Semantic ANN search is not meaningful on encrypted data — this
is a known trade-off documented in the backend plan.
Isolation: Pinecone uses ``namespace=user_id``; Qdrant filters by
``user_id`` payload field on a shared collection.
"""
from __future__ import annotations
import base64
import hashlib
from typing import Any
from pinecone import Pinecone
from qdrant_client import QdrantClient
from qdrant_client.models import FieldCondition, Filter, MatchValue, PointIdsList, PointStruct
from app.config.settings import settings
from app.schemas import VectorItem, VectorSearchResult
_QDRANT_COLLECTION = "adiuva_vectors"
def _blob_to_vector(blob: bytes) -> list[float]:
"""Derive a 32-dim float vector from *blob* for storage purposes only.
Uses SHA-256 to produce a deterministic 32-byte fingerprint, then
normalises each byte to the range [-1.0, 1.0]. This vector carries no
semantic meaning on encrypted data.
"""
return [(b - 128) / 128.0 for b in hashlib.sha256(blob).digest()]
class VectorStore:
"""Thin wrapper around Pinecone or Qdrant.
The backend to use is selected at runtime:
- Pinecone: when ``settings.PINECONE_API_KEY`` is non-empty.
- Qdrant: otherwise (requires ``settings.QDRANT_URL``).
"""
def _use_pinecone(self) -> bool:
return bool(settings.PINECONE_API_KEY)
# ── Pinecone helpers ──────────────────────────────────────────────
def _pinecone_index(self) -> Any:
pc = Pinecone(api_key=settings.PINECONE_API_KEY)
return pc.Index(settings.PINECONE_INDEX)
# ── Qdrant helpers ────────────────────────────────────────────────
def _qdrant_client(self) -> Any:
return QdrantClient(
url=settings.QDRANT_URL,
api_key=settings.QDRANT_API_KEY or None,
)
# ── Public API ────────────────────────────────────────────────────
async def upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
"""Store encrypted vectors in the backend.
Each ``VectorItem.blob`` is base64-encoded and kept in metadata/payload
so it can be returned verbatim during search.
Args:
user_id: Used as Pinecone namespace or Qdrant payload field.
vectors: List of encrypted vector items from the client.
"""
if self._use_pinecone():
await self._pinecone_upsert(user_id, vectors)
else:
await self._qdrant_upsert(user_id, vectors)
async def search(
self,
user_id: str,
query_blob: bytes,
top_k: int,
) -> list[VectorSearchResult]:
"""Query the vector store and return encrypted result blobs.
The query vector is derived from *query_blob* using the same
deterministic mapping as upsert.
Args:
user_id: Scopes the search to this user's namespace.
query_blob: Encrypted query from the client.
top_k: Maximum number of results to return.
Returns:
List of ``VectorSearchResult`` with ``id``, ``score``, and ``blob``.
"""
if self._use_pinecone():
return await self._pinecone_search(user_id, query_blob, top_k)
return await self._qdrant_search(user_id, query_blob, top_k)
async def delete(self, user_id: str, vector_ids: list[str]) -> None:
"""Remove vectors by ID, scoped to *user_id*.
Args:
user_id: Namespace / payload filter to prevent cross-user deletion.
vector_ids: List of vector IDs to remove.
"""
if self._use_pinecone():
await self._pinecone_delete(user_id, vector_ids)
else:
await self._qdrant_delete(user_id, vector_ids)
# ── Pinecone implementation ───────────────────────────────────────
async def _pinecone_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
index = self._pinecone_index()
records = [
{
"id": v.id,
"values": _blob_to_vector(v.blob),
"metadata": {
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
}
for v in vectors
]
index.upsert(vectors=records, namespace=user_id)
async def _pinecone_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
index = self._pinecone_index()
query_vector = _blob_to_vector(query_blob)
response = index.query(
vector=query_vector,
top_k=top_k,
namespace=user_id,
include_metadata=True,
)
results: list[VectorSearchResult] = []
for match in response.get("matches", []):
blob_bytes = base64.b64decode(match["metadata"]["blob"])
results.append(
VectorSearchResult(
id=match["id"],
score=match["score"],
blob=blob_bytes,
)
)
return results
async def _pinecone_delete(self, user_id: str, vector_ids: list[str]) -> None:
index = self._pinecone_index()
index.delete(ids=vector_ids, namespace=user_id)
# ── Qdrant implementation ─────────────────────────────────────────
async def _qdrant_upsert(self, user_id: str, vectors: list[VectorItem]) -> None:
client = self._qdrant_client()
points = [
PointStruct(
id=v.id,
vector=_blob_to_vector(v.blob),
payload={
"blob": base64.b64encode(v.blob).decode(),
"checksum": v.checksum,
"user_id": user_id,
},
)
for v in vectors
]
client.upsert(collection_name=_QDRANT_COLLECTION, points=points)
async def _qdrant_search(
self, user_id: str, query_blob: bytes, top_k: int
) -> list[VectorSearchResult]:
client = self._qdrant_client()
query_vector = _blob_to_vector(query_blob)
hits = client.search(
collection_name=_QDRANT_COLLECTION,
query_vector=query_vector,
query_filter=Filter(
must=[FieldCondition(key="user_id", match=MatchValue(value=user_id))]
),
limit=top_k,
)
return [
VectorSearchResult(
id=str(hit.id),
score=hit.score,
blob=base64.b64decode(hit.payload["blob"]),
)
for hit in hits
]
async def _qdrant_delete(self, user_id: str, vector_ids: list[str]) -> None:
client = self._qdrant_client()
client.delete(
collection_name=_QDRANT_COLLECTION,
points_selector=PointIdsList(points=vector_ids),
)

View File

@@ -1,27 +1,34 @@
# ── Adiuva Microservices ─────────────────────────────────────────────
# docker compose up --build
# docker compose up --build auth ws-gateway chat # subset
services: services:
app:
build: . # ═══════════════════════════════════════════════════════════════════
# Infrastructure
# ═══════════════════════════════════════════════════════════════════
traefik:
image: traefik:v3.1
ports: ports:
- "8080:8000" - "80:80"
env_file: - "443:443"
- path: .env - "8080:8080" # dashboard (dev only)
required: false
environment: environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva CF_DNS_API_TOKEN: ${CF_DNS_API_TOKEN:-}
GITHUB_COPILOT_TOKEN_DIR: /root/.config/litellm/github_copilot
volumes: volumes:
- copilot_tokens:/root/.config/litellm/github_copilot - /var/run/docker.sock:/var/run/docker.sock:ro
depends_on: - ./traefik/traefik.yml:/etc/traefik/traefik.yml:ro
db: - ./traefik/dynamic:/etc/traefik/dynamic:ro
condition: service_healthy - traefik_acme:/etc/traefik/acme
restart: unless-stopped restart: unless-stopped
db: db:
image: pgvector/pgvector:pg16 image: pgvector/pgvector:pg16
environment: environment:
POSTGRES_USER: postgres POSTGRES_USER: ${POSTGRES_USER:-postgres}
POSTGRES_PASSWORD: postgres POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:-postgres}
POSTGRES_DB: adiuva POSTGRES_DB: ${POSTGRES_DB:-adiuva}
volumes: volumes:
- postgres_data:/var/lib/postgresql/data - postgres_data:/var/lib/postgresql/data
healthcheck: healthcheck:
@@ -31,42 +38,161 @@ services:
retries: 5 retries: 5
restart: unless-stopped restart: unless-stopped
# Optional Redis for future rate-limit or caching needs redis:
# redis: image: redis:7-alpine
# image: redis:7-alpine command: redis-server --maxmemory 256mb --maxmemory-policy allkeys-lru
# restart: unless-stopped
# ── Local S3-compatible storage (MinIO) ──
minio:
image: minio/minio:latest
command: server /data --console-address ":9001"
ports:
- "9000:9000"
- "9001:9001"
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin
volumes: volumes:
- minio_data:/data - redis_data:/data
healthcheck: healthcheck:
test: ["CMD", "mc", "ready", "local"] test: ["CMD", "redis-cli", "ping"]
interval: 5s interval: 5s
timeout: 5s timeout: 3s
retries: 5 retries: 5
restart: unless-stopped restart: unless-stopped
# ── Local vector store (Qdrant) ── # ── Optional infrastructure (uncomment as needed) ────────────────
qdrant:
image: qdrant/qdrant:latest # minio:
ports: # image: minio/minio:latest
- "6333:6333" # command: server /data --console-address ":9001"
- "6334:6334" # ports:
volumes: # - "9000:9000"
- qdrant_data:/qdrant/storage # - "9001:9001"
# environment:
# MINIO_ROOT_USER: minioadmin
# MINIO_ROOT_PASSWORD: minioadmin
# volumes:
# - minio_data:/data
# healthcheck:
# test: ["CMD", "mc", "ready", "local"]
# interval: 5s
# timeout: 5s
# retries: 5
# restart: unless-stopped
# qdrant:
# image: qdrant/qdrant:latest
# ports:
# - "6333:6333"
# - "6334:6334"
# volumes:
# - qdrant_data:/qdrant/storage
# restart: unless-stopped
# ═══════════════════════════════════════════════════════════════════
# Migrations (run once, then exit)
# ═══════════════════════════════════════════════════════════════════
migrate:
build:
context: .
dockerfile: Dockerfile
command: ["python", "-m", "alembic", "upgrade", "head"]
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
depends_on:
db:
condition: service_healthy
restart: "no"
# ═══════════════════════════════════════════════════════════════════
# Application Services
# ═══════════════════════════════════════════════════════════════════
auth:
build:
context: .
dockerfile: services/auth/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
REDIS_URL: redis://redis:6379/0
depends_on:
db:
condition: service_healthy
migrate:
condition: service_completed_successfully
restart: unless-stopped
ws-gateway:
build:
context: .
dockerfile: services/ws-gateway/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
REDIS_URL: redis://redis:6379/0
depends_on:
redis:
condition: service_healthy
auth:
condition: service_started
restart: unless-stopped
chat:
build:
context: .
dockerfile: services/chat/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
REDIS_URL: redis://redis:6379/0
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
migrate:
condition: service_completed_successfully
restart: unless-stopped
batch-agent:
build:
context: .
dockerfile: services/batch-agent/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
REDIS_URL: redis://redis:6379/0
depends_on:
db:
condition: service_healthy
redis:
condition: service_healthy
migrate:
condition: service_completed_successfully
restart: unless-stopped
billing:
build:
context: .
dockerfile: services/billing/Dockerfile
env_file:
- path: .env
required: false
environment:
DATABASE_URL: postgresql+asyncpg://${POSTGRES_USER:-postgres}:${POSTGRES_PASSWORD:-postgres}@db:5432/${POSTGRES_DB:-adiuva}
depends_on:
db:
condition: service_healthy
migrate:
condition: service_completed_successfully
restart: unless-stopped restart: unless-stopped
volumes: volumes:
postgres_data: postgres_data:
minio_data: redis_data:
qdrant_data: traefik_acme:
copilot_tokens: # minio_data:
# qdrant_data:

View File

@@ -1,8 +1,10 @@
# Adiuva — Architettura Microservizi # Adiuva — Architettura Microservizi (MVP)
## Panoramica ## Panoramica
Il monolite attuale viene suddiviso in **5 servizi** + un **API Gateway**, orchestrati con Docker Compose e raggiungibili tramite dominio su Cloudflare. Il monolite viene suddiviso in **4 servizi MVP** + un **API Gateway (Traefik)**, orchestrati con Docker Compose su un singolo VPS raggiungibile via Cloudflare.
> **Fuori dall'MVP**: Storage Service (S3/backup CRUD) e Plugin Service (marketplace). Verranno aggiunti come servizi indipendenti in una fase successiva.
``` ```
┌──────────────┐ ┌──────────────┐
@@ -14,20 +16,21 @@ Il monolite attuale viene suddiviso in **5 servizi** + un **API Gateway**, orche
│ Traefik │ │ Traefik │
│ API Gateway │ │ API Gateway │
│ (routing, │ │ (routing, │
│ TLS term.) │ TLS, rate
│ limiting) │
└──────┬───────┘ └──────┬───────┘
┌──────────┬───────────┼───────────┬────────── ┌──────────┬───────────┼───────────┐
│ │ │ │ │ │ │ │
┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐ ┌───▼─────┐ ┌─────▼────┐ ┌───▼───┐ ┌────▼────┐ ┌────▼───┐
│ Auth │ │ Chat │ │ Storage │ │Billing │ │ Plugins │ │ Auth │ │ Chat │ │ Agent │ │Billing │
│ Service │ │Service│ │ Service │ │Service │ │ Service │ │ Service │ │Service│ │ Service │ │Service │
└─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘ └───┬─────┘ └─────┬────┘ └───┬───┘ └────┬────┘ └────┬───┘
│ │ │ │ │ │ │ │
┌─────▼──────────▼──────────▼───────────▼──────────▼───── ┌─────▼──────────▼──────────▼───────────▼────┐
│ Infrastruttura │ │ Infrastruttura │
│ PostgreSQL Redis │ MinIO (S3) │ Qdrant │ (Pinecone) │ PostgreSQL Redis │ Qdrant
└──────────────────────────────────────────────────────── └─────────────────────────────────────────────┘
``` ```
--- ---
@@ -83,46 +86,68 @@ def verify_token(token: str) -> dict:
--- ---
### 1.2 Chat Service (`chat-service`) ⭐ Core ### 1.2 Chat Service (`chat-service`) ⭐ Real-time
**Responsabilità**: WebSocket device, home chat, floating chat, agent runner, memory middleware, agent setup journeys. **Responsabilità**: WebSocket device connection, home chat, floating chat, memory middleware, streaming LLM responses verso il client.
| Endpoint originale | Tipo | Questo servizio gestisce la **connessione persistente** con l'app Electron e le interazioni **real-time** dell'utente (chat home, floating chat). È il proprietario della WebSocket.
| Endpoint | Tipo |
|---|---| |---|---|
| `/api/v1/ws/device` | WebSocket | | `/api/v1/ws/device` | WebSocket (connessione persistente) |
| `/api/v1/chat` | POST (REST fallback) | | `/api/v1/chat` | POST (REST fallback) |
| `/api/v1/agents/catalog` | GET |
| `/api/v1/agents/can-create` | POST |
| `/api/v1/agents/trigger` | POST |
**Moduli inclusi**: `deep_agent`, `agent_runner`, `agent_registry`, `memory_middleware`, `ws_context`, `device_manager`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`, `filesystem_agent`). **Moduli inclusi**: `deep_agent`, `memory_middleware`, `ws_context`, `device_manager` (Redis-backed), `output_formatter`, `llm`, tutti gli agent tools (`task_agent`, `project_agent`, `note_agent`, `timeline_agent`).
**Questa è la bestia che deve scalare orizzontalmente** — è il servizio più CPU/memory intensive (LLM calls, tool loops, WebSocket persistenti). **Perché separato dall'Agent Service**: Il Chat Service tiene la WebSocket aperta e risponde in tempo reale (streaming). Scalare aggiungendo repliche è semplice con sticky sessions + Redis pub/sub per il cross-instance routing dei tool_call.
**Scaling**: 2N repliche. Sticky cookies per le WS + Redis per cross-instance.
--- ---
### 1.3 Storage Service (`storage-service`) ### 1.3 Agent Service (`agent-service`) ⭐ Batch
**Responsabilità**: CRUD record crittografati su S3, vector operations, backup. **Responsabilità**: Batch agent processing (directory scanning, file classification, entity extraction), agent setup journeys, agent configuration CRUD.
| Endpoint originale | Metodo | Questo servizio gestisce i processi **long-running** e **CPU-intensive**: scansione filesystem, classificazione file con LLM, estrazione entità in batch. Non possiede la WebSocket — comunica con il device dell'utente tramite **Redis pub/sub** passando per il Chat Service.
| Endpoint | Tipo |
|---|---| |---|---|
| `/api/v1/storage/records` | POST / GET | | `/api/v1/agents/catalog` | GET |
| `/api/v1/storage/records/{id}` | GET / PUT / DELETE | | `/api/v1/agents/can-create` | POST |
| `/api/v1/vectors/upsert` | POST | | `/api/v1/agents/trigger` | POST |
| `/api/v1/vectors/search` | POST | | `/api/v1/agents/journey/start` | POST (o WS relay) |
| `/api/v1/vectors/embed` | POST | | `/api/v1/agents/journey/message` | POST (o WS relay) |
| `/api/v1/vectors` | DELETE |
| `/api/v1/backup` | PUT / GET / DELETE |
| `/api/v1/backup/history` | GET |
**Scaling**: 23 repliche. I/O bound (S3, Qdrant). Stateless. **Moduli inclusi**: `agent_runner`, `agent_registry`, `filesystem_agent`, `llm`.
**Flusso tool-call cross-service** (l'Agent Service non ha la WS):
```
┌──────────────┐ ┌──────────────┐ ┌──────────┐
│ Agent Service│ │ Redis │ │ Chat │
│ (batch run) │ │ │ │ Service │
│ │ │ │ │ (ha WS) │
│ 1. Needs to │ PUBLISH │ │ SUBSCRIBE │ │
│ read file ├───────────►│tool_call:u123├───────────►│ 2. Invia │
│ from │ │ │ │ al │
│ device │ │ │ │ device│
│ │ │ │ │ via WS│
│ │ SUBSCRIBE │ │ PUBLISH │ │
│ 4. Riceve ◄────────────┤tool_result:id│◄───────────┤ 3. Device│
│ risultato │ │ │ │ reply │
└──────────────┘ └──────────────┘ └──────────┘
```
**Scaling**: 1N repliche. Completamente stateless, scala indipendentemente dalla chat. Ogni replica processa batch job diversi. Può essere scalato a 0 se non ci sono agent attivi (risparmio risorse).
**Vantaggio dello split**: Se 50 utenti triggerano agenti batch contemporaneamente, il Chat Service non ne risente — le risposte real-time rimangono veloci.
--- ---
### 1.4 Billing Service (`billing-service`) ### 1.4 Billing Service (`billing-service`)
**Responsabilità**: Stripe checkout, webhook, subscription management, tier enforcement. **Responsabilità**: Stripe checkout, webhook, subscription management.
| Endpoint originale | Metodo | | Endpoint originale | Metodo |
|---|---| |---|---|
@@ -132,31 +157,125 @@ def verify_token(token: str) -> dict:
**Database**: Tabelle `subscriptions` (schema `billing`). **Database**: Tabelle `subscriptions` (schema `billing`).
**Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users (oppure i servizi leggono il tier direttamente dal JWT, aggiornato al prossimo refresh). **Comunicazione inter-servizio**: Quando Stripe invia un webhook e il tier cambia, il Billing Service pubblica un evento su **Redis pub/sub** channel `tier_changed:{user_id}`. L'Auth Service aggiorna il campo `tier` nella tabella users. Al prossimo token refresh il JWT conterrà il tier aggiornato.
**Scaling**: 1 replica sufficiente. Basso traffico. **Scaling**: 1 replica sufficiente. Basso traffico.
--- ---
### 1.5 Plugin Service (`plugin-service`) ### 1.5 Servizi esclusi dall'MVP
**Responsabilità**: Marketplace, installazione plugin, revenue split. I seguenti servizi verranno aggiunti post-MVP come servizi indipendenti:
| Endpoint originale | Metodo | | Servizio | Responsabilità | Note |
|---|---| |---|---|---|
| `/api/v1/plugins` | GET | | **Storage Service** | S3 blobs CRUD, vector ops, backup | Le funzionalità vector/embed possono restare nel Chat Service per il MVP |
| `/api/v1/plugins/{id}` | GET | | **Plugin Service** | Marketplace, install, revenue split | Feature non critica per il lancio |
| `/api/v1/plugins/{id}/install` | POST / DELETE |
**Database**: Tabelle `plugins`, `plugin_installations`, `revenue_events`.
**Scaling**: 1 replica. Basso traffico.
--- ---
## 2. WebSocket con Scaling Orizzontale — Il Problema Chiave ## 2. Tier Check — Dove e Come
### Il problema attuale Il tier dell'utente (free/pro/power/team) determina rate-limiting, quote e accesso a funzionalità. Con i microservizi, **ogni servizio controlla il tier autonomamente** senza chiamare l'Auth Service.
### Strategia: Tier nel JWT
L'Auth Service include il `tier` come claim nel JWT al momento del login/refresh:
```json
{
"sub": "user_123",
"tier": "pro",
"exp": 1742515200,
"iat": 1742511600
}
```
Ogni servizio:
1. Decodifica il JWT con la chiave pubblica (già lo fa per l'auth)
2. Legge `payload["tier"]`**zero chiamate extra**
3. Applica le sue regole di enforcement localmente
```python
# shared/auth.py — dependency FastAPI condivisa
from fastapi import Depends, HTTPException, Request
from jose import jwt
PUBLIC_KEY = ...
class CurrentUser:
def __init__(self, user_id: str, tier: str):
self.user_id = user_id
self.tier = tier
async def get_current_user(request: Request) -> CurrentUser:
token = request.headers.get("Authorization", "").removeprefix("Bearer ")
payload = jwt.decode(token, PUBLIC_KEY, algorithms=["RS256"])
return CurrentUser(user_id=payload["sub"], tier=payload["tier"])
def require_tier(*allowed_tiers: str):
"""Dependency che blocca se il tier non è tra quelli ammessi."""
async def check(user: CurrentUser = Depends(get_current_user)):
if user.tier not in allowed_tiers:
raise HTTPException(403, "Tier insufficient")
return user
return check
```
### Cosa succede quando il tier cambia (upgrade/downgrade)?
```
┌──────────┐ Stripe webhook ┌──────────┐ tier_changed ┌──────────┐
│ Stripe │ ─────────────────►│ Billing │ ───────────────►│ Auth │
│ │ │ Service │ (Redis pub/sub) │ Service │
└──────────┘ └──────────┘ └────┬─────┘
UPDATE users
SET tier = 'power'
Al prossimo /refresh
il JWT conterrà tier='power'
```
**Latenza del cambio**: Il tier si propaga al prossimo token refresh (tipicamente 1530 min, o il client può forzare un refresh immediato dopo il checkout). Per il billing webhook, il downgrade può essere forzato invalidando il refresh token su Redis → il client è obbligato a ri-autenticarsi.
### Dove si applica in ciascun servizio
| Servizio | Enforcement |
|---|---|
| **Auth Service** | Nessuno (è lui che scrive il tier) |
| **Chat Service** | Rate-limit per tier (req/min), quota messaggi |
| **Agent Service** | Max agent configs, max runs/day, max concurrent batches |
| **Billing Service** | Nessuno (gestisce i tier, non li consuma) |
### Rate-limit distribuito via Redis
Poiché ogni servizio ha le sue repliche, il rate-limiting deve essere **condiviso** via Redis:
```python
# shared/middleware/rate_limit.py
import redis.asyncio as aioredis
class DistributedRateLimiter:
def __init__(self, redis: aioredis.Redis):
self._redis = redis
async def check(self, user_id: str, tier: str, service: str) -> bool:
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
max_req = limits.get(tier, 20)
key = f"rate:{service}:{user_id}"
pipe = self._redis.pipeline()
pipe.incr(key)
pipe.expire(key, 60)
count, _ = await pipe.execute()
return count <= max_req
```
---
## 3. WebSocket con Scaling Orizzontale — Il Problema Chiave
`DeviceConnectionManager` è un **singleton in-memory**: `DeviceConnectionManager` è un **singleton in-memory**:
@@ -354,7 +473,7 @@ class RedisDeviceManager:
--- ---
## 3. Struttura Directory Proposta ## 4. Struttura Directory Proposta (MVP)
``` ```
adiuva-api/ adiuva-api/
@@ -364,7 +483,7 @@ adiuva-api/
│ ├── auth.py # JWT verification (chiave pubblica) │ ├── auth.py # JWT verification (chiave pubblica)
│ ├── schemas.py # Pydantic schemas condivisi │ ├── schemas.py # Pydantic schemas condivisi
│ ├── middleware/ │ ├── middleware/
│ │ ├── rate_limit.py │ │ ├── rate_limit.py # DistributedRateLimiter (Redis)
│ │ └── sanitizer.py │ │ └── sanitizer.py
│ └── models/ │ └── models/
│ └── base.py # SQLAlchemy base condivisa │ └── base.py # SQLAlchemy base condivisa
@@ -390,42 +509,45 @@ adiuva-api/
│ ├── main.py │ ├── main.py
│ ├── config.py │ ├── config.py
│ ├── db.py │ ├── db.py
│ ├── models.py # agent_run_logs, memory_* │ ├── models.py # memory_*
│ ├── routes/ │ ├── routes/
│ │ ├── device_ws.py │ │ ├── device_ws.py # WS connection owner
│ │ ── chat.py │ │ ── chat.py # REST fallback
│ │ └── agents.py
│ ├── core/ │ ├── core/
│ │ ├── device_manager.py # RedisDeviceManager │ │ ├── device_manager.py # RedisDeviceManager
│ │ ├── deep_agent.py │ │ ├── deep_agent.py # Home + floating chat
│ │ ├── agent_runner.py
│ │ ├── agent_registry.py
│ │ ├── memory_middleware.py │ │ ├── memory_middleware.py
│ │ ├── ws_context.py │ │ ├── ws_context.py
│ │ ├── output_formatter.py │ │ ├── output_formatter.py
│ │ └── llm.py │ │ └── llm.py
│ └── agents/ │ └── agents/ # Tool definitions (used by deep_agent)
│ ├── task_agent.py │ ├── task_agent.py
│ ├── project_agent.py │ ├── project_agent.py
│ ├── note_agent.py │ ├── note_agent.py
── timeline_agent.py ── timeline_agent.py
│ └── filesystem_agent.py
├── storage-service/ ├── agent-service/
│ ├── Dockerfile │ ├── Dockerfile
│ ├── requirements.txt │ ├── requirements.txt
│ └── app/ │ └── app/
│ ├── main.py │ ├── main.py
│ ├── config.py │ ├── config.py
│ ├── db.py │ ├── db.py
│ ├── models.py # storage_records, backup_metadata │ ├── models.py # agent_run_logs, local/cloud_agent_configs
│ ├── routes/ │ ├── routes/
│ │ ├── storage.py │ │ ├── agents.py # catalog, can-create, trigger
│ │ ── vectors.py │ │ ── agent_setup.py # journey start/message
│ └── backup.py ├── core/
└── services/ │ ├── agent_runner.py # Batch classify → process
├── blob_store.py ├── agent_registry.py
── vector_store.py ── redis_executor.py # execute_on_client via Redis pub/sub
│ │ └── llm.py
│ └── agents/
│ ├── task_agent.py # Tool definitions (batch context)
│ ├── project_agent.py
│ ├── note_agent.py
│ ├── timeline_agent.py
│ └── filesystem_agent.py
├── billing-service/ ├── billing-service/
│ ├── Dockerfile │ ├── Dockerfile
@@ -441,26 +563,18 @@ adiuva-api/
│ ├── stripe_service.py │ ├── stripe_service.py
│ └── tier_manager.py │ └── tier_manager.py
├── plugin-service/
│ ├── Dockerfile
│ ├── requirements.txt
│ └── app/
│ ├── main.py
│ ├── config.py
│ ├── db.py
│ ├── models.py # plugins, installations, revenue
│ └── routes/
│ └── plugins.py
└── infra/ └── infra/
├── traefik/ ├── traefik/
│ └── traefik.yml │ └── traefik.yml
├── keys/
│ ├── jwt_private.pem # Solo auth-service
│ └── jwt_public.pem # Tutti i servizi
└── alembic/ # Migrazioni condivise o per-servizio └── alembic/ # Migrazioni condivise o per-servizio
``` ```
--- ---
## 4. Docker Compose — Configurazione Completa ## 5. Docker Compose — Configurazione MVP
```yaml ```yaml
# docker-compose.yml # docker-compose.yml
@@ -478,14 +592,14 @@ services:
- "--providers.docker.exposedbydefault=false" - "--providers.docker.exposedbydefault=false"
- "--entrypoints.web.address=:80" - "--entrypoints.web.address=:80"
- "--entrypoints.websecure.address=:443" - "--entrypoints.websecure.address=:443"
# Cloudflare gestisce TLS, Traefik riceve HTTP dal proxy
- "--entrypoints.web.http.redirections.entrypoint.to=websecure" - "--entrypoints.web.http.redirections.entrypoint.to=websecure"
ports: ports:
- "80:80" - "80:80"
- "443:443" - "443:443"
- "8080:8080" # Dashboard Traefik - "8080:8080" # Dashboard Traefik (disabilitare in prod)
volumes: volumes:
- /var/run/docker.sock:/var/run/docker.sock:ro - /var/run/docker.sock:/var/run/docker.sock:ro
- ./infra/certs:/certs:ro
restart: unless-stopped restart: unless-stopped
# ══════════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════════
@@ -498,10 +612,12 @@ services:
env_file: .env env_file: .env
environment: environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key JWT_PRIVATE_KEY_FILE: /run/secrets/jwt_private_key
SERVICE_NAME: auth SERVICE_NAME: auth
secrets: secrets:
- jwt_private_key - jwt_private_key
- jwt_public_key
labels: labels:
- "traefik.enable=true" - "traefik.enable=true"
- "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)" - "traefik.http.routers.auth.rule=PathPrefix(`/api/v1/auth`)"
@@ -509,14 +625,16 @@ services:
depends_on: depends_on:
db: db:
condition: service_healthy condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════════
# Chat Service (scalabile, N repliche) # Chat Service — Real-time WS + Chat (scalabile)
# ══════════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════════
chat-service: chat-service:
build: ./chat-service build: ./chat-service
deploy: deploy:
replicas: 3 replicas: 2
env_file: .env env_file: .env
environment: environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
@@ -527,8 +645,8 @@ services:
- jwt_public_key - jwt_public_key
labels: labels:
- "traefik.enable=true" - "traefik.enable=true"
# REST routes # REST chat endpoint
- "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`) || PathPrefix(`/api/v1/agents`)" - "traefik.http.routers.chat.rule=PathPrefix(`/api/v1/chat`)"
- "traefik.http.services.chat.loadbalancer.server.port=8000" - "traefik.http.services.chat.loadbalancer.server.port=8000"
# WebSocket route con sticky session # WebSocket route con sticky session
- "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)" - "traefik.http.routers.ws.rule=PathPrefix(`/api/v1/ws`)"
@@ -543,26 +661,29 @@ services:
condition: service_healthy condition: service_healthy
# ══════════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════════
# Storage Service (2 repliche) # Agent Service — Batch processing (scalabile indipendentemente)
# ══════════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════════
storage-service: agent-service:
build: ./storage-service build: ./agent-service
deploy: deploy:
replicas: 2 replicas: 2
env_file: .env env_file: .env
environment: environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
REDIS_URL: redis://redis:6379
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: storage SERVICE_NAME: agent
secrets: secrets:
- jwt_public_key - jwt_public_key
labels: labels:
- "traefik.enable=true" - "traefik.enable=true"
- "traefik.http.routers.storage.rule=PathPrefix(`/api/v1/storage`) || PathPrefix(`/api/v1/vectors`) || PathPrefix(`/api/v1/backup`)" - "traefik.http.routers.agents.rule=PathPrefix(`/api/v1/agents`)"
- "traefik.http.services.storage.loadbalancer.server.port=8000" - "traefik.http.services.agents.loadbalancer.server.port=8000"
depends_on: depends_on:
db: db:
condition: service_healthy condition: service_healthy
redis:
condition: service_healthy
# ══════════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════════
# Billing Service (1 replica) # Billing Service (1 replica)
@@ -589,28 +710,6 @@ services:
redis: redis:
condition: service_healthy condition: service_healthy
# ══════════════════════════════════════════════════════════
# Plugin Service (1 replica)
# ══════════════════════════════════════════════════════════
plugin-service:
build: ./plugin-service
deploy:
replicas: 1
env_file: .env
environment:
DATABASE_URL: postgresql+asyncpg://postgres:postgres@db:5432/adiuva
JWT_PUBLIC_KEY_FILE: /run/secrets/jwt_public_key
SERVICE_NAME: plugins
secrets:
- jwt_public_key
labels:
- "traefik.enable=true"
- "traefik.http.routers.plugins.rule=PathPrefix(`/api/v1/plugins`)"
- "traefik.http.services.plugins.loadbalancer.server.port=8000"
depends_on:
db:
condition: service_healthy
# ══════════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════════
# Infrastruttura # Infrastruttura
# ══════════════════════════════════════════════════════════ # ══════════════════════════════════════════════════════════
@@ -641,19 +740,6 @@ services:
retries: 5 retries: 5
restart: unless-stopped restart: unless-stopped
minio:
image: minio/minio:latest
command: server /data --console-address ":9001"
ports:
- "9000:9000"
- "9001:9001"
environment:
MINIO_ROOT_USER: minioadmin
MINIO_ROOT_PASSWORD: minioadmin
volumes:
- minio_data:/data
restart: unless-stopped
qdrant: qdrant:
image: qdrant/qdrant:latest image: qdrant/qdrant:latest
volumes: volumes:
@@ -669,22 +755,21 @@ secrets:
volumes: volumes:
postgres_data: postgres_data:
redis_data: redis_data:
minio_data:
qdrant_data: qdrant_data:
``` ```
--- ---
## 5. Configurazione Cloudflare + VPS ## 6. Configurazione Cloudflare + VPS
### 5.1 DNS ### 6.1 DNS
``` ```
api.tuodominio.com → A record → IP del VPS api.tuodominio.com → A record → IP del VPS
→ Proxy: ON (orange cloud) → Proxy: ON (orange cloud)
``` ```
### 5.2 Cloudflare Settings ### 6.2 Cloudflare Settings
| Setting | Valore | Motivo | | Setting | Valore | Motivo |
|---------|--------|--------| |---------|--------|--------|
@@ -693,7 +778,7 @@ api.tuodominio.com → A record → IP del VPS
| Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ | | Proxy timeout | **100s** (Enterprise) o default | Le LLM calls possono durare 30s+ |
| Under Attack Mode | Off (attivare se necessario) | | | Under Attack Mode | Off (attivare se necessario) | |
### 5.3 TLS sul VPS ### 6.3 TLS sul VPS
Due opzioni: Due opzioni:
- **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik - **Opzione A (consigliata)**: Cloudflare Origin Certificate → montato in Traefik
@@ -711,7 +796,7 @@ tls:
keyFile: /certs/origin-key.pem keyFile: /certs/origin-key.pem
``` ```
### 5.4 Rete VPS ### 6.4 Rete VPS
```bash ```bash
# UFW firewall — solo Cloudflare può raggiungere le porte 80/443 # UFW firewall — solo Cloudflare può raggiungere le porte 80/443
@@ -726,9 +811,9 @@ ufw enable
--- ---
## 6. Comunicazione Inter-Servizio ## 7. Comunicazione Inter-Servizio
### 6.1 Pattern: Event Bus via Redis Pub/Sub ### 7.1 Redis Pub/Sub — Event Bus
``` ```
┌──────────┐ tier_changed:user_123 ┌──────────┐ ┌──────────┐ tier_changed:user_123 ┌──────────┐
@@ -736,87 +821,55 @@ ufw enable
│ Service │ │ Service │ │ Service │ │ Service │
└──────────┘ └──────────┘ └──────────┘ └──────────┘
┌──────────┐ agent_triggered:user_123 ┌──────────┐ ┌──────────┐ tool_call:user_123 ┌──────────┐
Chat ──────────────────────── │ Any Agent │ ──────────────────────── Chat
│ Service │ │ Service │ │ Service │ │ Service │
└──────────┘ └──────────┘ │ (batch) │ ◄────────────────────────│ (ha WS) │
└──────────┘ tool_result:{call_id} └──────────┘
``` ```
### 6.2 Pattern: HTTP Sincrono (per query semplici) ### 7.2 Health Checks e Service Discovery
Il Chat Service può avere bisogno del tier dell'utente per il rate-limiting degli agent. Due strategie:
- **Strategia A (preferita)**: Il tier è nel JWT. All'aggiornamento, il Billing Service forza token refresh invalidando i vecchi token su Redis.
- **Strategia B**: Il Chat Service chiama `http://auth-service:8000/internal/user/{id}/tier` (rete Docker interna, non esposta).
### 6.3 Health Checks e Service Discovery
Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via: Traefik gestisce automaticamente il service discovery via Docker labels. I servizi non devono conoscersi tra loro — comunicano solo via:
- **Redis pub/sub** (eventi asincroni) - **Redis pub/sub** (tool-call cross-instance, tier events)
- **Redis hash** (stato condiviso, es. `ws:connections`) - **Redis hash** (stato condiviso: `ws:connections`, rate-limit counters)
- **PostgreSQL** (dati persistenti condivisi) - **PostgreSQL** (dati persistenti condivisi)
--- ---
## 7. Piano di Migrazione Incrementale ## 8. Piano di Migrazione Incrementale (MVP)
### Fase 1 — Preparazione (senza rompere nulla) ### Fase 1 — Preparazione (nel monolite attuale)
1. Aggiungere Redis al `docker-compose.yml` attuale 1. Aggiungere Redis al `docker-compose.yml` attuale
2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi) 2. Migrare JWT da HS256 → RS256 (backward-compatible: accetta entrambi per un periodo)
3. Implementare `RedisDeviceManager` come drop-in replacement 3. Implementare `RedisDeviceManager` come drop-in replacement del singleton in-memory
4. Estrarre `shared/` con auth verification, schemas, middleware 4. Estrarre `shared/` con auth verification, schemas, middleware
### Fase 2 — Primo split: Auth Service ### Fase 2 — Auth Service (primo split)
1. Estrarre `auth.py` routes + models in `auth-service/` 1. Estrarre `auth.py` routes + models in `auth-service/`
2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite 2. Verificare che i JWT firmati da `auth-service` vengano validati dal monolite
3. Aggiornare Traefik per routare `/api/v1/auth/*` al nuovo servizio 3. Aggiungere Traefik e routare `/api/v1/auth/*` al nuovo servizio
4. Il monolite continua a servire tutto il resto 4. Il monolite continua a servire tutto il resto
### Fase 3 — Storage + Billing + Plugins ### Fase 3 — Billing Service
1. Servizi stateless e senza WebSocket → facili da estrarre 1. Estrarre billing routes, Stripe service, tier manager
2. Estrarre uno alla volta, testare, routare via Traefik 2. Configurare Redis pub/sub per `tier_changed` events
3. Il monolite diventa sempre più magro 3. Routare via Traefik
### Fase 4 — Chat Service (il più delicato) ### Fase 4 — Split Chat + Agent (il più delicato)
1. Il monolite residuo **diventa** il Chat Service 1. Il monolite residuo contiene WS + chat + agents
2. Rimuovere i route migrati, tenere solo WS + chat + agents 2. Separare Agent Service: estrarre `agent_runner`, `agent_registry`, `agent_setup`, route `/agents/*`
3. Testare lo scaling a 2+ istanze con `RedisDeviceManager` 3. Implementare `redis_executor.py` nell'Agent Service per tool-call via Redis
4. Verificare tool-call cross-instance 4. Il Chat Service resta proprietario della WS e sottoscrive i canali `tool_call:{user_id}`
5. Testare: trigger agent dall'Agent Service → tool_call via Redis → Chat Service → WS → device → risposta
### Fase 5 — Cleanup ### Fase 5 — Scaling test
1. Rimuovere il monolite originale 1. Scalare Chat Service a 2 repliche, verificare sticky sessions
2. CI/CD pipeline per build/push separati 2. Scalare Agent Service a 2 repliche, verificare batch processing distribuito
3. Monitoring (Prometheus + Grafana) per ogni servizio 3. Monitoring (Prometheus + Grafana) per ogni servizio
--- ---
## 8. Rate Limiting Distribuito
Il middleware attuale usa un contatore in-memory per il rate-limiting. Con i microservizi:
```python
# shared/middleware/rate_limit.py
import redis.asyncio as aioredis
class DistributedRateLimiter:
def __init__(self, redis: aioredis.Redis):
self._redis = redis
async def check(self, user_id: str, tier: str) -> bool:
limits = {"free": 20, "pro": 60, "power": 120, "team": 200}
max_req = limits.get(tier, 20)
key = f"rate:{user_id}"
pipe = self._redis.pipeline()
pipe.incr(key)
pipe.expire(key, 60) # Finestra di 60 secondi
count, _ = await pipe.execute()
return count <= max_req
```
---
## 9. Monitoraggio e Logging ## 9. Monitoraggio e Logging
```yaml ```yaml
@@ -845,35 +898,44 @@ Ogni servizio espone `/metrics` (Prometheus) e scrive log strutturati (JSON) rac
--- ---
## 10. Sizing VPS Minimo Consigliato ## 10. Sizing VPS Minimo Consigliato (MVP)
| Componente | CPU | RAM | Note | | Componente | CPU | RAM | Note |
|---|---|---|---| |---|---|---|---|
| Traefik | 0.25 | 128MB | | | Traefik | 0.25 | 128MB | |
| Auth Service ×2 | 0.25 ×2 | 128MB ×2 | | | Auth Service ×2 | 0.25 ×2 | 128MB ×2 | Stateless, leggero |
| Chat Service ×2 | 1.0 ×2 | 1GB ×2 | Il più pesante (LLM calls) | | Chat Service ×2 | 1.0 ×2 | 1GB ×2 | WS + streaming LLM |
| Storage Service ×2 | 0.5 ×2 | 256MB ×2 | I/O bound | | Agent Service ×2 | 0.75 ×2 | 512MB ×2 | Batch LLM, CPU-bound |
| Billing Service | 0.25 | 128MB | | | Billing Service | 0.25 | 128MB | |
| Plugin Service | 0.25 | 128MB | |
| PostgreSQL | 1.0 | 1GB | | | PostgreSQL | 1.0 | 1GB | |
| Redis | 0.25 | 256MB | | | Redis | 0.25 | 256MB | |
| Qdrant | 0.5 | 512MB | | | Qdrant | 0.5 | 512MB | |
| MinIO | 0.25 | 256MB | | | **Totale MVP** | **~5.5 vCPU** | **~5 GB** | |
| **Totale** | **~6 vCPU** | **~5.5 GB** | |
**Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. **Raccomandazione**: VPS con **8 vCPU / 16 GB RAM** per avere margine. Hetzner CPX41 (~€30/mese) o equivalente. Senza Storage/Plugin si risparmia ~1 vCPU e 512MB rispetto alla versione completa.
--- ---
## Riepilogo Decisioni Architetturali ## Riepilogo Architettura MVP
| Servizio | Repliche | Proprietario di |
|---|---|---|
| **Traefik** | 1 | Routing, TLS, sticky sessions |
| **Auth Service** | 2 | JWT RS256, registrazione, login, profilo |
| **Chat Service** | 2N | WebSocket, home/floating chat, streaming |
| **Agent Service** | 2N | Batch processing, directory scan, agent setup |
| **Billing Service** | 1 | Stripe, subscriptions, tier management |
| Decisione | Scelta | Motivazione | | Decisione | Scelta | Motivazione |
|---|---|---| |---|---|---|
| API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico | | API Gateway | Traefik | Nativo Docker, WebSocket support, service discovery automatico |
| JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service | | JWT | RS256 (asimmetrico) | Verifica distribuita senza contattare Auth Service |
| WebSocket scaling | Redis pub/sub + registry | Cross-instance tool-call routing | | Tier check | Claim nel JWT | Ogni servizio verifica localmente, zero roundtrip |
| Rate limiting | Redis contatori | Distribuito, sliding window | | WebSocket scaling | Redis pub/sub + sticky cookies | Cross-instance tool-call routing |
| Service communication | Redis pub/sub + HTTP interno | Asincrono per eventi, sincrono per query | | Chat ↔ Agent split | Servizi separati | Batch CPU-bound non impatta real-time chat |
| Database | PostgreSQL condiviso (un DB, schema separation opzionale) | Semplicità; split DB futuro facile | | Agent → Device comms | Redis pub/sub via Chat Service | Agent non possiede la WS, usa un relay |
| TLS | Cloudflare Origin Certificate | Zero maintenance, trust Cloudflare | | Rate limiting | Redis contatori distribuiti | Sliding window condivisa tra repliche |
| Database | PostgreSQL condiviso | Semplicità MVP; split DB futuro facile |
| TLS | Cloudflare Origin Certificate | Zero maintenance |
| Orchestrazione | Docker Compose | Sufficiente per un singolo VPS | | Orchestrazione | Docker Compose | Sufficiente per un singolo VPS |
| Storage / Plugin | Post-MVP | Non critici per il lancio |

View File

@@ -1,35 +0,0 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
langchain>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.1.0
litellm>=1.50.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
stripe>=11.0.0
boto3>=1.35.0
slowapi>=0.1.9
sqlalchemy>=2.0.0
asyncpg>=0.30.0
alembic>=1.14.0
bcrypt>=4.2.0
python-dotenv>=1.0.0
httpx>=0.28.0
websockets>=14.0
psycopg2-binary>=2.9.0
pytest>=8.0.0
pytest-asyncio>=0.24.0
aiosqlite>=0.20.0
moto[s3]>=5.0.0
pinecone>=5.0.0
qdrant-client>=1.7.0
croniter>=3.0.0
google-api-python-client>=2.130.0
google-auth>=2.29.0
google-auth-oauthlib>=1.2.0
google-auth-httplib2>=0.2.0
msal>=1.28.0
cryptography>=42.0.0
ruff>=0.8.0

View File

@@ -0,0 +1,19 @@
# ── Auth Service ──────────────────────────────────────────────────────────────
# This file contains env vars specific to the Auth Service.
# Shared vars (DATABASE_URL, REDIS_URL, etc.) come from the root .env
# or from docker-compose environment.
# ── JWT RS256 Keys ────────────────────────────────────────────────────────────
# Generate keypair:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# openssl rsa -in private.pem -pubout -out public.pem
#
# Paste PEM content with literal \n for newlines:
# JWT_PRIVATE_KEY=-----BEGIN PRIVATE KEY-----\nMIIEvQ...
# JWT_PUBLIC_KEY=-----BEGIN PUBLIC KEY-----\nMIIBIj...
# PRIVATE KEY — used to SIGN JWTs. NEVER share outside this service.
JWT_PRIVATE_KEY=
# PUBLIC KEY — used to VERIFY JWTs.
JWT_PUBLIC_KEY=

36
services/auth/Dockerfile Normal file
View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
# Install shared + service deps in one layer
COPY services/auth/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Copy shared module (available to all services)
COPY shared/ shared/
# Copy service source
COPY services/auth/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "30"]

16
services/auth/README.md Normal file
View File

@@ -0,0 +1,16 @@
# Auth Service
Owns: user registration, login, JWT RS256 issuance, token refresh, `/me` endpoint.
## Tables owned
- `users`
- `refresh_tokens`
- `subscriptions` (read; Billing Service writes)
## Endpoints
- `POST /auth/register`
- `POST /auth/login`
- `POST /auth/refresh`
- `GET /auth/me`
- `PUT /auth/me`
- `GET /auth/verify` (ForwardAuth for Traefik)

View File

@@ -0,0 +1,34 @@
"""Auth Service — local configuration.
Contains secrets that ONLY the Auth Service needs (e.g., JWT private key).
These are NOT in shared/config.py to prevent other services from accessing them.
"""
from pydantic import field_validator
from pydantic_settings import BaseSettings, SettingsConfigDict
class AuthSettings(BaseSettings):
# RS256 private key (PEM format). Used to SIGN JWTs.
# Only the Auth Service has this. Generate with:
# openssl genpkey -algorithm RSA -out private.pem -pkeyopt rsa_keygen_bits:2048
# Then set the env var (newlines as \n):
# JWT_PRIVATE_KEY="-----BEGIN PRIVATE KEY-----\nMIIEv..."
JWT_PRIVATE_KEY: str = ""
# RS256 public key (PEM format). Used to VERIFY JWTs.
# Derived from the private key:
# openssl rsa -in private.pem -pubout -out public.pem
JWT_PUBLIC_KEY: str = ""
@field_validator("JWT_PRIVATE_KEY", "JWT_PUBLIC_KEY", mode="before")
@classmethod
def _expand_pem_newlines(cls, v: str) -> str:
if isinstance(v, str) and r"\n" in v:
return v.replace(r"\n", "\n")
return v
model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")
auth_settings = AuthSettings()

View File

@@ -1,14 +1,7 @@
"""Auth middleware — JWT validation dependency. """Auth dependencies — JWT validation for the Auth Service.
``get_current_user`` is the FastAPI dependency used by all protected routes. This is the canonical get_current_user used by protected endpoints
It decodes the Bearer JWT (identity + expiry), then fetches the current tier within the Auth Service itself (/me, /me PUT).
from the ``subscriptions`` table so that tier changes take effect immediately
without requiring token re-issue.
Exempt routes (no JWT required):
- POST /api/v1/auth/register
- POST /api/v1/auth/login
- POST /api/v1/billing/webhook
""" """
from __future__ import annotations from __future__ import annotations
@@ -19,9 +12,12 @@ from jose import JWTError, jwt
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings from shared.config import settings
from app.db import get_session from shared.db import get_session
from app.schemas import UserProfile from shared.models import Subscription, User
from shared.schemas import UserProfile
from app.config import auth_settings
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login") oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/auth/login")
@@ -32,11 +28,8 @@ async def get_current_user(
) -> UserProfile: ) -> UserProfile:
"""Validate a Bearer JWT and return the authenticated user. """Validate a Bearer JWT and return the authenticated user.
The JWT is used for identity and expiry only. The tier is fetched live The JWT is used for identity and expiry. Tier is fetched live from the
from the ``subscriptions`` table so that upgrades/downgrades take effect subscriptions table so upgrades/downgrades take effect immediately.
immediately. Falls back to ``'free'`` when no subscription row exists.
Raises HTTP 401 on any invalid or expired token.
""" """
credentials_exc = HTTPException( credentials_exc = HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED, status_code=status.HTTP_401_UNAUTHORIZED,
@@ -45,7 +38,7 @@ async def get_current_user(
) )
try: try:
payload = jwt.decode( payload = jwt.decode(
token, settings.JWT_SECRET, algorithms=[settings.JWT_ALGORITHM] token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
) )
user_id: str | None = payload.get("sub") user_id: str | None = payload.get("sub")
email: str | None = payload.get("email") email: str | None = payload.get("email")
@@ -54,18 +47,14 @@ async def get_current_user(
except JWTError: except JWTError:
raise credentials_exc raise credentials_exc
# Live tier lookup — subscription row is the authoritative source. # Live tier lookup
# In dev, fall back to 'power' (unlimited) so quota limits don't
# block local development when no Stripe subscription exists.
from app.models import Subscription, User # noqa: PLC0415
result = await db.execute( result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id) select(Subscription.tier).where(Subscription.user_id == user_id)
) )
default_tier = "power" if settings.ENV == "dev" else "free" default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier tier: str = result.scalar_one_or_none() or default_tier
# Fetch name/surname from user row. # Fetch name/surname
user_result = await db.execute( user_result = await db.execute(
select(User.name, User.surname).where(User.id == user_id) select(User.name, User.surname).where(User.id == user_id)
) )

62
services/auth/app/main.py Normal file
View File

@@ -0,0 +1,62 @@
"""Auth Service — JWT issuance, user management, ForwardAuth verification.
Standalone FastAPI service extracted from the adiuva-api monolith.
Owns: users, refresh_tokens, subscriptions (read).
"""
import sys
from contextlib import asynccontextmanager
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable.
# In Docker, COPY shared/ puts it at /app/shared/ (already importable).
# In local dev, we need to add the repo root (two levels up from this file).
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from shared.config import settings
@asynccontextmanager
async def lifespan(app: FastAPI):
yield
from shared.db import engine
await engine.dispose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Auth Service",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from app.routes import router
from app.verify import router as verify_router
app.include_router(router, prefix="/api/v1")
app.include_router(verify_router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "auth", "version": app.version}
return app
app = create_app()

View File

@@ -1,8 +1,6 @@
"""Auth routes: register, login, refresh, me. """Auth routes: register, login, refresh, me.
Users and refresh tokens are persisted in PostgreSQL (users + refresh_tokens Extracted from app/api/routes/auth.py uses shared.* imports instead of app.*.
tables). Passwords are hashed with bcrypt; refresh tokens are stored as
SHA-256 hashes so plaintext never reaches the DB.
""" """
from __future__ import annotations from __future__ import annotations
@@ -20,11 +18,13 @@ from pydantic import BaseModel
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user from shared.config import settings
from app.config.settings import settings from shared.db import get_session
from app.db import get_session from shared.models import RefreshToken, Subscription, User
from app.models import RefreshToken, User from shared.schemas import AuthTokens, UserProfile
from app.schemas import AuthTokens, UserProfile
from app.config import auth_settings
from app.deps import get_current_user
router = APIRouter(prefix="/auth", tags=["auth"]) router = APIRouter(prefix="/auth", tags=["auth"])
@@ -46,7 +46,7 @@ def _hash_token(plain_token: str) -> str:
def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]: def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"""Return (signed JWT, expires_at_ms).""" """Return (RS256-signed JWT, expires_at_ms)."""
now = int(time.time()) now = int(time.time())
exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60 exp = now + settings.JWT_ACCESS_TOKEN_EXPIRE_MINUTES * 60
payload = { payload = {
@@ -56,10 +56,19 @@ def _make_access_token(user_id: str, email: str, tier: str) -> tuple[str, int]:
"exp": exp, "exp": exp,
"iat": now, "iat": now,
} }
token = jwt.encode(payload, settings.JWT_SECRET, algorithm=settings.JWT_ALGORITHM) token = jwt.encode(payload, auth_settings.JWT_PRIVATE_KEY, algorithm="RS256")
return token, exp * 1000 # ms for client return token, exp * 1000 # ms for client
async def _get_live_tier(db: AsyncSession, user_id: str) -> str:
"""Fetch authoritative tier from subscriptions table."""
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
return result.scalar_one_or_none() or default_tier
# ── Request bodies ──────────────────────────────────────────────────── # ── Request bodies ────────────────────────────────────────────────────
@@ -79,6 +88,11 @@ class _RefreshRequest(BaseModel):
refresh_token: str refresh_token: str
class _UpdateProfileRequest(BaseModel):
name: str | None = None
surname: str | None = None
# ── Routes ──────────────────────────────────────────────────────────── # ── Routes ────────────────────────────────────────────────────────────
@@ -102,7 +116,7 @@ async def register(
encryption_key=Fernet.generate_key().decode(), encryption_key=Fernet.generate_key().decode(),
) )
db.add(user) db.add(user)
await db.flush() # get user.id without committing await db.flush()
plain_token = str(uuid.uuid4()) plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta( expires_at = datetime.now(timezone.utc) + timedelta(
@@ -135,6 +149,9 @@ async def login(
if user is None or not _verify_password(body.password, user.password_hash): if user is None or not _verify_password(body.password, user.password_hash):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid credentials")
# Fetch live tier for the JWT claim
tier = await _get_live_tier(db, user.id)
plain_token = str(uuid.uuid4()) plain_token = str(uuid.uuid4())
expires_at = datetime.now(timezone.utc) + timedelta( expires_at = datetime.now(timezone.utc) + timedelta(
days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS
@@ -147,7 +164,7 @@ async def login(
db.add(rt) db.add(rt)
await db.commit() await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
return AuthTokens( return AuthTokens(
access_token=access_token, access_token=access_token,
refresh_token=plain_token, refresh_token=plain_token,
@@ -171,7 +188,6 @@ async def refresh(
if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now: if rt is None or rt.expires_at.replace(tzinfo=timezone.utc) < now:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired refresh token")
# Rotate: delete old token, issue new one.
await db.delete(rt) await db.delete(rt)
user_result = await db.execute(select(User).where(User.id == rt.user_id)) user_result = await db.execute(select(User).where(User.id == rt.user_id))
@@ -179,6 +195,9 @@ async def refresh(
if user is None: if user is None:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found") raise HTTPException(status.HTTP_401_UNAUTHORIZED, "User not found")
# Fetch live tier for the new JWT
tier = await _get_live_tier(db, user.id)
plain_token = str(uuid.uuid4()) plain_token = str(uuid.uuid4())
new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS) new_expires = now + timedelta(days=settings.JWT_REFRESH_TOKEN_EXPIRE_DAYS)
new_rt = RefreshToken( new_rt = RefreshToken(
@@ -189,7 +208,7 @@ async def refresh(
db.add(new_rt) db.add(new_rt)
await db.commit() await db.commit()
access_token, expires_at_ms = _make_access_token(user.id, user.email, user.tier) access_token, expires_at_ms = _make_access_token(user.id, user.email, tier)
return AuthTokens( return AuthTokens(
access_token=access_token, access_token=access_token,
refresh_token=plain_token, refresh_token=plain_token,
@@ -197,11 +216,6 @@ async def refresh(
) )
class _UpdateProfileRequest(BaseModel):
name: str | None = None
surname: str | None = None
@router.get("/me", response_model=UserProfile) @router.get("/me", response_model=UserProfile)
async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile: async def me(current_user: UserProfile = Depends(get_current_user)) -> UserProfile:
"""Return the profile for the authenticated user.""" """Return the profile for the authenticated user."""

View File

@@ -0,0 +1,66 @@
"""ForwardAuth verification endpoint for Traefik.
Traefik calls GET /api/v1/auth/verify on every request to a protected
service. This endpoint validates the JWT from the Authorization header
and returns identity headers that Traefik injects into downstream requests.
Downstream services NEVER validate JWTs themselves — they trust the
X-User-Id, X-User-Email, X-User-Tier headers injected by Traefik.
"""
from __future__ import annotations
from fastapi import APIRouter, Request, Response
from fastapi import status as http_status
from jose import JWTError, jwt
from sqlalchemy import select
from shared.config import settings
from shared.db import async_session
from shared.models import Subscription
from app.config import auth_settings
router = APIRouter(tags=["auth"])
@router.get("/auth/verify")
async def verify(request: Request) -> Response:
"""Validate JWT and return identity headers for Traefik ForwardAuth.
Returns 200 with X-User-* headers on success, 401 on failure.
Traefik copies response headers to the downstream request.
"""
auth_header = request.headers.get("Authorization", "")
if not auth_header.startswith("Bearer "):
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
token = auth_header[7:] # strip "Bearer "
try:
payload = jwt.decode(
token, auth_settings.JWT_PUBLIC_KEY, algorithms=["RS256"]
)
user_id: str | None = payload.get("sub")
email: str | None = payload.get("email")
if not user_id or not email:
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
except JWTError:
return Response(status_code=http_status.HTTP_401_UNAUTHORIZED)
# Live tier lookup from subscriptions table
async with async_session() as db:
result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id)
)
default_tier = "power" if settings.ENV == "dev" else "free"
tier: str = result.scalar_one_or_none() or default_tier
return Response(
status_code=http_status.HTTP_200_OK,
headers={
"X-User-Id": user_id,
"X-User-Email": email,
"X-User-Tier": tier,
},
)

View File

@@ -0,0 +1,11 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
python-jose[cryptography]>=3.3.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
bcrypt>=4.2.0
cryptography>=42.0.0
python-dotenv>=1.0.0

View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/batch-agent/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/batch-agent/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Batch runs are long-lived — use a longer timeout than chat (300s vs 120s)
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "2", \
"--timeout", "300"]

View File

@@ -0,0 +1,23 @@
# Batch Agent Service
Owns: agent_runner, journey builder, filesystem_agent, integrations (Gmail, MS Graph).
## Tables owned
- `local_agent_configs`
- `cloud_agent_configs`
- `agent_run_logs`
## Endpoints
- `GET /agents/catalog`
- `POST /agents/can-create`
- `POST /agents/trigger`
- `GET /agents/{id}/history`
## Redis channels
- Subscribe: `batch:request:{user_id}`
- Publish: `ws:out:{user_id}` (journey replies + tool calls)
- BRPOP: `tool:result:{call_id}` (30s timeout)
- SET+EX: `journey:{user_id}` (session state, TTL 1800s)
## TODO
- [ ] Integrate Langfuse tracing (reuse `services/chat/app/tracing.py` pattern — `trace_span()`, `get_langfuse_callback()`, prompt management). Each batch agent run should create a trace with input/output, link prompts, and pass the LangChain `CallbackHandler` to LLM calls.

View File

@@ -0,0 +1,910 @@
"""Agent run orchestrator — adapted for Batch Agent Service.
Key changes from monolith app/core/agent_runner.py:
- No DeviceConnectionManager — tool calls go through Redis ws_context.
- set_current_user / clear_current_user replace set_client_executor.
- run_local_agent accepts a serialized dict (from Redis / REST) instead
of SQLAlchemy model objects.
- _finalize_run writes to PostgreSQL via shared.db.async_session.
- Cloud agent import path changed to app.integrations.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timedelta, timezone
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from sqlalchemy import select
from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from shared.agents.note_agent import NOTE_TOOLS
from shared.agents.project_agent import PROJECT_TOOLS
from shared.agents.task_agent import TASK_TOOLS
from shared.agents.timeline_agent import TIMELINE_TOOLS
from shared.llm import get_llm
from shared.ws_context import execute_on_client, set_current_user, clear_current_user
import app.tracing as tracing
from shared.db import async_session
from shared.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from shared.redis import redis_client, ws_out_channel
logger = logging.getLogger(__name__)
# ── Concurrency guard ─────────────────────────────────────────────────────
_running_agents: set[str] = set()
def is_agent_running(agent_id: str) -> bool:
return agent_id in _running_agents
# ── Timeouts ───────────────────────────────────────────────────────────────
_TOOL_CALL_TIMEOUT: int = 30
_MAX_PROCESSING_STEPS: int = 12
_MAX_SCAN_DEPTH: int = 5
# ── Data-type to tool mapping ─────────────────────────────────────────────
_DATA_TYPE_TOOLS: dict[str, list[Any]] = {
"tasks": TASK_TOOLS,
"notes": NOTE_TOOLS,
"timelines": TIMELINE_TOOLS,
}
# ── Step 1: Classification prompt ─────────────────────────────────────────
_DOMAIN_DESCRIPTIONS: dict[str, str] = {
"tasks": (
"Action items, to-dos, deliverables — anything that describes work to be done, "
"assigned to someone, or tracked with a due date or status."
),
"notes": (
"Documentation, meeting notes, summaries, reference material — "
"written content meant to be read and referenced rather than acted on."
),
"timelines": (
"Project milestones, deadlines, scheduled events — "
"specific dates that mark a point in the progress of a project."
),
"projects": (
"High-level project entities — only relevant if the file clearly introduces "
"a new project or updates the scope of an existing one."
),
}
_STEP1_SYSTEM_PROMPT = """\
You are a file classifier for a freelance project management tool.
Your job is to match a file to an existing project and identify which data domains to extract.
## Project matching rules (STRICT — follow in order)
1. Search the file content for any mention of a project name, client name, acronym, or topic
that overlaps with the existing projects listed below.
2. The match does NOT need to be exact — partial name, abbreviation, or topic similarity is enough.
3. STRONGLY PREFER matching an existing project. Only return "new" as an absolute last resort
when the file has zero meaningful connection to any listed project.
4. When in doubt, pick the closest match from the list.
## Response format
Respond ONLY with a JSON object — no markdown, no explanation:
{{"project_id": "<exact id from the list below, or new>", "new_project_name": "<concise 2-5 word name, only when project_id is new>", "domains": ["tasks", "notes"]}}
## Domain definitions (only consider domains in the allowed list)
{domain_definitions}
## Existing projects
{projects_list}
"""
# ── Step 2: Processing prompt ─────────────────────────────────────────────
_PROCESSING_SYSTEM_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.
Your task: extract structured data from the file content and persist it using the available tools.
## Mandatory process — follow this order for EVERY item you extract
1. READ the existing records listed below for the relevant domain.
2. SEARCH for a match by title, topic, or semantic similarity.
3. If a match exists → call the update_* tool with the existing record's id.
4. If no match exists → call the create_* tool and set isAiSuggested=1.
NEVER call create_* without first checking the existing records.
NEVER duplicate a record that already exists under a different wording.
## Existing records (source of truth)
{existing_context}
## Context
Project: {project_context}
Domains to extract: {data_types}
{custom_prompt_section}
"""
# ── Cloud processing prompt ───────────────────────────────────────────────
_CLOUD_PROCESSING_PROMPT = """\
You are a data extraction and management assistant for a freelance project
management tool.
Available tools:
Filesystem : read_file_content, list_directory, get_file_metadata
Tasks : list_tasks, create_task, update_task, add_task_comment
Notes : list_notes, get_note, create_note, update_note
Timelines : list_timelines, create_timeline, update_timeline
Projects : list_all_projects, get_project, create_project, update_project
Your task:
1. Read the full content of each file below using read_file_content.
2. For each piece of information found, ALWAYS try to match and update an
existing record before creating a new one.
3. ONLY act on these entity types: {data_types}.
4. Do NOT invent data. Only extract what is clearly present in the files.
5. If a file contains no relevant data for the target entity types, skip it.
{project_context}
Files to process:
{file_list}
{custom_prompt_section}
After processing all files, respond with a brief summary of what you updated
and what you created.
"""
# ── LLM tool-calling loop ─────────────────────────────────────────────────
def _as_text(content: Any) -> str:
if content is None:
return ""
if isinstance(content, str):
return content
if isinstance(content, list):
parts: list[str] = []
for item in content:
if isinstance(item, str):
parts.append(item)
elif isinstance(item, dict):
text = item.get("text")
if isinstance(text, str):
parts.append(text)
return "".join(parts)
return str(content)
async def _run_agent_with_tools(
*,
system_prompt: str,
user_message: str,
tools: list[Any],
max_steps: int,
langfuse_handler: Any | None = None,
) -> str:
"""Run an LLM agent with tool-calling, returning the final text response."""
callbacks = [langfuse_handler] if langfuse_handler else None
llm = get_llm(callbacks=callbacks)
llm_with_tools = llm.bind_tools(tools)
messages: list[Any] = [
SystemMessage(content=system_prompt),
HumanMessage(content=user_message),
]
tool_map = {tool_def.name: tool_def for tool_def in tools}
for _ in range(max_steps):
response: AIMessage = await llm_with_tools.ainvoke(messages)
messages.append(response)
if not response.tool_calls:
return _as_text(response.content)
for call in response.tool_calls:
call_id = str(call.get("id", ""))
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_runner: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
tool_fn = tool_map.get(call_name)
if tool_fn is None:
tool_output = f"Unknown tool: {call_name}"
else:
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_runner: tool_result name=%s output=%s",
call_name,
str(tool_output)[:200],
)
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
final = await llm.ainvoke(messages)
return _as_text(final.content)
# ── Tool list builder ─────────────────────────────────────────────────────
def _build_processing_tools(data_types: list[str]) -> list[Any]:
tools: list[Any] = list(FILESYSTEM_TOOLS)
for dt in data_types:
dt_tools = _DATA_TYPE_TOOLS.get(dt)
if dt_tools:
tools.extend(dt_tools)
return tools
# ── Code-based directory scanner ─────────────────────────────────────────
async def _scan_directories(
paths: list[str],
extensions: list[str],
last_run_at: datetime | None,
) -> list[str]:
all_files: list[str] = []
ext_set = {e.lstrip(".").lower() for e in extensions} if extensions else set()
async def _walk(path: str, depth: int) -> None:
if depth > _MAX_SCAN_DEPTH:
return
try:
result = await execute_on_client(action="list_directory", data={"path": path})
except Exception as exc:
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
return
for entry in result.get("entries", []):
entry_path = entry.get("path", "")
if not entry_path:
continue
if entry.get("type") == "directory":
await _walk(entry_path, depth + 1)
elif entry.get("type") == "file":
if ext_set:
dot_pos = entry_path.rfind(".")
file_ext = entry_path[dot_pos + 1:].lower() if dot_pos != -1 else ""
if file_ext not in ext_set:
continue
all_files.append(entry_path)
for root in paths:
await _walk(root, depth=0)
if last_run_at is None:
return all_files
last_run_ms = int(last_run_at.timestamp() * 1000)
filtered: list[str] = []
for file_path in all_files:
try:
meta = await execute_on_client(action="get_file_metadata", data={"path": file_path})
modified_at = meta.get("modifiedAt")
if modified_at is None:
filtered.append(file_path)
continue
if isinstance(modified_at, (int, float)):
mod_ms = int(modified_at)
else:
mod_ms = int(datetime.fromisoformat(str(modified_at)).timestamp() * 1000)
if mod_ms > last_run_ms:
filtered.append(file_path)
except Exception:
filtered.append(file_path)
return filtered
# ── Code-based entity fetchers ────────────────────────────────────────────
async def _fetch_projects() -> list[dict]:
try:
result = await execute_on_client(action="select", table="projects")
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch projects: %s", exc)
return []
_DOMAIN_TABLE: dict[str, str] = {
"tasks": "tasks",
"notes": "notes",
"timelines": "timelines",
"projects": "projects",
}
async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
table = _DOMAIN_TABLE.get(domain)
if not table:
return []
filters: dict[str, Any] = {}
if project_id != "standalone" and domain != "projects":
filters["projectId"] = project_id
try:
result = await execute_on_client(
action="select",
table=table,
filters=filters if filters else None,
)
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
return []
def _format_entities_for_context(domain: str, rows: list[dict]) -> str:
if not rows:
return f"No existing {domain}."
lines: list[str] = []
for r in rows:
if domain == "tasks":
desc = r.get("description") or ""
desc_part = f"{desc[:120]}" if desc else ""
assignee = r.get("assignee") or r.get("assignees") or ""
due = r.get("dueDate") or r.get("due_date") or ""
meta = ", ".join(filter(None, [
f"priority: {r.get('priority', '')}" if r.get("priority") else "",
f"assignee: {assignee}" if assignee else "",
f"due: {due}" if due else "",
]))
lines.append(
f" - [{r.get('status', '?')}] {r.get('title', '')}{desc_part}"
f" ({meta}, id: {r['id']})"
)
elif domain == "notes":
snippet = (r.get("content") or "")[:200].replace("\n", " ")
snippet_part = f"\n Preview: {snippet}" if snippet else ""
lines.append(
f" - {r.get('title', '')} (id: {r['id']}){snippet_part}"
)
elif domain == "timelines":
lines.append(
f" - {r.get('title', '')} date={r.get('date', '')} (id: {r['id']})"
)
elif domain == "projects":
summary = (r.get("aiSummary") or r.get("ai_summary") or "")[:120]
summary_part = f"{summary}" if summary else ""
lines.append(
f" - {r.get('name', '')} [{r.get('status', '')}]{summary_part}"
f" (id: {r['id']})"
)
return f"Existing {domain}:\n" + "\n".join(lines)
# ── Step 1: LLM file classifier ───────────────────────────────────────────
async def _classify_file(
file_path: str,
file_content: str,
projects: list[dict],
config_data_types: list[str],
langfuse_handler: Any | None = None,
custom_system_prompt: str | None = None,
) -> tuple[str, list[str], str | None]:
fallback: tuple[str, list[str], str | None] = ("new", list(config_data_types), None)
if not file_content.strip():
return fallback
valid_project_ids = {p["id"] for p in projects}
def _fmt_project(p: dict) -> str:
summary = (p.get("aiSummary") or p.get("ai_summary") or "").strip()
summary_part = f"{summary[:100]}" if summary else ""
return f" - id={p['id']} | name={p.get('name', '')} | status={p.get('status', '')}{summary_part}"
projects_list = "\n".join(_fmt_project(p) for p in projects) or " (none yet)"
domain_definitions = "\n".join(
f" - {d}: {_DOMAIN_DESCRIPTIONS[d]}"
for d in config_data_types
if d in _DOMAIN_DESCRIPTIONS
)
if custom_system_prompt:
# Fixture-provided prompt takes absolute priority
system = custom_system_prompt.format_map(
{"domain_definitions": domain_definitions, "projects_list": projects_list}
)
else:
system = tracing.compile_prompt(
"batch_file_classifier",
fallback=_STEP1_SYSTEM_PROMPT,
variables={
"domain_definitions": domain_definitions,
"projects_list": projects_list,
},
)
llm = get_llm(callbacks=[langfuse_handler] if langfuse_handler else None)
try:
response = await llm.ainvoke([
SystemMessage(content=system),
HumanMessage(content=f"File: {file_path}\n\nContent:\n{file_content[:4000]}"),
])
raw = _as_text(response.content).strip()
if raw.startswith("```"):
raw = raw.split("```")[1]
if raw.startswith("json"):
raw = raw[4:]
parsed = json.loads(raw.strip())
raw_project_id: str = str(parsed.get("project_id") or "new")
project_id = raw_project_id if raw_project_id in valid_project_ids else "new"
new_project_name: str | None = (
str(parsed["new_project_name"]).strip() or None
if project_id == "new" and parsed.get("new_project_name")
else None
)
domains: list[str] = [
d for d in parsed.get("domains", [])
if d in config_data_types
]
if not domains:
domains = list(config_data_types)
return project_id, domains, new_project_name
except Exception as exc:
logger.warning(
"agent_runner: step1 classification failed for %r: %s", file_path, exc
)
return fallback
# ── Local agent runner (two-step per file) ────────────────────────────────
async def run_local_agent(user_id: str, trigger_data: dict[str, Any], *, langfuse_handler: Any | None = None) -> None:
"""Execute a local directory agent run.
In the microservice world, trigger_data is a serialized dict from
the REST route (forwarded via Redis), containing the agent config
fields and run_context.
set_current_user() must be called BEFORE this function.
"""
run_context: dict = trigger_data.get("run_context", {})
agent_id = run_context.get("agent_id", str(uuid.uuid4()))
run_id = run_context.get("run_id")
_running_agents.add(agent_id)
# Extract config from trigger payload
directory_paths: list[str] = trigger_data.get("directory_paths", [])
if not directory_paths:
directory = trigger_data.get("directory", "")
if directory:
directory_paths = [directory]
data_types: list[str] = trigger_data.get("data_types", [])
file_extensions: list[str] = trigger_data.get("file_extensions", [])
prompt_template: str = trigger_data.get("prompt_template", "")
last_run_at_raw = trigger_data.get("last_run_at")
last_run_at: datetime | None = None
if last_run_at_raw:
if isinstance(last_run_at_raw, str):
last_run_at = datetime.fromisoformat(last_run_at_raw)
elif isinstance(last_run_at_raw, (int, float)):
last_run_at = datetime.fromtimestamp(last_run_at_raw / 1000, tz=timezone.utc)
errors: list[str] = []
items_processed = 0
items_created = 0
custom_section = (
f"User instructions:\n{prompt_template}"
if prompt_template
else ""
)
# Create or load run log
run_log_id = run_id
if not run_log_id:
async with async_session() as db:
run_log = AgentRunLog(
agent_id=agent_id,
agent_type="local",
user_id=user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
try:
# ── Scan directories ─────────────────────────────────────────
logger.info("agent_runner: run=%s scanning directories user=%s", run_log_id, user_id)
file_paths = await _scan_directories(
paths=directory_paths,
extensions=file_extensions,
last_run_at=last_run_at,
)
logger.info(
"agent_runner: run=%s found %d file(s) after filtering", run_log_id, len(file_paths)
)
if not file_paths:
await _finalize_run(run_log_id, status="success", items_processed=0, items_created=0)
return
# ── Fetch all projects once ──────────────────────────────────
projects = await _fetch_projects()
for file_path in file_paths:
try:
file_result = await execute_on_client(
action="read_file_content", data={"path": file_path}
)
file_content: str = file_result.get("content", "")
if not file_content:
continue
items_processed += 1
# Step 1 — classify file
project_id, domains, new_project_name = await _classify_file(
file_path=file_path,
file_content=file_content,
projects=projects,
config_data_types=data_types,
langfuse_handler=langfuse_handler,
)
# Step 2 — resolve project_id, fetch entities, process
if project_id == "new":
proj_name = new_project_name or "Untitled Project"
try:
proj_result = await execute_on_client(
action="insert",
table="projects",
data={"name": proj_name, "clientId": None},
)
created = proj_result.get("row", {})
effective_project_id = created.get("id", "standalone")
if "id" in created:
projects.append(created)
except Exception as exc:
logger.warning("agent_runner: run=%s create project failed: %s", run_log_id, exc)
effective_project_id = "standalone"
proj_name = "unknown"
project_context = (
f"Project: {proj_name} (id: {effective_project_id}). "
"Always set projectId to this id on every record you create."
)
else:
effective_project_id = project_id
proj = next((p for p in projects if p["id"] == project_id), None)
proj_name = proj.get("name", project_id) if proj else project_id
project_context = (
f"Project: {proj_name} (id: {project_id}). "
"Always set projectId to this id on every record you create."
)
domains = [d for d in domains if d != "projects"]
existing_blocks: list[str] = []
for domain in domains:
rows = await _fetch_domain_entities(domain, effective_project_id)
existing_blocks.append(_format_entities_for_context(domain, rows))
existing_context = "\n\n".join(existing_blocks)
system_prompt = tracing.compile_prompt(
"batch_processing",
fallback=_PROCESSING_SYSTEM_PROMPT,
variables={
"existing_context": existing_context,
"project_context": project_context,
"data_types": ", ".join(domains),
"custom_prompt_section": custom_section,
},
)
processing_tools = _build_processing_tools(domains)
result_text = await _run_agent_with_tools(
system_prompt=system_prompt,
user_message=(
f"Process this file and extract relevant information.\n\n"
f"File: {file_path}\n\nContent:\n{file_content}"
),
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
langfuse_handler=langfuse_handler,
)
logger.info(
"agent_runner: run=%s file=%r result=%s",
run_log_id, file_path, result_text[:200],
)
except Exception as exc:
errors.append(f"Error processing '{file_path}': {exc}")
logger.error("agent_runner: run=%s file=%r failed: %s", run_log_id, file_path, exc)
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
logger.error("agent_runner: run=%s failed: %s", run_log_id, exc)
finally:
_running_agents.discard(agent_id)
# ── Finalise ────────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log_id,
status=final_status,
items_processed=items_processed,
items_created=items_created,
errors=errors,
)
# Notify Electron that the run is complete via Redis
if run_context:
try:
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps({
"type": "run_complete",
"run_context": run_context,
"status": final_status,
}))
except Exception as exc:
logger.warning("agent_runner: run=%s failed to send run_complete: %s", run_log_id, exc)
# ── Cloud agent runner ─────────────────────────────────────────────────────
_CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
async def run_cloud_agent(user_id: str, config_id: str, *, langfuse_handler: Any | None = None) -> None:
"""Execute a cloud connector agent run.
Loads the CloudAgentConfig from DB, decrypts OAuth tokens, fetches
messages from the provider, and runs LLM extraction.
set_current_user() must be called BEFORE this function.
"""
from app.integrations import decrypt_token, encrypt_token, get_provider
async with async_session() as db:
result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
config = result.scalar_one_or_none()
if config is None:
logger.error("agent_runner: cloud config %s not found", config_id)
return
# Create run log
run_log = AgentRunLog(
agent_id=config.id,
agent_type="cloud",
user_id=user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
# ── Decrypt OAuth token ────────────────────────────────────────
if not config.oauth_token_encrypted:
await _finalize_run(
run_log_id,
status="error",
errors=[f"No OAuth token stored for cloud agent '{config.name}'"],
)
return
try:
credentials_info = decrypt_token(config.oauth_token_encrypted)
except ValueError as exc:
await _finalize_run(
run_log_id,
status="error",
errors=[f"Failed to decrypt OAuth token: {exc}"],
)
return
# ── Instantiate provider ──────────────────────────────────────
try:
provider = get_provider(config.provider, credentials_info)
except ValueError as exc:
await _finalize_run(run_log_id, status="error", errors=[str(exc)])
return
# ── Fetch messages ────────────────────────────────────────────
since: datetime | None = config.last_run_at
if since is None:
since = datetime.now(timezone.utc) - timedelta(days=_CLOUD_DEFAULT_LOOKBACK_DAYS)
if since.tzinfo is None:
since = since.replace(tzinfo=timezone.utc)
errors: list[str] = []
items_processed = 0
try:
if config.provider == "gmail":
raw_messages = await provider.fetch_messages(
filter_config=config.filter_config,
since=since,
)
elif config.provider == "outlook":
raw_messages = await provider.fetch_emails(
filter_config=config.filter_config,
since=since,
)
elif config.provider == "teams":
raw_messages = await provider.fetch_messages(
filter_config=config.filter_config,
since=since,
)
else:
raw_messages = []
except RuntimeError as exc:
await _finalize_run(
run_log_id,
status="error",
errors=[f"Provider fetch failed: {exc}"],
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
return
logger.info(
"agent_runner: cloud agent %s fetched %d item(s) from %s",
config.id, len(raw_messages), config.provider,
)
# ── Extract + insert via LLM ─────────────────────────────────
try:
processing_tools = _build_processing_tools(config.data_types)
custom_section = (
f"User instructions:\n{config.prompt_template}"
if config.prompt_template
else ""
)
for msg in raw_messages:
content_text = msg.as_text
if not content_text:
continue
items_processed += 1
processing_prompt = tracing.compile_prompt(
"batch_cloud_processing",
fallback=_CLOUD_PROCESSING_PROMPT,
variables={
"data_types": ", ".join(config.data_types),
"project_context": "Determine the appropriate project from the message context.",
"file_list": f"Message from {config.provider} (id: {msg.id})",
"custom_prompt_section": custom_section,
},
)
try:
await _run_agent_with_tools(
system_prompt=processing_prompt,
user_message=f"Process this message content:\n\n{content_text[:8000]}",
tools=processing_tools,
max_steps=_MAX_PROCESSING_STEPS,
langfuse_handler=langfuse_handler,
)
except Exception as exc:
errors.append(f"LLM processing error for message {msg.id!r}: {exc}")
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
# ── Persist refreshed token ───────────────────────────────────
refreshed = getattr(provider, "refreshed_credentials", None)
if refreshed:
try:
new_encrypted = encrypt_token(refreshed)
async with async_session() as db:
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
)
cfg_row = cfg_result.scalar_one_or_none()
if cfg_row:
cfg_row.oauth_token_encrypted = new_encrypted
await db.commit()
except Exception as exc:
logger.warning("agent_runner: failed to persist refreshed token: %s", exc)
# ── Finalise ──────────────────────────────────────────────────
if errors and items_processed == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log_id,
status=final_status,
items_processed=items_processed,
items_created=0,
errors=errors,
update_config_last_run=True,
config_id=config.id,
config_type="cloud",
)
# ── Internal helper ─────────────────────────────────────────────────────────
async def _finalize_run(
run_log_id: int | str,
*,
status: str,
items_processed: int = 0,
items_created: int = 0,
errors: list[str] | None = None,
update_config_last_run: bool = False,
config_id: str | None = None,
config_type: str | None = None,
) -> None:
"""Persist the run outcome and optionally update last_run_at on the config."""
now = datetime.now(timezone.utc)
try:
async with async_session() as db:
result = await db.execute(
select(AgentRunLog).where(AgentRunLog.id == run_log_id)
)
managed = result.scalar_one_or_none()
if managed is None:
logger.warning("agent_runner: run_log %s not found for finalization", run_log_id)
return
managed.status = status
managed.items_processed = items_processed
managed.items_created = items_created
managed.errors = errors or []
managed.completed_at = now
if update_config_last_run and config_id:
if config_type == "local":
cfg_result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
elif config_type == "cloud":
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
await db.commit()
except Exception as exc:
logger.error("agent_runner: failed to finalize run_log=%s: %s", run_log_id, exc)

View File

@@ -0,0 +1 @@
"""Batch Agent Service domain agents and filesystem tools."""

View File

@@ -1,8 +1,6 @@
"""Filesystem agent — tools for reading local directories and files on Electron. """Filesystem agent — tools for reading local directories and files on Electron.
These tools delegate to the Electron client via ``execute_on_client()`` using Adapted for Batch Agent Service: import from app.ws_context.
the same WS tool-call round-trip pattern as CRUD tools. The Electron app
handles actual disk I/O and responds with ``tool_result`` frames.
""" """
from __future__ import annotations from __future__ import annotations
@@ -11,7 +9,7 @@ from typing import Any
from langchain_core.tools import tool from langchain_core.tools import tool
from app.core.ws_context import execute_on_client from shared.ws_context import execute_on_client
@tool @tool

View File

@@ -1,20 +1,11 @@
"""Cloud provider integration utilities. """Cloud provider integration utilities.
Provides: Adapted for Batch Agent Service: import from shared.config instead of app.config.
* Shared message dataclasses (``EmailMessage``, ``ChatMessage``) used by
both the Gmail and MS Graph clients and consumed by ``agent_runner``.
* ``get_provider()`` factory that returns the correct client given a
provider name and decrypted OAuth credentials dict.
* ``encrypt_token()`` / ``decrypt_token()`` Fernet-based at-rest
encryption for OAuth tokens stored in ``cloud_agent_configs``.
Encryption rationale Provides:
-------------------- * Shared message dataclasses (EmailMessage, ChatMessage)
Unlike user content (which is E2E-encrypted client-side and **never** * get_provider() factory for Gmail/MS Graph clients
decrypted server-side), OAuth tokens *must* be decrypted server-side * encrypt_token() / decrypt_token() Fernet-based OAuth token encryption
because the backend makes provider API calls on behalf of the user.
The Fernet key lives solely in ``OAUTH_ENCRYPTION_KEY`` env var it
is never returned to clients.
""" """
from __future__ import annotations from __future__ import annotations
@@ -27,7 +18,7 @@ from typing import TYPE_CHECKING
from cryptography.fernet import Fernet, InvalidToken from cryptography.fernet import Fernet, InvalidToken
from app.config.settings import settings from shared.config import settings
if TYPE_CHECKING: if TYPE_CHECKING:
from app.integrations.gmail import GmailClient from app.integrations.gmail import GmailClient
@@ -35,13 +26,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# ── Shared message types ──────────────────────────────────────────────────
@dataclass @dataclass
class EmailMessage: class EmailMessage:
"""A single email message fetched from Gmail or Outlook."""
id: str id: str
subject: str subject: str
sender: str sender: str
@@ -51,7 +38,6 @@ class EmailMessage:
@property @property
def as_text(self) -> str: def as_text(self) -> str:
"""Return a human-readable text representation for LLM extraction."""
date_str = self.date.strftime("%Y-%m-%d %H:%M") date_str = self.date.strftime("%Y-%m-%d %H:%M")
labels_str = f" [{', '.join(self.labels)}]" if self.labels else "" labels_str = f" [{', '.join(self.labels)}]" if self.labels else ""
return ( return (
@@ -64,8 +50,6 @@ class EmailMessage:
@dataclass @dataclass
class ChatMessage: class ChatMessage:
"""A single Teams chat or channel message fetched from MS Graph."""
id: str id: str
content: str content: str
sender: str sender: str
@@ -74,7 +58,6 @@ class ChatMessage:
@property @property
def as_text(self) -> str: def as_text(self) -> str:
"""Return a human-readable text representation for LLM extraction."""
date_str = self.date.strftime("%Y-%m-%d %H:%M") date_str = self.date.strftime("%Y-%m-%d %H:%M")
channel_str = f" [channel: {self.channel}]" if self.channel else "" channel_str = f" [channel: {self.channel}]" if self.channel else ""
return ( return (
@@ -84,15 +67,7 @@ class ChatMessage:
) )
# ── Fernet helpers ────────────────────────────────────────────────────────
def _get_fernet() -> Fernet: def _get_fernet() -> Fernet:
"""Return a ``Fernet`` instance using ``settings.OAUTH_ENCRYPTION_KEY``.
Raises ``RuntimeError`` if ``OAUTH_ENCRYPTION_KEY`` is not set callers
must ensure this is configured before persisting OAuth tokens.
"""
key = settings.OAUTH_ENCRYPTION_KEY key = settings.OAUTH_ENCRYPTION_KEY
if not key: if not key:
raise RuntimeError( raise RuntimeError(
@@ -103,15 +78,6 @@ def _get_fernet() -> Fernet:
def encrypt_token(token_info: dict) -> str: def encrypt_token(token_info: dict) -> str:
"""Fernet-encrypt an OAuth credential dict and return a base64 string.
Stores the full ``{access_token, refresh_token, token_uri, client_id,
client_secret, scopes, expiry}`` dict (or equivalent MSAL shape).
Raises:
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
ValueError: ``token_info`` is not a non-empty dict.
"""
if not isinstance(token_info, dict) or not token_info: if not isinstance(token_info, dict) or not token_info:
raise ValueError("token_info must be a non-empty dict") raise ValueError("token_info must be a non-empty dict")
plaintext = json.dumps(token_info).encode("utf-8") plaintext = json.dumps(token_info).encode("utf-8")
@@ -119,13 +85,6 @@ def encrypt_token(token_info: dict) -> str:
def decrypt_token(encrypted: str) -> dict: def decrypt_token(encrypted: str) -> dict:
"""Decrypt a Fernet-encrypted token string and return the credential dict.
Raises:
RuntimeError: OAUTH_ENCRYPTION_KEY is not configured.
ValueError: The encrypted string is invalid or was encrypted with a
different key.
"""
try: try:
plaintext = _get_fernet().decrypt(encrypted.encode("utf-8")) plaintext = _get_fernet().decrypt(encrypted.encode("utf-8"))
return json.loads(plaintext) return json.loads(plaintext)
@@ -133,25 +92,10 @@ def decrypt_token(encrypted: str) -> dict:
raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc raise ValueError(f"Failed to decrypt OAuth token: {exc}") from exc
# ── Provider factory ──────────────────────────────────────────────────────
def get_provider( def get_provider(
provider: str, provider: str,
credentials_info: dict, credentials_info: dict,
) -> "GmailClient | MSGraphClient": ) -> "GmailClient | MSGraphClient":
"""Return the correct provider client for *provider*.
Parameters
----------
provider:
One of ``"gmail"``, ``"outlook"``, ``"teams"``.
credentials_info:
Decrypted OAuth credential dict (Google or Microsoft shape).
Raises:
ValueError: Unknown provider name.
"""
if provider == "gmail": if provider == "gmail":
from app.integrations.gmail import GmailClient from app.integrations.gmail import GmailClient
return GmailClient(credentials_info) return GmailClient(credentials_info)

View File

@@ -1,26 +1,7 @@
"""Gmail API client for cloud agent integration. """Gmail API client for cloud agent integration.
Wraps the Google Gmail REST API to fetch email messages matching a Adapted for Batch Agent Service: import from app.integrations instead of
``filter_config`` dict. Uses the official ``google-api-python-client`` app.integrations (same relative path within the service).
library (synchronous) wrapped in ``asyncio.to_thread()`` to avoid
blocking the event loop.
Token refresh is handled transparently: when the stored access token has
expired, ``google.auth.transport.requests.Request`` will use the refresh
token to obtain a fresh one. The caller is responsible for persisting
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
(see ``agent_runner.run_cloud_agent``).
Credential dict shape (Google OAuth2):
{
"token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_uri": "https://oauth2.googleapis.com/token",
"client_id": "<client_id>",
"client_secret": "<client_secret>",
"scopes": ["https://www.googleapis.com/auth/gmail.readonly"],
"expiry": "2025-01-01T00:00:00Z" # optional ISO-8601
}
""" """
from __future__ import annotations from __future__ import annotations
@@ -38,13 +19,8 @@ from app.integrations import EmailMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Gmail search date format — e.g. "after:2025/01/01"
_GMAIL_DATE_FMT = "%Y/%m/%d" _GMAIL_DATE_FMT = "%Y/%m/%d"
# Maximum characters of body text forwarded to the LLM.
_BODY_TRUNCATE = 8_000 _BODY_TRUNCATE = 8_000
# Maximum messages retrieved per run (prevents runaway quota usage).
_MAX_MESSAGES = 200 _MAX_MESSAGES = 200
@@ -52,20 +28,9 @@ def _build_gmail_query(
filter_config: dict[str, Any] | None, filter_config: dict[str, Any] | None,
since: datetime | None, since: datetime | None,
) -> str: ) -> str:
"""Build a Gmail search query string from *filter_config* and *since*.
Supported ``filter_config`` keys:
labels (list[str]): Gmail label names, e.g. ``["INBOX", "work"]``
senders (list[str]): Sender addresses or domains to include
date_range (dict): ``{from: "<YYYY-MM-DD>", to: "<YYYY-MM-DD>"}``
A hard ``since`` date (from last run) always overrides ``date_range.from``
when it is earlier.
"""
parts: list[str] = [] parts: list[str] = []
cfg = filter_config or {} cfg = filter_config or {}
# Labels — joined with OR when multiple given.
labels: list[str] = cfg.get("labels", []) labels: list[str] = cfg.get("labels", [])
if labels: if labels:
if len(labels) == 1: if len(labels) == 1:
@@ -74,17 +39,14 @@ def _build_gmail_query(
label_expr = " OR ".join(f"label:{lbl}" for lbl in labels) label_expr = " OR ".join(f"label:{lbl}" for lbl in labels)
parts.append(f"({label_expr})") parts.append(f"({label_expr})")
# Senders — each prefixed with "from:".
senders: list[str] = cfg.get("senders", []) senders: list[str] = cfg.get("senders", [])
for sender in senders: for sender in senders:
parts.append(f"from:{sender}") parts.append(f"from:{sender}")
# Date range.
date_range: dict = cfg.get("date_range", {}) date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from") from_str: str | None = date_range.get("from")
to_str: str | None = date_range.get("to") to_str: str | None = date_range.get("to")
# Determine effective "from" date: most recent of filter_config.date_range.from and since.
effective_since: datetime | None = since effective_since: datetime | None = since
if from_str: if from_str:
try: try:
@@ -110,18 +72,12 @@ def _build_gmail_query(
def _strip_html(raw_html: str) -> str: def _strip_html(raw_html: str) -> str:
"""Remove HTML tags and decode entities to get plain text."""
no_tags = re.sub(r"<[^>]+>", " ", raw_html) no_tags = re.sub(r"<[^>]+>", " ", raw_html)
decoded = html.unescape(no_tags) decoded = html.unescape(no_tags)
return re.sub(r"\s+", " ", decoded).strip() return re.sub(r"\s+", " ", decoded).strip()
def _parse_body(payload: dict[str, Any]) -> str: def _parse_body(payload: dict[str, Any]) -> str:
"""Recursively extract the plain-text body from a Gmail message payload.
Prefers ``text/plain``; falls back to ``text/html`` (stripped of tags).
Returns an empty string if no body can be extracted.
"""
mime_type: str = payload.get("mimeType", "") mime_type: str = payload.get("mimeType", "")
body: dict = payload.get("body", {}) body: dict = payload.get("body", {})
parts: list[dict] = payload.get("parts", []) parts: list[dict] = payload.get("parts", [])
@@ -139,7 +95,6 @@ def _parse_body(payload: dict[str, Any]) -> str:
return _strip_html(raw) return _strip_html(raw)
return "" return ""
# Multipart — prefer text/plain part, fall back to text/html.
plain_fallback = "" plain_fallback = ""
for part in parts: for part in parts:
part_mime = part.get("mimeType", "") part_mime = part.get("mimeType", "")
@@ -155,7 +110,6 @@ def _parse_body(payload: dict[str, Any]) -> str:
def _parse_date(raw: str) -> datetime: def _parse_date(raw: str) -> datetime:
"""Parse an RFC 2822 email date header into a UTC ``datetime``."""
try: try:
parsed = email.utils.parsedate_to_datetime(raw) parsed = email.utils.parsedate_to_datetime(raw)
if parsed.tzinfo is None: if parsed.tzinfo is None:
@@ -166,16 +120,6 @@ def _parse_date(raw: str) -> datetime:
class GmailClient: class GmailClient:
"""Fetch email messages from a Gmail account via the Gmail REST API.
Parameters
----------
credentials_info:
Decrypted OAuth2 credential dict. Must contain at minimum
``token`` (access token) or ``refresh_token`` + ``token_uri`` +
``client_id`` + ``client_secret``.
"""
def __init__(self, credentials_info: dict[str, Any]) -> None: def __init__(self, credentials_info: dict[str, Any]) -> None:
from google.oauth2.credentials import Credentials from google.oauth2.credentials import Credentials
@@ -200,38 +144,20 @@ class GmailClient:
expiry=expiry, expiry=expiry,
) )
# ── Public API ─────────────────────────────────────────────────────────
async def fetch_messages( async def fetch_messages(
self, self,
filter_config: dict[str, Any] | None = None, filter_config: dict[str, Any] | None = None,
since: datetime | None = None, since: datetime | None = None,
) -> list[EmailMessage]: ) -> list[EmailMessage]:
"""Return up to ``_MAX_MESSAGES`` emails matching *filter_config*.
Runs the synchronous Google API calls inside ``asyncio.to_thread()``
to avoid blocking the async event loop.
Token refresh is performed automatically when the access token has
expired. After the call, ``self.refreshed_credentials`` may be
consulted to detect whether new credentials should be persisted.
"""
query = _build_gmail_query(filter_config, since) query = _build_gmail_query(filter_config, since)
logger.debug("gmail: executing search query %r", query) logger.debug("gmail: executing search query %r", query)
return await asyncio.to_thread(self._fetch_sync, query) return await asyncio.to_thread(self._fetch_sync, query)
@property @property
def refreshed_credentials(self) -> dict[str, Any] | None: def refreshed_credentials(self) -> dict[str, Any] | None:
"""Return updated credential dict if the access token was refreshed.
If the credentials were refreshed during ``fetch_messages()``, returns
a new dict that should be re-encrypted and written back to the DB.
Returns ``None`` if no refresh occurred.
"""
creds = self._credentials creds = self._credentials
if not creds.valid and creds.expired: if not creds.valid and creds.expired:
return None return None
# Check whether the token changed from what was stored.
if creds.token != self._credentials_info.get("token"): if creds.token != self._credentials_info.get("token"):
result = { result = {
"token": creds.token, "token": creds.token,
@@ -246,15 +172,11 @@ class GmailClient:
return result return result
return None return None
# ── Internal sync worker ───────────────────────────────────────────────
def _fetch_sync(self, query: str) -> list[EmailMessage]: def _fetch_sync(self, query: str) -> list[EmailMessage]:
"""Synchronous worker — called inside ``asyncio.to_thread()``."""
import googleapiclient.discovery import googleapiclient.discovery
import googleapiclient.errors import googleapiclient.errors
from google.auth.transport.requests import Request from google.auth.transport.requests import Request
# Refresh token if needed before building the service.
if self._credentials.expired and self._credentials.refresh_token: if self._credentials.expired and self._credentials.refresh_token:
try: try:
self._credentials.refresh(Request()) self._credentials.refresh(Request())
@@ -264,9 +186,8 @@ class GmailClient:
service = googleapiclient.discovery.build( service = googleapiclient.discovery.build(
"gmail", "v1", credentials=self._credentials, cache_discovery=False "gmail", "v1", credentials=self._credentials, cache_discovery=False
) )
user_api = service.users() # type: ignore[attr-defined] user_api = service.users()
# ── List matching message IDs ──────────────────────────────────────
ids: list[str] = [] ids: list[str] = []
page_token: str | None = None page_token: str | None = None
while len(ids) < _MAX_MESSAGES: while len(ids) < _MAX_MESSAGES:
@@ -293,12 +214,10 @@ class GmailClient:
break break
if not ids: if not ids:
logger.debug("gmail: no messages matched query %r", query)
return [] return []
logger.info("gmail: fetching %d message(s)", len(ids)) logger.info("gmail: fetching %d message(s)", len(ids))
# ── Fetch individual message details ──────────────────────────────
messages: list[EmailMessage] = [] messages: list[EmailMessage] = []
for msg_id in ids: for msg_id in ids:
try: try:
@@ -326,10 +245,8 @@ class GmailClient:
date=date, date=date,
labels=labels, labels=labels,
)) ))
except googleapiclient.errors.HttpError as exc:
logger.warning("gmail: skipping message %s — HTTP error: %s", msg_id, exc)
except Exception as exc: except Exception as exc:
logger.warning("gmail: skipping message %s — unexpected error: %s", msg_id, exc) logger.warning("gmail: skipping message %s: %s", msg_id, exc)
logger.info("gmail: returned %d message(s)", len(messages)) logger.info("gmail: returned %d message(s)", len(messages))
return messages return messages

View File

@@ -1,52 +1,30 @@
"""Microsoft Graph API client for Outlook and Teams cloud agent integration. """Microsoft Graph API client for Outlook and Teams.
Handles two data sources: Adapted for Batch Agent Service: import settings from shared.config.
* **Outlook email** (``provider="outlook"``) ``fetch_emails()`` calls
``/me/messages`` with an OData ``$filter`` built from ``filter_config``.
* **Teams messages** (``provider="teams"``) ``fetch_messages()`` calls
``/me/chats/getAllMessages`` filtered by date.
Authentication uses MSAL ``PublicClientApplication`` to acquire a token
from a stored refresh token. The ``httpx.AsyncClient`` (already a project
dependency) is used for all API calls.
Credential dict shape (Microsoft OAuth2 / MSAL):
{
"access_token": "<access_token>",
"refresh_token": "<refresh_token>",
"token_type": "Bearer",
"scope": "Mail.Read ChannelMessage.Read.All offline_access",
"expires_in": 3600
}
""" """
from __future__ import annotations from __future__ import annotations
import logging import logging
import re import re
from datetime import datetime, timedelta, timezone from datetime import datetime, timezone
from typing import Any from typing import Any
import httpx import httpx
from app.config.settings import settings from shared.config import settings
from app.integrations import ChatMessage, EmailMessage from app.integrations import ChatMessage, EmailMessage
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_GRAPH_BASE = "https://graph.microsoft.com/v1.0" _GRAPH_BASE = "https://graph.microsoft.com/v1.0"
# Max items fetched per run.
_MAX_EMAILS = 200 _MAX_EMAILS = 200
_MAX_MESSAGES = 200 _MAX_MESSAGES = 200
# Max characters of body forwarded to the LLM.
_BODY_TRUNCATE = 8_000 _BODY_TRUNCATE = 8_000
def _strip_html(raw: str) -> str: def _strip_html(raw: str) -> str:
"""Strip HTML tags and collapse whitespace."""
no_tags = re.sub(r"<[^>]+>", " ", raw) no_tags = re.sub(r"<[^>]+>", " ", raw)
import html as _html import html as _html
decoded = _html.unescape(no_tags) decoded = _html.unescape(no_tags)
@@ -54,7 +32,6 @@ def _strip_html(raw: str) -> str:
def _odata_datetime(dt: datetime) -> str: def _odata_datetime(dt: datetime) -> str:
"""Format a datetime as an OData datetime literal (UTC, ISO 8601)."""
utc = dt.astimezone(timezone.utc) utc = dt.astimezone(timezone.utc)
return utc.strftime("%Y-%m-%dT%H:%M:%SZ") return utc.strftime("%Y-%m-%dT%H:%M:%SZ")
@@ -63,29 +40,14 @@ def _build_email_filter(
filter_config: dict[str, Any] | None, filter_config: dict[str, Any] | None,
since: datetime | None, since: datetime | None,
) -> str: ) -> str:
"""Build an OData ``$filter`` expression for the ``/me/messages`` endpoint.
Supported ``filter_config`` keys:
senders (list[str]): Sender email addresses.
date_range (dict): ``{from: "<ISO-8601>", to: "<ISO-8601>"}``
folders (list[str]): Folder display names (not directly filterable
via OData, so ignored here callers iterate
folder IDs separately if needed; listed for
completeness).
A hard ``since`` date always overrides ``date_range.from`` when it is
earlier.
"""
clauses: list[str] = [] clauses: list[str] = []
cfg = filter_config or {} cfg = filter_config or {}
# Senders.
senders: list[str] = cfg.get("senders", []) senders: list[str] = cfg.get("senders", [])
if senders: if senders:
sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders] sender_clauses = [f"from/emailAddress/address eq '{s}'" for s in senders]
clauses.append("(" + " or ".join(sender_clauses) + ")") clauses.append("(" + " or ".join(sender_clauses) + ")")
# Date range.
date_range: dict = cfg.get("date_range", {}) date_range: dict = cfg.get("date_range", {})
from_str: str | None = date_range.get("from") from_str: str | None = date_range.get("from")
@@ -117,33 +79,16 @@ def _build_email_filter(
class MSGraphClient: class MSGraphClient:
"""Fetch emails and Teams messages via the Microsoft Graph REST API.
Parameters
----------
credentials_info:
Decrypted MSAL credential dict.
"""
def __init__(self, credentials_info: dict[str, Any]) -> None: def __init__(self, credentials_info: dict[str, Any]) -> None:
self._credentials_info = credentials_info self._credentials_info = credentials_info
self._access_token: str = credentials_info.get("access_token", "") self._access_token: str = credentials_info.get("access_token", "")
self._original_access_token: str = self._access_token self._original_access_token: str = self._access_token
self._refresh_token: str | None = credentials_info.get("refresh_token") self._refresh_token: str | None = credentials_info.get("refresh_token")
# ── Token management ───────────────────────────────────────────────────
def _auth_headers(self) -> dict[str, str]: def _auth_headers(self) -> dict[str, str]:
return {"Authorization": f"Bearer {self._access_token}"} return {"Authorization": f"Bearer {self._access_token}"}
async def _refresh_access_token(self) -> None: async def _refresh_access_token(self) -> None:
"""Use MSAL to exchange the refresh token for a fresh access token.
Updates ``self._access_token`` and ``self._credentials_info`` in-place.
Raises:
RuntimeError: MSAL reports an auth error.
"""
import msal import msal
app = msal.ConfidentialClientApplication( app = msal.ConfidentialClientApplication(
@@ -164,7 +109,6 @@ class MSGraphClient:
raise RuntimeError(f"MS Graph token refresh failed: {error}") raise RuntimeError(f"MS Graph token refresh failed: {error}")
self._access_token = result["access_token"] self._access_token = result["access_token"]
# MSAL may issue a new refresh token.
if "refresh_token" in result: if "refresh_token" in result:
self._refresh_token = result["refresh_token"] self._refresh_token = result["refresh_token"]
self._credentials_info["refresh_token"] = result["refresh_token"] self._credentials_info["refresh_token"] = result["refresh_token"]
@@ -172,16 +116,10 @@ class MSGraphClient:
@property @property
def refreshed_credentials(self) -> dict[str, Any] | None: def refreshed_credentials(self) -> dict[str, Any] | None:
"""Return updated credential dict if the access token was refreshed.
Returns ``None`` if no change was made.
"""
if self._access_token != self._original_access_token: if self._access_token != self._original_access_token:
return {**self._credentials_info, "access_token": self._access_token} return {**self._credentials_info, "access_token": self._access_token}
return None return None
# ── HTTP helpers ───────────────────────────────────────────────────────
async def _get( async def _get(
self, self,
client: httpx.AsyncClient, client: httpx.AsyncClient,
@@ -190,10 +128,8 @@ class MSGraphClient:
*, *,
retry_on_401: bool = True, retry_on_401: bool = True,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""GET *url* with auth; refresh token on 401 and retry once."""
resp = await client.get(url, params=params, headers=self._auth_headers()) resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 401 and retry_on_401 and self._refresh_token: if resp.status_code == 401 and retry_on_401 and self._refresh_token:
logger.debug("ms_graph: 401 on %s — refreshing token", url)
await self._refresh_access_token() await self._refresh_access_token()
resp = await client.get(url, params=params, headers=self._auth_headers()) resp = await client.get(url, params=params, headers=self._auth_headers())
if resp.status_code == 429: if resp.status_code == 429:
@@ -201,22 +137,11 @@ class MSGraphClient:
resp.raise_for_status() resp.raise_for_status()
return resp.json() return resp.json()
# ── Public API ─────────────────────────────────────────────────────────
async def fetch_emails( async def fetch_emails(
self, self,
filter_config: dict[str, Any] | None = None, filter_config: dict[str, Any] | None = None,
since: datetime | None = None, since: datetime | None = None,
) -> list[EmailMessage]: ) -> list[EmailMessage]:
"""Return up to ``_MAX_EMAILS`` Outlook messages matching *filter_config*.
Parameters
----------
filter_config:
Optional dict with ``senders``, ``date_range``, ``folders`` keys.
since:
Hard lower-bound on email date (from last agent run).
"""
odata_filter = _build_email_filter(filter_config, since) odata_filter = _build_email_filter(filter_config, since)
params: dict[str, Any] = { params: dict[str, Any] = {
"$top": 50, "$top": 50,
@@ -237,7 +162,7 @@ class MSGraphClient:
if len(emails) >= _MAX_EMAILS: if len(emails) >= _MAX_EMAILS:
break break
url = data.get("@odata.nextLink", "") url = data.get("@odata.nextLink", "")
params = {} # nextLink already contains encoded params. params = {}
logger.info("ms_graph: fetched %d Outlook email(s)", len(emails)) logger.info("ms_graph: fetched %d Outlook email(s)", len(emails))
return emails return emails
@@ -247,13 +172,6 @@ class MSGraphClient:
filter_config: dict[str, Any] | None = None, filter_config: dict[str, Any] | None = None,
since: datetime | None = None, since: datetime | None = None,
) -> list[ChatMessage]: ) -> list[ChatMessage]:
"""Return up to ``_MAX_MESSAGES`` Teams messages matching *filter_config*.
Fetches from ``/me/chats/getAllMessages`` (personal + group chats).
The ``filter_config.channels`` key is checked as a text-filter on
the channel name post-fetch (the API doesn't support channel OData
filter directly on ``getAllMessages``).
"""
cfg = filter_config or {} cfg = filter_config or {}
channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])] channel_filter: list[str] = [c.lower() for c in cfg.get("channels", [])]
params: dict[str, Any] = {"$top": 50} params: dict[str, Any] = {"$top": 50}
@@ -268,11 +186,9 @@ class MSGraphClient:
try: try:
data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None) data = await self._get(client, url, params if url.startswith(_GRAPH_BASE) else None)
except httpx.HTTPStatusError as exc: except httpx.HTTPStatusError as exc:
# getAllMessages requires specific licensing; degrade gracefully.
if exc.response.status_code in (403, 404): if exc.response.status_code in (403, 404):
logger.warning( logger.warning(
"ms_graph: /me/chats/getAllMessages not available (%d)" "ms_graph: /me/chats/getAllMessages not available (%d)",
"check Teams license or permissions",
exc.response.status_code, exc.response.status_code,
) )
break break
@@ -292,8 +208,6 @@ class MSGraphClient:
logger.info("ms_graph: fetched %d Teams message(s)", len(messages)) logger.info("ms_graph: fetched %d Teams message(s)", len(messages))
return messages return messages
# ── Parsers ────────────────────────────────────────────────────────────
@staticmethod @staticmethod
def _parse_email(item: dict[str, Any]) -> EmailMessage: def _parse_email(item: dict[str, Any]) -> EmailMessage:
subject: str = item.get("subject", "(no subject)") or "(no subject)" subject: str = item.get("subject", "(no subject)") or "(no subject)"

View File

@@ -1,22 +1,16 @@
"""Chatbot Journey — WS-based guided conversation to build an agent prompt_template. """Chatbot Journey — guided conversation to build an agent prompt_template.
The journey is driven entirely through WebSocket frames (no REST endpoints). Adapted for Batch Agent Service: imports from app.agents.filesystem_agent
The device WS handler dispatches ``journey_start`` and ``journey_message`` and app.llm instead of monolith paths. Session state is in-memory (could
frames to the functions exported here. be moved to Redis for horizontal scaling in the future).
Journey flow: Journey flow:
1. FE sends ``journey_start`` frame with basic agent config (directory, 1. Redis consumer dispatches ``journey_start`` with basic agent config.
data_types, schedule). 2. Server creates an in-memory session, runs the setup LLM with
2. Server creates an in-memory session, sets up a WS executor so the file-system tools to explore the directory, returns first question.
setup LLM can use file-system tools, does a first directory scrape, 3. ``journey_message`` frames drive the conversation.
and sends back a ``journey_reply`` with the first question. 4. After 3-5 turns the LLM emits PROMPT_TEMPLATE_START / _END block.
3. FE sends ``journey_message`` frames for each user reply. 5. Server parses the block and returns ``journey_reply`` with ``done=True``.
4. Server appends the user message, calls the LLM (which may read files
via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting a ``prompt_template``
block delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
6. Server parses the block, sends ``journey_reply`` with ``done=True``
and the template. FE stores it locally.
""" """
from __future__ import annotations from __future__ import annotations
@@ -31,7 +25,8 @@ from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from app.agents.filesystem_agent import FILESYSTEM_TOOLS from app.agents.filesystem_agent import FILESYSTEM_TOOLS
from app.core.llm import get_llm from shared.llm import get_llm
import app.tracing as tracing
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -43,9 +38,8 @@ _SESSION_TTL_SECONDS: int = 1800 # 30 minutes
_TEMPLATE_START = "PROMPT_TEMPLATE_START" _TEMPLATE_START = "PROMPT_TEMPLATE_START"
_TEMPLATE_END = "PROMPT_TEMPLATE_END" _TEMPLATE_END = "PROMPT_TEMPLATE_END"
# Maximum number of conversation turns before the LLM is nudged to wrap up. _MIN_TURNS_BEFORE_NUDGE: int = 3
_MAX_TURNS: int = 5 _MAX_TURNS: int = 15
# Max tool-calling steps per LLM invocation.
_MAX_TOOL_STEPS: int = 6 _MAX_TOOL_STEPS: int = 6
# ── In-memory session store ─────────────────────────────────────────────── # ── In-memory session store ───────────────────────────────────────────────
@@ -86,17 +80,9 @@ def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
_SYSTEM_PROMPT_TEMPLATE = """\ _SYSTEM_PROMPT_TEMPLATE = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent. You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand exactly what data the user wants to extract from their Your job is to understand exactly what data the user wants to extract from their
local directory and produce a detailed prompt_template that a separate AI will use local directory and produce a concise prompt_template that a separate AI will use
as its instruction set. as its instruction set.
The extraction agent already has this base behaviour built in:
- Reads each file using file-system tools.
- Creates records (tasks, notes, timelines, projects) via CRUD tools.
- Sets isAiSuggested=1 and isApproved=0 on every record.
- Only extracts data explicitly present in the files it never invents information.
The user's custom prompt is appended AFTER this base behaviour, so focus on
what to look for and how to map it not on the general extraction mechanics.
You have access to file-system tools to explore the user's directory: You have access to file-system tools to explore the user's directory:
- list_directory: to see folder structure - list_directory: to see folder structure
- read_file_content: to peek at file contents - read_file_content: to peek at file contents
@@ -105,31 +91,43 @@ You have access to file-system tools to explore the user's directory:
The user's configured directory is: {directory} The user's configured directory is: {directory}
Target data types: {data_types} Target data types: {data_types}
Start by exploring the directory to understand its structure. Then ask concise, IMPORTANT project assignment is handled automatically. You MUST NOT ask the user
focused questions one at a time. Cover these topics (not necessarily in this order): about projects, projectId, or how to link records to projects. Never include
1. The type and format of the source content (confirmed by your exploration). projectId logic or project creation instructions in the generated prompt_template.
2. How fields should be mapped (e.g. filename task title).
3. Priority or status rules (e.g. "urgent" keyword high priority).
4. Any special handling, date extraction, or exclusions.
After 3-5 questions (when you have enough information), output the final prompt_template Start by exploring the directory to understand its structure. Then ask concise,
between these exact markers on their own lines: focused questions one at a time. Cover only the topics relevant to the target
data types listed above:
1. Content type and format confirmed by your exploration.
2. For TASKS (if in scope): field mapping for title, status, priority, content,
dueDate (where is the date found? what's the fallback when absent?),
and assignee (is there a person name to assign?).
3. For NOTES when TASKS are also in scope: note vs task distinction
what makes something a note rather than a task?
4. For TIMELINES (if in scope): the date source what marks a milestone or event?
5. Exclusions and special handling applicable to the target data types.
Keep asking focused questions until you are at least 90% confident. Then stop and
output the final prompt_template immediately, wrapped between these exact markers
on their own lines:
{template_start} {template_start}
<the complete extraction prompt here> <the complete extraction prompt here>
{template_end} {template_end}
The prompt_template must be a self-contained instruction for an AI that reads files The prompt_template must be concise (bullet points, ~1525 lines maximum).
and must perform CRUD operations using tools to create records. It should specify: Specify only:
- What entity types to create (tasks, notes, timelines, projects). - Scope: what files/content qualify and what entity types to create.
- How to map file content to record fields (camelCase: title, status, priority, - Field mapping rules per entity type (camelCase fields: title, status, priority,
dueDate, projectId, content, etc.). dueDate, content, assignee, etc.).
- That isAiSuggested must be set to 1 and isApproved to 0 on every record. - dueDate rule (if tasks in scope): source and fallback behaviour.
- Concrete examples of mappings based on what you discovered in the directory. - Note vs task rule (if both in scope): the criterion that separates them.
- Timeline date rule (if timelines in scope): what constitutes a timeline event.
- Exclusion/filtering rules.
- 23 concrete mapping examples based on what you discovered.
{existing_section}\ {existing_section}Begin by exploring the directory, then ask your first question.\
Do not ask more than {max_turns} questions total. Begin by exploring the directory,
then ask your first question.\
""" """
@@ -144,13 +142,15 @@ def _build_system_prompt(
if existing_template if existing_template
else "" else ""
) )
return _SYSTEM_PROMPT_TEMPLATE.format( # Use Langfuse compile_prompt ({{variable}} syntax) with Python .format() fallback
directory=directory, return tracing.compile_prompt(
data_types=", ".join(data_types), "journey_system",
template_start=_TEMPLATE_START, fallback=_SYSTEM_PROMPT_TEMPLATE,
template_end=_TEMPLATE_END, variables={
existing_section=existing_section, "directory": directory,
max_turns=_MAX_TURNS, "data_types": ", ".join(data_types),
"existing_section": existing_section,
},
) )
@@ -191,6 +191,7 @@ async def _call_llm_with_tools(
system_prompt: str, system_prompt: str,
history: list[dict[str, Any]], history: list[dict[str, Any]],
tools: list[Any], tools: list[Any],
langfuse_handler: Any | None = None,
) -> str: ) -> str:
"""Build LangChain messages from history and invoke the LLM with tools. """Build LangChain messages from history and invoke the LLM with tools.
@@ -204,7 +205,8 @@ async def _call_llm_with_tools(
else: else:
messages.append(AIMessage(content=turn["content"])) messages.append(AIMessage(content=turn["content"]))
llm = get_llm(model=None, temperature=0.4) callbacks = [langfuse_handler] if langfuse_handler else None
llm = get_llm(model=None, temperature=0.4, callbacks=callbacks)
llm_with_tools = llm.bind_tools(tools) llm_with_tools = llm.bind_tools(tools)
tool_map = {tool_def.name: tool_def for tool_def in tools} tool_map = {tool_def.name: tool_def for tool_def in tools}
@@ -219,7 +221,7 @@ async def _call_llm_with_tools(
call_name = str(call.get("name", "")) call_name = str(call.get("name", ""))
call_args = call.get("args", {}) call_args = call.get("args", {})
logger.info( logger.info(
"agent_setup: journey tool_call name=%s args=%s", "journey: tool_call name=%s args=%s",
call_name, call_name,
json.dumps(call_args, ensure_ascii=True)[:500], json.dumps(call_args, ensure_ascii=True)[:500],
) )
@@ -231,25 +233,27 @@ async def _call_llm_with_tools(
tool_output = await tool_fn.ainvoke(call_args) tool_output = await tool_fn.ainvoke(call_args)
logger.info( logger.info(
"agent_setup: journey tool_result name=%s output=%s", "journey: tool_result name=%s output=%s",
call_name, call_name,
str(tool_output)[:800], str(tool_output)[:800],
) )
messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"])) messages.append(ToolMessage(content=str(tool_output), tool_call_id=call["id"]))
# Fallback: exceeded max steps. # Fallback: exceeded max tool steps.
final = await llm.ainvoke(messages) final = await llm.ainvoke(messages)
return _as_text(final.content) return _as_text(final.content)
# ── Journey handlers (called from device_ws.py) ────────────────────────── # ── Journey handlers (called from redis_consumer) ────────────────────────
async def handle_journey_start( async def handle_journey_start(
user_id: str, user_id: str,
frame: dict[str, Any], frame: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Handle a ``journey_start`` WS frame. """Handle a ``journey_start`` request.
Creates a session, runs the setup LLM with directory exploration, Creates a session, runs the setup LLM with directory exploration,
and returns the ``journey_reply`` payload. and returns the ``journey_reply`` payload.
@@ -259,8 +263,6 @@ async def handle_journey_start(
data_types = frame.get("data_types", []) data_types = frame.get("data_types", [])
existing_template = frame.get("existing_template") existing_template = frame.get("existing_template")
# Use the session_id provided by the FE so the reply matches the
# listener key; fall back to a generated one if absent.
session_id = frame.get("session_id") or str(uuid.uuid4()) session_id = frame.get("session_id") or str(uuid.uuid4())
system_prompt = _build_system_prompt(directory, data_types, existing_template) system_prompt = _build_system_prompt(directory, data_types, existing_template)
@@ -273,10 +275,6 @@ async def handle_journey_start(
system_prompt=system_prompt, system_prompt=system_prompt,
) )
# The LLM will explore the directory using FILESYSTEM_TOOLS via the
# ws_context executor (already set by the WS handler before calling us).
# Seed with an initial user message — some providers (e.g. GitHub Copilot)
# require at least one user/input message to be present.
seed_history: list[dict[str, Any]] = [ seed_history: list[dict[str, Any]] = [
{"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."}, {"role": "user", "content": "Hi, I'm ready to set up my agent. Please explore my directory and ask me your first question."},
] ]
@@ -284,6 +282,7 @@ async def handle_journey_start(
system_prompt=system_prompt, system_prompt=system_prompt,
history=seed_history, history=seed_history,
tools=list(FILESYSTEM_TOOLS), tools=list(FILESYSTEM_TOOLS),
langfuse_handler=langfuse_handler,
) )
session.history.extend(seed_history) session.history.extend(seed_history)
@@ -291,13 +290,12 @@ async def handle_journey_start(
_sessions[session_id] = session _sessions[session_id] = session
logger.info( logger.info(
"agent_setup: journey session %s started for user %s (directory=%s)", "journey: session %s started for user %s (directory=%s)",
session_id, session_id,
user_id, user_id,
directory, directory,
) )
# Check if the LLM produced the template on the first turn (unlikely but possible).
prompt_template = _extract_template(ai_reply) prompt_template = _extract_template(ai_reply)
done = prompt_template is not None done = prompt_template is not None
@@ -321,8 +319,10 @@ async def handle_journey_start(
async def handle_journey_message( async def handle_journey_message(
user_id: str, user_id: str,
frame: dict[str, Any], frame: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Handle a ``journey_message`` WS frame. """Handle a ``journey_message`` request.
Appends the user message, calls the LLM, and returns the Appends the user message, calls the LLM, and returns the
``journey_reply`` payload. ``journey_reply`` payload.
@@ -340,24 +340,20 @@ async def handle_journey_message(
"prompt_template": None, "prompt_template": None,
} }
# Append user turn.
session.history.append({"role": "user", "content": message}) session.history.append({"role": "user", "content": message})
# Call the LLM with tools.
ai_reply = await _call_llm_with_tools( ai_reply = await _call_llm_with_tools(
system_prompt=session.system_prompt, system_prompt=session.system_prompt,
history=session.history, history=session.history,
tools=list(FILESYSTEM_TOOLS), tools=list(FILESYSTEM_TOOLS),
langfuse_handler=langfuse_handler,
) )
session.history.append({"role": "assistant", "content": ai_reply}) session.history.append({"role": "assistant", "content": ai_reply})
# Check if the LLM produced the final template.
prompt_template = _extract_template(ai_reply) prompt_template = _extract_template(ai_reply)
done = prompt_template is not None done = prompt_template is not None
# If the LLM didn't produce a template but we've hit max turns, nudge it
# and call the LLM one more time to force template generation.
if not done: if not done:
turns = sum(1 for t in session.history if t["role"] == "user") turns = sum(1 for t in session.history if t["role"] == "user")
if turns >= _MAX_TURNS: if turns >= _MAX_TURNS:
@@ -371,6 +367,7 @@ async def handle_journey_message(
system_prompt=session.system_prompt, system_prompt=session.system_prompt,
history=session.history, history=session.history,
tools=list(FILESYSTEM_TOOLS), tools=list(FILESYSTEM_TOOLS),
langfuse_handler=langfuse_handler,
) )
session.history.append({"role": "assistant", "content": nudge_reply}) session.history.append({"role": "assistant", "content": nudge_reply})
@@ -387,7 +384,7 @@ async def handle_journey_message(
else "Here is your agent configuration. You can save it or continue refining." else "Here is your agent configuration. You can save it or continue refining."
) )
_sessions.pop(session_id, None) _sessions.pop(session_id, None)
logger.info("agent_setup: journey session %s completed for user %s", session_id, user_id) logger.info("journey: session %s completed for user %s", session_id, user_id)
return { return {
"type": "journey_reply", "type": "journey_reply",

View File

@@ -0,0 +1,76 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Identical to services/chat/app/llm.py. Uses shared.config.settings.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from shared.config import settings
litellm.drop_params = True
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github/"):
return settings.GITHUB_TOKEN or None
if model.startswith("github_copilot/"):
return None
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
callbacks: list | None = None,
) -> ChatOpenAI | ChatLiteLLM:
model = model or settings.LLM_MODEL
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
if settings.GITHUB_TOKEN:
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
callbacks=callbacks,
)
async def embed(text: str) -> list[float]:
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

View File

@@ -0,0 +1,79 @@
"""Batch Agent Service — FastAPI application.
Owns: agent_runner (local directory + cloud connectors), journey builder,
filesystem_agent, integrations (Gmail, MS Graph).
Communicates with WS Gateway via Redis:
- Subscribes to batch:request:{user_id} (journey_start, journey_message)
- Publishes to ws:out:{user_id} (journey replies + tool calls)
- BRPOP on tool:result:{call_id} (tool-call round-trip, 30s timeout)
- SET+EX on journey:{user_id} (journey session state, TTL 1800s)
"""
from __future__ import annotations
import asyncio
import logging
import sys
from pathlib import Path
# Ensure the repo root is on sys.path so ``shared`` is importable when
# running locally (in Docker the COPY already places it at /app/shared/).
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from contextlib import asynccontextmanager
from typing import AsyncGenerator
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.redis_consumer import start_consumer
from app.routes import router
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
# Initialise Langfuse tracing (no-op if keys are missing)
from app.tracing import init_langfuse
init_langfuse()
logger.info("batch-agent: starting Redis consumer")
task = asyncio.create_task(start_consumer())
yield
task.cancel()
try:
await task
except asyncio.CancelledError:
pass
from app.tracing import shutdown as shutdown_langfuse
shutdown_langfuse()
from shared.db import engine
await engine.dispose()
from shared.redis import redis_client
await redis_client.aclose()
logger.info("batch-agent: Redis consumer stopped")
app = FastAPI(title="Adiuva Batch Agent Service", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST"],
allow_headers=["*"],
)
app.include_router(router)
@app.get("/health")
async def health() -> dict[str, str]:
return {"status": "ok", "service": "batch-agent"}

View File

@@ -0,0 +1,183 @@
"""Redis consumer for the Batch Agent Service.
Subscribes to batch:request:* (pattern) and dispatches:
- journey_start → handle_journey_start
- journey_message → handle_journey_message
- agent_trigger → run_local_agent / run_cloud_agent
Results are published back to ws:out:{user_id} via Redis.
"""
from __future__ import annotations
import asyncio
import json
import logging
from typing import Any
from shared.redis import redis_client, batch_request_channel, ws_out_channel
import app.tracing as tracing
from shared.ws_context import set_current_user, clear_current_user
logger = logging.getLogger(__name__)
async def _publish_to_user(user_id: str, payload: dict[str, Any]) -> None:
"""Publish a frame to the user's WS outbound channel."""
channel = ws_out_channel(user_id)
await redis_client.publish(channel, json.dumps(payload))
async def _handle_journey_start(user_id: str, data: dict[str, Any]) -> None:
"""Handle a journey_start request from WS Gateway."""
from app.journey import handle_journey_start
session_id = data.get("session_id", "")
set_current_user(user_id)
try:
with tracing.trace_span(
name="journey_start",
user_id=user_id,
session_id=session_id,
input=data.get("directory", ""),
metadata={"data_types": data.get("data_types", [])},
tags=["journey"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
reply = await handle_journey_start(user_id, data, langfuse_handler=langfuse_handler)
tracing.link_prompt_to_trace(span, "journey_system")
span.update(output=reply.get("message", "")[:500])
await _publish_to_user(user_id, reply)
tracing.flush()
except Exception as exc:
logger.error("batch-agent: journey_start failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey setup failed: {exc}",
"done": True,
"prompt_template": None,
})
finally:
clear_current_user()
async def _handle_journey_message(user_id: str, data: dict[str, Any]) -> None:
"""Handle a journey_message from WS Gateway."""
from app.journey import handle_journey_message
session_id = data.get("session_id", "")
set_current_user(user_id)
try:
with tracing.trace_span(
name="journey_message",
user_id=user_id,
session_id=session_id,
input=data.get("message", "")[:200],
tags=["journey"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
reply = await handle_journey_message(user_id, data, langfuse_handler=langfuse_handler)
tracing.link_prompt_to_trace(span, "journey_system")
span.update(output=reply.get("message", "")[:500])
await _publish_to_user(user_id, reply)
tracing.flush()
except Exception as exc:
logger.error("batch-agent: journey_message failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "journey_reply",
"session_id": session_id,
"message": f"Journey processing failed: {exc}",
"done": True,
"prompt_template": None,
})
finally:
clear_current_user()
async def _handle_agent_trigger(user_id: str, data: dict[str, Any]) -> None:
"""Handle an agent_trigger request from the REST route (forwarded via Redis)."""
from app.agent_runner import run_local_agent
run_context = data.get("run_context", {})
agent_id = run_context.get("agent_id", "")
set_current_user(user_id)
try:
with tracing.trace_span(
name="agent_trigger",
user_id=user_id,
trace_id=run_context.get("run_id"),
input={"agent_id": agent_id, "directory": data.get("directory", "")},
metadata={"data_types": data.get("data_types", [])},
tags=["batch", "agent_run"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
await run_local_agent(user_id, data, langfuse_handler=langfuse_handler)
tracing.link_prompt_to_trace(span, "batch_processing")
span.update(output={"status": "completed"})
tracing.flush()
except Exception as exc:
logger.error("batch-agent: agent_trigger failed user=%s: %s", user_id, exc)
await _publish_to_user(user_id, {
"type": "run_complete",
"status": "error",
"run_context": run_context,
})
finally:
clear_current_user()
async def _dispatch(user_id: str, message_data: dict[str, Any]) -> None:
"""Route a batch request to the correct handler."""
msg_type = message_data.get("type", "")
if msg_type == "journey_start":
await _handle_journey_start(user_id, message_data)
elif msg_type == "journey_message":
await _handle_journey_message(user_id, message_data)
elif msg_type == "agent_trigger":
await _handle_agent_trigger(user_id, message_data)
elif msg_type == "device_online":
logger.info("batch-agent: device_online user=%s device=%s", user_id, message_data.get("device_id", "?"))
else:
logger.warning("batch-agent: unknown message type %r from user=%s", msg_type, user_id)
async def start_consumer() -> None:
"""Subscribe to batch:request:* and dispatch incoming frames."""
pubsub = redis_client.pubsub()
await pubsub.psubscribe("batch:request:*")
logger.info("batch-agent: subscribed to batch:request:*")
try:
async for message in pubsub.listen():
if message["type"] != "pmessage":
continue
channel: str = message["channel"]
if isinstance(channel, bytes):
channel = channel.decode()
# Extract user_id from channel: batch:request:{user_id}
parts = channel.split(":", 2)
if len(parts) < 3:
continue
user_id = parts[2]
raw = message["data"]
if isinstance(raw, bytes):
raw = raw.decode()
try:
data = json.loads(raw)
except json.JSONDecodeError:
logger.warning("batch-agent: invalid JSON on channel %s", channel)
continue
# Dispatch in a separate task to avoid blocking the consumer
asyncio.create_task(_dispatch(user_id, data))
except asyncio.CancelledError:
logger.info("batch-agent: consumer shutting down")
finally:
await pubsub.punsubscribe("batch:request:*")

View File

@@ -0,0 +1,208 @@
"""Agent REST routes — catalog, billing checks, trigger.
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
Agent trigger dispatches via Redis to the consumer instead of spawning
an in-process background task.
"""
from __future__ import annotations
import json
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Header, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from shared.db import async_session
from shared.models import AgentRunLog
from shared.redis import redis_client, batch_request_channel
from app.agent_runner import is_agent_running
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Tier feature limits ───────────────────────────────────────────────
# Mirrors app/billing/tier_manager.py FEATURES dict.
FEATURES: dict[str, dict] = {
"free": {"batch_active": 1, "batch_runs_per_day": 3},
"pro": {"batch_active": 5, "batch_runs_per_day": 20},
"power": {"batch_active": 20, "batch_runs_per_day": 100},
"team": {"batch_active": -1, "batch_runs_per_day": -1},
}
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
async with async_session() as db:
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog")
async def get_agent_catalog(
x_user_id: str = Header(..., alias="X-User-Id"),
) -> list[dict]:
return [
{
"type": "local_directory",
"name": "Local Directory Monitor",
"description": "Watches local directories, extracts data from files using AI",
},
{
"type": "gmail",
"name": "Gmail Connector",
"description": "Scans Gmail inbox, extracts tasks/notes from emails",
},
{
"type": "teams",
"name": "Microsoft Teams Connector",
"description": "Monitors Teams messages, extracts action items",
},
{
"type": "outlook",
"name": "Outlook Connector",
"description": "Scans Outlook inbox, extracts tasks/notes",
},
]
# ── Can-create check ─────────────────────────────────────────────────
@router.post("/can-create")
async def can_create_agent(
body: dict,
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
active_agents = body.get("active_agents", 0)
limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or active_agents < limit
return {
"allowed": allowed,
"tier": x_user_tier,
"active_agents": active_agents,
"limit": limit,
}
# ── Trigger ──────────────────────────────────────────────────────────
@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: dict,
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
"""Trigger a local agent run — creates run log and dispatches via Redis."""
active_agents = body.get("active_agents", 0)
_enforce_agent_limit(x_user_tier, active_agents)
await _enforce_run_frequency(x_user_tier, x_user_id)
stable_agent_id = body.get("agent_id") or str(uuid.uuid4())
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running.",
)
# Create run log in DB
async with async_session() as db:
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=x_user_id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_log_id = run_log.id
run_context = {
"type": "agent_batch",
"run_id": run_log_id,
"agent_id": stable_agent_id,
}
# Dispatch to the Redis consumer for processing
trigger_data = {
"type": "agent_trigger",
"directory": body.get("directory", ""),
"directory_paths": [body.get("directory", "")] if body.get("directory") else [],
"data_types": _to_data_types(body.get("what_to_extract", [])),
"file_extensions": body.get("file_extensions", []),
"prompt_template": body.get("custom_agent_prompt", ""),
"device_id": body.get("device_id", ""),
"run_context": run_context,
}
channel = batch_request_channel(x_user_id)
await redis_client.publish(channel, json.dumps(trigger_data))
return {
"id": run_log_id,
"agent_id": stable_agent_id,
"agent_type": "local",
"status": "running",
"items_processed": 0,
"items_created": 0,
"errors": [],
"started_at": _dt_ms(run_log.started_at),
"completed_at": None,
}

View File

@@ -0,0 +1,336 @@
"""Langfuse tracing & prompt management for the Batch Agent Service (v4 SDK).
Provides:
- ``init_langfuse()`` — initialise the singleton client at startup
- ``trace_span()`` — context manager that creates a trace + span
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
- ``flush()`` / ``shutdown()`` — lifecycle management
All functions gracefully degrade to no-ops when Langfuse is not configured,
so the service works identically with or without observability keys.
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import Any
from shared.config import settings
logger = logging.getLogger(__name__)
# ── State ────────────────────────────────────────────────────────────────
_initialised: bool = False
_disabled: bool = False
def _is_configured() -> bool:
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
def init_langfuse() -> None:
"""Initialise the Langfuse singleton. Call once at startup."""
global _initialised, _disabled
if _initialised or _disabled:
return
if not _is_configured():
_disabled = True
logger.info("tracing: Langfuse keys not set — tracing disabled")
return
try:
from langfuse import Langfuse
Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST,
)
_initialised = True
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
except Exception as exc:
_disabled = True
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
def _get_client() -> Any | None:
"""Return the singleton Langfuse client, or *None* if disabled."""
if _disabled:
return None
if not _initialised:
init_langfuse()
if _disabled:
return None
try:
from langfuse import get_client
return get_client()
except Exception:
return None
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
class _NullSpan:
"""Drop-in replacement when Langfuse is disabled."""
def update(self, **_: Any) -> None: ...
def set_trace_io(self, **_: Any) -> None: ...
def score_trace(self, **_: Any) -> None: ...
# ── Trace context manager ───────────────────────────────────────────────
@contextmanager
def trace_span(
*,
name: str,
user_id: str,
session_id: str | None = None,
trace_id: str | None = None,
input: Any = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
):
"""Context manager that creates a Langfuse trace/span.
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
A ``CallbackHandler`` created inside this block auto-inherits the trace
context, so there is no need to pass trace IDs manually.
"""
lf = _get_client()
if lf is None:
yield _NullSpan()
return
try:
from langfuse import Langfuse, propagate_attributes
trace_ctx: dict[str, str] = {}
if trace_id is not None:
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
with lf.start_as_current_observation(
as_type="span",
name=name,
input=input,
metadata=metadata or {},
**({"trace_context": trace_ctx} if trace_ctx else {}),
) as span:
with propagate_attributes(
user_id=user_id,
session_id=session_id,
tags=tags or [],
):
yield span
except Exception as exc:
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
yield _NullSpan()
# ── LangChain callback handler ──────────────────────────────────────────
def get_langfuse_callback() -> Any | None:
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
Must be called inside a ``trace_span()`` block for proper linking.
Returns *None* when Langfuse is disabled.
"""
if _disabled and not _initialised:
return None
try:
from langfuse.langchain import CallbackHandler
return CallbackHandler()
except Exception as exc:
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
return None
# ── Prompt management ────────────────────────────────────────────────────
def get_prompt(
name: str,
*,
version: int | None = None,
label: str | None = None,
fallback: str | None = None,
cache_ttl_seconds: int = 300,
) -> str | None:
"""Fetch a managed prompt from Langfuse by name (without variable compilation).
Returns the raw prompt string, or *fallback* if the prompt is not
found or Langfuse is disabled.
"""
lf = _get_client()
if lf is None:
return fallback
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
return prompt.prompt
except Exception as exc:
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
return fallback
def compile_prompt(
name: str,
*,
fallback: str,
variables: dict[str, str],
version: int | None = None,
label: str | None = None,
cache_ttl_seconds: int = 300,
) -> str:
"""Fetch a managed prompt from Langfuse and compile it with ``{{variables}}``.
If the prompt exists in Langfuse, uses the SDK's ``.compile(**variables)``
which replaces ``{{key}}`` placeholders. If Langfuse is disabled or the
prompt is not found, falls back to ``fallback.format(**variables)`` (Python
``{key}`` placeholders).
This means:
- Langfuse prompts use ``{{variable}}`` syntax.
- Hardcoded fallback strings use Python ``{variable}`` syntax.
"""
lf = _get_client()
if lf is None:
return fallback.format(**variables)
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
return prompt.compile(**variables)
except Exception as exc:
logger.warning("tracing: compile_prompt(%s) failed, using fallback: %s", name, exc)
return fallback.format(**variables)
def get_prompt_object(
name: str,
*,
version: int | None = None,
label: str | None = None,
cache_ttl_seconds: int = 300,
) -> Any | None:
"""Fetch the raw Langfuse prompt *object* (not the compiled string).
Returns ``None`` when Langfuse is disabled or the prompt is not found.
Use this when you need to pass the prompt to ``start_observation(prompt=...)``
for linking the prompt to a trace in the Langfuse UI.
"""
lf = _get_client()
if lf is None:
return None
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
return lf.get_prompt(**kwargs)
except Exception as exc:
logger.warning("tracing: get_prompt_object(%s) failed: %s", name, exc)
return None
def link_prompt_to_trace(
span: Any,
prompt_name: str,
*,
version: int | None = None,
label: str | None = None,
) -> None:
"""Link a Langfuse managed prompt to a span/observation.
Uses the SDK v4 ``prompt=`` parameter so that the prompt version
appears linked in the Langfuse UI with metrics tracking.
"""
lf = _get_client()
if lf is None or isinstance(span, _NullSpan):
return
try:
prompt = get_prompt_object(prompt_name, version=version, label=label)
if prompt is not None:
span.update(prompt=prompt)
except Exception as exc:
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
# ── Scoring helper ───────────────────────────────────────────────────────
def score_trace(
trace_id: str,
name: str,
value: float,
*,
comment: str | None = None,
) -> None:
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
lf = _get_client()
if lf is None:
return
try:
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
except Exception as exc:
logger.warning("tracing: score_trace failed: %s", exc)
# ── Shutdown ─────────────────────────────────────────────────────────────
def flush() -> None:
"""Flush pending Langfuse events."""
lf = _get_client()
if lf is not None:
try:
lf.flush()
except Exception as exc:
logger.warning("tracing: flush failed: %s", exc)
def shutdown() -> None:
"""Flush and close the Langfuse client."""
global _initialised, _disabled
lf = _get_client()
if lf is not None:
try:
lf.flush()
lf.shutdown()
except Exception as exc:
logger.warning("tracing: shutdown failed: %s", exc)
_initialised = False
_disabled = False

View File

@@ -0,0 +1 @@
"""Batch Agent E2E evaluation harness."""

View File

@@ -0,0 +1,5 @@
"""Allow running the eval package as ``python -m eval``."""
from eval.cli import main
main()

View File

@@ -0,0 +1,285 @@
"""CLI entry point for the batch agent evaluation harness.
Usage::
# From services/batch-agent/:
python -m eval run # all agent fixtures, default model
python -m eval run --fixture=classify-invoices # single fixture
python -m eval run --models=gpt-4o,gpt-5.3-codex # multiple models
python -m eval run --mode=step1 # only step1 fixtures
python -m eval run --no-judge # skip LLM judge scoring
python -m eval interactive # interactive journey session
python -m eval interactive --fixture=journey-invoice-setup
python -m eval interactive --model=gpt-4o
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
python -m eval list # list all fixtures
python -m eval sync # sync fixtures to Langfuse datasets
"""
from __future__ import annotations
import argparse
import asyncio
import logging
import sys
from pathlib import Path
# Ensure the service root and repo root are in sys.path.
# Service root must come BEFORE repo root so its ``app/`` package
# shadows the monolith ``app/`` in the repo root.
_SERVICE_ROOT = Path(__file__).resolve().parent.parent
_REPO_ROOT = _SERVICE_ROOT.parent.parent
_sr = str(_SERVICE_ROOT)
_rr = str(_REPO_ROOT)
if _rr not in sys.path:
sys.path.insert(0, _rr)
# Always force service root to position 0 (python -m may have already
# added CWD further down the list, which loses to repo root).
if _sr in sys.path:
sys.path.remove(_sr)
sys.path.insert(0, _sr)
from eval.config import discover_fixtures, discover_journey_fixtures
from eval.runner import run_fixture_eval, print_results
from eval.interactive import run_interactive
from eval import langfuse_eval
def _setup_logging(verbose: bool) -> None:
level = logging.DEBUG if verbose else logging.INFO
logging.basicConfig(
level=level,
format="%(asctime)s %(name)-20s %(levelname)-5s %(message)s",
datefmt="%H:%M:%S",
)
# Quiet noisy libraries
for name in ("httpx", "httpcore", "openai", "litellm", "urllib3"):
logging.getLogger(name).setLevel(logging.WARNING)
def _parse_args() -> argparse.Namespace:
parser = argparse.ArgumentParser(
description="Batch Agent E2E evaluation harness",
prog="python -m eval",
)
sub = parser.add_subparsers(dest="command", required=True)
# ── run ───────────────────────────────────────────────────────
run_cmd = sub.add_parser("run", help="Run evaluations")
run_cmd.add_argument(
"--fixture", "-f",
help="Run only the named fixture (default: all)",
)
run_cmd.add_argument(
"--models", "-m",
default="github_copilot/gpt-5.3-codex",
help="Comma-separated list of models to test (default: github_copilot/gpt-5.3-codex)",
)
run_cmd.add_argument(
"--mode",
default=None,
choices=["step1", "step2", "full"],
help="Only run fixtures with this mode (default: all)",
)
run_cmd.add_argument(
"--no-judge",
action="store_true",
help="Skip LLM-as-judge scoring",
)
run_cmd.add_argument(
"--judge-model",
default="gpt-4o",
help="Model for LLM judge (default: gpt-4o)",
)
run_cmd.add_argument(
"--fixtures-dir",
default=None,
help="Path to fixtures directory (default: eval/fixtures/)",
)
run_cmd.add_argument("-v", "--verbose", action="store_true")
# ── list ──────────────────────────────────────────────────────
list_cmd = sub.add_parser("list", help="List available fixtures")
list_cmd.add_argument("--fixtures-dir", default=None)
list_cmd.add_argument("-v", "--verbose", action="store_true")
# ── sync ──────────────────────────────────────────────────────
sync_cmd = sub.add_parser("sync", help="Sync fixtures to Langfuse datasets")
sync_cmd.add_argument("--fixture", "-f", default=None, help="Sync only the named fixture")
sync_cmd.add_argument("--fixtures-dir", default=None)
sync_cmd.add_argument("-v", "--verbose", action="store_true")
# ── interactive ───────────────────────────────────────────────
inter_cmd = sub.add_parser("interactive", help="Interactive journey session (human-in-the-loop)")
inter_cmd.add_argument(
"--fixture", "-f",
help="Journey fixture to use (default: pick interactively)",
)
inter_cmd.add_argument(
"--model", "-m",
default="github_copilot/gpt-5.3-codex",
help="Model for the journey AI (default: github_copilot/gpt-5.3-codex)",
)
inter_cmd.add_argument(
"--judge-model",
default="gpt-4o",
help="Model for LLM judge (default: gpt-4o)",
)
inter_cmd.add_argument(
"--fixtures-dir",
default=None,
help="Path to fixtures directory (default: eval/fixtures/)",
)
inter_cmd.add_argument(
"--data-dir",
default=None,
help="Override sample data directory (e.g. path to private test files not in git)",
)
inter_cmd.add_argument("-v", "--verbose", action="store_true")
return parser.parse_args()
def _fixtures_dir(arg: str | None) -> Path | None:
if arg:
return Path(arg)
return None
async def _cmd_run(args: argparse.Namespace) -> None:
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
if not fixtures:
print("No fixtures found. Create YAML files in eval/fixtures/.")
return
if args.fixture:
fixtures = [f for f in fixtures if f.name == args.fixture]
if not fixtures:
print(f"Fixture '{args.fixture}' not found.")
return
models = [m.strip() for m in args.models.split(",")]
all_results = []
for fixture in fixtures:
if args.mode and fixture.mode != args.mode:
continue
results = await run_fixture_eval(
fixture,
models=models,
use_llm_judge=not args.no_judge,
judge_model=args.judge_model,
)
all_results.extend(results)
print_results(all_results)
def _cmd_list(args: argparse.Namespace) -> None:
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
if not fixtures and not journey_fixtures:
print("No fixtures found.")
return
if fixtures:
print(f"\n{'[Agent Fixtures]'}")
print(f"{'Name':<30} {'Mode':<6} {'Types':<25} {'Expected'}")
print("-" * 90)
for f in fixtures:
types = ", ".join(f.data_types)
n_expected = len(f.expected) + len(f.expected_classification)
print(f"{f.name:<30} {f.mode:<6} {types:<25} {n_expected}")
if journey_fixtures:
print(f"\n{'[Journey Fixtures]'}")
print(f"{'Name':<30} {'Types':<25} {'Messages':<10} {'Criteria'}")
print("-" * 90)
for f in journey_fixtures:
types = ", ".join(f.data_types)
print(f"{f.name:<30} {types:<25} {len(f.user_messages):<10} {len(f.expected_template_criteria)}")
print()
def _cmd_sync(args: argparse.Namespace) -> None:
fixtures = discover_fixtures(_fixtures_dir(args.fixtures_dir))
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
if args.fixture:
fixtures = [f for f in fixtures if f.name == args.fixture]
journey_fixtures = [f for f in journey_fixtures if f.name == args.fixture]
if not fixtures and not journey_fixtures:
print("No fixtures to sync.")
return
for fixture in fixtures:
name = langfuse_eval.sync_fixture_to_dataset(fixture)
if name:
print(f"Synced: {fixture.name}{name}")
else:
print(f"Skipped: {fixture.name} (Langfuse not configured)")
for fixture in journey_fixtures:
name = langfuse_eval.sync_journey_fixture_to_dataset(fixture)
if name:
print(f"Synced: {fixture.name}{name}")
else:
print(f"Skipped: {fixture.name} (Langfuse not configured)")
async def _cmd_interactive(args: argparse.Namespace) -> None:
journey_fixtures = discover_journey_fixtures(_fixtures_dir(args.fixtures_dir))
if not journey_fixtures:
print("No journey fixtures found. Create YAML files with type: journey in eval/fixtures/.")
return
if args.fixture:
fixtures = [f for f in journey_fixtures if f.name == args.fixture]
if not fixtures:
print(f"Journey fixture '{args.fixture}' not found.")
return
fixture = fixtures[0]
elif len(journey_fixtures) == 1:
fixture = journey_fixtures[0]
else:
# Let user pick
print("\nAvailable journey fixtures:")
for i, f in enumerate(journey_fixtures, 1):
print(f" {i}. {f.name}{f.description[:60]}")
print()
try:
choice = int(input("Pick a fixture number: ").strip()) - 1
fixture = journey_fixtures[choice]
except (ValueError, IndexError, EOFError, KeyboardInterrupt):
print("Invalid choice.")
return
await run_interactive(
fixture,
model=args.model,
judge_model=args.judge_model,
data_dir=Path(args.data_dir).resolve() if args.data_dir else None,
)
def main() -> None:
args = _parse_args()
_setup_logging(args.verbose)
if args.command == "run":
asyncio.run(_cmd_run(args))
elif args.command == "interactive":
asyncio.run(_cmd_interactive(args))
elif args.command == "list":
_cmd_list(args)
elif args.command == "sync":
_cmd_sync(args)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,220 @@
"""Eval configuration — YAML fixture loader and dataclasses.
Fixtures come in two families:
1. **Agent fixtures** — test the batch agent pipeline.
Three modes controlled by ``mode``:
``step1`` — classification prompt only.
``step2`` — processing prompt only.
``full`` — both steps in sequence.
2. **Journey fixtures** — test the prompt-template builder conversation
(unchanged).
"""
from __future__ import annotations
import logging
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any, Literal
import yaml
logger = logging.getLogger(__name__)
EvalMode = Literal["step1", "step2", "full"]
@dataclass
class ExpectedRecord:
"""A single expected extraction result.
Only the fields specified are checked — unspecified fields are ignored.
"""
table: str # tasks | notes | timelines | projects
fields: dict[str, Any] # field_name → expected_value
@dataclass
class ExpectedClassification:
"""Expected output of step-1 classification for one file."""
file: str # relative path to the sample file
project_id: str # expected matched project id, or "new"
domains: list[str] # expected domain list
new_project_name: str | None = None
@dataclass
class EvalFixture:
"""A complete test scenario loaded from YAML.
``mode`` determines which pipeline steps are exercised:
- **step1**: only ``_classify_file``
- **step2**: only the processing LLM + tool loop
- **full**: both steps in sequence (``run_local_agent``)
"""
name: str
description: str
mode: EvalMode
directory: str # relative path to sample files
data_types: list[str]
file_extensions: list[str]
models: list[str] # if empty, use CLI default
fixture_path: Path = field(default_factory=lambda: Path("."))
# ── Step-1 inputs (classification) ───────────────────────────
domain_definitions: str = ""
projects_list: list[dict[str, Any]] = field(default_factory=list)
custom_step1_prompt: str = ""
# ── Step-2 inputs (processing) ───────────────────────────────
existing_context: str = ""
project_context: str = ""
custom_prompt_section: str = ""
# ── Seed records for mock executor ───────────────────────────
seed_records: dict[str, list[dict]] = field(default_factory=dict)
# ── Expected outputs ─────────────────────────────────────────
expected_classification: list[ExpectedClassification] = field(default_factory=list)
expected: list[ExpectedRecord] = field(default_factory=list)
@property
def fixture_dir(self) -> Path:
"""Absolute path to the sample files directory."""
return self.fixture_path.parent / self.directory
@classmethod
def from_yaml(cls, path: Path) -> "EvalFixture":
"""Load a fixture from a YAML file."""
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
mode: EvalMode = raw.get("mode", "full")
# Parse expected records (step2/full)
expected: list[ExpectedRecord] = []
for table, records in (raw.get("expected") or {}).items():
for rec in records:
expected.append(ExpectedRecord(table=table, fields=rec))
# Parse expected classification (step1/full)
expected_classification: list[ExpectedClassification] = []
for item in raw.get("expected_classification") or []:
expected_classification.append(ExpectedClassification(
file=item["file"],
project_id=item["project_id"],
domains=item.get("domains", []),
new_project_name=item.get("new_project_name"),
))
return cls(
name=raw["name"],
description=raw.get("description", ""),
mode=mode,
directory=raw.get("directory", "sample_files"),
data_types=raw.get("data_types", ["tasks"]),
file_extensions=raw.get("file_extensions", []),
models=raw.get("models", []),
fixture_path=path,
# Step-1 inputs
domain_definitions=raw.get("domain_definitions", ""),
projects_list=raw.get("projects_list", []),
# Step-2 inputs
existing_context=raw.get("existing_context", ""),
project_context=raw.get("project_context", ""),
custom_prompt_section=raw.get("custom_prompt_section", ""),
# Shared
seed_records=raw.get("seed_records", {}),
expected_classification=expected_classification,
expected=expected,
)
def discover_fixtures(fixtures_dir: Path | None = None) -> list[EvalFixture]:
"""Find and load all YAML fixtures in the fixtures directory."""
if fixtures_dir is None:
fixtures_dir = Path(__file__).parent / "fixtures"
fixtures: list[EvalFixture] = []
if not fixtures_dir.is_dir():
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
return fixtures
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
try:
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if raw.get("type") == "journey":
continue # Skip journey fixtures
fixtures.append(EvalFixture.from_yaml(yaml_path))
logger.info("eval: loaded fixture %s from %s", fixtures[-1].name, yaml_path.name)
except Exception as exc:
logger.error("eval: failed to load fixture %s: %s", yaml_path.name, exc)
return fixtures
# ── Journey fixtures ─────────────────────────────────────────────────────
@dataclass
class JourneyFixture:
"""A journey test scenario — tests the prompt_template builder conversation."""
name: str
description: str
directory: str # relative path to sample files
data_types: list[str]
expected_template_criteria: list[str] # what the template should contain/satisfy
user_messages: list[str] = field(default_factory=list) # for automated journey runs (unused in interactive mode)
models: list[str] = field(default_factory=list)
fixture_path: Path = field(default_factory=lambda: Path("."))
@property
def fixture_dir(self) -> Path:
"""Absolute path to the sample files directory."""
return self.fixture_path.parent / self.directory
@classmethod
def from_yaml(cls, path: Path) -> "JourneyFixture":
"""Load a journey fixture from a YAML file."""
raw = yaml.safe_load(path.read_text(encoding="utf-8"))
return cls(
name=raw["name"],
description=raw.get("description", ""),
directory=raw.get("directory", "sample_files"),
data_types=raw.get("data_types", ["tasks"]),
user_messages=raw.get("user_messages", []),
expected_template_criteria=raw.get("expected_template_criteria", []),
models=raw.get("models", []),
fixture_path=path,
)
def discover_journey_fixtures(fixtures_dir: Path | None = None) -> list[JourneyFixture]:
"""Find and load all journey YAML fixtures in the fixtures directory."""
if fixtures_dir is None:
fixtures_dir = Path(__file__).parent / "fixtures"
fixtures: list[JourneyFixture] = []
if not fixtures_dir.is_dir():
logger.warning("eval: fixtures directory not found: %s", fixtures_dir)
return fixtures
for yaml_path in sorted(fixtures_dir.glob("*.yaml")):
try:
raw = yaml.safe_load(yaml_path.read_text(encoding="utf-8"))
if raw.get("type") != "journey":
continue
fixtures.append(JourneyFixture.from_yaml(yaml_path))
logger.info("eval: loaded journey fixture %s from %s", fixtures[-1].name, yaml_path.name)
except Exception as exc:
logger.error("eval: failed to load journey fixture %s: %s", yaml_path.name, exc)
return fixtures

View File

@@ -0,0 +1,40 @@
# Fixture: classify-invoices (step1)
# Tests _STEP1_SYSTEM_PROMPT — file classification and project matching.
# Verifies that the LLM correctly matches files to existing projects
# and identifies the right data domains.
name: classify-invoices
mode: step1
description: >
Test file classification on Italian freelance invoices and meeting notes.
Verifies project matching and domain identification.
directory: sample_files/invoices
data_types: [tasks, notes, timelines]
file_extensions: [txt, md]
# ── Step-1 prompt variables ──────────────────────────────────────
domain_definitions: |
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
projects_list:
- id: "proj-web-redesign"
name: "Redesign Sito Web Corporate"
status: "active"
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
- id: "proj-ecommerce"
name: "E-Commerce FashionStore"
status: "active"
aiSummary: "Next.js e-commerce platform for FashionStore srl"
# ── Expected classification results ─────────────────────────────
expected_classification:
- file: "sample_files/invoices/fattura_042.txt"
project_id: "proj-web-redesign"
domains: [tasks, notes, timelines]
- file: "sample_files/invoices/meeting_ecommerce.md"
project_id: "proj-ecommerce"
domains: [tasks, notes, timelines]

View File

@@ -0,0 +1,108 @@
# Fixture: full-invoices (full)
# Tests both _STEP1_SYSTEM_PROMPT and _PROCESSING_SYSTEM_PROMPT in sequence
# via run_local_agent(). Verifies end-to-end classification + extraction.
name: full-invoices
mode: full
description: >
End-to-end test: classify Italian invoices/meeting notes into the
correct project, then extract tasks, notes, and timeline events.
directory: sample_files/invoices
data_types: [tasks, notes, timelines]
file_extensions: [txt, md]
# ── Step-1 prompt variables ──────────────────────────────────────
domain_definitions: |
- tasks: Action items, deliverables, things to do — anything that someone needs to complete.
- notes: Meeting summaries, decisions, reference information — permanent knowledge entries.
- timelines: Project milestones, deadlines, scheduled events — specific dates that mark a point in the progress of a project.
projects_list:
- id: "proj-web-redesign"
name: "Redesign Sito Web Corporate"
status: "active"
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
- id: "proj-ecommerce"
name: "E-Commerce FashionStore"
status: "active"
aiSummary: "Next.js e-commerce platform for FashionStore srl"
# ── Step-2 prompt variables ──────────────────────────────────────
existing_context: |
Existing tasks:
(none)
Existing notes:
(none)
Existing timelines:
(none)
project_context: ""
custom_prompt_section: |
User instructions:
Estrai i dati dai file come segue:
- TASK: ogni azione da fare, deliverable, o item con scadenza.
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
Mappa "media priorità" → priority: medium.
Mappa "bassa priorità" → priority: low.
Se un item è marcato come "completato" o [x], impostalo status: done.
Altrimenti status: todo.
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
- TIMELINE: date di scadenza, milestone, meeting futuri.
Imposta sempre isAiSuggested=1.
# ── Seed records (pre-existing DB state) ─────────────────────────
seed_records:
projects:
- id: "proj-web-redesign"
name: "Redesign Sito Web Corporate"
status: "active"
aiSummary: "Corporate website redesign for Studio Architettura Bianchi"
- id: "proj-ecommerce"
name: "E-Commerce FashionStore"
status: "active"
aiSummary: "Next.js e-commerce platform for FashionStore srl"
tasks: []
notes: []
timelines: []
# ── Expected classification (step 1) ─────────────────────────────
expected_classification:
- file: "sample_files/invoices/fattura_042.txt"
project_id: "proj-web-redesign"
domains: [tasks, notes, timelines]
- file: "sample_files/invoices/meeting_ecommerce.md"
project_id: "proj-ecommerce"
domains: [tasks, notes, timelines]
# ── Expected extractions (step 2) ────────────────────────────────
expected:
tasks:
- title: "Sviluppo frontend React"
priority: "high"
status: "todo"
- title: "Integrazione API backend"
priority: "medium"
status: "todo"
- title: "Testing cross-browser e fix bug responsive"
status: "todo"
- title: "Preparare wireframe homepage"
priority: "high"
status: "todo"
- title: "Setup progetto Next.js e configurare CI/CD"
priority: "medium"
status: "todo"
- title: "Ricerca plugin Stripe per gestione abbonamenti"
priority: "low"
status: "todo"
notes:
- title: "Meeting Kickoff Progetto E-Commerce"
timelines:
- title: "MVP E-Commerce pronto"
- title: "Meeting di revisione"

View File

@@ -0,0 +1,28 @@
# Journey Fixture: journey-invoice-setup
# Used by `python -m eval interactive` for human-in-the-loop testing
# of the journey chatbot's prompt-building conversation.
type: journey
name: journey-invoice-setup
description: >
Interactive test for the journey chatbot — explore a directory of
Italian invoices and meeting notes, answer the chatbot's questions,
and verify it produces a well-structured prompt_template for data
extraction.
directory: sample_files/invoices
data_types: [tasks, notes, timelines, projects]
# Criteria the generated prompt_template must satisfy
# Each is scored 0-1 by an LLM judge
expected_template_criteria:
- "Mentions creating tasks from action items and work descriptions"
- "Mentions creating notes from meeting summaries"
- "Mentions extracting timeline events from deadlines and meeting dates"
- "Mentions creating projects from relevant information"
- "Sets isAiSuggested=1 on all created records"
- "Does NOT include projectId assignment logic"
- "Uses camelCase field names (title, status, priority, dueDate, content)"
# Models to test (empty = use CLI --models default)
models: []

View File

@@ -0,0 +1,81 @@
# Fixture: process-invoices (step2)
# Tests _PROCESSING_SYSTEM_PROMPT — data extraction & tool calling.
# The classification step is skipped; prompt variables are injected directly.
name: process-invoices
mode: step2
description: >
Test data extraction from Italian freelance invoices.
Verifies correct record creation via tool calls with the right
fields, priorities, and status values.
directory: sample_files/invoices
data_types: [tasks, notes, timelines]
file_extensions: [txt, md]
# ── Step-2 prompt variables ──────────────────────────────────────
existing_context: |
Existing tasks:
(none)
Existing notes:
(none)
Existing timelines:
(none)
project_context: >
Project: Redesign Sito Web Corporate (id: proj-web-redesign).
Always set projectId to this id on every record you create.
custom_prompt_section: |
User instructions:
Estrai i dati dai file come segue:
- TASK: ogni azione da fare, deliverable, o item con scadenza.
Mappa "URGENTE" o "ALTA PRIORITÀ" → priority: high.
Mappa "media priorità" → priority: medium.
Mappa "bassa priorità" → priority: low.
Se un item è marcato come "completato" o [x], impostalo status: done.
Altrimenti status: todo.
- NOTE: riassunti di meeting, decisioni prese, note tecniche.
Il titolo deve essere descrittivo. Il content deve includere tutti i dettagli.
- TIMELINE: date di scadenza, milestone, meeting futuri.
Imposta sempre isAiSuggested=1.
# ── Seed records (pre-existing DB state) ─────────────────────────
seed_records:
projects:
- id: "proj-web-redesign"
name: "Redesign Sito Web Corporate"
status: "active"
tasks: []
notes: []
timelines: []
# ── Expected extractions ─────────────────────────────────────────
expected:
tasks:
- title: "Sviluppo frontend React"
priority: "high"
status: "todo"
- title: "Integrazione API backend"
priority: "medium"
status: "todo"
- title: "Testing cross-browser e fix bug responsive"
status: "todo"
- title: "Preparare wireframe homepage"
priority: "high"
status: "todo"
- title: "Setup progetto Next.js e configurare CI/CD"
priority: "medium"
status: "todo"
- title: "Ricerca plugin Stripe per gestione abbonamenti"
priority: "low"
status: "todo"
notes:
- title: "Meeting Kickoff Progetto E-Commerce"
timelines:
- title: "MVP E-Commerce pronto"
- title: "Meeting di revisione"

View File

@@ -0,0 +1,18 @@
FATTURA N. 2026-0042
Data: 15 Marzo 2026
Cliente: Studio Architettura Bianchi
Progetto: Redesign Sito Web Corporate
Descrizione lavori:
- Sviluppo frontend React (40 ore) — URGENTE, completare entro 20 marzo
- Integrazione API backend (20 ore) — priorità media
- Design UI/UX mockup homepage (8 ore) — completato
- Testing cross-browser e fix bug responsive (12 ore) — da iniziare
Totale: €4.800,00 + IVA
Note:
Meeting di revisione previsto per il 18 marzo alle 10:00.
Il cliente ha richiesto modifiche al layout mobile della sezione contatti.
Attendere conferma budget aggiuntivo per sezione blog.

View File

@@ -0,0 +1,25 @@
# Meeting Notes - Kickoff Progetto E-Commerce
**Data:** 10 Marzo 2026
**Partecipanti:** Marco R., Giulia T., Cliente (FashionStore srl)
## Decisioni prese
1. **Piattaforma**: Next.js + Stripe per i pagamenti
2. **Timeline**: MVP pronto entro 30 aprile 2026
3. **Budget**: €12.000 totale, €4.000 anticipo già ricevuto
## Action items
- [ ] Marco: preparare wireframe homepage entro 14 marzo — ALTA PRIORITÀ
- [ ] Giulia: setup progetto Next.js e configurare CI/CD — media priorità
- [ ] Marco: ricerca plugin Stripe per gestione abbonamenti — bassa priorità
- [x] Giulia: inviare contratto firmato al cliente — COMPLETATO
## Note aggiuntive
Il cliente vuole un design minimalista, ispirato a Zara.com.
Colori primari: nero, bianco, oro.
Font: Inter per body, Playfair Display per headings.
Prossimo meeting: 24 marzo 2026 ore 15:00.

View File

@@ -0,0 +1,471 @@
"""Interactive journey session — human-in-the-loop CLI conversation.
Flow:
1. Show the system prompt used by the journey AI.
2. Start the journey (AI explores files, asks first question).
3. User types responses in the terminal — AI replies.
4. User types `/done` to end the conversation.
5. User writes a comment about the interaction quality.
6. LLM judge scores the conversation + generated template.
7. Results are reported to Langfuse.
Usage::
python -m eval interactive # pick a fixture interactively
python -m eval interactive --fixture=journey-invoice-setup
python -m eval interactive --model=gpt-4o
python -m eval interactive --judge-model=github_copilot/gpt-4o-mini
"""
from __future__ import annotations
import asyncio
import json
import logging
import sys
import time
import uuid
from dataclasses import dataclass, field
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from eval.config import JourneyFixture, discover_journey_fixtures
from eval.mock_executor import MockExecutor
from eval import langfuse_eval
logger = logging.getLogger(__name__)
# ── Special commands ─────────────────────────────────────────────────────
_CMD_DONE = "/done"
_CMD_QUIT = "/quit"
_CMD_TEMPLATE = "/template"
_CMD_HELP = "/help"
_HELP_TEXT = f"""\
{_CMD_DONE} — End the conversation and proceed to evaluation
{_CMD_QUIT} — Abort without evaluation
{_CMD_TEMPLATE} — Show the generated template (if any)
{_CMD_HELP} — Show this help"""
# ── Terminal colours (ANSI) ──────────────────────────────────────────────
_C_RESET = "\033[0m"
_C_BOLD = "\033[1m"
_C_DIM = "\033[2m"
_C_CYAN = "\033[36m"
_C_GREEN = "\033[32m"
_C_YELLOW = "\033[33m"
_C_MAGENTA = "\033[35m"
_C_RED = "\033[31m"
_C_BLUE = "\033[34m"
def _print_header(text: str) -> None:
print(f"\n{_C_BOLD}{_C_CYAN}{'' * 80}")
print(f" {text}")
print(f"{'' * 80}{_C_RESET}\n")
def _print_ai(text: str) -> None:
print(f"\n{_C_GREEN}{_C_BOLD}AI:{_C_RESET} {text}\n")
def _print_system(text: str) -> None:
print(f"{_C_DIM}{text}{_C_RESET}")
def _print_score(label: str, score: float) -> None:
if score >= 0.7:
color = _C_GREEN
tag = "PASS"
elif score >= 0.4:
color = _C_YELLOW
tag = "PARTIAL"
else:
color = _C_RED
tag = "FAIL"
print(f" {color}{tag:>7}{_C_RESET} ({score:.1f}) {label}")
# ── Result type ──────────────────────────────────────────────────────────
@dataclass
class InteractiveResult:
fixture_name: str
model: str
judge_model: str
prompt_template: str | None
conversation: list[dict[str, str]]
user_comment: str
done: bool
criteria_scores: dict[str, float]
overall_score: float
judge_reasoning: str
elapsed_seconds: float
def summary(self) -> dict[str, Any]:
return {
"fixture": self.fixture_name,
"model": self.model,
"judge_model": self.judge_model,
"done": self.done,
"turns": len([c for c in self.conversation if c["role"] == "user"]),
"overall_score": round(self.overall_score, 3),
"user_comment": self.user_comment,
"criteria_scores": {k: round(v, 3) for k, v in self.criteria_scores.items()},
"elapsed_s": round(self.elapsed_seconds, 1),
}
# ── LLM judge ────────────────────────────────────────────────────────────
_INTERACTIVE_JUDGE_SYSTEM = """\
You are an evaluation judge for AI-generated prompt templates produced during
an interactive conversation between a human and a journey chatbot.
The chatbot explored a directory and through multi-turn conversation with the
user produced a prompt_template — an instruction set for a data-extraction agent.
You have access to:
- The full conversation transcript
- The generated prompt_template (if any)
- The user's own comment about the interaction
- A list of quality criteria
Score each criterion from 0 to 1:
- 1.0: Fully satisfied
- 0.5: Partially satisfied
- 0.0: Not satisfied
Also provide an overall_quality score (0-1) evaluating the conversation flow,
how well the AI understood the user, and the template quality.
Respond with ONLY a JSON object:
{
"criteria_scores": {"criterion_1": 0.8, ...},
"overall_quality": 0.85,
"reasoning": "Brief explanation covering both conversation quality and template accuracy"
}
"""
async def _judge_interactive(
conversation: list[dict[str, str]],
prompt_template: str | None,
user_comment: str,
criteria: list[str],
*,
judge_model: str = "gpt-4o-mini",
) -> tuple[dict[str, float], float, str]:
"""Score an interactive session. Returns (criteria_scores, overall_quality, reasoning)."""
from shared.llm import get_llm
llm = get_llm(model=judge_model, temperature=0)
conv_text = "\n".join(
f"{'USER' if t['role'] == 'user' else 'AI'}: {t['content']}"
for t in conversation
)
criteria_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(criteria))
user_content = (
f"## Conversation transcript\n```\n{conv_text}\n```\n\n"
f"## Generated prompt_template\n```\n{prompt_template or '(none — conversation did not complete)'}\n```\n\n"
f"## User's comment\n{user_comment}\n\n"
f"## Criteria to evaluate\n{criteria_text}"
)
try:
response = await llm.ainvoke([
SystemMessage(content=_INTERACTIVE_JUDGE_SYSTEM),
HumanMessage(content=user_content),
])
raw = response.content.strip()
if raw.startswith("```"):
raw = raw.split("```")[1]
if raw.startswith("json"):
raw = raw[4:]
parsed = json.loads(raw.strip())
scores_raw = parsed.get("criteria_scores", parsed.get("scores", {}))
criteria_scores: dict[str, float] = {}
for i, criterion in enumerate(criteria):
key_candidates = [f"criterion_{i+1}", criterion, criterion[:50], str(i + 1)]
score = 0.0
for key in key_candidates:
if key in scores_raw:
score = float(scores_raw[key])
break
if score == 0.0 and i < len(scores_raw):
score = float(list(scores_raw.values())[i])
criteria_scores[criterion] = score
overall = float(parsed.get("overall_quality", 0.0))
reasoning = str(parsed.get("reasoning", ""))
return criteria_scores, overall, reasoning
except Exception as exc:
logger.warning("interactive judge failed: %s", exc)
return {c: 0.0 for c in criteria}, 0.0, f"Judge error: {exc}"
# ── Interactive session ──────────────────────────────────────────────────
async def run_interactive(
fixture: JourneyFixture,
*,
model: str = "gpt-4o",
judge_model: str = "gpt-4o-mini",
data_dir: Path | None = None,
) -> InteractiveResult:
"""Run an interactive journey session in the terminal.
Parameters
----------
data_dir :
If set, overrides the fixture's sample-file directory. The LLM
will explore this folder instead of the default
``fixtures/sample_files/…``. Useful for private test data that
shouldn't be committed to git.
"""
from shared.config import settings
from shared.ws_context import set_current_user, clear_current_user
from app.journey import (
handle_journey_start,
handle_journey_message,
_build_system_prompt,
)
# When --data-dir is given, the MockExecutor's root becomes
# data_dir's parent and the journey directory is data_dir's name.
# This way the LLM sees a meaningful directory name (not ".") and
# MockExecutor resolves paths correctly.
# Otherwise, use the fixture's YAML parent and its relative path.
if data_dir:
mock_root = data_dir.parent
journey_directory = data_dir.name
else:
mock_root = fixture.fixture_path.parent
journey_directory = fixture.directory
mock = MockExecutor(
fixture_dir=mock_root,
seed_records={},
)
original_model = settings.LLM_MODEL
settings.LLM_MODEL = model
eval_user_id = f"interactive-{uuid.uuid4().hex[:8]}"
# ── Show system prompt ───────────────────────────────────────
system_prompt = _build_system_prompt(journey_directory, fixture.data_types)
_print_header("SYSTEM PROMPT")
print(f"{_C_DIM}{system_prompt}{_C_RESET}")
_print_header(f"INTERACTIVE JOURNEY | fixture: {fixture.name} | model: {model}")
print(f" Data dir: {mock_root}")
print(f" Type your responses. Commands: {_CMD_DONE}, {_CMD_QUIT}, {_CMD_TEMPLATE}, {_CMD_HELP}")
print(f" Judge model: {judge_model}")
print(f" Criteria: {len(fixture.expected_template_criteria)}")
print()
conversation: list[dict[str, str]] = []
prompt_template: str | None = None
done = False
start_time = time.time()
try:
set_current_user(eval_user_id)
with mock.patch():
# ── Start ────────────────────────────────────────────
_print_system("Starting journey... (AI is exploring your files)")
start_frame: dict[str, Any] = {
"agent_type": "local",
"directory": journey_directory,
"data_types": fixture.data_types,
"session_id": f"interactive-{uuid.uuid4().hex[:8]}",
}
reply = await handle_journey_start(eval_user_id, start_frame)
session_id = reply["session_id"]
conversation.append({"role": "assistant", "content": reply["message"]})
_print_ai(reply["message"])
if reply["done"]:
prompt_template = reply.get("prompt_template")
done = True
_print_system("Journey completed on first reply (template generated).")
# ── Conversation loop ────────────────────────────────
while not done:
try:
user_input = input(f"{_C_BOLD}{_C_BLUE}YOU:{_C_RESET} ").strip()
except (EOFError, KeyboardInterrupt):
print()
user_input = _CMD_QUIT
if not user_input:
continue
# Handle commands
if user_input.lower() == _CMD_QUIT:
_print_system("Aborted — no evaluation will be performed.")
settings.LLM_MODEL = original_model
clear_current_user()
return InteractiveResult(
fixture_name=fixture.name, model=model, judge_model=judge_model,
prompt_template=None, conversation=conversation,
user_comment="(aborted)", done=False,
criteria_scores={}, overall_score=0.0,
judge_reasoning="Session aborted by user.",
elapsed_seconds=time.time() - start_time,
)
if user_input.lower() == _CMD_HELP:
print(_HELP_TEXT)
continue
if user_input.lower() == _CMD_TEMPLATE:
if prompt_template:
print(f"\n{_C_MAGENTA}{prompt_template}{_C_RESET}\n")
else:
_print_system("No template generated yet.")
continue
if user_input.lower() == _CMD_DONE:
_print_system("Ending conversation...")
break
# ── Send message to AI ───────────────────────────
conversation.append({"role": "user", "content": user_input})
_print_system("AI is thinking...")
msg_frame: dict[str, Any] = {
"session_id": session_id,
"message": user_input,
}
reply = await handle_journey_message(eval_user_id, msg_frame)
conversation.append({"role": "assistant", "content": reply["message"]})
_print_ai(reply["message"])
if reply["done"]:
prompt_template = reply.get("prompt_template")
done = True
_print_system("Journey completed — template generated!")
except Exception as exc:
logger.error("interactive journey failed: %s", exc)
_print_system(f"Error: {exc}")
finally:
settings.LLM_MODEL = original_model
clear_current_user()
elapsed = time.time() - start_time
turns = len([c for c in conversation if c["role"] == "user"])
# ── Show template if generated ───────────────────────────────
if prompt_template:
_print_header("GENERATED TEMPLATE")
print(f"{_C_MAGENTA}{prompt_template}{_C_RESET}\n")
else:
_print_system("No template was generated during this session.")
# ── User comment ─────────────────────────────────────────────
_print_header("YOUR EVALUATION")
print(" Write your comment about this interaction (press Enter twice to finish):")
print()
comment_lines: list[str] = []
try:
while True:
line = input()
if line == "" and comment_lines and comment_lines[-1] == "":
comment_lines.pop() # remove trailing empty
break
comment_lines.append(line)
except (EOFError, KeyboardInterrupt):
pass
user_comment = "\n".join(comment_lines).strip() or "(no comment)"
# ── Judge ────────────────────────────────────────────────────
_print_header("LLM JUDGE EVALUATION")
_print_system(f"Scoring with {judge_model}...")
criteria_scores, overall_quality, judge_reasoning = await _judge_interactive(
conversation=conversation,
prompt_template=prompt_template,
user_comment=user_comment,
criteria=fixture.expected_template_criteria,
judge_model=judge_model,
)
# ── Display scores ───────────────────────────────────────────
print()
for criterion, score in criteria_scores.items():
_print_score(criterion, score)
overall = (
sum(criteria_scores.values()) / len(criteria_scores)
if criteria_scores
else 0.0
)
print(f"\n {_C_BOLD}Criteria avg: {overall:.2f}{_C_RESET}")
print(f" {_C_BOLD}Overall quality: {overall_quality:.2f}{_C_RESET}")
print(f" {_C_BOLD}Turns: {turns}{_C_RESET}")
print(f" {_C_BOLD}Time: {elapsed:.1f}s{_C_RESET}")
print(f"\n {_C_DIM}Judge: {judge_reasoning}{_C_RESET}")
print(f" {_C_DIM}Your comment: {user_comment}{_C_RESET}\n")
result = InteractiveResult(
fixture_name=fixture.name,
model=model,
judge_model=judge_model,
prompt_template=prompt_template,
conversation=conversation,
user_comment=user_comment,
done=done,
criteria_scores=criteria_scores,
overall_score=overall_quality,
judge_reasoning=judge_reasoning,
elapsed_seconds=elapsed,
)
# ── Report to Langfuse ───────────────────────────────────────
trace_id = langfuse_eval.log_eval_trace(
fixture_name=fixture.name,
model=model,
prompt_variant="interactive",
prompt_template=prompt_template or "(not generated)",
actual_mutations=[{
"conversation": conversation[:30],
"user_comment": user_comment,
}],
scores_summary=result.summary(),
langfuse_prompt_names=["journey_system"],
)
if trace_id:
from eval.scorer import EvalScores
scores_obj = EvalScores(
fixture_name=fixture.name,
model=model,
prompt_variant="interactive",
precision=overall,
recall=float(done),
f1=overall,
llm_judge_score=overall_quality,
llm_judge_reasoning=judge_reasoning,
)
langfuse_eval.post_eval_scores(scores_obj, trace_id=trace_id)
_print_system(f"Results reported to Langfuse (trace: {trace_id})")
else:
_print_system("Langfuse not configured — results not reported.")
return result

View File

@@ -0,0 +1,385 @@
"""Journey eval runner — tests the prompt_template builder conversation.
For each (journey_fixture × model) combination:
1. Build a MockExecutor (for filesystem tools used during journey)
2. Patch execute_on_client
3. Override LLM_MODEL
4. Call handle_journey_start to kick off the conversation
5. Feed simulated user_messages via handle_journey_message
6. Collect the generated prompt_template
7. Score it against expected_template_criteria (via LLM judge)
8. Report to Langfuse
"""
from __future__ import annotations
import asyncio
import copy
import json
import logging
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
from eval.config import JourneyFixture
from eval.mock_executor import MockExecutor
from eval import langfuse_eval
logger = logging.getLogger(__name__)
# ── Result type ──────────────────────────────────────────────────────────
@dataclass
class JourneyEvalResult:
"""Result of one journey eval run."""
fixture_name: str
model: str
prompt_template: str | None # the generated template (None if journey failed)
conversation_turns: int
done: bool # whether journey reached completion
criteria_scores: dict[str, float] # criterion → 0-1 score
overall_score: float # average of criteria scores
judge_reasoning: str
elapsed_seconds: float
def summary(self) -> dict[str, Any]:
return {
"fixture": self.fixture_name,
"model": self.model,
"done": self.done,
"turns": self.conversation_turns,
"overall_score": round(self.overall_score, 3),
"criteria_scores": {k: round(v, 3) for k, v in self.criteria_scores.items()},
"elapsed_s": round(self.elapsed_seconds, 1),
}
# ── LLM judge for template quality ──────────────────────────────────────
_JOURNEY_JUDGE_SYSTEM = """\
You are an evaluation judge for AI-generated prompt templates.
A journey chatbot explored a user's directory structure and through
conversation produced a prompt_template — an instruction set for a
data-extraction agent.
Your task: evaluate the generated template against a list of criteria.
Score each criterion from 0 to 1:
- 1.0: Fully satisfied, clearly present in the template
- 0.5: Partially satisfied or ambiguously addressed
- 0.0: Not satisfied, missing from the template
Respond with ONLY a JSON object:
{
"scores": {"criterion_1": 0.8, "criterion_2": 1.0, ...},
"reasoning": "Brief explanation"
}
"""
async def _judge_template(
prompt_template: str,
criteria: list[str],
*,
judge_model: str = "gpt-4o-mini",
) -> tuple[dict[str, float], str]:
"""Use an LLM to evaluate a generated prompt_template against criteria.
Returns (criteria_scores, reasoning).
"""
from shared.llm import get_llm
llm = get_llm(model=judge_model, temperature=0)
criteria_text = "\n".join(f" {i+1}. {c}" for i, c in enumerate(criteria))
user_content = (
f"## Generated prompt_template\n```\n{prompt_template}\n```\n\n"
f"## Criteria to evaluate\n{criteria_text}"
)
try:
response = await llm.ainvoke([
SystemMessage(content=_JOURNEY_JUDGE_SYSTEM),
HumanMessage(content=user_content),
])
raw = response.content.strip()
if raw.startswith("```"):
raw = raw.split("```")[1]
if raw.startswith("json"):
raw = raw[4:]
parsed = json.loads(raw.strip())
scores_raw = parsed.get("scores", {})
# Map criterion keys back to the original criteria text
criteria_scores: dict[str, float] = {}
for i, criterion in enumerate(criteria):
# Try matching by index key or exact criterion text
key_candidates = [
f"criterion_{i+1}",
criterion,
criterion[:50],
str(i + 1),
]
score = 0.0
for key in key_candidates:
if key in scores_raw:
score = float(scores_raw[key])
break
# If no match found, try values in order
if score == 0.0 and i < len(scores_raw):
score = float(list(scores_raw.values())[i])
criteria_scores[criterion] = score
reasoning = str(parsed.get("reasoning", ""))
return criteria_scores, reasoning
except Exception as exc:
logger.warning("journey_eval: LLM judge failed: %s", exc)
return {c: 0.0 for c in criteria}, f"Judge error: {exc}"
# ── Journey runner ───────────────────────────────────────────────────────
async def run_single_journey_eval(
fixture: JourneyFixture,
model: str,
*,
judge_model: str = "gpt-4o-mini",
data_dir: Path | None = None,
) -> JourneyEvalResult:
"""Execute one journey eval: start \u2192 messages \u2192 score template."""
from shared.config import settings
# When data_dir is given, use its parent as MockExecutor root
# and its name as the journey directory so the LLM sees a
# meaningful path (not ".").
if data_dir:
mock_root = data_dir.parent
journey_directory = data_dir.name
else:
mock_root = fixture.fixture_path.parent
journey_directory = fixture.directory
mock = MockExecutor(
fixture_dir=mock_root,
seed_records={},
)
original_model = settings.LLM_MODEL
settings.LLM_MODEL = model
eval_user_id = f"eval-journey-{uuid.uuid4().hex[:8]}"
logger.info(
"journey_eval: starting %s | model=%s",
fixture.name, model,
)
start_time = time.time()
prompt_template: str | None = None
conversation: list[dict[str, str]] = []
done = False
try:
from shared.ws_context import set_current_user, clear_current_user
from app.journey import handle_journey_start, handle_journey_message, _sessions
set_current_user(eval_user_id)
with mock.patch():
# ── Start the journey ────────────────────────────────
start_frame: dict[str, Any] = {
"agent_type": "local",
"directory": journey_directory,
"data_types": fixture.data_types,
"session_id": f"eval-{uuid.uuid4().hex[:8]}",
}
reply = await handle_journey_start(eval_user_id, start_frame)
session_id = reply["session_id"]
conversation.append({"role": "assistant", "content": reply["message"]})
logger.info(
"journey_eval: start reply (%d chars), done=%s",
len(reply["message"]), reply["done"],
)
if reply["done"]:
prompt_template = reply.get("prompt_template")
done = True
else:
# ── Send user messages ───────────────────────────
for i, user_msg in enumerate(fixture.user_messages):
if done:
break
conversation.append({"role": "user", "content": user_msg})
msg_frame: dict[str, Any] = {
"session_id": session_id,
"message": user_msg,
}
reply = await handle_journey_message(eval_user_id, msg_frame)
conversation.append({"role": "assistant", "content": reply["message"]})
logger.info(
"journey_eval: turn %d reply (%d chars), done=%s",
i + 1, len(reply["message"]), reply["done"],
)
if reply["done"]:
prompt_template = reply.get("prompt_template")
done = True
# If not done after all user messages, send a final nudge
if not done:
nudge = "Please generate the final prompt_template now. I'm satisfied with the configuration."
conversation.append({"role": "user", "content": nudge})
nudge_frame: dict[str, Any] = {
"session_id": session_id,
"message": nudge,
}
reply = await handle_journey_message(eval_user_id, nudge_frame)
conversation.append({"role": "assistant", "content": reply["message"]})
if reply["done"]:
prompt_template = reply.get("prompt_template")
done = True
except Exception as exc:
logger.error("journey_eval: pipeline failed for %s/%s: %s", fixture.name, model, exc)
finally:
settings.LLM_MODEL = original_model
from shared.ws_context import clear_current_user
clear_current_user()
elapsed = time.time() - start_time
turns = len([c for c in conversation if c["role"] == "user"])
logger.info(
"journey_eval: completed in %.1fs — %d turns, done=%s, template=%s",
elapsed, turns, done, "yes" if prompt_template else "no",
)
# ── Score the template ───────────────────────────────────────
criteria_scores: dict[str, float] = {}
judge_reasoning = ""
if prompt_template and fixture.expected_template_criteria:
criteria_scores, judge_reasoning = await _judge_template(
prompt_template,
fixture.expected_template_criteria,
judge_model=judge_model,
)
elif not prompt_template:
criteria_scores = {c: 0.0 for c in fixture.expected_template_criteria}
judge_reasoning = "No prompt_template was generated — journey did not complete."
overall = (
sum(criteria_scores.values()) / len(criteria_scores)
if criteria_scores
else 0.0
)
result = JourneyEvalResult(
fixture_name=fixture.name,
model=model,
prompt_template=prompt_template,
conversation_turns=turns,
done=done,
criteria_scores=criteria_scores,
overall_score=overall,
judge_reasoning=judge_reasoning,
elapsed_seconds=elapsed,
)
# ── Report to Langfuse ───────────────────────────────────────
trace_id = langfuse_eval.log_eval_trace(
fixture_name=fixture.name,
model=model,
prompt_variant="journey",
prompt_template=prompt_template or "(not generated)",
actual_mutations=[{"conversation": conversation[:20]}],
scores_summary=result.summary(),
langfuse_prompt_names=["journey_system"],
)
if trace_id:
from eval.scorer import EvalScores
scores_obj = EvalScores(
fixture_name=fixture.name,
model=model,
prompt_variant="journey",
precision=overall,
recall=float(done),
f1=overall,
llm_judge_score=overall,
llm_judge_reasoning=judge_reasoning,
)
langfuse_eval.post_eval_scores(scores_obj, trace_id=trace_id)
return result
async def run_journey_fixture_eval(
fixture: JourneyFixture,
models: list[str],
*,
judge_model: str = "gpt-4o-mini",
data_dir: Path | None = None,
) -> list[JourneyEvalResult]:
"""Run all models for a journey fixture."""
langfuse_eval.sync_journey_fixture_to_dataset(fixture)
results: list[JourneyEvalResult] = []
for model in models:
result = await run_single_journey_eval(
fixture, model, judge_model=judge_model,
data_dir=data_dir,
)
results.append(result)
return results
def print_journey_results(results: list[JourneyEvalResult]) -> None:
"""Print a formatted summary of journey eval results."""
if not results:
print("\nNo journey eval results.")
return
print("\n" + "=" * 95)
print(f"{'Fixture':<25} {'Model':<25} {'Done':>5} {'Turns':>6} {'Score':>7} {'Time':>7}")
print("-" * 95)
for r in results:
done_str = "yes" if r.done else "NO"
print(
f"{r.fixture_name:<25} {r.model:<25} {done_str:>5} "
f"{r.conversation_turns:>6} {r.overall_score:>7.2f} {r.elapsed_seconds:>6.1f}s"
)
print("=" * 95)
# Criteria breakdown
for r in results:
if r.criteria_scores:
print(f"\n[{r.model}] Criteria scores:")
for criterion, score in r.criteria_scores.items():
indicator = "PASS" if score >= 0.7 else "PARTIAL" if score >= 0.4 else "FAIL"
print(f" {indicator:>7} ({score:.1f}) {criterion}")
if r.judge_reasoning:
print(f" Judge: {r.judge_reasoning}")
if r.prompt_template:
preview = r.prompt_template[:200].replace("\n", " ")
print(f" Template preview: {preview}...")
print()

View File

@@ -0,0 +1,327 @@
"""Langfuse evaluation integration — datasets, runs, and scoring.
Uses the Langfuse Python SDK v4 (OpenTelemetry-based) to:
1. **Sync fixtures → Langfuse datasets**: Each YAML fixture becomes a dataset,
each prompt variant + expected pair becomes a dataset item.
2. **Track eval runs**: Each (fixture × model × prompt_variant) execution
is recorded as a trace with linked scores.
3. **Post scores**: precision, recall, F1, field_accuracy, llm_judge are
posted as numeric scores on the trace.
"""
from __future__ import annotations
import logging
import os
from typing import Any
from shared.config import settings
from eval.config import EvalFixture
from eval.scorer import EvalScores
logger = logging.getLogger(__name__)
def _get_langfuse():
"""Get or create a Langfuse client instance (SDK v4)."""
if not settings.LANGFUSE_SECRET_KEY or not settings.LANGFUSE_PUBLIC_KEY:
return None
try:
os.environ.setdefault("LANGFUSE_SECRET_KEY", settings.LANGFUSE_SECRET_KEY)
os.environ.setdefault("LANGFUSE_PUBLIC_KEY", settings.LANGFUSE_PUBLIC_KEY)
if settings.LANGFUSE_HOST:
os.environ.setdefault("LANGFUSE_HOST", settings.LANGFUSE_HOST)
from langfuse import get_client
return get_client()
except Exception as exc:
logger.warning("langfuse_eval: failed to create client: %s", exc)
return None
def sync_fixture_to_dataset(fixture: EvalFixture) -> str | None:
"""Create or update a Langfuse dataset from a fixture.
Each prompt variant becomes a separate dataset item with:
- input: {directory, data_types, prompt_template, seed_records}
- expected_output: {expected records}
Returns the dataset name, or None if Langfuse is unavailable.
"""
lf = _get_langfuse()
if lf is None:
logger.info("langfuse_eval: Langfuse not configured — skipping dataset sync")
return None
dataset_name = f"batch-eval-{fixture.name}"
try:
lf.create_dataset(
name=dataset_name,
description=fixture.description,
metadata={
"data_types": ",".join(fixture.data_types),
"file_extensions": ",".join(fixture.file_extensions) if fixture.file_extensions else "",
},
)
except Exception:
# Dataset may already exist — that's fine
pass
# Build expected_output appropriate to the fixture's mode
expected_output: dict[str, Any] = {}
if fixture.mode in ("step1", "full") and fixture.expected_classification:
expected_output["classifications"] = [
{"file": ec.file, "project_id": ec.project_id, "domains": ec.domains}
for ec in fixture.expected_classification
]
if fixture.mode in ("step2", "full") and fixture.expected:
for rec in fixture.expected:
expected_output.setdefault(rec.table, []).append(rec.fields)
item_id = f"{fixture.name}--{fixture.mode}"
try:
lf.create_dataset_item(
dataset_name=dataset_name,
id=item_id,
input={
"directory": fixture.directory,
"data_types": fixture.data_types,
"mode": fixture.mode,
"seed_records": fixture.seed_records,
},
expected_output=expected_output,
metadata={"mode": fixture.mode},
)
except Exception as exc:
logger.warning(
"langfuse_eval: failed to upsert dataset item %s: %s", item_id, exc
)
lf.flush()
logger.info("langfuse_eval: synced fixture '%s' → dataset '%s'", fixture.name, dataset_name)
return dataset_name
def sync_journey_fixture_to_dataset(fixture) -> str | None:
"""Create or update a Langfuse dataset from a journey fixture.
Each journey fixture becomes a single dataset item with:
- input: {directory, data_types, user_messages}
- expected_output: {criteria}
"""
lf = _get_langfuse()
if lf is None:
logger.info("langfuse_eval: Langfuse not configured — skipping journey dataset sync")
return None
dataset_name = f"journey-eval-{fixture.name}"
try:
lf.create_dataset(
name=dataset_name,
description=fixture.description,
metadata={"type": "journey", "data_types": ",".join(fixture.data_types)},
)
except Exception:
pass # Dataset may already exist
item_id = f"{fixture.name}--journey"
try:
lf.create_dataset_item(
dataset_name=dataset_name,
id=item_id,
input={
"directory": fixture.directory,
"data_types": fixture.data_types,
"user_messages": fixture.user_messages,
},
expected_output={
"criteria": fixture.expected_template_criteria,
},
metadata={"type": "journey"},
)
except Exception as exc:
logger.warning("langfuse_eval: failed to upsert journey dataset item %s: %s", item_id, exc)
lf.flush()
logger.info("langfuse_eval: synced journey fixture '%s' → dataset '%s'", fixture.name, dataset_name)
return dataset_name
def create_eval_run(
dataset_name: str,
run_name: str,
*,
metadata: dict[str, Any] | None = None,
) -> str:
"""Create a dataset run in Langfuse. Returns the run name.
Note: In SDK v4, dataset runs are created implicitly via
dataset.run_experiment(). This function is kept for backwards
compatibility but may not create a run.
"""
lf = _get_langfuse()
if lf is None:
return run_name
try:
if hasattr(lf, "create_dataset_run"):
lf.create_dataset_run(
dataset_name=dataset_name,
run_name=run_name,
metadata=metadata or {},
)
lf.flush()
else:
logger.debug("langfuse_eval: create_dataset_run not available in SDK v4")
except Exception as exc:
logger.warning("langfuse_eval: failed to create run %s: %s", run_name, exc)
return run_name
def post_eval_scores(
scores: EvalScores,
*,
trace_id: str | None = None,
dataset_name: str | None = None,
run_name: str | None = None,
) -> None:
"""Post evaluation scores to Langfuse.
If trace_id is provided, scores are attached to that trace.
"""
lf = _get_langfuse()
if lf is None:
return
score_data = [
("precision", scores.precision),
("recall", scores.recall),
("f1", scores.f1),
]
# Only post field_accuracy when there are field-level scores (step2/full)
if scores.field_scores:
score_data.append(("field_accuracy", scores.field_accuracy))
if scores.llm_judge_score is not None:
score_data.append(("llm_judge", scores.llm_judge_score))
for name, value in score_data:
try:
lf.create_score(
name=name,
value=value,
trace_id=trace_id,
data_type="NUMERIC",
comment=f"{scores.fixture_name} | {scores.model} | {scores.prompt_variant}",
)
except Exception as exc:
logger.warning("langfuse_eval: failed to post score %s: %s", name, exc)
lf.flush()
logger.info(
"langfuse_eval: posted %d scores for %s/%s/%s",
len(score_data), scores.fixture_name, scores.model, scores.prompt_variant,
)
def log_eval_trace(
*,
fixture_name: str,
model: str,
prompt_variant: str,
prompt_template: str,
actual_mutations: list[dict],
scores_summary: dict[str, Any],
step1_results: list[dict] | None = None,
dataset_name: str | None = None,
run_name: str | None = None,
dataset_item_id: str | None = None,
langfuse_prompt_names: list[str] | None = None,
) -> str | None:
"""Create a Langfuse trace for one eval execution and link it to a dataset run.
Uses SDK v4 observation API (traces are created implicitly by root spans).
``langfuse_prompt_names`` can contain one or two prompt names to link
(e.g. ``["batch_file_classifier", "batch_processing"]`` for full mode).
Each prompt gets its own generation-type observation for per-version
metrics tracking.
Returns the trace_id, or None if Langfuse is unavailable.
"""
lf = _get_langfuse()
if lf is None:
return None
try:
from langfuse import propagate_attributes
# Fetch prompt objects for linking
prompt_objs: list[tuple[str, Any]] = []
for pname in (langfuse_prompt_names or []):
try:
obj = lf.get_prompt(name=pname, cache_ttl_seconds=300)
prompt_objs.append((pname, obj))
logger.info("langfuse_eval: linked prompt '%s' (type=%s)", pname, type(obj).__name__)
except Exception as exc:
logger.warning("langfuse_eval: prompt '%s' not found — %s", pname, exc)
# Build trace output dict
trace_output: dict[str, Any] = {"scores": scores_summary}
if step1_results:
trace_output["classifications"] = step1_results
if actual_mutations:
trace_output["mutations"] = actual_mutations[:50]
with propagate_attributes(
trace_name=f"eval-{fixture_name}",
metadata={
"eval": "true",
"fixture": fixture_name,
"model": model,
"prompt_variant": prompt_variant,
},
tags=["eval", f"model:{model}", f"variant:{prompt_variant}"],
):
# Root span for the eval run
span = lf.start_observation(name=f"eval-{fixture_name}")
span.update(
input={
"prompt_template": prompt_template,
"model": model,
"prompt_variant": prompt_variant,
},
output=trace_output,
)
trace_id = span.trace_id
# Create a generation-type observation per linked prompt
for pname, pobj in prompt_objs:
gen = lf.start_observation(
name=f"prompt-{pname}",
prompt=pobj,
as_type="generation",
)
gen.end()
# Link to dataset run if available
if dataset_name and run_name and dataset_item_id:
try:
dataset = lf.get_dataset(dataset_name)
for item in dataset.items:
if item.id == dataset_item_id:
item.link(span, run_name)
break
except Exception as exc:
logger.warning("langfuse_eval: failed to link trace to dataset run: %s", exc)
span.end()
lf.flush()
return trace_id
except Exception as exc:
logger.warning("langfuse_eval: failed to create eval trace: %s", exc)
return None

View File

@@ -0,0 +1,258 @@
"""Mock executor — intercepts execute_on_client for offline E2E testing.
Patches ``execute_on_client`` at all usage sites so agent pipeline runs don't
require a live Electron client or Redis. Instead:
- **Filesystem actions** (list_directory, read_file_content, get_file_metadata)
are served from local fixture files on disk.
- **Read actions** (select, get) return preseeded records from an in-memory
store provided by the test fixture.
- **Write actions** (insert, update, delete) are captured as *mutations* and
stored for later comparison against expected results.
"""
from __future__ import annotations
import json
import os
import time
import uuid
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from contextlib import contextmanager, asynccontextmanager
from unittest.mock import AsyncMock, patch
@dataclass
class Mutation:
"""A single recorded write operation."""
action: str # insert | update | delete
table: str
data: dict[str, Any]
timestamp: float = field(default_factory=time.time)
# ── Fake DB helpers (used to bypass async_session in full mode) ───────
class _FakeRow:
"""Mimics an AgentRunLog row returned by SQLAlchemy."""
id = 0
status = "running"
items_processed = 0
items_created = 0
errors: list[str] = []
completed_at = None
def __setattr__(self, name: str, value: Any) -> None:
object.__setattr__(self, name, value)
class _FakeResult:
"""Mimics a SQLAlchemy ``Result`` with ``scalar_one_or_none``."""
def __init__(self, row: _FakeRow) -> None:
self._row = row
def scalar_one_or_none(self) -> _FakeRow:
return self._row
@dataclass
class MockExecutor:
"""In-memory executor that replaces Redis-based tool round-trip.
Parameters
----------
fixture_dir : Path
Directory containing sample files for filesystem tool calls.
seed_records : dict[str, list[dict]]
Pre-existing records per table, e.g. ``{"tasks": [...], "projects": [...]}``.
The executor returns these for ``select`` / ``get`` actions and auto-updates
them on ``insert`` / ``update`` / ``delete`` so subsequent selects reflect changes.
"""
fixture_dir: Path
seed_records: dict[str, list[dict]] = field(default_factory=dict)
mutations: list[Mutation] = field(default_factory=list)
_id_counter: int = field(default=1000, repr=False)
# ── Public API ───────────────────────────────────────────────────
def reset(self) -> None:
"""Clear recorded mutations (keep seed_records intact)."""
self.mutations.clear()
def get_mutations(self, *, table: str | None = None, action: str | None = None) -> list[Mutation]:
"""Filter mutations by table and/or action."""
result = self.mutations
if table:
result = [m for m in result if m.table == table]
if action:
result = [m for m in result if m.action == action]
return result
def created_records(self, table: str) -> list[dict]:
"""Return data dicts of all inserts into *table*."""
return [m.data for m in self.mutations if m.table == table and m.action == "insert"]
def updated_records(self, table: str) -> list[dict]:
"""Return data dicts of all updates to *table*."""
return [m.data for m in self.mutations if m.table == table and m.action == "update"]
# ── Context manager for patching ──────────────────────────────
@contextmanager
def patch(self):
"""Patch execute_on_client and DB session at all usage sites."""
mock_fn = AsyncMock(side_effect=self._handle)
targets = [
"shared.ws_context.execute_on_client",
"app.agent_runner.execute_on_client",
"app.agents.filesystem_agent.execute_on_client",
]
# Mock async_session so run_local_agent / _finalize_run skip real DB
fake_row = _FakeRow()
fake_db = AsyncMock()
fake_db.commit = AsyncMock()
fake_db.refresh = AsyncMock()
fake_db.execute = AsyncMock(return_value=_FakeResult(fake_row))
fake_db.add = lambda obj: None # noqa: ARG005
@asynccontextmanager
async def _fake_session():
yield fake_db
patches = [patch(t, new=mock_fn) for t in targets]
patches.append(patch("app.agent_runner.async_session", _fake_session))
for p in patches:
p.start()
try:
yield mock_fn
finally:
for p in patches:
p.stop()
# ── Internal dispatch ─────────────────────────────────────────
async def _handle(
self,
action: str,
table: str | None = None,
data: dict[str, Any] | None = None,
filters: dict[str, Any] | None = None,
vector: list[float] | None = None,
limit: int | None = None,
) -> dict[str, Any]:
# Filesystem
if action == "list_directory":
return self._list_directory(data or {})
if action == "read_file_content":
return self._read_file(data or {})
if action == "get_file_metadata":
return self._get_file_metadata(data or {})
# CRUD
if action == "select":
return self._select(table or "", filters)
if action == "get":
return self._get(table or "", data or {})
if action == "insert":
return self._insert(table or "", data or {})
if action == "update":
return self._update(table or "", data or {})
if action == "delete":
return self._delete(table or "", data or {})
# Vector (no-op for eval)
if action in ("vector_upsert", "vector_search"):
return {"rows": []}
return {"error": f"Unknown action: {action}"}
# ── Filesystem handlers ───────────────────────────────────────
def _list_directory(self, data: dict) -> dict:
rel_path = data.get("path", "")
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
if not abs_path.is_dir():
return {"entries": []}
entries: list[dict] = []
for child in sorted(abs_path.iterdir()):
entry_type = "directory" if child.is_dir() else "file"
# Return paths relative to fixture_dir but with the original prefix
entry_path = rel_path.rstrip("/\\") + "/" + child.name
entries.append({
"name": child.name,
"path": entry_path,
"type": entry_type,
})
return {"entries": entries}
def _read_file(self, data: dict) -> dict:
rel_path = data.get("path", "")
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
if not abs_path.is_file():
return {"content": "", "error": f"File not found: {rel_path}"}
return {"content": abs_path.read_text(encoding="utf-8", errors="replace")}
def _get_file_metadata(self, data: dict) -> dict:
rel_path = data.get("path", "")
abs_path = self.fixture_dir / rel_path.lstrip("/\\")
if not abs_path.exists():
return {"error": f"Not found: {rel_path}"}
stat = abs_path.stat()
return {
"path": rel_path,
"size": stat.st_size,
"modifiedAt": int(stat.st_mtime * 1000),
"createdAt": int(stat.st_ctime * 1000),
"isDirectory": abs_path.is_dir(),
}
# ── CRUD handlers ─────────────────────────────────────────────
def _select(self, table: str, filters: dict | None) -> dict:
rows = list(self.seed_records.get(table, []))
if filters:
rows = [
r for r in rows
if all(r.get(k) == v for k, v in filters.items() if v is not None)
]
return {"rows": rows}
def _get(self, table: str, data: dict) -> dict:
record_id = data.get("id", "")
rows = self.seed_records.get(table, [])
for r in rows:
if r.get("id") == record_id:
return {"row": r}
return {"row": None}
def _insert(self, table: str, data: dict) -> dict:
self._id_counter += 1
record = {**data, "id": str(self._id_counter)}
# Add to seed so subsequent selects can find it
self.seed_records.setdefault(table, []).append(record)
self.mutations.append(Mutation(action="insert", table=table, data=record))
return {"row": record}
def _update(self, table: str, data: dict) -> dict:
record_id = data.get("id", "")
rows = self.seed_records.get(table, [])
for r in rows:
if r.get("id") == record_id:
r.update({k: v for k, v in data.items() if v is not None and v != ""})
self.mutations.append(Mutation(action="update", table=table, data=dict(r)))
return {"row": r}
# Record not found — still log the mutation
self.mutations.append(Mutation(action="update", table=table, data=data))
return {"row": data}
def _delete(self, table: str, data: dict) -> dict:
record_id = data.get("id", "")
rows = self.seed_records.get(table, [])
self.seed_records[table] = [r for r in rows if r.get("id") != record_id]
self.mutations.append(Mutation(action="delete", table=table, data={"id": record_id}))
return {"deleted": True}

View File

@@ -0,0 +1,2 @@
# Extra dependencies for the eval harness (on top of the service requirements.txt)
pyyaml>=6.0.0

View File

@@ -0,0 +1,545 @@
"""Eval runner — orchestrates fixture → mock → agent pipeline → scoring.
Supports three eval modes:
- **step1**: Test classification prompt only (``_STEP1_SYSTEM_PROMPT``).
Calls the LLM with fixture-provided ``domain_definitions`` and
``projects_list`` and compares output against ``expected_classification``.
- **step2**: Test processing prompt only (``_PROCESSING_SYSTEM_PROMPT``).
Compiles the prompt with fixture-provided ``existing_context``,
``project_context``, ``data_types``, and ``custom_prompt_section``,
then runs the tool-calling loop. Mutations are scored against
``expected`` records.
- **full**: Run ``run_local_agent()`` end-to-end (both steps).
Scored on both classification and extraction.
"""
from __future__ import annotations
import copy
import json
import logging
import time
import uuid
from typing import Any
from eval.config import EvalFixture, ExpectedClassification
from eval.mock_executor import MockExecutor
from eval.scorer import (
EvalScores,
FieldScore,
compute_precision_recall,
llm_judge_score,
score_field_match,
)
from eval import langfuse_eval
logger = logging.getLogger(__name__)
# ── Step 1 runner ─────────────────────────────────────────────────────────
async def _run_step1(
fixture: EvalFixture,
model: str,
mock: MockExecutor,
) -> list[dict[str, Any]]:
"""Run step-1 classification for every file in the fixture directory.
Scans the directory recursively, classifies each file, and returns
a list of result dicts:
``[{file, project_id, domains, new_project_name}, ...]``
"""
from app.agent_runner import _classify_file
# Build project name lookup for display
proj_names: dict[str, str] = {
p.get("id", ""): p.get("name", "") for p in fixture.projects_list
}
# Discover all files in the fixture directory
all_files = await _scan_fixture_files(mock, fixture.directory)
print(f"\n Scanning {len(all_files)} files in {fixture.directory}\n")
results: list[dict[str, Any]] = []
for i, file_path in enumerate(all_files, 1):
file_result = await mock._handle(
action="read_file_content",
data={"path": file_path},
)
file_content: str = file_result.get("content", "")
if not file_content.strip():
continue
project_id, domains, new_name = await _classify_file(
file_path=file_path,
file_content=file_content,
projects=fixture.projects_list,
config_data_types=fixture.data_types,
custom_system_prompt=fixture.custom_step1_prompt or None,
)
short_name = file_path.rsplit("/", 1)[-1] if "/" in file_path else file_path
proj_label = proj_names.get(project_id, new_name or "?")
print(f" [{i}/{len(all_files)}] {short_name}{project_id} ({proj_label}) {domains}")
results.append({
"file": file_path,
"project_id": project_id,
"domains": domains,
"new_project_name": new_name,
})
return results
async def _scan_fixture_files(mock: MockExecutor, directory: str) -> list[str]:
"""Recursively list all files under *directory* via the mock executor."""
files: list[str] = []
async def _walk(path: str) -> None:
result = await mock._handle(action="list_directory", data={"path": path})
for entry in result.get("entries", []):
if entry.get("type") == "directory":
await _walk(entry["path"])
elif entry.get("type") == "file":
files.append(entry["path"])
await _walk(directory)
return sorted(files)
def _score_step1(
fixture: EvalFixture,
results: list[dict[str, Any]],
) -> tuple[float, float, float, str]:
"""Score step-1 results. Returns (precision, recall, f1, reasoning).
Files with expected classifications are scored (OK/FAIL).
Files without expectations are shown as informational (INFO).
"""
if not fixture.expected_classification:
return 0.0, 0.0, 0.0, "No expected classifications"
# Build project name lookup
proj_names: dict[str, str] = {
p.get("id", ""): p.get("name", "") for p in fixture.projects_list
}
proj_names["new"] = "(new project)"
def _proj_label(pid: str, new_name: str | None = None) -> str:
name = proj_names.get(pid, "?")
if pid == "new" and new_name:
return f"new → \"{new_name}\""
return f"{pid} ({name})" if name and name != "?" else pid
def _short_file(path: str) -> str:
"""Use just the filename for cleaner display."""
return path.rsplit("/", 1)[-1] if "/" in path else path
expected_files = {ec.file for ec in fixture.expected_classification}
total = len(fixture.expected_classification)
matched = 0
scored_lines: list[str] = []
info_lines: list[str] = []
# Score expected files
for ec in fixture.expected_classification:
actual = next((r for r in results if r["file"] == ec.file), None)
fname = _short_file(ec.file)
if actual is None:
scored_lines.append(f" MISS {fname}")
scored_lines.append(f" expected: {_proj_label(ec.project_id)}")
continue
pid_ok = actual["project_id"] == ec.project_id
domains_ok = set(actual["domains"]) == set(ec.domains) if ec.domains else True
if pid_ok and domains_ok:
matched += 1
scored_lines.append(f" OK {fname}")
scored_lines.append(f" project: {_proj_label(actual['project_id'])}")
scored_lines.append(f" domains: {actual['domains']}")
else:
scored_lines.append(f" FAIL {fname}")
if not pid_ok:
scored_lines.append(f" project: {_proj_label(actual['project_id'])} (expected: {_proj_label(ec.project_id)})")
else:
scored_lines.append(f" project: {_proj_label(actual['project_id'])}")
if not domains_ok:
scored_lines.append(f" domains: {actual['domains']} (expected: {ec.domains})")
else:
scored_lines.append(f" domains: {actual['domains']}")
# Show unscored files
for r in results:
if r["file"] not in expected_files:
fname = _short_file(r["file"])
proj = _proj_label(r["project_id"], r.get("new_project_name"))
info_lines.append(f" · {fname}")
info_lines.append(f" project: {proj} | domains: {r['domains']}")
precision = matched / total if total > 0 else 0.0
recall = precision
f1 = precision
parts: list[str] = []
if scored_lines:
parts.append(f"Scored ({matched}/{total}):")
parts.extend(scored_lines)
if info_lines:
parts.append(f"\nOther files ({len(info_lines) // 2}):")
parts.extend(info_lines)
return precision, recall, f1, "\n".join(parts)
# ── Step 2 runner ─────────────────────────────────────────────────────────
async def _run_step2(
fixture: EvalFixture,
model: str,
mock: MockExecutor,
) -> None:
"""Run step-2 processing for each file in the fixture directory.
Compiles ``_PROCESSING_SYSTEM_PROMPT`` with fixture-provided variables
and runs the tool-calling loop. Mutations are captured by the mock.
"""
from app.agent_runner import (
_PROCESSING_SYSTEM_PROMPT,
_build_processing_tools,
_run_agent_with_tools,
_MAX_PROCESSING_STEPS,
)
from app import tracing
# Compile the processing prompt with fixture variables
system_prompt = tracing.compile_prompt(
"batch_processing",
fallback=_PROCESSING_SYSTEM_PROMPT,
variables={
"existing_context": fixture.existing_context,
"project_context": fixture.project_context,
"data_types": ", ".join(fixture.data_types),
"custom_prompt_section": fixture.custom_prompt_section,
},
)
tools = _build_processing_tools(fixture.data_types)
# Scan files in the fixture directory
file_entries = await mock._handle(
action="list_directory",
data={"path": fixture.directory},
)
for entry in file_entries.get("entries", []):
if entry.get("type") != "file":
continue
# Filter by extension if specified
if fixture.file_extensions:
ext = entry["name"].rsplit(".", 1)[-1] if "." in entry["name"] else ""
if ext not in fixture.file_extensions:
continue
file_result = await mock._handle(
action="read_file_content",
data={"path": entry["path"]},
)
file_content: str = file_result.get("content", "")
if not file_content.strip():
continue
await _run_agent_with_tools(
system_prompt=system_prompt,
user_message=(
f"Process this file and extract relevant information.\n\n"
f"File: {entry['path']}\n\nContent:\n{file_content}"
),
tools=tools,
max_steps=_MAX_PROCESSING_STEPS,
)
# ── Full runner ───────────────────────────────────────────────────────────
async def _run_full(
fixture: EvalFixture,
model: str,
mock: MockExecutor,
user_id: str,
) -> None:
"""Run the full two-step pipeline via ``run_local_agent``."""
from app.agent_runner import run_local_agent
trigger_data: dict[str, Any] = {
"type": "agent_trigger",
"directory": fixture.directory,
"directory_paths": [fixture.directory],
"data_types": fixture.data_types,
"file_extensions": fixture.file_extensions,
"prompt_template": fixture.custom_prompt_section,
"device_id": "eval-harness",
"run_context": {
"agent_id": f"eval-{fixture.name}",
"run_id": None,
},
}
with mock.patch():
await run_local_agent(user_id, trigger_data)
# ── Scoring helpers ───────────────────────────────────────────────────────
def _score_mutations(
fixture: EvalFixture,
mock: MockExecutor,
) -> tuple[list[FieldScore], float, float, float, int, int]:
"""Score mutations against expected records.
Returns (field_scores, precision, recall, f1, extra, missing).
"""
all_field_scores: list[FieldScore] = []
total_expected = 0
total_actual = 0
total_matched = 0
total_extra = 0
total_missing = 0
expected_by_table: dict[str, list[dict]] = {}
for rec in fixture.expected:
expected_by_table.setdefault(rec.table, []).append(rec.fields)
tables = set(expected_by_table.keys()) | {m.table for m in mock.mutations}
for table in tables:
expected_records = expected_by_table.get(table, [])
actual_records = mock.created_records(table) + mock.updated_records(table)
field_scores, extra, missing = score_field_match(expected_records, actual_records, table)
all_field_scores.extend(field_scores)
matched = sum(1 for s in field_scores if s.best_match is not None)
total_expected += len(expected_records)
total_actual += len(actual_records)
total_matched += matched
total_extra += extra
total_missing += missing
precision, recall, f1 = compute_precision_recall(total_expected, total_actual, total_matched)
return all_field_scores, precision, recall, f1, total_extra, total_missing
# ── Main entry point ──────────────────────────────────────────────────────
async def run_single_eval(
fixture: EvalFixture,
model: str,
*,
use_llm_judge: bool = True,
judge_model: str = "gpt-4o-mini",
) -> EvalScores:
"""Execute one eval run for a fixture + model. Mode is read from the fixture."""
from shared.config import settings
from shared.ws_context import set_current_user, clear_current_user
seed = copy.deepcopy(fixture.seed_records)
mock = MockExecutor(
fixture_dir=fixture.fixture_path.parent,
seed_records=seed,
)
original_model = settings.LLM_MODEL
settings.LLM_MODEL = model
eval_user_id = str(uuid.uuid4())
logger.info(
"eval: starting %s | mode=%s | model=%s",
fixture.name, fixture.mode, model,
)
start_time = time.time()
step1_results: list[dict[str, Any]] = []
step1_reasoning = ""
try:
set_current_user(eval_user_id)
if fixture.mode == "step1":
with mock.patch():
step1_results = await _run_step1(fixture, model, mock)
elif fixture.mode == "step2":
with mock.patch():
await _run_step2(fixture, model, mock)
elif fixture.mode == "full":
with mock.patch():
# Step 1 — classification (independent from run_local_agent)
if fixture.expected_classification:
step1_results = await _run_step1(fixture, model, mock)
# Step 2 — full pipeline (run_local_agent handles both steps)
await _run_full(fixture, model, mock, eval_user_id)
except Exception as exc:
logger.error("eval: pipeline failed for %s/%s: %s", fixture.name, model, exc)
finally:
settings.LLM_MODEL = original_model
clear_current_user()
elapsed = time.time() - start_time
logger.info("eval: completed in %.1fs — %d mutations", elapsed, len(mock.mutations))
# ── Score ─────────────────────────────────────────────────────
if fixture.mode == "step1":
s1_precision, s1_recall, s1_f1, step1_reasoning = _score_step1(fixture, step1_results)
scores = EvalScores(
fixture_name=fixture.name,
model=model,
prompt_variant=fixture.mode,
precision=s1_precision,
recall=s1_recall,
f1=s1_f1,
llm_judge_reasoning=step1_reasoning,
)
else:
# step2 or full — score mutations
field_scores, precision, recall, f1, extra, missing = _score_mutations(fixture, mock)
scores = EvalScores(
fixture_name=fixture.name,
model=model,
prompt_variant=fixture.mode,
field_scores=field_scores,
precision=precision,
recall=recall,
f1=f1,
extra_records=extra,
missing_records=missing,
)
# Add step1 classification scores for full mode
if fixture.mode == "full" and fixture.expected_classification:
s1_p, s1_r, s1_f1, step1_reasoning = _score_step1(fixture, step1_results)
scores.llm_judge_reasoning = f"Step1 classification:\n{step1_reasoning}"
# Optional LLM judge for extraction quality
if use_llm_judge and fixture.expected:
all_expected = [r.fields for r in fixture.expected]
all_actual = [m.data for m in mock.mutations if m.action in ("insert", "update")]
judge_score, reasoning = await llm_judge_score(
all_expected, all_actual, judge_model=judge_model,
)
scores.llm_judge_score = judge_score
if step1_reasoning:
scores.llm_judge_reasoning += f"\n\nLLM judge:\n{reasoning}"
else:
scores.llm_judge_reasoning = reasoning
# ── Report to Langfuse ────────────────────────────────────────
prompt_names = {
"step1": ["batch_file_classifier"],
"step2": ["batch_processing"],
"full": ["batch_file_classifier", "batch_processing"],
}.get(fixture.mode, ["batch_processing"])
trace_id = langfuse_eval.log_eval_trace(
fixture_name=fixture.name,
model=model,
prompt_variant=fixture.mode,
prompt_template=fixture.custom_prompt_section or "(default)",
actual_mutations=[{"action": m.action, "table": m.table, "data": m.data} for m in mock.mutations],
scores_summary=scores.summary(),
step1_results=step1_results or None,
langfuse_prompt_names=prompt_names,
)
if trace_id:
langfuse_eval.post_eval_scores(scores, trace_id=trace_id)
# For full mode, post classification scores separately
if fixture.mode == "full" and fixture.expected_classification:
s1_p, s1_r, s1_f1, _ = _score_step1(fixture, step1_results)
for name, value in [
("classification_precision", s1_p),
("classification_recall", s1_r),
("classification_f1", s1_f1),
]:
try:
from langfuse import get_client
lf = get_client()
if lf:
lf.create_score(
name=name,
value=value,
trace_id=trace_id,
data_type="NUMERIC",
comment=f"{fixture.name} | {model} | full",
)
except Exception:
pass
return scores
async def run_fixture_eval(
fixture: EvalFixture,
models: list[str],
*,
use_llm_judge: bool = True,
judge_model: str = "gpt-4o-mini",
) -> list[EvalScores]:
"""Run all models for a fixture."""
langfuse_eval.sync_fixture_to_dataset(fixture)
results: list[EvalScores] = []
for model in models:
scores = await run_single_eval(
fixture, model,
use_llm_judge=use_llm_judge,
judge_model=judge_model,
)
results.append(scores)
return results
def print_results(results: list[EvalScores]) -> None:
"""Print a formatted summary table of eval results."""
if not results:
print("\nNo eval results.")
return
W = 90
print("\n" + "=" * W)
print(f"{'Fixture':<25} {'Mode':<6} {'Model':<25} {'P':>6} {'R':>6} {'F1':>6} {'FA':>6} {'LLM':>6}")
print("-" * W)
for s in results:
llm_str = f"{s.llm_judge_score:.2f}" if s.llm_judge_score is not None else " --"
fa_str = f"{s.field_accuracy:.2f}" if s.field_scores else " --"
print(
f"{s.fixture_name:<25} {s.prompt_variant:<6} {s.model:<25} "
f"{s.precision:>6.2f} {s.recall:>6.2f} {s.f1:>6.2f} "
f"{fa_str:>6} {llm_str:>6}"
)
print("=" * W)
for s in results:
if s.llm_judge_reasoning:
print(f"\n{'' * W}")
print(f" {s.fixture_name} | {s.model} | {s.prompt_variant}")
print(f"{'' * W}")
print(s.llm_judge_reasoning)
print()

View File

@@ -0,0 +1,268 @@
"""Scoring functions for batch agent evaluation.
Two scoring strategies:
1. **FieldMatchScorer** — deterministic check: for each expected record,
find the best-matching actual record and compare specified fields.
Returns precision, recall, and per-field accuracy.
2. **LLMJudgeScorer** — uses a secondary LLM to semantically evaluate
whether the actual extractions satisfy the expected intent, even if
wording differs. Returns a 0-1 score + reasoning.
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass, field
from difflib import SequenceMatcher
from typing import Any
from langchain_core.messages import HumanMessage, SystemMessage
logger = logging.getLogger(__name__)
# ── Result types ─────────────────────────────────────────────────────────
@dataclass
class FieldScore:
"""Score for a single expected record against its best match."""
expected: dict[str, Any]
best_match: dict[str, Any] | None
matched_fields: dict[str, bool]
similarity: float # 0-1 overall similarity
@property
def field_accuracy(self) -> float:
if not self.matched_fields:
return 0.0
return sum(self.matched_fields.values()) / len(self.matched_fields)
@dataclass
class EvalScores:
"""Aggregated scores for one eval run."""
fixture_name: str
model: str
prompt_variant: str
field_scores: list[FieldScore] = field(default_factory=list)
precision: float = 0.0
recall: float = 0.0
f1: float = 0.0
llm_judge_score: float | None = None
llm_judge_reasoning: str = ""
extra_records: int = 0 # records created but not expected
missing_records: int = 0 # expected but not found
@property
def field_accuracy(self) -> float:
if not self.field_scores:
return 0.0
return sum(s.field_accuracy for s in self.field_scores) / len(self.field_scores)
def summary(self) -> dict[str, Any]:
return {
"fixture": self.fixture_name,
"model": self.model,
"prompt_variant": self.prompt_variant,
"precision": round(self.precision, 3),
"recall": round(self.recall, 3),
"f1": round(self.f1, 3),
"field_accuracy": round(self.field_accuracy, 3),
"llm_judge_score": round(self.llm_judge_score, 3) if self.llm_judge_score is not None else None,
"extra_records": self.extra_records,
"missing_records": self.missing_records,
}
# ── Field Match Scorer ───────────────────────────────────────────────────
def _normalize(value: Any) -> str:
"""Normalize a value for comparison."""
if value is None:
return ""
return str(value).strip().lower()
def _text_similarity(a: str, b: str) -> float:
"""Fuzzy text similarity using SequenceMatcher."""
if not a and not b:
return 1.0
if not a or not b:
return 0.0
return SequenceMatcher(None, a.lower(), b.lower()).ratio()
def _find_best_match(
expected: dict[str, Any],
actuals: list[dict[str, Any]],
) -> tuple[dict[str, Any] | None, float]:
"""Find the actual record most similar to expected, return (match, similarity)."""
if not actuals:
return None, 0.0
best_match = None
best_score = 0.0
# Primary matching key: title or name
expected_title = _normalize(expected.get("title", expected.get("name", "")))
for actual in actuals:
actual_title = _normalize(actual.get("title", actual.get("name", "")))
sim = _text_similarity(expected_title, actual_title)
if sim > best_score:
best_score = sim
best_match = actual
return best_match, best_score
def _compare_fields(
expected: dict[str, Any],
actual: dict[str, Any],
) -> dict[str, bool]:
"""Compare each expected field against the actual record."""
results: dict[str, bool] = {}
for key, expected_val in expected.items():
actual_val = actual.get(key)
# Exact match for non-string types
if not isinstance(expected_val, str):
results[key] = actual_val == expected_val
else:
# Fuzzy match for strings (threshold: 0.7)
results[key] = _text_similarity(
_normalize(expected_val), _normalize(actual_val)
) >= 0.7
return results
def score_field_match(
expected_records: list[dict[str, Any]],
actual_records: list[dict[str, Any]],
table: str,
) -> tuple[list[FieldScore], int, int]:
"""Score actual extractions against expected records for one table.
Returns (field_scores, extra_count, missing_count).
"""
field_scores: list[FieldScore] = []
matched_actuals: set[int] = set()
for exp in expected_records:
# Find best match among unmatched actuals
candidates = [
(i, a) for i, a in enumerate(actual_records) if i not in matched_actuals
]
if not candidates:
field_scores.append(FieldScore(
expected=exp, best_match=None, matched_fields={}, similarity=0.0,
))
continue
best_idx, best_match = None, None
best_sim = 0.0
for idx, actual in candidates:
_, sim = _find_best_match(exp, [actual])
if sim > best_sim:
best_sim = sim
best_idx = idx
best_match = actual
if best_sim >= 0.5 and best_match is not None:
matched_actuals.add(best_idx)
matched_fields = _compare_fields(exp, best_match)
field_scores.append(FieldScore(
expected=exp, best_match=best_match,
matched_fields=matched_fields, similarity=best_sim,
))
else:
field_scores.append(FieldScore(
expected=exp, best_match=None, matched_fields={}, similarity=0.0,
))
extra_count = len(actual_records) - len(matched_actuals)
missing_count = sum(1 for s in field_scores if s.best_match is None)
return field_scores, extra_count, missing_count
def compute_precision_recall(
expected_count: int,
actual_count: int,
matched_count: int,
) -> tuple[float, float, float]:
"""Compute precision, recall, F1."""
precision = matched_count / actual_count if actual_count > 0 else 0.0
recall = matched_count / expected_count if expected_count > 0 else 0.0
f1 = (
2 * precision * recall / (precision + recall)
if (precision + recall) > 0
else 0.0
)
return precision, recall, f1
# ── LLM Judge Scorer ─────────────────────────────────────────────────────
_JUDGE_SYSTEM_PROMPT = """\
You are an evaluation judge for a data extraction system.
Your task is to compare the EXPECTED extractions against the ACTUAL extractions
produced by an AI agent, and assess quality on a 0-1 scale.
Scoring criteria:
- 1.0: All expected records found with correct fields, no significant extras
- 0.8: Most expected records found, minor field differences or extras
- 0.6: Core extractions present but some missing or incorrect
- 0.4: Partial match — several expected records missing or wrong
- 0.2: Poor quality — most expected records missing or incorrect
- 0.0: Complete failure — no meaningful overlap
Consider semantic equivalence: "Send invoice" and "Email the invoice" are matches.
Ignore field ordering and formatting differences.
Respond with ONLY a JSON object:
{"score": 0.85, "reasoning": "Brief explanation of the score"}
"""
async def llm_judge_score(
expected: list[dict[str, Any]],
actual: list[dict[str, Any]],
*,
judge_model: str = "gpt-4o-mini",
) -> tuple[float, str]:
"""Use an LLM to semantically evaluate extraction quality.
Returns (score, reasoning).
"""
from shared.llm import get_llm
llm = get_llm(model=judge_model, temperature=0)
user_content = (
f"## Expected extractions\n```json\n{json.dumps(expected, indent=2, default=str)}\n```\n\n"
f"## Actual extractions\n```json\n{json.dumps(actual, indent=2, default=str)}\n```"
)
try:
response = await llm.ainvoke([
SystemMessage(content=_JUDGE_SYSTEM_PROMPT),
HumanMessage(content=user_content),
])
raw = response.content.strip()
if raw.startswith("```"):
raw = raw.split("```")[1]
if raw.startswith("json"):
raw = raw[4:]
parsed = json.loads(raw.strip())
return float(parsed.get("score", 0.0)), str(parsed.get("reasoning", ""))
except Exception as exc:
logger.warning("eval: LLM judge failed: %s", exc)
return 0.0, f"Judge error: {exc}"

View File

@@ -0,0 +1,21 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
redis>=5.0.0
cryptography>=42.0.0
python-dotenv>=1.0.0
langchain-core>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.3.0
litellm>=1.50.0
openai>=1.50.0
httpx>=0.27.0
langfuse>=3.0.0
croniter>=2.0.0
google-api-python-client>=2.130.0
google-auth>=2.30.0
msal>=1.28.0

View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/billing/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/billing/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Billing is lightweight — single worker is fine
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "1", \
"--timeout", "30"]

View File

@@ -0,0 +1,15 @@
# Billing Service
Owns: Stripe integration, tier management, subscription CRUD.
## Tables owned (write)
- `subscriptions`
## Endpoints
- `POST /billing/checkout`
- `POST /billing/webhook` (Stripe, no JWT auth)
- `GET /billing/subscription`
- `DELETE /billing/subscription`
## Redis channels
- Publish: `tier:changed:{user_id}` on tier change

View File

@@ -0,0 +1,53 @@
"""Billing Service — FastAPI application.
Owns: Stripe checkout/webhook, subscription management, tier feature matrix,
quota enforcement.
Downstream services query this service (or read the user's tier from
the X-User-Tier header injected by Traefik) for billing decisions.
The webhook endpoint is exposed WITHOUT ForwardAuth so Stripe can reach it.
"""
from __future__ import annotations
import logging
import sys
from contextlib import asynccontextmanager
from pathlib import Path
from typing import AsyncGenerator
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.routes import router
logger = logging.getLogger(__name__)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
logger.info("billing: service started")
yield
logger.info("billing: service stopped")
app = FastAPI(title="Adiuva Billing Service", lifespan=lifespan)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_methods=["GET", "POST", "DELETE"],
allow_headers=["*"],
)
app.include_router(router)
@app.get("/health")
async def health() -> dict[str, str]:
return {"status": "ok", "service": "billing"}

View File

@@ -0,0 +1,134 @@
"""Billing routes: Stripe checkout, webhook, subscription, tier query.
Adapted for the Billing microservice:
- Authenticated routes use Traefik-injected headers (X-User-Id, X-User-Tier)
- Webhook route has NO auth (Stripe signature verification only)
- Added /tier/{user_id} for internal service-to-service tier lookups
- Added /features/{tier} for feature matrix queries
"""
from __future__ import annotations
from typing import Any
from fastapi import APIRouter, Header, HTTPException, Request, status
from pydantic import BaseModel
from shared.db import async_session
from shared.schemas import BillingTier
from app.stripe_service import stripe_service
from app.tier_manager import tier_manager, FEATURES, RATE_LIMITS
router = APIRouter(prefix="/billing", tags=["billing"])
# ── Request bodies ─────────────────────────────────────────────────────
class _CheckoutRequest(BaseModel):
tier: BillingTier
# ── Checkout ───────────────────────────────────────────────────────────
@router.post("/checkout")
async def create_checkout(
body: _CheckoutRequest,
x_user_id: str = Header(..., alias="X-User-Id"),
) -> dict[str, str]:
"""Create a Stripe checkout session for a tier upgrade."""
url = stripe_service.create_checkout_session(x_user_id, body.tier)
return {"checkout_url": url}
# ── Webhook (NO auth — Stripe signature only) ─────────────────────────
@router.post("/webhook")
async def stripe_webhook(
request: Request,
stripe_signature: str = Header(default="", alias="Stripe-Signature"),
) -> dict[str, bool]:
"""Handle Stripe webhook events.
This endpoint is exposed without ForwardAuth in Traefik config
so Stripe can reach it directly.
"""
payload = await request.body()
async with async_session() as db:
await stripe_service.handle_webhook(payload, stripe_signature, db)
return {"ok": True}
# ── Subscription CRUD ─────────────────────────────────────────────────
@router.get("/subscription")
async def get_subscription(
x_user_id: str = Header(..., alias="X-User-Id"),
x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict[str, Any]:
"""Return the current subscription info for the authenticated user."""
async with async_session() as db:
sub = await stripe_service.get_subscription(x_user_id, db)
if sub is None:
return {
"tier": x_user_tier,
"status": "free",
"stripe_subscription_id": None,
"current_period_end": None,
}
return sub
@router.delete("/subscription")
async def cancel_subscription(
x_user_id: str = Header(..., alias="X-User-Id"),
) -> dict[str, bool]:
"""Cancel the active subscription."""
async with async_session() as db:
await stripe_service.cancel_subscription(x_user_id, db)
return {"ok": True}
# ── Tier query (internal, service-to-service) ─────────────────────────
@router.get("/tier/{user_id}")
async def get_user_tier(user_id: str) -> dict[str, str]:
"""Return the billing tier for a given user_id.
Used by other services for tier lookups. Protected by Traefik
ForwardAuth — only internal services should call this.
"""
async with async_session() as db:
tier = await tier_manager.get_tier(user_id, db)
return {"user_id": user_id, "tier": tier}
# ── Feature matrix (public, cacheable) ────────────────────────────────
@router.get("/features/{tier}")
async def get_tier_features(tier: str) -> dict[str, Any]:
"""Return the feature matrix for a tier."""
if tier not in FEATURES:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Unknown tier: {tier}",
)
return {
"tier": tier,
"features": FEATURES[tier],
"rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]),
}
@router.get("/features")
async def get_all_features() -> dict[str, Any]:
"""Return the full feature matrix for all tiers."""
return {
"tiers": {
tier: {
"features": features,
"rate_limit_rpm": RATE_LIMITS.get(tier, RATE_LIMITS["free"]),
}
for tier, features in FEATURES.items()
},
}

View File

@@ -1,7 +1,7 @@
"""Stripe service: checkout sessions, webhook handling, subscription management. """Stripe service: checkout sessions, webhook handling, subscription management.
Subscription records are persisted in the PostgreSQL ``subscriptions`` table. Adapted for the Billing microservice uses shared.models and shared.db.
All Stripe calls are gracefully stubbed when ``STRIPE_SECRET_KEY`` is not All Stripe calls are gracefully stubbed when STRIPE_SECRET_KEY is not
configured, enabling local development without live credentials. configured, enabling local development without live credentials.
""" """
@@ -15,7 +15,8 @@ from fastapi import HTTPException, status
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.config.settings import settings from shared.config import settings
from shared.models import Subscription
# Stripe price IDs per tier — replace with real IDs in production .env # Stripe price IDs per tier — replace with real IDs in production .env
TIER_PRICE_IDS: dict[str, str] = { TIER_PRICE_IDS: dict[str, str] = {
@@ -46,11 +47,7 @@ class StripeService:
success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}", success_url: str = "https://app.adiuva.app/billing/success?session_id={CHECKOUT_SESSION_ID}",
cancel_url: str = "https://app.adiuva.app/billing/cancel", cancel_url: str = "https://app.adiuva.app/billing/cancel",
) -> str: ) -> str:
"""Create a Stripe checkout session and return the URL. """Create a Stripe checkout session and return the URL."""
Returns a stub URL when Stripe is not configured.
Raises ``HTTP 400`` for the free tier or an unknown tier.
"""
if tier == "free": if tier == "free":
raise HTTPException( raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, status_code=status.HTTP_400_BAD_REQUEST,
@@ -87,8 +84,6 @@ class StripeService:
"""Process a Stripe webhook event. """Process a Stripe webhook event.
Verifies the signature, then dispatches on event type. Verifies the signature, then dispatches on event type.
Raises ``HTTP 400`` on signature mismatch.
No-ops when Stripe is not configured.
""" """
if not self._configured(): if not self._configured():
return return
@@ -155,9 +150,7 @@ class StripeService:
async def get_subscription( async def get_subscription(
self, user_id: str, db: AsyncSession self, user_id: str, db: AsyncSession
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Return the subscription record for ``user_id``, or ``None`` if absent.""" """Return the subscription record for user_id, or None."""
from app.models import Subscription # noqa: PLC0415
result = await db.execute( result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id) select(Subscription).where(Subscription.user_id == user_id)
) )
@@ -176,12 +169,7 @@ class StripeService:
} }
async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None: async def cancel_subscription(self, user_id: str, db: AsyncSession) -> None:
"""Cancel the user's Stripe subscription and downgrade them to free. """Cancel the user's Stripe subscription and downgrade to free."""
Raises ``HTTP 404`` when no active subscription exists.
"""
from app.models import Subscription # noqa: PLC0415
result = await db.execute( result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id) select(Subscription).where(Subscription.user_id == user_id)
) )
@@ -211,8 +199,6 @@ class StripeService:
sub_status: str, sub_status: str,
current_period_end: datetime | None, current_period_end: datetime | None,
) -> None: ) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute( result = await db.execute(
select(Subscription).where(Subscription.user_id == user_id) select(Subscription).where(Subscription.user_id == user_id)
) )
@@ -234,8 +220,6 @@ class StripeService:
status: str | None = None, status: str | None = None,
current_period_end: datetime | None = None, current_period_end: datetime | None = None,
) -> None: ) -> None:
from app.models import Subscription # noqa: PLC0415
result = await db.execute( result = await db.execute(
select(Subscription).where( select(Subscription).where(
Subscription.stripe_subscription_id == stripe_subscription_id Subscription.stripe_subscription_id == stripe_subscription_id
@@ -252,5 +236,5 @@ class StripeService:
sub.current_period_end = current_period_end sub.current_period_end = current_period_end
# Module-level singleton shared across the app. # Module-level singleton
stripe_service = StripeService() stripe_service = StripeService()

View File

@@ -1,9 +1,8 @@
"""Tier manager: feature matrix and quota enforcement. """Tier manager: feature matrix and quota enforcement.
``TierManager`` is the single source of truth for what each billing tier Single source of truth for what each billing tier allows.
allows. ``get_tier`` queries the ``subscriptions`` table for the live tier. Other services can query the /tier/{user_id} endpoint or rely on the
Quota-enforcement helpers take ``tier`` directly the caller already has it X-User-Tier header injected by Traefik.
from ``current_user.tier`` (provided by ``get_current_user``).
""" """
from __future__ import annotations from __future__ import annotations
@@ -14,7 +13,9 @@ from fastapi import HTTPException, status
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.schemas import BillingTier from shared.config import settings
from shared.models import Subscription
from shared.schemas import BillingTier
# Feature matrix per tier. -1 means unlimited; 0 means disabled. # Feature matrix per tier. -1 means unlimited; 0 means disabled.
FEATURES: dict[str, dict[str, Any]] = { FEATURES: dict[str, dict[str, Any]] = {
@@ -30,7 +31,7 @@ FEATURES: dict[str, dict[str, Any]] = {
"sso": False, "sso": False,
}, },
"pro": { "pro": {
"agents": -1, # unlimited "agents": -1,
"batch_active": 10, "batch_active": 10,
"batch_runs_per_day": 50, "batch_runs_per_day": 50,
"cloud_storage_gb": 5, "cloud_storage_gb": 5,
@@ -42,8 +43,8 @@ FEATURES: dict[str, dict[str, Any]] = {
}, },
"power": { "power": {
"agents": -1, "agents": -1,
"batch_active": -1, # unlimited "batch_active": -1,
"batch_runs_per_day": -1, # unlimited "batch_runs_per_day": -1,
"cloud_storage_gb": 25, "cloud_storage_gb": 25,
"backup_gb": 25, "backup_gb": 25,
"providers": -1, "providers": -1,
@@ -54,9 +55,9 @@ FEATURES: dict[str, dict[str, Any]] = {
"team": { "team": {
"agents": -1, "agents": -1,
"batch_active": -1, "batch_active": -1,
"batch_runs_per_day": -1, # unlimited "batch_runs_per_day": -1,
"cloud_storage_gb": -1, # unlimited "cloud_storage_gb": -1,
"backup_gb": -1, # unlimited "backup_gb": -1,
"providers": -1, "providers": -1,
"batch_builder": True, "batch_builder": True,
"plugin_marketplace": True, "plugin_marketplace": True,
@@ -76,17 +77,8 @@ RATE_LIMITS: dict[str, int] = {
class TierManager: class TierManager:
"""Centralises tier feature-gating, rate-limit lookups, and quota checks.""" """Centralises tier feature-gating, rate-limit lookups, and quota checks."""
# ── Tier lookup ─────────────────────────────────────────────────────
async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier: async def get_tier(self, user_id: str, db: AsyncSession) -> BillingTier:
"""Return the current billing tier for ``user_id`` from the DB. """Return the current billing tier for user_id from the DB."""
Falls back to ``'power'`` in dev (unlimited) or ``'free'`` in prod
when no subscription row exists.
"""
from app.models import Subscription # noqa: PLC0415
from app.config.settings import settings # noqa: PLC0415
result = await db.execute( result = await db.execute(
select(Subscription.tier).where(Subscription.user_id == user_id) select(Subscription.tier).where(Subscription.user_id == user_id)
) )
@@ -95,13 +87,12 @@ class TierManager:
return "power" if settings.ENV == "dev" else "free" return "power" if settings.ENV == "dev" else "free"
return tier # type: ignore[return-value] return tier # type: ignore[return-value]
# ── Feature access ─────────────────────────────────────────────────── def get_features(self, tier: BillingTier) -> dict[str, Any]:
"""Return the full feature dict for a tier."""
return FEATURES.get(tier, FEATURES["free"])
def check_feature(self, tier: BillingTier, feature: str) -> bool: def check_feature(self, tier: BillingTier, feature: str) -> bool:
"""Return ``True`` if ``tier`` has ``feature`` enabled. """Return True if tier has feature enabled."""
For numeric features, any value > 0 or -1 (unlimited) counts as enabled.
"""
value = FEATURES.get(tier, FEATURES["free"]).get(feature) value = FEATURES.get(tier, FEATURES["free"]).get(feature)
if value is None: if value is None:
return False return False
@@ -110,7 +101,7 @@ class TierManager:
return value != 0 return value != 0
def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None: def require_feature(self, tier: BillingTier, feature: str, tier_name: str = "") -> None:
"""Raise ``HTTP 403`` if ``tier`` does not have ``feature``.""" """Raise HTTP 403 if tier does not have feature."""
if not self.check_feature(tier, feature): if not self.check_feature(tier, feature):
detail = ( detail = (
f"Feature '{feature}' requires {tier_name} tier or above." f"Feature '{feature}' requires {tier_name} tier or above."
@@ -119,25 +110,17 @@ class TierManager:
) )
raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail) raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
# ── Rate limiting ────────────────────────────────────────────────────
def get_rate_limit(self, tier: BillingTier) -> int: def get_rate_limit(self, tier: BillingTier) -> int:
"""Return the requests-per-minute limit for ``tier``.""" """Return the requests-per-minute limit for tier."""
return RATE_LIMITS.get(tier, RATE_LIMITS["free"]) return RATE_LIMITS.get(tier, RATE_LIMITS["free"])
# ── Storage quota ────────────────────────────────────────────────────
def enforce_quota( def enforce_quota(
self, self,
tier: BillingTier, tier: BillingTier,
current_bytes: int = 0, current_bytes: int = 0,
additional_bytes: int = 0, additional_bytes: int = 0,
) -> None: ) -> None:
"""Raise ``HTTP 402`` if the user would exceed their cloud storage quota. """Raise HTTP 402 if the user would exceed their cloud storage quota."""
``tier`` is the caller's current tier (from ``current_user.tier``).
``current_bytes`` is the total bytes already stored (queried by caller).
"""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"] limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0: if limit_gb == 0:
raise HTTPException( raise HTTPException(
@@ -145,7 +128,7 @@ class TierManager:
detail=f"Cloud storage is not available on the '{tier}' tier", detail=f"Cloud storage is not available on the '{tier}' tier",
) )
if limit_gb == -1: if limit_gb == -1:
return # unlimited return
limit_bytes = limit_gb * 1024 ** 3 limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes: if current_bytes + additional_bytes > limit_bytes:
raise HTTPException( raise HTTPException(
@@ -159,7 +142,7 @@ class TierManager:
current_bytes: int = 0, current_bytes: int = 0,
additional_bytes: int = 0, additional_bytes: int = 0,
) -> None: ) -> None:
"""Raise ``HTTP 402`` if the user would exceed their backup quota.""" """Raise HTTP 402 if the user would exceed their backup quota."""
limit_gb: int = FEATURES[tier]["backup_gb"] limit_gb: int = FEATURES[tier]["backup_gb"]
if limit_gb == 0: if limit_gb == 0:
raise HTTPException( raise HTTPException(
@@ -167,7 +150,7 @@ class TierManager:
detail=f"Backup is not available on the '{tier}' tier", detail=f"Backup is not available on the '{tier}' tier",
) )
if limit_gb == -1: if limit_gb == -1:
return # unlimited return
limit_bytes = limit_gb * 1024 ** 3 limit_bytes = limit_gb * 1024 ** 3
if current_bytes + additional_bytes > limit_bytes: if current_bytes + additional_bytes > limit_bytes:
raise HTTPException( raise HTTPException(
@@ -181,7 +164,7 @@ class TierManager:
current_bytes: int = 0, current_bytes: int = 0,
additional_bytes: int = 0, additional_bytes: int = 0,
) -> bool: ) -> bool:
"""Return ``True`` if the user can store ``additional_bytes`` more data.""" """Return True if the user can store additional_bytes more data."""
limit_gb: int = FEATURES[tier]["cloud_storage_gb"] limit_gb: int = FEATURES[tier]["cloud_storage_gb"]
if limit_gb == 0: if limit_gb == 0:
return False return False
@@ -191,5 +174,5 @@ class TierManager:
return current_bytes + additional_bytes <= limit_bytes return current_bytes + additional_bytes <= limit_bytes
# Module-level singleton shared across the app. # Module-level singleton
tier_manager = TierManager() tier_manager = TierManager()

View File

@@ -0,0 +1,9 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
python-dotenv>=1.0.0
stripe>=8.0.0

View File

@@ -3,37 +3,34 @@ FROM python:3.12-slim AS builder
WORKDIR /build WORKDIR /build
COPY requirements.txt . COPY services/chat/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \ RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ────────────────────────────────────────────────────────────────── # ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime FROM python:3.12-slim AS runtime
# Non-root user
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app WORKDIR /app
# Copy installed packages from builder
COPY --from=builder /install /usr/local COPY --from=builder /install /usr/local
# Copy application source # Shared module
COPY app/ app/ COPY shared/ shared/
# Copy Alembic migration files # Service source
COPY alembic/ alembic/ COPY services/chat/app/ app/
COPY alembic.ini .
# Ensure appuser owns the working directory
RUN chown -R appuser:appgroup /app RUN chown -R appuser:appgroup /app
USER appuser USER appuser
EXPOSE 8000 EXPOSE 8000
# Chat service is CPU-bound (LLM calls) — use multiple workers
CMD ["gunicorn", "app.main:app", \ CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \ "-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \ "--bind", "0.0.0.0:8000", \
"--workers", "4", \ "--workers", "2", \
"--timeout", "120"] "--timeout", "120"]

21
services/chat/README.md Normal file
View File

@@ -0,0 +1,21 @@
# Chat Service
Owns: deep_agent (home + floating chat), memory middleware, domain agents
(task, note, project, timeline), LLM orchestration.
## Tables owned
- `memory_core`
- `memory_associative`
- `memory_episodic`
- `memory_proactive`
## Tables read (cross-service)
- `users` (for encryption_key — memory decryption)
## Endpoints
- `POST /chat` (REST fallback)
## Redis channels
- Subscribe: `chat:request:{user_id}`
- Publish: `ws:out:{user_id}` (stream frames + tool calls)
- BRPOP: `tool:result:{call_id}` (30s timeout)

View File

@@ -1,4 +1,8 @@
"""Single-agent runners for home and floating chat contexts.""" """Single-agent runners for home and floating chat contexts.
Adapted from app/core/deep_agent.py for the Chat Service.
Import paths changed to use local app modules and shared/.
"""
from __future__ import annotations from __future__ import annotations
@@ -12,14 +16,15 @@ from typing import Any, Literal
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool from langchain_core.tools import tool
from app.agents.note_agent import NOTE_TOOLS from shared.agents.note_agent import NOTE_TOOLS
from app.agents.project_agent import PROJECT_TOOLS from shared.agents.project_agent import PROJECT_TOOLS
from app.agents.task_agent import TASK_TOOLS from shared.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS from shared.agents.timeline_agent import TIMELINE_TOOLS
from app.core.llm import get_llm from shared.llm import get_llm
from app.core.memory_middleware import MemoryMiddleware from app.memory_middleware import MemoryMiddleware
from app.core.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector from shared.ws_context import clear_tool_result_collector, execute_on_client, set_tool_result_collector
from app.db import async_session from app import tracing
from shared.db import async_session
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@@ -240,7 +245,6 @@ def _strip_floating_markup(text: str) -> str:
return text return text
cleaned = _strip_floating_markup_fragment(text) cleaned = _strip_floating_markup_fragment(text)
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()] lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line) return "\n".join(line for line in lines if line)
@@ -279,7 +283,6 @@ class _FloatingStreamSanitizer:
return _strip_floating_markup_fragment(safe_text) return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str: def finalize(self) -> str:
# Drop dangling unfinished wrappers at the very end.
tail = re.sub(r"<[^>\n]*$", "", self._pending) tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail) tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = "" self._pending = ""
@@ -525,7 +528,9 @@ def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) ->
return {"type": "task", "id": None, "section": None} return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]: async def _infer_floating_domain(
message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None,
) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
@@ -535,10 +540,14 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
} }
try: try:
llm = get_llm() classifier_prompt = _get_system_prompt(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_SYSTEM,
)
callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
response = await llm.ainvoke( response = await llm.ainvoke(
[ [
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_SYSTEM), SystemMessage(content=classifier_prompt),
HumanMessage( HumanMessage(
content=( content=(
f"Message:\n{message}\n\n" f"Message:\n{message}\n\n"
@@ -564,6 +573,19 @@ async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[
return _infer_floating_domain_rule_based(message, context) return _infer_floating_domain_rule_based(message, context)
def _get_system_prompt(langfuse_name: str, fallback: str) -> str:
"""Fetch a managed prompt from Langfuse, falling back to the hardcoded string."""
managed = tracing.get_prompt(langfuse_name, fallback=None)
return managed if managed is not None else fallback
def _build_callbacks(langfuse_handler: Any | None) -> list[Any] | None:
"""Return a callbacks list if a Langfuse handler is available."""
if langfuse_handler is None:
return None
return [langfuse_handler]
async def _run_single_agent( async def _run_single_agent(
*, *,
user_id: str, user_id: str,
@@ -571,9 +593,11 @@ async def _run_single_agent(
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
max_steps: int = 6, max_steps: int = 6,
langfuse_handler: Any | None = None,
) -> str: ) -> str:
trace_id = _trace_id_from_context(context) trace_id = _trace_id_from_context(context)
llm = get_llm() callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
tools = _all_tools_for_user(user_id, trace_id) tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context) model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id) logger.info("deep_agent: run_single_agent_start trace=%s user=%s", trace_id or "-", user_id)
@@ -656,9 +680,11 @@ async def _run_single_agent_stream(
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
max_steps: int = 6, max_steps: int = 6,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
trace_id = _trace_id_from_context(context) trace_id = _trace_id_from_context(context)
llm = get_llm() callbacks = _build_callbacks(langfuse_handler)
llm = get_llm(callbacks=callbacks)
tools = _all_tools_for_user(user_id, trace_id) tools = _all_tools_for_user(user_id, trace_id)
model_context = _context_for_model(context) model_context = _context_for_model(context)
logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id) logger.info("deep_agent: run_single_agent_stream_start trace=%s user=%s", trace_id or "-", user_id)
@@ -691,7 +717,6 @@ async def _run_single_agent_stream(
emitted_any = True emitted_any = True
yield "token", token yield "token", token
# Some providers return final text in `response.content` but stream no chunks.
if not emitted_any: if not emitted_any:
fallback_text = _as_text(response.content) fallback_text = _as_text(response.content)
if fallback_text: if fallback_text:
@@ -750,25 +775,29 @@ async def _run_single_agent_stream(
clear_tool_result_collector() clear_tool_result_collector()
async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str: async def run_home(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> str:
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
response = await _run_single_agent( response = await _run_single_agent(
user_id=user_id, user_id=user_id,
system_prompt=_HOME_SINGLE_AGENT_SYSTEM, system_prompt=system_prompt,
message=message, message=message,
context=prepared_context, context=prepared_context,
langfuse_handler=langfuse_handler,
) )
return _normalize_tagged_list_lines(response, message) return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]: async def run_floating(user_id: str, message: str, context: dict[str, Any], *, langfuse_handler: Any | None = None) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context) domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
response = await _run_single_agent( response = await _run_single_agent(
user_id=user_id, user_id=user_id,
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM, system_prompt=system_prompt,
message=message, message=message,
context=prepared_context, context=prepared_context,
langfuse_handler=langfuse_handler,
) )
sanitized = _strip_floating_markup(response) sanitized = _strip_floating_markup(response)
if not sanitized and response: if not sanitized and response:
@@ -780,14 +809,18 @@ async def run_home_stream(
user_id: str, user_id: str,
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
system_prompt = _get_system_prompt("home_system", _HOME_SINGLE_AGENT_SYSTEM)
text_chunks: list[str] = [] text_chunks: list[str] = []
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
system_prompt=_HOME_SINGLE_AGENT_SYSTEM, system_prompt=system_prompt,
message=message, message=message,
context=prepared_context, context=prepared_context,
langfuse_handler=langfuse_handler,
): ):
event_type, data = event event_type, data = event
if event_type != "token": if event_type != "token":
@@ -804,19 +837,23 @@ async def run_floating_stream(
user_id: str, user_id: str,
message: str, message: str,
context: dict[str, Any], context: dict[str, Any],
*,
langfuse_handler: Any | None = None,
) -> AsyncGenerator[tuple[str, Any], None]: ) -> AsyncGenerator[tuple[str, Any], None]:
prepared_context = await _prepare_context(message, context) prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context) domain = await _infer_floating_domain(message, prepared_context, langfuse_handler=langfuse_handler)
yield "floating_domain", domain yield "floating_domain", domain
system_prompt = _get_system_prompt("floating_system", _FLOATING_SINGLE_AGENT_SYSTEM)
sanitizer = _FloatingStreamSanitizer() sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False emitted_sanitized = False
raw_chunks: list[str] = [] raw_chunks: list[str] = []
async for event in _run_single_agent_stream( async for event in _run_single_agent_stream(
user_id=user_id, user_id=user_id,
system_prompt=_FLOATING_SINGLE_AGENT_SYSTEM, system_prompt=system_prompt,
message=message, message=message,
context=prepared_context, context=prepared_context,
langfuse_handler=langfuse_handler,
): ):
event_type, data = event event_type, data = event
if event_type != "token": if event_type != "token":

77
services/chat/app/llm.py Normal file
View File

@@ -0,0 +1,77 @@
"""LLM factory — centralised model instantiation via LiteLLM.
Adapted from app/core/llm.py for the Chat Service.
Uses shared.config.settings instead of app.config.settings.
"""
from __future__ import annotations
import os
import warnings
from openai import AsyncOpenAI
import litellm
from langchain_openai import ChatOpenAI
from langchain_litellm import ChatLiteLLM
from shared.config import settings
litellm.drop_params = True
warnings.filterwarnings(
"ignore",
message=r"PydanticSerializationUnexpectedValue\(Expected `ResponseAPIUsage`",
category=UserWarning,
)
def _api_key_for_model(model: str) -> str | None:
if model.startswith("anthropic/"):
return settings.ANTHROPIC_API_KEY or None
if model.startswith("gemini/") or model.startswith("google/"):
return settings.GOOGLE_API_KEY or None
if model.startswith("cerebras/"):
return settings.CEREBRAS_API_KEY or None
if model.startswith("github/"):
return settings.GITHUB_TOKEN or None
if model.startswith("github_copilot/"):
return None
return settings.OPENAI_API_KEY or None
def get_llm(
*,
model: str | None = None,
temperature: float = 0,
callbacks: list | None = None,
) -> ChatOpenAI | ChatLiteLLM:
model = model or settings.LLM_MODEL
if settings.GITHUB_COPILOT_TOKEN_DIR:
os.environ.setdefault("GITHUB_COPILOT_TOKEN_DIR", settings.GITHUB_COPILOT_TOKEN_DIR)
if settings.GITHUB_TOKEN:
os.environ.setdefault("GITHUB_TOKEN", settings.GITHUB_TOKEN)
if "/" in model:
return ChatLiteLLM(model=model, temperature=temperature, callbacks=callbacks)
return ChatOpenAI(
model=model,
temperature=temperature,
api_key=_api_key_for_model(model),
callbacks=callbacks,
)
async def embed(text: str) -> list[float]:
model = settings.LLM_EMBED_MODEL
if model.startswith("github_copilot/") or "/" in model:
response = await litellm.aembedding(model=model, input=[text])
return response.data[0]["embedding"]
client = AsyncOpenAI(api_key=settings.OPENAI_API_KEY)
response = await client.embeddings.create(model=model, input=text)
return response.data[0].embedding

87
services/chat/app/main.py Normal file
View File

@@ -0,0 +1,87 @@
"""Chat Service — LLM orchestration, domain agents, memory.
Consumes chat requests from Redis, executes deep_agent (home/floating),
streams responses back via Redis pub/sub to WS Gateway.
Owns: memory_core, memory_associative, memory_episodic, memory_proactive tables.
"""
import sys
from contextlib import asynccontextmanager
import logging
from pathlib import Path
# Ensure the repo root is on sys.path so "shared" is importable in local dev.
_repo_root = str(Path(__file__).resolve().parents[3])
if _repo_root not in sys.path:
sys.path.insert(0, _repo_root)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from shared.config import settings
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s %(levelname)s %(name)s: %(message)s",
)
logging.getLogger("sqlalchemy.engine").setLevel(logging.WARNING)
logging.getLogger("sqlalchemy.pool").setLevel(logging.WARNING)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Initialise Langfuse tracing (no-op if keys are missing)
from app.tracing import init_langfuse
init_langfuse()
# Start Redis consumer in background
from app.redis_consumer import start_consumer
consumer_task = start_consumer()
yield
consumer_task.cancel()
from app.tracing import shutdown as shutdown_langfuse
shutdown_langfuse()
from shared.db import engine
await engine.dispose()
from shared.redis import redis_client
await redis_client.aclose()
def create_app() -> FastAPI:
app = FastAPI(
title="Adiuva Chat Service",
version="0.1.0",
docs_url="/docs" if settings.ENV == "dev" else None,
redoc_url=None,
lifespan=lifespan,
)
app.add_middleware(
CORSMiddleware,
allow_origins=settings.CORS_ORIGINS,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
from app.routes import router
app.include_router(router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:
return {"status": "ok", "service": "chat", "version": app.version}
return app
app = create_app()

View File

@@ -1,19 +1,7 @@
"""Memory Middleware — enrich requests with memory context and store interactions. """Memory Middleware — adapted for Chat Service.
Four-tier memory model (MemGPT-style): Uses shared.models instead of app.models. Otherwise identical to the
core persistent key/value user preferences, always injected monolith's app/core/memory_middleware.py.
associative semantic similarity search via pgvector (top-k)
episodic recent session summaries (last N)
proactive behavioral patterns above confidence threshold
All memory content is encrypted at rest using the per-user Fernet key
stored in User.encryption_key. Decryption happens in-memory only.
Usage:
memory = MemoryMiddleware(db_session)
context = await memory.enrich_context(user_id, message)
# ... run agent ...
await memory.store_episode(user_id, session_id, message, response)
""" """
from __future__ import annotations from __future__ import annotations
@@ -26,7 +14,7 @@ from cryptography.fernet import Fernet, InvalidToken
from sqlalchemy import select from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.models import ( from shared.models import (
MemoryAssociative, MemoryAssociative,
MemoryCore, MemoryCore,
MemoryEpisodic, MemoryEpisodic,
@@ -36,20 +24,16 @@ from app.models import (
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
# Tuning constants
_ASSOCIATIVE_TOP_K = 5 _ASSOCIATIVE_TOP_K = 5
_EPISODIC_RECENT_N = 10 _EPISODIC_RECENT_N = 10
_PROACTIVE_CONFIDENCE_THRESHOLD = 0.6 _PROACTIVE_CONFIDENCE_THRESHOLD = 0.6
class MemoryMiddleware: class MemoryMiddleware:
"""Enrich orchestrator context with memory and persist interactions after."""
def __init__(self, db: AsyncSession) -> None: def __init__(self, db: AsyncSession) -> None:
self._db = db self._db = db
# ── Public API ────────────────────────────────────────────────────────────
async def enrich_context( async def enrich_context(
self, self,
user_id: str, user_id: str,
@@ -57,14 +41,6 @@ class MemoryMiddleware:
trace_id: str | None = None, trace_id: str | None = None,
session_id: str | None = None, session_id: str | None = None,
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Build memory context dict to inject into the orchestrator before LLM call.
Returns a dict with keys:
core_memory {key: plaintext_value, ...}
associative_memory [plaintext_content, ...] (top-k by keyword match)
episodic_memory [plaintext_summary, ...] (most recent N)
proactive_hints [plaintext_pattern, ...] (above threshold)
"""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return {} return {}
@@ -74,16 +50,9 @@ class MemoryMiddleware:
episodic = await self._load_episodic(user_id, fernet, session_id=session_id) episodic = await self._load_episodic(user_id, fernet, session_id=session_id)
proactive = await self._load_proactive(user_id, fernet) proactive = await self._load_proactive(user_id, fernet)
user_dbg = await self._get_user_debug(user_id)
logger.info( logger.info(
"memory: enrich_context trace=%s user=%s tier=%s core=%d associative=%d episodic=%d proactive=%d", "memory: enrich_context trace=%s user=%s core=%d assoc=%d episodic=%d proactive=%d",
trace_id or "-", trace_id or "-", user_id, len(core), len(associative), len(episodic), len(proactive),
user_id,
user_dbg.get("tier") or "-",
len(core),
len(associative),
len(episodic),
len(proactive),
) )
return { return {
@@ -94,18 +63,9 @@ class MemoryMiddleware:
} }
async def store_episode( async def store_episode(
self, self, user_id: str, session_id: str, message: str, response: str,
user_id: str,
session_id: str,
message: str,
response: str,
trace_id: str | None = None, trace_id: str | None = None,
) -> None: ) -> None:
"""Summarise and store a completed interaction in episodic memory.
The summary is a simple heuristic concatenation (no LLM call) to keep
latency low. Full LLM summarisation can be added in a later step.
"""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return return
@@ -122,113 +82,68 @@ class MemoryMiddleware:
self._db.add(row) self._db.add(row)
try: try:
await self._db.commit() await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: store_episode trace=%s user=%s tier=%s session=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
session_id,
)
except Exception as exc: except Exception as exc:
logger.error("memory: store_episode failed user=%s: %s", user_id, exc) logger.error("memory: store_episode failed user=%s: %s", user_id, exc)
await self._db.rollback() await self._db.rollback()
async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None: async def update_core(self, user_id: str, key: str, value: str, trace_id: str | None = None) -> None:
"""Upsert a core memory key/value for a user."""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return return
encrypted = _encrypt(fernet, value) encrypted = _encrypt(fernet, value)
result = await self._db.execute( result = await self._db.execute(
select(MemoryCore).where( select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == key)
MemoryCore.user_id == user_id,
MemoryCore.key == key,
)
) )
existing = result.scalar_one_or_none() existing = result.scalar_one_or_none()
if existing is not None: if existing is not None:
existing.value_encrypted = encrypted existing.value_encrypted = encrypted
else: else:
self._db.add(MemoryCore( self._db.add(MemoryCore(
id=str(uuid.uuid4()), id=str(uuid.uuid4()), user_id=user_id, key=key, value_encrypted=encrypted,
user_id=user_id,
key=key,
value_encrypted=encrypted,
)) ))
try: try:
await self._db.commit() await self._db.commit()
user_dbg = await self._get_user_debug(user_id)
logger.info(
"memory: update_core trace=%s user=%s tier=%s key=%s",
trace_id or "-",
user_id,
user_dbg.get("tier") or "-",
key,
)
except Exception as exc: except Exception as exc:
logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc) logger.error("memory: update_core failed user=%s key=%s: %s", user_id, key, exc)
await self._db.rollback() await self._db.rollback()
async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]: async def list_core_blocks(self, user_id: str) -> list[dict[str, str]]:
"""Return core memory as editable blocks (label/value)."""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return [] return []
result = await self._db.execute( result = await self._db.execute(
select(MemoryCore) select(MemoryCore).where(MemoryCore.user_id == user_id).order_by(MemoryCore.key.asc())
.where(MemoryCore.user_id == user_id)
.order_by(MemoryCore.key.asc())
) )
rows = result.scalars().all()
out: list[dict[str, str]] = [] out: list[dict[str, str]] = []
for row in rows: for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.value_encrypted) plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None: if plaintext is not None:
out.append({"label": row.key, "value": plaintext}) out.append({"label": row.key, "value": plaintext})
logger.debug("memory: list_core_blocks user=%s count=%d", user_id, len(out))
return out return out
async def get_core_block(self, user_id: str, label: str) -> str | None: async def get_core_block(self, user_id: str, label: str) -> str | None:
"""Return a single core memory block value by label."""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return None return None
result = await self._db.execute( result = await self._db.execute(
select(MemoryCore).where( select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
) )
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if row is None: if row is None:
logger.debug("memory: get_core_block user=%s label=%s found=0", user_id, label)
return None return None
value = _safe_decrypt(fernet, row.value_encrypted) return _safe_decrypt(fernet, row.value_encrypted)
logger.debug("memory: get_core_block user=%s label=%s found=%d", user_id, label, 1 if value is not None else 0)
return value
async def delete_core(self, user_id: str, label: str) -> bool: async def delete_core(self, user_id: str, label: str) -> bool:
"""Delete a core memory block by label. Returns True if deleted."""
result = await self._db.execute( result = await self._db.execute(
select(MemoryCore).where( select(MemoryCore).where(MemoryCore.user_id == user_id, MemoryCore.key == label)
MemoryCore.user_id == user_id,
MemoryCore.key == label,
)
) )
row = result.scalar_one_or_none() row = result.scalar_one_or_none()
if row is None: if row is None:
logger.debug("memory: delete_core user=%s label=%s found=0", user_id, label)
return False return False
await self._db.delete(row) await self._db.delete(row)
try: try:
await self._db.commit() await self._db.commit()
logger.info("memory: delete_core user=%s label=%s", user_id, label)
return True return True
except Exception as exc: except Exception as exc:
logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc) logger.error("memory: delete_core failed user=%s label=%s: %s", user_id, label, exc)
@@ -236,64 +151,47 @@ class MemoryMiddleware:
return False return False
async def append_core(self, user_id: str, label: str, content: str) -> None: async def append_core(self, user_id: str, label: str, content: str) -> None:
"""Append content to a core block, creating it if missing."""
current = await self.get_core_block(user_id, label) current = await self.get_core_block(user_id, label)
if current is None: if current is None:
await self.update_core(user_id, label, content) await self.update_core(user_id, label, content)
logger.info("memory: append_core user=%s label=%s created=1", user_id, label)
return return
await self.update_core(user_id, label, f"{current}\n{content}") await self.update_core(user_id, label, f"{current}\n{content}")
logger.info("memory: append_core user=%s label=%s created=0", user_id, label)
async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool: async def replace_core(self, user_id: str, label: str, old: str, new: str) -> bool:
"""Replace one exact string inside a core block. Returns False if not found."""
current = await self.get_core_block(user_id, label) current = await self.get_core_block(user_id, label)
if current is None or old not in current: if current is None or old not in current:
logger.debug("memory: replace_core user=%s label=%s changed=0", user_id, label)
return False return False
await self.update_core(user_id, label, current.replace(old, new, 1)) await self.update_core(user_id, label, current.replace(old, new, 1))
logger.info("memory: replace_core user=%s label=%s changed=1", user_id, label)
return True return True
async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None: async def insert_archival(self, user_id: str, content: str, source: str = "manual") -> None:
"""Insert a long-term archival memory entry."""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return return
encrypted = _encrypt(fernet, content) encrypted = _encrypt(fernet, content)
row = MemoryAssociative( row = MemoryAssociative(
id=str(uuid.uuid4()), id=str(uuid.uuid4()), user_id=user_id,
user_id=user_id, content_encrypted=encrypted, embedding=None,
content_encrypted=encrypted, entity_type=source, entity_id=None,
embedding=None,
entity_type=source,
entity_id=None,
) )
self._db.add(row) self._db.add(row)
try: try:
await self._db.commit() await self._db.commit()
logger.info("memory: insert_archival user=%s source=%s", user_id, source)
except Exception as exc: except Exception as exc:
logger.error("memory: insert_archival failed user=%s: %s", user_id, exc) logger.error("memory: insert_archival failed user=%s: %s", user_id, exc)
await self._db.rollback() await self._db.rollback()
async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]: async def search_archival(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search archival memory (keyword fallback; semantic ranking can replace this)."""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return [] return []
result = await self._db.execute( result = await self._db.execute(
select(MemoryAssociative) select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
.where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()).limit(100)
.order_by(MemoryAssociative.updated_at.desc())
.limit(100)
) )
rows = result.scalars().all()
needle = query.strip().lower() needle = query.strip().lower()
out: list[str] = [] out: list[str] = []
for row in rows: for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.content_encrypted) plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is None: if plaintext is None:
continue continue
@@ -301,25 +199,19 @@ class MemoryMiddleware:
out.append(plaintext) out.append(plaintext)
if len(out) >= max(top_k, 1): if len(out) >= max(top_k, 1):
break break
logger.info("memory: search_archival user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out return out
async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]: async def search_recall(self, user_id: str, query: str, top_k: int = 5) -> list[str]:
"""Search recall memory (episodic summaries) by keyword."""
fernet = await self._get_fernet(user_id) fernet = await self._get_fernet(user_id)
if fernet is None: if fernet is None:
return [] return []
result = await self._db.execute( result = await self._db.execute(
select(MemoryEpisodic) select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
.where(MemoryEpisodic.user_id == user_id) .order_by(MemoryEpisodic.created_at.desc()).limit(100)
.order_by(MemoryEpisodic.created_at.desc())
.limit(100)
) )
rows = result.scalars().all()
needle = query.strip().lower() needle = query.strip().lower()
out: list[str] = [] out: list[str] = []
for row in rows: for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.summary_encrypted) plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is None: if plaintext is None:
continue continue
@@ -327,13 +219,11 @@ class MemoryMiddleware:
out.append(plaintext) out.append(plaintext)
if len(out) >= max(top_k, 1): if len(out) >= max(top_k, 1):
break break
logger.info("memory: search_recall user=%s query=%s hits=%d", user_id, query[:80], len(out))
return out return out
# ── Private helpers ─────────────────────────────────────────────────────── # ── Private ───────────────────────────────────────────────────────
async def _get_fernet(self, user_id: str) -> Fernet | None: async def _get_fernet(self, user_id: str) -> Fernet | None:
"""Load the user's Fernet key from DB. Returns None if missing."""
result = await self._db.execute(select(User).where(User.id == user_id)) result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none() user = result.scalar_one_or_none()
if user is None or not user.encryption_key: if user is None or not user.encryption_key:
@@ -341,68 +231,38 @@ class MemoryMiddleware:
return None return None
return Fernet(user.encryption_key.encode()) return Fernet(user.encryption_key.encode())
async def _get_user_debug(self, user_id: str) -> dict[str, str | None]:
"""Load lightweight user debug fields for trace logs."""
result = await self._db.execute(select(User).where(User.id == user_id))
user = result.scalar_one_or_none()
if user is None:
return {"tier": None}
return {
"tier": user.tier,
}
async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]: async def _load_core(self, user_id: str, fernet: Fernet) -> dict[str, str]:
result = await self._db.execute( result = await self._db.execute(
select(MemoryCore).where(MemoryCore.user_id == user_id) select(MemoryCore).where(MemoryCore.user_id == user_id)
) )
rows = result.scalars().all()
out: dict[str, str] = {} out: dict[str, str] = {}
for row in rows: for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.value_encrypted) plaintext = _safe_decrypt(fernet, row.value_encrypted)
if plaintext is not None: if plaintext is not None:
out[row.key] = plaintext out[row.key] = plaintext
return out return out
async def _load_associative( async def _load_associative(self, user_id: str, message: str, fernet: Fernet) -> list[str]:
self, user_id: str, message: str, fernet: Fernet
) -> list[str]:
"""Load top-k associative memories.
Production: uses pgvector cosine similarity on the message embedding.
Current implementation: keyword-based fallback (no external embedding call)
so tests pass without a live OpenAI key.
"""
result = await self._db.execute( result = await self._db.execute(
select(MemoryAssociative) select(MemoryAssociative).where(MemoryAssociative.user_id == user_id)
.where(MemoryAssociative.user_id == user_id) .order_by(MemoryAssociative.updated_at.desc()).limit(_ASSOCIATIVE_TOP_K)
.order_by(MemoryAssociative.updated_at.desc())
.limit(_ASSOCIATIVE_TOP_K)
) )
rows = result.scalars().all()
out: list[str] = [] out: list[str] = []
for row in rows: for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.content_encrypted) plaintext = _safe_decrypt(fernet, row.content_encrypted)
if plaintext is not None: if plaintext is not None:
out.append(plaintext) out.append(plaintext)
return out return out
async def _load_episodic( async def _load_episodic(self, user_id: str, fernet: Fernet, session_id: str | None = None) -> list[str]:
self,
user_id: str,
fernet: Fernet,
session_id: str | None = None,
) -> list[str]:
query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id) query = select(MemoryEpisodic).where(MemoryEpisodic.user_id == user_id)
if session_id: if session_id:
query = query.where(MemoryEpisodic.session_id == session_id) query = query.where(MemoryEpisodic.session_id == session_id)
result = await self._db.execute( result = await self._db.execute(
query query.order_by(MemoryEpisodic.created_at.desc()).limit(_EPISODIC_RECENT_N)
.order_by(MemoryEpisodic.created_at.desc())
.limit(_EPISODIC_RECENT_N)
) )
rows = result.scalars().all()
out: list[str] = [] out: list[str] = []
for row in rows: for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.summary_encrypted) plaintext = _safe_decrypt(fernet, row.summary_encrypted)
if plaintext is not None: if plaintext is not None:
out.append(plaintext) out.append(plaintext)
@@ -410,30 +270,24 @@ class MemoryMiddleware:
async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]: async def _load_proactive(self, user_id: str, fernet: Fernet) -> list[str]:
result = await self._db.execute( result = await self._db.execute(
select(MemoryProactive) select(MemoryProactive).where(
.where(
MemoryProactive.user_id == user_id, MemoryProactive.user_id == user_id,
MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD, MemoryProactive.confidence >= _PROACTIVE_CONFIDENCE_THRESHOLD,
).order_by(MemoryProactive.confidence.desc())
) )
.order_by(MemoryProactive.confidence.desc())
)
rows = result.scalars().all()
out: list[str] = [] out: list[str] = []
for row in rows: for row in result.scalars().all():
plaintext = _safe_decrypt(fernet, row.pattern_encrypted) plaintext = _safe_decrypt(fernet, row.pattern_encrypted)
if plaintext is not None: if plaintext is not None:
out.append(plaintext) out.append(plaintext)
return out return out
# ── Encryption helpers ────────────────────────────────────────────────────────
def _encrypt(fernet: Fernet, plaintext: str) -> str: def _encrypt(fernet: Fernet, plaintext: str) -> str:
return fernet.encrypt(plaintext.encode()).decode() return fernet.encrypt(plaintext.encode()).decode()
def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None: def _safe_decrypt(fernet: Fernet, ciphertext: str) -> str | None:
"""Decrypt and return plaintext, or None on error (corrupted/wrong key)."""
try: try:
return fernet.decrypt(ciphertext.encode()).decode() return fernet.decrypt(ciphertext.encode()).decode()
except (InvalidToken, Exception) as exc: except (InvalidToken, Exception) as exc:

View File

@@ -1,11 +1,14 @@
"""Output formatter for deep-agent stream events.""" """Output formatter for deep-agent stream events — Chat Service copy.
Converts (event_type, data) tuples into WebSocket frame Pydantic models.
"""
from __future__ import annotations from __future__ import annotations
from collections.abc import AsyncGenerator from collections.abc import AsyncGenerator
from typing import Any from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText from shared.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain

View File

@@ -0,0 +1,209 @@
"""Redis consumer — listens for chat requests and dispatches to deep_agent.
Subscribes to a Redis pattern channel chat:request:* so it receives
requests for ALL users. Each request is processed in a separate asyncio task.
"""
from __future__ import annotations
import asyncio
import json
import logging
from uuid import uuid4
from shared.db import async_session
from shared.redis import redis_client, ws_out_channel
from app.deep_agent import run_floating_stream, run_home_stream
from app.memory_middleware import MemoryMiddleware
from app.output_formatter import StreamFormatter
from shared.ws_context import clear_current_user, set_current_user
from app import tracing
logger = logging.getLogger(__name__)
def start_consumer() -> asyncio.Task:
"""Start the Redis consumer as a background asyncio task."""
return asyncio.create_task(_consumer_loop())
async def _consumer_loop() -> None:
"""Subscribe to chat:request:* and dispatch incoming frames."""
pubsub = redis_client.pubsub()
await pubsub.psubscribe("chat:request:*")
logger.info("redis_consumer: subscribed to chat:request:*")
try:
while True:
message = await pubsub.get_message(
ignore_subscribe_messages=True, timeout=1.0
)
if message is not None and message["type"] == "pmessage":
frame = json.loads(message["data"])
asyncio.create_task(_dispatch(frame))
else:
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("redis_consumer: shutting down")
finally:
await pubsub.punsubscribe()
await pubsub.aclose()
async def _dispatch(frame: dict) -> None:
"""Route a chat request frame to the appropriate handler."""
frame_type = frame.get("type")
user_id = frame.get("user_id")
if not user_id:
logger.warning("redis_consumer: frame missing user_id: %s", frame.get("type"))
return
if frame_type == "home_request":
await _handle_home_request(user_id, frame)
elif frame_type == "floating_request":
await _handle_floating_request(user_id, frame)
else:
logger.debug("redis_consumer: unknown frame type %r", frame_type)
async def _publish_frame(user_id: str, frame_data: str) -> None:
"""Publish a frame to ws:out:{user_id} for the WS Gateway to forward."""
channel = ws_out_channel(user_id)
await redis_client.publish(channel, frame_data)
async def _handle_home_request(user_id: str, frame: dict) -> None:
"""Process a home_request — enrich with memory, run deep_agent, stream results."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
logger.info(
"redis_consumer: home_request user=%s req=%s msg=%s",
user_id, request_id, message[:200],
)
response_chunks: list[str] = []
with tracing.trace_span(
name="home_request",
user_id=user_id,
session_id=session_id,
trace_id=request_id,
input=message,
metadata={"message_preview": message[:200]},
tags=["home"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
set_current_user(user_id)
try:
event_stream = run_home_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: home_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
# Link prompt and attach output preview
tracing.link_prompt_to_trace(span, "home_system")
response_text = "".join(response_chunks)
span.update(output=response_text[:500] if response_text else None)
tracing.flush()
# Store episode
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks),
trace_id=request_id,
)
async def _handle_floating_request(user_id: str, frame: dict) -> None:
"""Process a floating_request — enrich with memory, run deep_agent, stream results."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
logger.info(
"redis_consumer: floating_request user=%s req=%s scope=%s msg=%s",
user_id, request_id, json.dumps(scope)[:200], message[:200],
)
response_chunks: list[str] = []
with tracing.trace_span(
name="floating_request",
user_id=user_id,
session_id=session_id,
trace_id=request_id,
input=message,
metadata={"message_preview": message[:200], "scope": scope},
tags=["floating"],
) as span:
langfuse_handler = tracing.get_langfuse_callback()
# Enrich with memory context
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
user_id, message,
trace_id=request_id, session_id=session_id,
)
context: dict = {
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
set_current_user(user_id)
try:
event_stream = run_floating_stream(user_id, message, context, langfuse_handler=langfuse_handler)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await _publish_frame(user_id, ws_frame.model_dump_json())
if hasattr(ws_frame, "chunk"):
response_chunks.append(ws_frame.chunk)
except Exception as exc:
logger.error("redis_consumer: floating_request failed user=%s req=%s: %s", user_id, request_id, exc)
finally:
clear_current_user()
# Link prompt and attach output preview
tracing.link_prompt_to_trace(span, "floating_system")
response_text = "".join(response_chunks)
span.update(output=response_text[:500] if response_text else None)
tracing.flush()
# Store episode
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks),
trace_id=request_id,
)

View File

@@ -0,0 +1,37 @@
"""Chat REST route — POST /chat fallback when WS is unavailable."""
from __future__ import annotations
from fastapi import APIRouter, Request
from fastapi.responses import JSONResponse
from shared.schemas import ChatRequest
from app.deep_agent import run_home
from shared.ws_context import clear_current_user, set_current_user
router = APIRouter(prefix="/chat", tags=["chat"])
@router.post("")
async def chat(body: ChatRequest, request: Request) -> JSONResponse:
"""REST fallback for home chat.
In the microservices setup, Traefik ForwardAuth has already validated
the JWT and injected X-User-Id / X-User-Email / X-User-Tier headers.
"""
user_id = request.headers.get("X-User-Id", "")
if not user_id:
return JSONResponse(status_code=401, content={"detail": "Missing X-User-Id header"})
set_current_user(user_id)
try:
response = await run_home(
user_id=user_id,
message=body.message,
context=body.context.model_dump(),
)
finally:
clear_current_user()
return JSONResponse(content={"response": response})

View File

@@ -0,0 +1,304 @@
"""Langfuse tracing & prompt management for the Chat Service (v4 SDK).
Provides:
- ``init_langfuse()`` — initialise the singleton client at startup
- ``trace_span()`` — context manager that creates a trace + span
- ``get_langfuse_callback()`` — LangChain callback handler (auto-inherits trace)
- ``get_prompt()`` — fetch a managed prompt from Langfuse by name
- ``flush()`` / ``shutdown()`` — lifecycle management
All functions gracefully degrade to no-ops when Langfuse is not configured,
so the service works identically with or without observability keys.
Requires ``langfuse >= 3.0.0`` (v4 / "Fast Preview" SDK).
"""
from __future__ import annotations
import logging
from contextlib import contextmanager
from typing import Any
from shared.config import settings
logger = logging.getLogger(__name__)
# ── State ────────────────────────────────────────────────────────────────
_initialised: bool = False
_disabled: bool = False
def _is_configured() -> bool:
return bool(settings.LANGFUSE_SECRET_KEY and settings.LANGFUSE_PUBLIC_KEY)
def init_langfuse() -> None:
"""Initialise the Langfuse singleton. Call once at startup."""
global _initialised, _disabled
if _initialised or _disabled:
return
if not _is_configured():
_disabled = True
logger.info("tracing: Langfuse keys not set — tracing disabled")
return
try:
from langfuse import Langfuse
Langfuse(
secret_key=settings.LANGFUSE_SECRET_KEY,
public_key=settings.LANGFUSE_PUBLIC_KEY,
host=settings.LANGFUSE_HOST,
)
_initialised = True
logger.info("tracing: Langfuse client initialised (host=%s)", settings.LANGFUSE_HOST)
except Exception as exc:
_disabled = True
logger.warning("tracing: failed to initialise Langfuse: %s", exc)
def _get_client() -> Any | None:
"""Return the singleton Langfuse client, or *None* if disabled."""
if _disabled:
return None
if not _initialised:
init_langfuse()
if _disabled:
return None
try:
from langfuse import get_client
return get_client()
except Exception:
return None
# ── Null span (no-op when Langfuse is disabled) ─────────────────────────
class _NullSpan:
"""Drop-in replacement when Langfuse is disabled."""
def update(self, **_: Any) -> None: ...
def set_trace_io(self, **_: Any) -> None: ...
def score_trace(self, **_: Any) -> None: ...
# ── Trace context manager ───────────────────────────────────────────────
@contextmanager
def trace_span(
*,
name: str,
user_id: str,
session_id: str | None = None,
trace_id: str | None = None,
input: Any = None,
metadata: dict[str, Any] | None = None,
tags: list[str] | None = None,
):
"""Context manager that creates a Langfuse trace/span.
Yields the span object (or a ``_NullSpan`` if Langfuse is disabled).
A ``CallbackHandler`` created inside this block auto-inherits the trace
context, so there is no need to pass trace IDs manually.
"""
lf = _get_client()
if lf is None:
yield _NullSpan()
return
try:
from langfuse import Langfuse, propagate_attributes
trace_ctx: dict[str, str] = {}
if trace_id is not None:
trace_ctx["trace_id"] = Langfuse.create_trace_id(seed=trace_id)
with lf.start_as_current_observation(
as_type="span",
name=name,
input=input,
metadata=metadata or {},
**({"trace_context": trace_ctx} if trace_ctx else {}),
) as span:
with propagate_attributes(
user_id=user_id,
session_id=session_id,
tags=tags or [],
):
yield span
except Exception as exc:
logger.warning("tracing: trace_span(%s) failed: %s", name, exc)
yield _NullSpan()
# ── LangChain callback handler ──────────────────────────────────────────
def get_langfuse_callback() -> Any | None:
"""Return a LangChain ``CallbackHandler`` that auto-inherits the current trace.
Must be called inside a ``trace_span()`` block for proper linking.
Returns *None* when Langfuse is disabled.
"""
if _disabled and not _initialised:
return None
try:
from langfuse.langchain import CallbackHandler
return CallbackHandler()
except Exception as exc:
logger.warning("tracing: get_langfuse_callback failed: %s", exc)
return None
# ── Prompt management ────────────────────────────────────────────────────
def get_prompt(
name: str,
*,
version: int | None = None,
label: str | None = None,
fallback: str | None = None,
cache_ttl_seconds: int = 300,
) -> str | None:
"""Fetch a managed prompt from Langfuse by name (without variable compilation).
Returns the raw prompt string, or *fallback* if the prompt is not
found or Langfuse is disabled.
"""
lf = _get_client()
if lf is None:
return fallback
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
return prompt.prompt
except Exception as exc:
logger.warning("tracing: get_prompt(%s) failed: %s", name, exc)
return fallback
def compile_prompt(
name: str,
*,
fallback: str,
variables: dict[str, str],
version: int | None = None,
label: str | None = None,
cache_ttl_seconds: int = 300,
) -> str:
"""Fetch a managed prompt from Langfuse and compile it with ``{{variables}}``.
If the prompt exists in Langfuse, uses the SDK's ``.compile(**variables)``
which replaces ``{{key}}`` placeholders. If Langfuse is disabled or the
prompt is not found, falls back to ``fallback.format(**variables)`` (Python
``{key}`` placeholders).
This means:
- Langfuse prompts use ``{{variable}}`` syntax.
- Hardcoded fallback strings use Python ``{variable}`` syntax.
"""
lf = _get_client()
if lf is None:
return fallback.format(**variables)
try:
kwargs: dict[str, Any] = {
"name": name,
"cache_ttl_seconds": cache_ttl_seconds,
}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
return prompt.compile(**variables)
except Exception as exc:
logger.warning("tracing: compile_prompt(%s) failed, using fallback: %s", name, exc)
return fallback.format(**variables)
def link_prompt_to_trace(
span: Any,
prompt_name: str,
*,
version: int | None = None,
label: str | None = None,
) -> None:
"""Attach prompt metadata to a span/trace."""
lf = _get_client()
if lf is None or isinstance(span, _NullSpan):
return
try:
kwargs: dict[str, Any] = {"name": prompt_name}
if version is not None:
kwargs["version"] = version
if label is not None:
kwargs["label"] = label
prompt = lf.get_prompt(**kwargs)
span.update(metadata={"prompt": {"name": prompt_name, "version": prompt.version}})
except Exception as exc:
logger.warning("tracing: link_prompt_to_trace(%s) failed: %s", prompt_name, exc)
# ── Scoring helper ───────────────────────────────────────────────────────
def score_trace(
trace_id: str,
name: str,
value: float,
*,
comment: str | None = None,
) -> None:
"""Post a score to a trace (e.g. user feedback, latency, quality)."""
lf = _get_client()
if lf is None:
return
try:
lf.create_score(trace_id=trace_id, name=name, value=value, comment=comment)
except Exception as exc:
logger.warning("tracing: score_trace failed: %s", exc)
# ── Shutdown ─────────────────────────────────────────────────────────────
def flush() -> None:
"""Flush pending Langfuse events."""
lf = _get_client()
if lf is not None:
try:
lf.flush()
except Exception as exc:
logger.warning("tracing: flush failed: %s", exc)
def shutdown() -> None:
"""Flush and close the Langfuse client."""
global _initialised, _disabled
lf = _get_client()
if lf is not None:
try:
lf.flush()
lf.shutdown()
except Exception as exc:
logger.warning("tracing: shutdown failed: %s", exc)
_initialised = False
_disabled = False

View File

@@ -0,0 +1,17 @@
fastapi>=0.115.0
uvicorn[standard]>=0.34.0
gunicorn>=22.0.0
pydantic>=2.10.0
pydantic-settings>=2.7.0
sqlalchemy>=2.0.0
asyncpg>=0.30.0
redis>=5.0.0
cryptography>=42.0.0
python-dotenv>=1.0.0
langchain-core>=0.3.0
langchain-openai>=0.3.0
langchain-litellm>=0.3.0
litellm>=1.50.0
openai>=1.50.0
httpx>=0.27.0
langfuse>=3.0.0

View File

@@ -0,0 +1,36 @@
# ── builder ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS builder
WORKDIR /build
COPY services/ws-gateway/requirements.txt ./requirements.txt
RUN pip install --upgrade pip && \
pip install --no-cache-dir --prefix=/install -r requirements.txt
# ── runtime ──────────────────────────────────────────────────────────────────
FROM python:3.12-slim AS runtime
RUN addgroup --system appgroup && adduser --system --ingroup appgroup appuser
WORKDIR /app
COPY --from=builder /install /usr/local
# Shared module
COPY shared/ shared/
# Service source
COPY services/ws-gateway/app/ app/
RUN chown -R appuser:appgroup /app
USER appuser
EXPOSE 8000
# Single worker — each instance handles many WS connections via asyncio
CMD ["gunicorn", "app.main:app", \
"-k", "uvicorn.workers.UvicornWorker", \
"--bind", "0.0.0.0:8000", \
"--workers", "1", \
"--timeout", "0"]

View File

@@ -0,0 +1,17 @@
# WS Gateway
Stateless WebSocket proxy. Accepts Electron connections, authenticates JWT,
routes frames to Chat/Batch services via Redis pub/sub.
## No business logic
This service does NOT know what tasks, notes, or agents are.
It only routes JSON frames between Electron and downstream services.
## Scaling
Sticky sessions on `user_id` (Traefik consistent hashing).
## Redis channels used
- Subscribe: `ws:out:{user_id}` (frames to send to client)
- Publish: `chat:request:{user_id}`, `batch:request:{user_id}`
- LPUSH: `tool:result:{call_id}` (from client tool_result frames)
- HSET/HDEL: `ws:devices:{user_id}` (device registry)

Some files were not shown because too many files have changed in this diff Show More