- agent_runner: local directory + cloud agent orchestration via Redis
- 5 domain agents: filesystem, task, note, project, timeline
- integrations: Gmail, MS Graph (Outlook + Teams)
- journey: guided chatbot conversation to build prompt_template
- routes: REST endpoints (catalog, can-create, trigger)
- redis_consumer: subscribes to batch:request:* pattern
- ws_context: Redis-based execute_on_client for tool round-trip
- Dockerfile with 300s timeout for long-running batch jobs
209 lines
7.1 KiB
Python
209 lines
7.1 KiB
Python
"""Agent REST routes — catalog, billing checks, trigger.
|
|
|
|
Adapted for Batch Agent Service: uses shared.db, shared.models, shared.schemas.
|
|
Agent trigger dispatches via Redis to the consumer instead of spawning
|
|
an in-process background task.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
|
|
from fastapi import APIRouter, Header, HTTPException, status
|
|
from sqlalchemy import func, select
|
|
from sqlalchemy.ext.asyncio import AsyncSession
|
|
|
|
from shared.db import async_session
|
|
from shared.models import AgentRunLog
|
|
from shared.redis import redis_client, batch_request_channel
|
|
|
|
from app.agent_runner import is_agent_running
|
|
|
|
# All routes in this module are mounted under /agents and grouped as "agents".
router = APIRouter(tags=["agents"], prefix="/agents")
|
|
|
|
# ── Tier feature limits ───────────────────────────────────────────────
# Mirrors app/billing/tier_manager.py FEATURES dict.
# Per tier: ``batch_active`` is compared against the caller-reported active
# agent count, ``batch_runs_per_day`` against runs started since UTC midnight.
# A value of -1 disables the corresponding check.
FEATURES: dict[str, dict] = dict(
    free=dict(batch_active=1, batch_runs_per_day=3),
    pro=dict(batch_active=5, batch_runs_per_day=20),
    power=dict(batch_active=20, batch_runs_per_day=100),
    team=dict(batch_active=-1, batch_runs_per_day=-1),
)
|
|
|
|
|
|
# Unix epoch as an aware datetime; anchor for exact integer millisecond math.
_UNIX_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)


def _dt_ms(dt: datetime) -> int:
    """Return *dt* as integer milliseconds since the Unix epoch.

    Uses exact integer ``timedelta`` arithmetic rather than
    ``int(dt.timestamp() * 1000)``: the float product can land just below a
    whole millisecond (``1.001 * 1000 == 1000.9999999999999``), and the
    truncation then reports one millisecond too few.

    Naive datetimes are interpreted as system local time, matching
    ``datetime.timestamp()``. Sub-millisecond parts are floored (this also
    floors, rather than truncates toward zero, for pre-1970 values).
    """
    delta = dt.astimezone(timezone.utc) - _UNIX_EPOCH
    # timedelta is normalized: 0 <= seconds < 86400, 0 <= microseconds < 1e6.
    return delta.days * 86_400_000 + delta.seconds * 1_000 + delta.microseconds // 1_000


def _dt_ms_opt(dt: datetime | None) -> int | None:
    """Like ``_dt_ms``, but pass ``None`` through unchanged."""
    return None if dt is None else _dt_ms(dt)
|
|
|
|
|
|
# Accepted spellings of each extractable data type, mapped to the canonical
# plural name placed in the consumer's ``data_types`` list.
_DATA_TYPE_ALIASES: dict[str, str] = {
    "task": "tasks", "tasks": "tasks",
    "note": "notes", "notes": "notes",
    "timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
    "project": "projects", "projects": "projects",
}


def _to_data_types(values: list[str] | None) -> list[str]:
    """Normalize requested extraction targets to canonical data-type names.

    Each recognized alias is mapped to its canonical plural form. Unknown
    values, non-string entries and duplicates are dropped, and first-seen
    order is preserved.

    Args:
        values: Raw ``what_to_extract`` list from the request body. ``None``
            (e.g. an explicit JSON ``null``) is treated as empty.

    Returns:
        Canonical, de-duplicated data-type names.
    """
    if not values:
        return []
    # Non-strings (dicts/lists from a malformed body) are unhashable and would
    # make the alias lookup raise TypeError; skip them instead.
    mapped = (_DATA_TYPE_ALIASES.get(v) for v in values if isinstance(v, str))
    # dict.fromkeys de-duplicates while keeping insertion order.
    return list(dict.fromkeys(name for name in mapped if name))
|
|
|
|
|
|
def _enforce_agent_limit(tier: str, current_count: int) -> int:
    """Ensure one more agent fits within *tier*'s ``batch_active`` quota.

    Unknown tiers fall back to the ``"free"`` limits; a limit of -1 means
    no cap.

    Args:
        tier: Caller's billing tier name.
        current_count: Number of active agents the caller already has.

    Returns:
        The ``batch_active`` limit that was applied.

    Raises:
        HTTPException: 403 when *current_count* has reached the limit.
    """
    tier_limits = FEATURES[tier] if tier in FEATURES else FEATURES["free"]
    limit: int = tier_limits["batch_active"]

    if limit == -1:
        return limit
    if current_count >= limit:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
        )
    return limit
|
|
|
|
|
|
async def _enforce_run_frequency(tier: str, user_id: str) -> None:
    """Reject the request once *user_id* has used up today's run quota.

    Counts the user's ``AgentRunLog`` rows whose ``started_at`` is at or
    after 00:00 UTC today and compares that to the tier's
    ``batch_runs_per_day``. Unknown tiers use the ``"free"`` limits; -1
    disables the check and skips the database query entirely.

    Raises:
        HTTPException: 402 when the daily limit has been reached.
    """
    tier_limits = FEATURES.get(tier, FEATURES["free"])
    limit: int = tier_limits["batch_runs_per_day"]
    if limit == -1:
        return

    now_utc = datetime.now(timezone.utc)
    midnight_utc = now_utc.replace(hour=0, minute=0, second=0, microsecond=0)
    count_stmt = select(func.count(AgentRunLog.id)).where(
        AgentRunLog.user_id == user_id,
        AgentRunLog.started_at >= midnight_utc,
    )

    async with async_session() as db:
        rows = await db.execute(count_stmt)
        runs_today: int = rows.scalar_one()

    if runs_today >= limit:
        raise HTTPException(
            status_code=status.HTTP_402_PAYMENT_REQUIRED,
            detail=f"Daily batch run limit ({limit}) reached for your tier.",
        )
|
|
|
|
|
|
# ── Catalog ───────────────────────────────────────────────────────────
|
|
|
|
@router.get("/catalog")
async def get_agent_catalog(
    x_user_id: str = Header(..., alias="X-User-Id"),
) -> list[dict]:
    """List the agent types that can be configured.

    The ``X-User-Id`` header is required, but its value is not used here.
    Returns a fresh list of ``{"type", "name", "description"}`` dicts.
    """
    entries = (
        (
            "local_directory",
            "Local Directory Monitor",
            "Watches local directories, extracts data from files using AI",
        ),
        (
            "gmail",
            "Gmail Connector",
            "Scans Gmail inbox, extracts tasks/notes from emails",
        ),
        (
            "teams",
            "Microsoft Teams Connector",
            "Monitors Teams messages, extracts action items",
        ),
        (
            "outlook",
            "Outlook Connector",
            "Scans Outlook inbox, extracts tasks/notes",
        ),
    )
    return [
        {"type": agent_type, "name": label, "description": blurb}
        for agent_type, label, blurb in entries
    ]
|
|
|
|
|
|
# ── Can-create check ─────────────────────────────────────────────────
|
|
|
|
@router.post("/can-create")
async def can_create_agent(
    body: dict,
    x_user_id: str = Header(..., alias="X-User-Id"),
    x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
    """Report whether the caller may create another agent under its tier.

    The caller supplies its current ``active_agents`` count in the body
    (default 0); this endpoint trusts that value. Unknown tiers are
    evaluated with the ``"free"`` limits, but the tier name is echoed
    back as sent.
    """
    tier_limits = FEATURES.get(x_user_tier, FEATURES["free"])
    limit: int = tier_limits["batch_active"]
    active_agents = body.get("active_agents", 0)

    if limit == -1:
        allowed = True
    else:
        allowed = active_agents < limit

    return {
        "allowed": allowed,
        "tier": x_user_tier,
        "active_agents": active_agents,
        "limit": limit,
    }
|
|
|
|
|
|
# ── Trigger ──────────────────────────────────────────────────────────
|
|
|
|
def _build_trigger_payload(body: dict, run_context: dict) -> dict:
    """Translate a trigger request *body* into the Redis consumer message."""
    directory = body.get("directory", "")
    return {
        "type": "agent_trigger",
        "directory": directory,
        "directory_paths": [directory] if directory else [],
        "data_types": _to_data_types(body.get("what_to_extract", [])),
        "file_extensions": body.get("file_extensions", []),
        "prompt_template": body.get("custom_agent_prompt", ""),
        "device_id": body.get("device_id", ""),
        "run_context": run_context,
    }


async def _discard_run_log(run_log_id) -> None:
    """Delete the ``AgentRunLog`` row *run_log_id*, if it still exists."""
    async with async_session() as db:
        run_log = await db.get(AgentRunLog, run_log_id)
        if run_log is not None:
            await db.delete(run_log)
            await db.commit()


@router.post("/trigger", status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
    body: dict,
    x_user_id: str = Header(..., alias="X-User-Id"),
    x_user_tier: str = Header("free", alias="X-User-Tier"),
) -> dict:
    """Trigger a local agent run — creates run log and dispatches via Redis.

    The steps are:
      1. Enforce the tier limits.
      2. Reject the request if the agent is already running.
      3. Persist an ``AgentRunLog`` with status ``"running"``.
      4. Publish an ``agent_trigger`` message on the user's batch request channel.

    Raises:
        HTTPException: 403 when the active-agent limit is reached, 402 when
            the daily run quota is used up, 409 when the agent is already
            running, and 503 when the run could not be dispatched.
    """
    active_agents = body.get("active_agents", 0)
    _enforce_agent_limit(x_user_tier, active_agents)
    await _enforce_run_frequency(x_user_tier, x_user_id)

    # Fall back to a fresh UUID when the caller supplies no agent_id.
    stable_agent_id = body.get("agent_id") or str(uuid.uuid4())

    if is_agent_running(stable_agent_id):
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Agent is already running.",
        )

    # Create run log in DB
    async with async_session() as db:
        run_log = AgentRunLog(
            agent_id=stable_agent_id,
            agent_type="local",
            user_id=x_user_id,
            status="running",
        )
        db.add(run_log)
        await db.commit()
        await db.refresh(run_log)
        run_log_id = run_log.id
        # Read the loaded values while the session is still open.
        started_at = run_log.started_at

    run_context = {
        "type": "agent_batch",
        "run_id": run_log_id,
        "agent_id": stable_agent_id,
    }
    trigger_data = _build_trigger_payload(body, run_context)

    # Dispatch to the Redis consumer for processing.
    try:
        message = json.dumps(trigger_data)
        await redis_client.publish(batch_request_channel(x_user_id), message)
    except Exception as exc:
        # The row was committed as "running" above. Without a dispatched
        # message it would stay that way and still count toward the daily
        # quota, so remove it before reporting the failure.
        await _discard_run_log(run_log_id)
        raise HTTPException(
            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            detail="Failed to dispatch agent run.",
        ) from exc

    return {
        "id": run_log_id,
        "agent_id": stable_agent_id,
        "agent_type": "local",
        "status": "running",
        "items_processed": 0,
        "items_created": 0,
        "errors": [],
        "started_at": _dt_ms(started_at),
        "completed_at": None,
    }
|