refactor agents to client-owned config flow

This commit is contained in:
2026-03-16 22:35:46 +01:00
parent 02a9684cd6
commit 5faa6b1d7c
6 changed files with 259 additions and 589 deletions

View File

@@ -16,9 +16,9 @@ Journey flow:
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
5. Server parses the block, sets ``done=True``, and returns the template.
The ``prompt_template`` from the final response is meant to be stored in
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
by the Electron client (via the agent CRUD endpoints).
The ``prompt_template`` from the final response is meant to be stored by
the Electron client in local agent settings and later sent to
``POST /agents/trigger`` when a run is executed.
"""
from __future__ import annotations

View File

@@ -1,45 +1,35 @@
"""Agent CRUD routes: local directory agents and cloud connector agents.
"""Agent routes.
Endpoints:
GET /agents/catalog — hardcoded agent type catalog
GET /agents/local — list user's local agent configs
POST /agents/local create local agent (tier-gated)
PUT /agents/local/{agent_id} — partial update (ownership check)
DELETE /agents/local/{agent_id} — delete + cascade run logs
GET /agents/cloud — list user's cloud agent configs
POST /agents/cloud — create cloud agent (tier-gated)
PUT /agents/cloud/{agent_id} — partial update (ownership check)
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
GET /agents/runs — paginated run logs (agent_id, page, limit)
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
Backend responsibilities are intentionally minimal:
GET /agents/catalog — static catalog for UI display
POST /agents/can-create — billing eligibility check
POST /agents/triggertrigger a local agent run
Agent configuration is owned by the Electron app and is not persisted
in backend agent-config tables.
"""
from __future__ import annotations
import asyncio
import uuid
from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status
from pydantic import BaseModel
from sqlalchemy import func, or_, select
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import run_cloud_agent, run_local_agent
from app.core.agent_runner import run_local_agent
from app.core.device_manager import device_manager
from app.db import get_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import (
AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse,
CloudAgentConfigCreate,
CloudAgentConfigResponse,
CloudAgentConfigUpdate,
LocalAgentConfigCreate,
LocalAgentConfigResponse,
LocalAgentConfigUpdate,
AgentTriggerRequest,
UserProfile,
)
@@ -56,39 +46,14 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
# ── Model → schema converters ─────────────────────────────────────────
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
return LocalAgentConfigResponse(
id=a.id,
name=a.name,
device_id=a.device_id,
directory_paths=a.directory_paths,
data_types=a.data_types,
prompt_template=a.prompt_template,
file_extensions=a.file_extensions,
schedule_cron=a.schedule_cron,
enabled=a.enabled,
last_run_at=_dt_ms_opt(a.last_run_at),
created_at=_dt_ms(a.created_at),
updated_at=_dt_ms(a.updated_at),
)
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
return CloudAgentConfigResponse(
id=a.id,
provider=a.provider, # type: ignore[arg-type]
name=a.name,
data_types=a.data_types,
prompt_template=a.prompt_template,
schedule_cron=a.schedule_cron,
filter_config=a.filter_config,
enabled=a.enabled,
last_run_at=_dt_ms_opt(a.last_run_at),
created_at=_dt_ms(a.created_at),
updated_at=_dt_ms(a.updated_at),
)
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks",
"note": "notes",
"timeline": "timelines",
"project": "projects",
}
return [normalize[v] for v in values if v in normalize]
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
@@ -105,77 +70,14 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
)
# ── Ownership-checked lookups ─────────────────────────────────────────
async def _get_local_agent_for_user(
agent_id: str, user_id: str, db: AsyncSession
) -> LocalAgentConfig:
result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == user_id,
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
return record
async def _get_cloud_agent_for_user(
agent_id: str, user_id: str, db: AsyncSession
) -> CloudAgentConfig:
result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == user_id,
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
return record
# ── Tier limit helper ─────────────────────────────────────────────────
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
"""Return combined enabled local + cloud agent count for the user."""
local_count = (
await db.execute(
select(func.count(LocalAgentConfig.id)).where(
LocalAgentConfig.user_id == user_id,
LocalAgentConfig.enabled == True, # noqa: E712
)
)
).scalar_one()
cloud_count = (
await db.execute(
select(func.count(CloudAgentConfig.id)).where(
CloudAgentConfig.user_id == user_id,
CloudAgentConfig.enabled == True, # noqa: E712
)
)
).scalar_one()
return local_count + cloud_count
def _enforce_agent_limit(tier: str, current_count: int) -> None:
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
# ── Local page schema (used by runs endpoint) ─────────────────────────
class _RunsPage(BaseModel):
total: int
page: int
limit: int
items: list[AgentRunLogResponse]
return limit
# ── Catalog ───────────────────────────────────────────────────────────
@@ -190,6 +92,24 @@ async def get_agent_catalog(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
config_schema={
"directory": {"type": "string", "required": True},
"what_to_extract": {
"type": "array",
"items": ["task", "note", "timeline", "project"],
"required": True,
},
"actions_by_type": {
"type": "object",
"example": {
"task": ["add", "update"],
"note": ["add", "update"],
},
"required": False,
},
"batch_interval": {"type": "string", "required": True},
"custom_agent_prompt": {"type": "string", "required": True},
},
),
AgentCatalogItem(
type="gmail",
@@ -209,229 +129,51 @@ async def get_agent_catalog(
]
# ── Local agent CRUD ──────────────────────────────────────────────────
@router.get("/local", response_model=list[LocalAgentConfigResponse])
async def list_local_agents(
@router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
body: AgentCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[LocalAgentConfigResponse]:
"""List all local directory agent configs owned by the authenticated user."""
result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
)
return [_to_local_response(a) for a in result.scalars().all()]
) -> AgentCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_local_agent(
body: LocalAgentConfigCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> LocalAgentConfigResponse:
"""Create a new local directory agent config.
The combined count of enabled local and cloud agents for the user is
checked against the ``batch_active`` limit for their billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
agent = LocalAgentConfig(
user_id=current_user.id,
name=body.name,
device_id=body.device_id,
directory_paths=body.directory_paths,
data_types=body.data_types,
prompt_template=body.prompt_template,
file_extensions=body.file_extensions,
schedule_cron=body.schedule_cron,
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return _to_local_response(agent)
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
async def update_local_agent(
agent_id: str,
body: LocalAgentConfigUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> LocalAgentConfigResponse:
"""Partially update a local agent config. Only provided fields are changed."""
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(agent, field, value)
await db.commit()
await db.refresh(agent)
return _to_local_response(agent)
@router.delete("/local/{agent_id}", response_model=dict)
async def delete_local_agent(
agent_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a local agent config. Associated run logs are cascade-deleted."""
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
await db.delete(agent)
await db.commit()
return {"ok": True}
# ── Cloud agent CRUD ──────────────────────────────────────────────────
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
async def list_cloud_agents(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[CloudAgentConfigResponse]:
"""List all cloud connector agent configs owned by the authenticated user."""
result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
)
return [_to_cloud_response(a) for a in result.scalars().all()]
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_cloud_agent(
body: CloudAgentConfigCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> CloudAgentConfigResponse:
"""Create a new cloud connector agent config.
The combined count of enabled local and cloud agents for the user is
checked against the ``batch_active`` limit for their billing tier.
"""
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
agent = CloudAgentConfig(
user_id=current_user.id,
provider=body.provider,
name=body.name,
data_types=body.data_types,
prompt_template=body.prompt_template,
oauth_token_encrypted=body.oauth_token_encrypted,
schedule_cron=body.schedule_cron,
filter_config=body.filter_config,
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return _to_cloud_response(agent)
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
async def update_cloud_agent(
agent_id: str,
body: CloudAgentConfigUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> CloudAgentConfigResponse:
"""Partially update a cloud agent config. Only provided fields are changed."""
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(agent, field, value)
await db.commit()
await db.refresh(agent)
return _to_cloud_response(agent)
@router.delete("/cloud/{agent_id}", response_model=dict)
async def delete_cloud_agent(
agent_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
await db.delete(agent)
await db.commit()
return {"ok": True}
# ── Run logs ──────────────────────────────────────────────────────────
@router.get("/runs", response_model=_RunsPage)
async def list_run_logs(
agent_id: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=20, ge=1, le=100),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _RunsPage:
"""Return paginated run logs for the authenticated user.
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
"""
base_filter = [AgentRunLog.user_id == current_user.id]
if agent_id:
base_filter.append(AgentRunLog.agent_id == agent_id)
total = (
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
).scalar_one()
result = await db.execute(
select(AgentRunLog)
.where(*base_filter)
.order_by(AgentRunLog.started_at.desc())
.offset((page - 1) * limit)
.limit(limit)
)
items = [_to_run_log_response(log) for log in result.scalars().all()]
return _RunsPage(total=total, page=page, limit=limit, items=items)
# ── Manual trigger stub ───────────────────────────────────────────────
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
agent_id: str,
body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse:
"""Manually trigger an agent run.
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
Looks up the agent config (local or cloud) by ID with ownership check,
creates a run log entry with ``status="running"``, and returns it.
Actual dispatch to the agent runner is wired in Step 3.4 once
``DeviceConnectionManager`` and ``agent_runner`` are available.
"""
# Determine agent type by trying local first, then cloud.
# Keep the full config object so we can pass it to the agent runner.
local_config: LocalAgentConfig | None = None
cloud_config: CloudAgentConfig | None = None
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == current_user.id,
)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id="",
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
)
local_config = local_result.scalar_one_or_none()
if local_config is not None:
agent_type = "local"
else:
cloud_result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == current_user.id,
)
)
cloud_config = cloud_result.scalar_one_or_none()
if cloud_config is not None:
agent_type = "cloud"
else:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
run_log = AgentRunLog(
agent_id=agent_id,
agent_type=agent_type,
agent_id=config.id,
agent_type="local",
user_id=current_user.id,
status="running",
)
@@ -439,14 +181,8 @@ async def trigger_agent_run(
await db.commit()
await db.refresh(run_log)
# Dispatch the run as a background task — returns 202 immediately.
if agent_type == "local" and local_config is not None:
asyncio.create_task(
run_local_agent(current_user.id, local_config, run_log, device_manager)
)
elif agent_type == "cloud" and cloud_config is not None:
asyncio.create_task(
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
)
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager)
)
return _to_run_log_response(run_log)