diff --git a/AI_REFACTOR_PLAN.md b/AI_REFACTOR_PLAN.md index 9517a11..975b93c 100644 --- a/AI_REFACTOR_PLAN.md +++ b/AI_REFACTOR_PLAN.md @@ -322,7 +322,7 @@ Cloud Agent: - **Outcome:** Agent config and run tracking tables in PostgreSQL. ### Step 3.2 — Agent CRUD API routes -- [ ] Create `app/api/routes/agents.py`: +- [x] Create `app/api/routes/agents.py`: - `GET /api/v1/agents/catalog` — returns hardcoded agent type catalog: - `local_directory`: "Watches local directories, extracts data from files using AI" - `gmail`: "Scans Gmail inbox, extracts tasks/notes from emails" @@ -343,7 +343,7 @@ Cloud Agent: - `GET /api/v1/agents/runs` — query params: `agent_id`, `page`, `limit` → paginated run logs - `POST /api/v1/agents/{id}/run` — manual trigger (dispatches to agent runner) - All routes require JWT auth; ownership enforced on all mutations -- [ ] Register router in `app/main.py` +- [x] Register router in `app/main.py` - **Files:** `app/api/routes/agents.py`, `app/main.py` - **Outcome:** Full CRUD for agent configs with tier-gated creation limits. diff --git a/app/agents/task_agent.py b/app/agents/task_agent.py index 6d932a7..1d6e32d 100644 --- a/app/agents/task_agent.py +++ b/app/agents/task_agent.py @@ -2,6 +2,7 @@ from __future__ import annotations +import json from datetime import datetime, timezone from typing import Any diff --git a/app/api/routes/agents.py b/app/api/routes/agents.py new file mode 100644 index 0000000..748ffc9 --- /dev/null +++ b/app/api/routes/agents.py @@ -0,0 +1,432 @@ +"""Agent CRUD routes: local directory agents and cloud connector agents. + +Endpoints: + GET /agents/catalog — hardcoded agent type catalog + GET /agents/local — list user's local agent configs + POST /agents/local — create local agent (tier-gated) + PUT /agents/local/{agent_id} — partial update (ownership check) + DELETE /agents/local/{agent_id} — delete + cascade run logs + GET /agents/cloud — list user's cloud agent configs + POST /agents/cloud — create cloud agent (tier-gated) + PUT /agents/cloud/{agent_id} — partial update (ownership check) + DELETE /agents/cloud/{agent_id} — delete + cascade run logs + GET /agents/runs — paginated run logs (agent_id, page, limit) + POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4) +""" + +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from fastapi import APIRouter, Depends, HTTPException, Query, status +from pydantic import BaseModel +from sqlalchemy import func, or_, select +from sqlalchemy.ext.asyncio import AsyncSession + +from app.api.deps import get_current_user +from app.billing.tier_manager import FEATURES +from app.db import get_session +from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig +from app.schemas import ( + AgentCatalogItem, + AgentRunLogResponse, + CloudAgentConfigCreate, + CloudAgentConfigResponse, + CloudAgentConfigUpdate, + LocalAgentConfigCreate, + LocalAgentConfigResponse, + LocalAgentConfigUpdate, + UserProfile, +) + +router = APIRouter(prefix="/agents", tags=["agents"]) + + +# ── Datetime helpers ────────────────────────────────────────────────── + +def _dt_ms(dt: datetime) -> int: + return int(dt.timestamp() * 1000) + + +def _dt_ms_opt(dt: datetime | None) -> int | None: + return int(dt.timestamp() * 1000) if dt else None + + +# ── Model → schema converters ───────────────────────────────────────── + +def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse: + return LocalAgentConfigResponse( + id=a.id, + name=a.name, + device_id=a.device_id, + directory_paths=a.directory_paths, + data_types=a.data_types, + prompt_template=a.prompt_template, + file_extensions=a.file_extensions, + schedule_cron=a.schedule_cron, + enabled=a.enabled, + last_run_at=_dt_ms_opt(a.last_run_at), + created_at=_dt_ms(a.created_at), + updated_at=_dt_ms(a.updated_at), + ) + + +def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse: + return CloudAgentConfigResponse( + id=a.id, + provider=a.provider, # type: ignore[arg-type] + name=a.name, + data_types=a.data_types, + prompt_template=a.prompt_template, + schedule_cron=a.schedule_cron, + filter_config=a.filter_config, + enabled=a.enabled, + last_run_at=_dt_ms_opt(a.last_run_at), + created_at=_dt_ms(a.created_at), + updated_at=_dt_ms(a.updated_at), + ) + + +def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse: + return AgentRunLogResponse( + id=log.id, + agent_id=log.agent_id, + agent_type=log.agent_type, # type: ignore[arg-type] + status=log.status, # type: ignore[arg-type] + items_processed=log.items_processed, + items_created=log.items_created, + errors=log.errors or [], + started_at=_dt_ms(log.started_at), + completed_at=_dt_ms_opt(log.completed_at), + ) + + +# ── Ownership-checked lookups ───────────────────────────────────────── + +async def _get_local_agent_for_user( + agent_id: str, user_id: str, db: AsyncSession +) -> LocalAgentConfig: + result = await db.execute( + select(LocalAgentConfig).where( + LocalAgentConfig.id == agent_id, + LocalAgentConfig.user_id == user_id, + ) + ) + record = result.scalar_one_or_none() + if record is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found") + return record + + +async def _get_cloud_agent_for_user( + agent_id: str, user_id: str, db: AsyncSession +) -> CloudAgentConfig: + result = await db.execute( + select(CloudAgentConfig).where( + CloudAgentConfig.id == agent_id, + CloudAgentConfig.user_id == user_id, + ) + ) + record = result.scalar_one_or_none() + if record is None: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found") + return record + + +# ── Tier limit helper ───────────────────────────────────────────────── + +async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int: + """Return combined enabled local + cloud agent count for the user.""" + local_count = ( + await db.execute( + select(func.count(LocalAgentConfig.id)).where( + LocalAgentConfig.user_id == user_id, + LocalAgentConfig.enabled == True, # noqa: E712 + ) + ) + ).scalar_one() + cloud_count = ( + await db.execute( + select(func.count(CloudAgentConfig.id)).where( + CloudAgentConfig.user_id == user_id, + CloudAgentConfig.enabled == True, # noqa: E712 + ) + ) + ).scalar_one() + return local_count + cloud_count + + +def _enforce_agent_limit(tier: str, current_count: int) -> None: + limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"] + if limit != -1 and current_count >= limit: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.", + ) + + +# ── Local page schema (used by runs endpoint) ───────────────────────── + +class _RunsPage(BaseModel): + total: int + page: int + limit: int + items: list[AgentRunLogResponse] + + +# ── Catalog ─────────────────────────────────────────────────────────── + +@router.get("/catalog", response_model=list[AgentCatalogItem]) +async def get_agent_catalog( + current_user: UserProfile = Depends(get_current_user), +) -> list[AgentCatalogItem]: + """Return the static list of available agent types and their descriptions.""" + return [ + AgentCatalogItem( + type="local_directory", + name="Local Directory Monitor", + description="Watches local directories, extracts data from files using AI", + ), + AgentCatalogItem( + type="gmail", + name="Gmail Connector", + description="Scans Gmail inbox, extracts tasks/notes from emails", + ), + AgentCatalogItem( + type="teams", + name="Microsoft Teams Connector", + description="Monitors Teams messages, extracts action items", + ), + AgentCatalogItem( + type="outlook", + name="Outlook Connector", + description="Scans Outlook inbox, extracts tasks/notes", + ), + ] + + +# ── Local agent CRUD ────────────────────────────────────────────────── + +@router.get("/local", response_model=list[LocalAgentConfigResponse]) +async def list_local_agents( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> list[LocalAgentConfigResponse]: + """List all local directory agent configs owned by the authenticated user.""" + result = await db.execute( + select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id) + ) + return [_to_local_response(a) for a in result.scalars().all()] + + +@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED) +async def create_local_agent( + body: LocalAgentConfigCreate, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> LocalAgentConfigResponse: + """Create a new local directory agent config. + + The combined count of enabled local and cloud agents for the user is + checked against the ``batch_active`` limit for their billing tier. + """ + _enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db)) + agent = LocalAgentConfig( + user_id=current_user.id, + name=body.name, + device_id=body.device_id, + directory_paths=body.directory_paths, + data_types=body.data_types, + prompt_template=body.prompt_template, + file_extensions=body.file_extensions, + schedule_cron=body.schedule_cron, + ) + db.add(agent) + await db.commit() + await db.refresh(agent) + return _to_local_response(agent) + + +@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse) +async def update_local_agent( + agent_id: str, + body: LocalAgentConfigUpdate, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> LocalAgentConfigResponse: + """Partially update a local agent config. Only provided fields are changed.""" + agent = await _get_local_agent_for_user(agent_id, current_user.id, db) + for field, value in body.model_dump(exclude_unset=True).items(): + setattr(agent, field, value) + await db.commit() + await db.refresh(agent) + return _to_local_response(agent) + + +@router.delete("/local/{agent_id}", response_model=dict) +async def delete_local_agent( + agent_id: str, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, bool]: + """Delete a local agent config. Associated run logs are cascade-deleted.""" + agent = await _get_local_agent_for_user(agent_id, current_user.id, db) + await db.delete(agent) + await db.commit() + return {"ok": True} + + +# ── Cloud agent CRUD ────────────────────────────────────────────────── + +@router.get("/cloud", response_model=list[CloudAgentConfigResponse]) +async def list_cloud_agents( + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> list[CloudAgentConfigResponse]: + """List all cloud connector agent configs owned by the authenticated user.""" + result = await db.execute( + select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id) + ) + return [_to_cloud_response(a) for a in result.scalars().all()] + + +@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED) +async def create_cloud_agent( + body: CloudAgentConfigCreate, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> CloudAgentConfigResponse: + """Create a new cloud connector agent config. + + The combined count of enabled local and cloud agents for the user is + checked against the ``batch_active`` limit for their billing tier. + """ + _enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db)) + agent = CloudAgentConfig( + user_id=current_user.id, + provider=body.provider, + name=body.name, + data_types=body.data_types, + prompt_template=body.prompt_template, + oauth_token_encrypted=body.oauth_token_encrypted, + schedule_cron=body.schedule_cron, + filter_config=body.filter_config, + ) + db.add(agent) + await db.commit() + await db.refresh(agent) + return _to_cloud_response(agent) + + +@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse) +async def update_cloud_agent( + agent_id: str, + body: CloudAgentConfigUpdate, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> CloudAgentConfigResponse: + """Partially update a cloud agent config. Only provided fields are changed.""" + agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db) + for field, value in body.model_dump(exclude_unset=True).items(): + setattr(agent, field, value) + await db.commit() + await db.refresh(agent) + return _to_cloud_response(agent) + + +@router.delete("/cloud/{agent_id}", response_model=dict) +async def delete_cloud_agent( + agent_id: str, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> dict[str, bool]: + """Delete a cloud agent config. Associated run logs are cascade-deleted.""" + agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db) + await db.delete(agent) + await db.commit() + return {"ok": True} + + +# ── Run logs ────────────────────────────────────────────────────────── + +@router.get("/runs", response_model=_RunsPage) +async def list_run_logs( + agent_id: str | None = Query(default=None), + page: int = Query(default=1, ge=1), + limit: int = Query(default=20, ge=1, le=100), + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> _RunsPage: + """Return paginated run logs for the authenticated user. + + Optionally filter by ``agent_id``. Results are ordered from newest to oldest. + """ + base_filter = [AgentRunLog.user_id == current_user.id] + if agent_id: + base_filter.append(AgentRunLog.agent_id == agent_id) + + total = ( + await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter)) + ).scalar_one() + + result = await db.execute( + select(AgentRunLog) + .where(*base_filter) + .order_by(AgentRunLog.started_at.desc()) + .offset((page - 1) * limit) + .limit(limit) + ) + items = [_to_run_log_response(log) for log in result.scalars().all()] + + return _RunsPage(total=total, page=page, limit=limit, items=items) + + +# ── Manual trigger stub ─────────────────────────────────────────────── + +@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED) +async def trigger_agent_run( + agent_id: str, + current_user: UserProfile = Depends(get_current_user), + db: AsyncSession = Depends(get_session), +) -> AgentRunLogResponse: + """Manually trigger an agent run. + + Looks up the agent config (local or cloud) by ID with ownership check, + creates a run log entry with ``status="running"``, and returns it. + + Actual dispatch to the agent runner is wired in Step 3.4 once + ``DeviceConnectionManager`` and ``agent_runner`` are available. + """ + # Determine agent type by trying local first, then cloud. + agent_type: str + local_result = await db.execute( + select(LocalAgentConfig).where( + LocalAgentConfig.id == agent_id, + LocalAgentConfig.user_id == current_user.id, + ) + ) + if local_result.scalar_one_or_none() is not None: + agent_type = "local" + else: + cloud_result = await db.execute( + select(CloudAgentConfig).where( + CloudAgentConfig.id == agent_id, + CloudAgentConfig.user_id == current_user.id, + ) + ) + if cloud_result.scalar_one_or_none() is not None: + agent_type = "cloud" + else: + raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found") + + run_log = AgentRunLog( + agent_id=agent_id, + agent_type=agent_type, + user_id=current_user.id, + status="running", + ) + db.add(run_log) + await db.commit() + await db.refresh(run_log) + return _to_run_log_response(run_log) diff --git a/app/main.py b/app/main.py index 29d7230..31a9822 100644 --- a/app/main.py +++ b/app/main.py @@ -43,7 +43,7 @@ def create_app() -> FastAPI: app.add_middleware(SanitizerMiddleware) app.add_middleware(TierRateLimitMiddleware) - from app.api.routes import auth, backup, billing, chat, plans, plugins, storage, vectors + from app.api.routes import agents, auth, backup, billing, chat, plans, plugins, storage, vectors app.include_router(auth.router, prefix="/api/v1") app.include_router(chat.router, prefix="/api/v1") @@ -53,6 +53,7 @@ def create_app() -> FastAPI: app.include_router(backup.router, prefix="/api/v1") app.include_router(plugins.router, prefix="/api/v1") app.include_router(billing.router, prefix="/api/v1") + app.include_router(agents.router, prefix="/api/v1") @app.get("/api/v1/health", tags=["health"]) async def health() -> dict: