"""Agent REST routes — catalog, billing checks, trigger. Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas. Agent trigger dispatches via Redis to the consumer instead of spawning an in-process background task. """ from __future__ import annotations import json import uuid from datetime import datetime, timezone from fastapi import APIRouter, Header, HTTPException, status from sqlalchemy import func, select from sqlalchemy.ext.asyncio import AsyncSession from shared.db import async_session from shared.models import AgentRunLog from shared.redis import redis_client, batch_request_channel from app.agent_runner import is_agent_running router = APIRouter(prefix="/agents", tags=["agents"]) # ── Tier feature limits ─────────────────────────────────────────────── # Mirrors app/billing/tier_manager.py FEATURES dict. FEATURES: dict[str, dict] = { "free": {"batch_active": 1, "batch_runs_per_day": 3}, "pro": {"batch_active": 5, "batch_runs_per_day": 20}, "power": {"batch_active": 20, "batch_runs_per_day": 100}, "team": {"batch_active": -1, "batch_runs_per_day": -1}, } def _dt_ms(dt: datetime) -> int: return int(dt.timestamp() * 1000) def _dt_ms_opt(dt: datetime | None) -> int | None: return int(dt.timestamp() * 1000) if dt else None def _to_data_types(values: list[str]) -> list[str]: normalize = { "task": "tasks", "tasks": "tasks", "note": "notes", "notes": "notes", "timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines", "project": "projects", "projects": "projects", } seen: set[str] = set() result: list[str] = [] for v in values: mapped = normalize.get(v) if mapped and mapped not in seen: seen.add(mapped) result.append(mapped) return result def _enforce_agent_limit(tier: str, current_count: int) -> int: limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"] if limit != -1 and current_count >= limit: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.", ) return limit async def _enforce_run_frequency(tier: str, user_id: str) -> None: limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"] if limit == -1: return today_start = datetime.now(timezone.utc).replace( hour=0, minute=0, second=0, microsecond=0 ) async with async_session() as db: result = await db.execute( select(func.count(AgentRunLog.id)).where( AgentRunLog.user_id == user_id, AgentRunLog.started_at >= today_start, ) ) runs_today: int = result.scalar_one() if runs_today >= limit: raise HTTPException( status_code=status.HTTP_402_PAYMENT_REQUIRED, detail=f"Daily batch run limit ({limit}) reached for your tier.", ) # ── Catalog ─────────────────────────────────────────────────────────── @router.get("/catalog") async def get_agent_catalog( x_user_id: str = Header(..., alias="X-User-Id"), ) -> list[dict]: return [ { "type": "local_directory", "name": "Local Directory Monitor", "description": "Watches local directories, extracts data from files using AI", }, { "type": "gmail", "name": "Gmail Connector", "description": "Scans Gmail inbox, extracts tasks/notes from emails", }, { "type": "teams", "name": "Microsoft Teams Connector", "description": "Monitors Teams messages, extracts action items", }, { "type": "outlook", "name": "Outlook Connector", "description": "Scans Outlook inbox, extracts tasks/notes", }, ] # ── Can-create check ───────────────────────────────────────────────── @router.post("/can-create") async def can_create_agent( body: dict, x_user_id: str = Header(..., alias="X-User-Id"), x_user_tier: str = Header("free", alias="X-User-Tier"), ) -> dict: active_agents = body.get("active_agents", 0) limit: int = FEATURES.get(x_user_tier, FEATURES["free"])["batch_active"] allowed = limit == -1 or active_agents < limit return { "allowed": allowed, "tier": x_user_tier, "active_agents": active_agents, "limit": limit, } # ── Trigger ────────────────────────────────────────────────────────── @router.post("/trigger", status_code=status.HTTP_202_ACCEPTED) async def trigger_agent_run( body: dict, x_user_id: str = Header(..., alias="X-User-Id"), x_user_tier: str = Header("free", alias="X-User-Tier"), ) -> dict: """Trigger a local agent run — creates run log and dispatches via Redis.""" active_agents = body.get("active_agents", 0) _enforce_agent_limit(x_user_tier, active_agents) await _enforce_run_frequency(x_user_tier, x_user_id) stable_agent_id = body.get("agent_id") or str(uuid.uuid4()) if is_agent_running(stable_agent_id): raise HTTPException( status_code=status.HTTP_409_CONFLICT, detail="Agent is already running.", ) # Create run log in DB async with async_session() as db: run_log = AgentRunLog( agent_id=stable_agent_id, agent_type="local", user_id=x_user_id, status="running", ) db.add(run_log) await db.commit() await db.refresh(run_log) run_log_id = run_log.id run_context = { "type": "agent_batch", "run_id": run_log_id, "agent_id": stable_agent_id, } # Dispatch to the Redis consumer for processing trigger_data = { "type": "agent_trigger", "directory": body.get("directory", ""), "directory_paths": [body.get("directory", "")] if body.get("directory") else [], "data_types": _to_data_types(body.get("what_to_extract", [])), "file_extensions": body.get("file_extensions", []), "prompt_template": body.get("custom_agent_prompt", ""), "device_id": body.get("device_id", ""), "run_context": run_context, } channel = batch_request_channel(x_user_id) await redis_client.publish(channel, json.dumps(trigger_data)) return { "id": run_log_id, "agent_id": stable_agent_id, "agent_type": "local", "status": "running", "items_processed": 0, "items_created": 0, "errors": [], "started_at": _dt_ms(run_log.started_at), "completed_at": None, }