refactor agents to client-owned config flow
This commit is contained in:
@@ -0,0 +1,92 @@
|
||||
"""Deprecate backend agent config tables.
|
||||
|
||||
The Electron client is now the source of truth for agent configuration
|
||||
(directory, extract targets, batch interval, custom prompt). Backend keeps
|
||||
billing checks and trigger/run logs only.
|
||||
|
||||
Revision ID: 9a1f2d0b6c7e
|
||||
Revises: 818478c251dc
|
||||
Create Date: 2026-03-16
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
revision: str = "9a1f2d0b6c7e"
|
||||
down_revision: Union[str, None] = "818478c251dc"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
bind = op.get_bind()
|
||||
inspector = sa.inspect(bind)
|
||||
existing = set(inspector.get_table_names())
|
||||
|
||||
if "cloud_agent_configs" in existing:
|
||||
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
|
||||
op.drop_table("cloud_agent_configs")
|
||||
|
||||
if "local_agent_configs" in existing:
|
||||
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
|
||||
op.drop_table("local_agent_configs")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.create_table(
|
||||
"local_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("device_id", sa.String(255), nullable=False),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
|
||||
EXCEPTION WHEN duplicate_object THEN NULL;
|
||||
END $$;
|
||||
"""
|
||||
)
|
||||
|
||||
op.create_table(
|
||||
"cloud_agent_configs",
|
||||
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
|
||||
sa.Column(
|
||||
"provider",
|
||||
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("name", sa.String(255), nullable=False),
|
||||
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
|
||||
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
|
||||
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
|
||||
sa.Column("filter_config", sa.JSON, nullable=True),
|
||||
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
|
||||
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
|
||||
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
|
||||
)
|
||||
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])
|
||||
@@ -16,9 +16,9 @@ Journey flow:
|
||||
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
|
||||
5. Server parses the block, sets ``done=True``, and returns the template.
|
||||
|
||||
The ``prompt_template`` from the final response is meant to be stored in
|
||||
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template``
|
||||
by the Electron client (via the agent CRUD endpoints).
|
||||
The ``prompt_template`` from the final response is meant to be stored by
|
||||
the Electron client in local agent settings and later sent to
|
||||
``POST /agents/trigger`` when a run is executed.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
@@ -1,45 +1,35 @@
|
||||
"""Agent CRUD routes: local directory agents and cloud connector agents.
|
||||
"""Agent routes.
|
||||
|
||||
Endpoints:
|
||||
GET /agents/catalog — hardcoded agent type catalog
|
||||
GET /agents/local — list user's local agent configs
|
||||
POST /agents/local — create local agent (tier-gated)
|
||||
PUT /agents/local/{agent_id} — partial update (ownership check)
|
||||
DELETE /agents/local/{agent_id} — delete + cascade run logs
|
||||
GET /agents/cloud — list user's cloud agent configs
|
||||
POST /agents/cloud — create cloud agent (tier-gated)
|
||||
PUT /agents/cloud/{agent_id} — partial update (ownership check)
|
||||
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
|
||||
GET /agents/runs — paginated run logs (agent_id, page, limit)
|
||||
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
|
||||
Backend responsibilities are intentionally minimal:
|
||||
GET /agents/catalog — static catalog for UI display
|
||||
POST /agents/can-create — billing eligibility check
|
||||
POST /agents/trigger — trigger a local agent run
|
||||
|
||||
Agent configuration is owned by the Electron app and is not persisted
|
||||
in backend agent-config tables.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func, or_, select
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import FEATURES
|
||||
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||
from app.core.agent_runner import run_local_agent
|
||||
from app.core.device_manager import device_manager
|
||||
from app.db import get_session
|
||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||
from app.models import AgentRunLog, LocalAgentConfig
|
||||
from app.schemas import (
|
||||
AgentCatalogItem,
|
||||
AgentCreationCheckRequest,
|
||||
AgentCreationCheckResponse,
|
||||
AgentRunLogResponse,
|
||||
CloudAgentConfigCreate,
|
||||
CloudAgentConfigResponse,
|
||||
CloudAgentConfigUpdate,
|
||||
LocalAgentConfigCreate,
|
||||
LocalAgentConfigResponse,
|
||||
LocalAgentConfigUpdate,
|
||||
AgentTriggerRequest,
|
||||
UserProfile,
|
||||
)
|
||||
|
||||
@@ -56,39 +46,14 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
|
||||
return int(dt.timestamp() * 1000) if dt else None
|
||||
|
||||
|
||||
# ── Model → schema converters ─────────────────────────────────────────
|
||||
|
||||
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse:
|
||||
return LocalAgentConfigResponse(
|
||||
id=a.id,
|
||||
name=a.name,
|
||||
device_id=a.device_id,
|
||||
directory_paths=a.directory_paths,
|
||||
data_types=a.data_types,
|
||||
prompt_template=a.prompt_template,
|
||||
file_extensions=a.file_extensions,
|
||||
schedule_cron=a.schedule_cron,
|
||||
enabled=a.enabled,
|
||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||
created_at=_dt_ms(a.created_at),
|
||||
updated_at=_dt_ms(a.updated_at),
|
||||
)
|
||||
|
||||
|
||||
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
|
||||
return CloudAgentConfigResponse(
|
||||
id=a.id,
|
||||
provider=a.provider, # type: ignore[arg-type]
|
||||
name=a.name,
|
||||
data_types=a.data_types,
|
||||
prompt_template=a.prompt_template,
|
||||
schedule_cron=a.schedule_cron,
|
||||
filter_config=a.filter_config,
|
||||
enabled=a.enabled,
|
||||
last_run_at=_dt_ms_opt(a.last_run_at),
|
||||
created_at=_dt_ms(a.created_at),
|
||||
updated_at=_dt_ms(a.updated_at),
|
||||
)
|
||||
def _to_data_types(values: list[str]) -> list[str]:
|
||||
normalize = {
|
||||
"task": "tasks",
|
||||
"note": "notes",
|
||||
"timeline": "timelines",
|
||||
"project": "projects",
|
||||
}
|
||||
return [normalize[v] for v in values if v in normalize]
|
||||
|
||||
|
||||
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||
@@ -105,77 +70,14 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
|
||||
)
|
||||
|
||||
|
||||
# ── Ownership-checked lookups ─────────────────────────────────────────
|
||||
|
||||
async def _get_local_agent_for_user(
|
||||
agent_id: str, user_id: str, db: AsyncSession
|
||||
) -> LocalAgentConfig:
|
||||
result = await db.execute(
|
||||
select(LocalAgentConfig).where(
|
||||
LocalAgentConfig.id == agent_id,
|
||||
LocalAgentConfig.user_id == user_id,
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if record is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||
return record
|
||||
|
||||
|
||||
async def _get_cloud_agent_for_user(
|
||||
agent_id: str, user_id: str, db: AsyncSession
|
||||
) -> CloudAgentConfig:
|
||||
result = await db.execute(
|
||||
select(CloudAgentConfig).where(
|
||||
CloudAgentConfig.id == agent_id,
|
||||
CloudAgentConfig.user_id == user_id,
|
||||
)
|
||||
)
|
||||
record = result.scalar_one_or_none()
|
||||
if record is None:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||
return record
|
||||
|
||||
|
||||
# ── Tier limit helper ─────────────────────────────────────────────────
|
||||
|
||||
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
|
||||
"""Return combined enabled local + cloud agent count for the user."""
|
||||
local_count = (
|
||||
await db.execute(
|
||||
select(func.count(LocalAgentConfig.id)).where(
|
||||
LocalAgentConfig.user_id == user_id,
|
||||
LocalAgentConfig.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
cloud_count = (
|
||||
await db.execute(
|
||||
select(func.count(CloudAgentConfig.id)).where(
|
||||
CloudAgentConfig.user_id == user_id,
|
||||
CloudAgentConfig.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
).scalar_one()
|
||||
return local_count + cloud_count
|
||||
|
||||
|
||||
def _enforce_agent_limit(tier: str, current_count: int) -> None:
|
||||
def _enforce_agent_limit(tier: str, current_count: int) -> int:
|
||||
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
|
||||
if limit != -1 and current_count >= limit:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
|
||||
)
|
||||
|
||||
|
||||
# ── Local page schema (used by runs endpoint) ─────────────────────────
|
||||
|
||||
class _RunsPage(BaseModel):
|
||||
total: int
|
||||
page: int
|
||||
limit: int
|
||||
items: list[AgentRunLogResponse]
|
||||
return limit
|
||||
|
||||
|
||||
# ── Catalog ───────────────────────────────────────────────────────────
|
||||
@@ -190,6 +92,24 @@ async def get_agent_catalog(
|
||||
type="local_directory",
|
||||
name="Local Directory Monitor",
|
||||
description="Watches local directories, extracts data from files using AI",
|
||||
config_schema={
|
||||
"directory": {"type": "string", "required": True},
|
||||
"what_to_extract": {
|
||||
"type": "array",
|
||||
"items": ["task", "note", "timeline", "project"],
|
||||
"required": True,
|
||||
},
|
||||
"actions_by_type": {
|
||||
"type": "object",
|
||||
"example": {
|
||||
"task": ["add", "update"],
|
||||
"note": ["add", "update"],
|
||||
},
|
||||
"required": False,
|
||||
},
|
||||
"batch_interval": {"type": "string", "required": True},
|
||||
"custom_agent_prompt": {"type": "string", "required": True},
|
||||
},
|
||||
),
|
||||
AgentCatalogItem(
|
||||
type="gmail",
|
||||
@@ -209,229 +129,51 @@ async def get_agent_catalog(
|
||||
]
|
||||
|
||||
|
||||
# ── Local agent CRUD ──────────────────────────────────────────────────
|
||||
|
||||
@router.get("/local", response_model=list[LocalAgentConfigResponse])
|
||||
async def list_local_agents(
|
||||
@router.post("/can-create", response_model=AgentCreationCheckResponse)
|
||||
async def can_create_agent(
|
||||
body: AgentCreationCheckRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[LocalAgentConfigResponse]:
|
||||
"""List all local directory agent configs owned by the authenticated user."""
|
||||
result = await db.execute(
|
||||
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
|
||||
)
|
||||
return [_to_local_response(a) for a in result.scalars().all()]
|
||||
) -> AgentCreationCheckResponse:
|
||||
"""Check if the user can create one more agent based on billing tier.
|
||||
|
||||
|
||||
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_local_agent(
|
||||
body: LocalAgentConfigCreate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> LocalAgentConfigResponse:
|
||||
"""Create a new local directory agent config.
|
||||
|
||||
The combined count of enabled local and cloud agents for the user is
|
||||
checked against the ``batch_active`` limit for their billing tier.
|
||||
Since configuration is client-owned, the Electron app sends its current
|
||||
active agent count and the backend applies tier limits.
|
||||
"""
|
||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||
agent = LocalAgentConfig(
|
||||
user_id=current_user.id,
|
||||
name=body.name,
|
||||
device_id=body.device_id,
|
||||
directory_paths=body.directory_paths,
|
||||
data_types=body.data_types,
|
||||
prompt_template=body.prompt_template,
|
||||
file_extensions=body.file_extensions,
|
||||
schedule_cron=body.schedule_cron,
|
||||
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
|
||||
allowed = limit == -1 or body.active_agents < limit
|
||||
return AgentCreationCheckResponse(
|
||||
allowed=allowed,
|
||||
tier=current_user.tier,
|
||||
active_agents=body.active_agents,
|
||||
limit=limit,
|
||||
)
|
||||
db.add(agent)
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
return _to_local_response(agent)
|
||||
|
||||
|
||||
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse)
|
||||
async def update_local_agent(
|
||||
agent_id: str,
|
||||
body: LocalAgentConfigUpdate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> LocalAgentConfigResponse:
|
||||
"""Partially update a local agent config. Only provided fields are changed."""
|
||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
setattr(agent, field, value)
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
return _to_local_response(agent)
|
||||
|
||||
|
||||
@router.delete("/local/{agent_id}", response_model=dict)
|
||||
async def delete_local_agent(
|
||||
agent_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a local agent config. Associated run logs are cascade-deleted."""
|
||||
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
|
||||
await db.delete(agent)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── Cloud agent CRUD ──────────────────────────────────────────────────
|
||||
|
||||
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
|
||||
async def list_cloud_agents(
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> list[CloudAgentConfigResponse]:
|
||||
"""List all cloud connector agent configs owned by the authenticated user."""
|
||||
result = await db.execute(
|
||||
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
|
||||
)
|
||||
return [_to_cloud_response(a) for a in result.scalars().all()]
|
||||
|
||||
|
||||
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
|
||||
async def create_cloud_agent(
|
||||
body: CloudAgentConfigCreate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> CloudAgentConfigResponse:
|
||||
"""Create a new cloud connector agent config.
|
||||
|
||||
The combined count of enabled local and cloud agents for the user is
|
||||
checked against the ``batch_active`` limit for their billing tier.
|
||||
"""
|
||||
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
|
||||
agent = CloudAgentConfig(
|
||||
user_id=current_user.id,
|
||||
provider=body.provider,
|
||||
name=body.name,
|
||||
data_types=body.data_types,
|
||||
prompt_template=body.prompt_template,
|
||||
oauth_token_encrypted=body.oauth_token_encrypted,
|
||||
schedule_cron=body.schedule_cron,
|
||||
filter_config=body.filter_config,
|
||||
)
|
||||
db.add(agent)
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
return _to_cloud_response(agent)
|
||||
|
||||
|
||||
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
|
||||
async def update_cloud_agent(
|
||||
agent_id: str,
|
||||
body: CloudAgentConfigUpdate,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> CloudAgentConfigResponse:
|
||||
"""Partially update a cloud agent config. Only provided fields are changed."""
|
||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||
for field, value in body.model_dump(exclude_unset=True).items():
|
||||
setattr(agent, field, value)
|
||||
await db.commit()
|
||||
await db.refresh(agent)
|
||||
return _to_cloud_response(agent)
|
||||
|
||||
|
||||
@router.delete("/cloud/{agent_id}", response_model=dict)
|
||||
async def delete_cloud_agent(
|
||||
agent_id: str,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> dict[str, bool]:
|
||||
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
|
||||
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
|
||||
await db.delete(agent)
|
||||
await db.commit()
|
||||
return {"ok": True}
|
||||
|
||||
|
||||
# ── Run logs ──────────────────────────────────────────────────────────
|
||||
|
||||
@router.get("/runs", response_model=_RunsPage)
|
||||
async def list_run_logs(
|
||||
agent_id: str | None = Query(default=None),
|
||||
page: int = Query(default=1, ge=1),
|
||||
limit: int = Query(default=20, ge=1, le=100),
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> _RunsPage:
|
||||
"""Return paginated run logs for the authenticated user.
|
||||
|
||||
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
|
||||
"""
|
||||
base_filter = [AgentRunLog.user_id == current_user.id]
|
||||
if agent_id:
|
||||
base_filter.append(AgentRunLog.agent_id == agent_id)
|
||||
|
||||
total = (
|
||||
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
|
||||
).scalar_one()
|
||||
|
||||
result = await db.execute(
|
||||
select(AgentRunLog)
|
||||
.where(*base_filter)
|
||||
.order_by(AgentRunLog.started_at.desc())
|
||||
.offset((page - 1) * limit)
|
||||
.limit(limit)
|
||||
)
|
||||
items = [_to_run_log_response(log) for log in result.scalars().all()]
|
||||
|
||||
return _RunsPage(total=total, page=page, limit=limit, items=items)
|
||||
|
||||
|
||||
# ── Manual trigger stub ───────────────────────────────────────────────
|
||||
|
||||
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
|
||||
async def trigger_agent_run(
|
||||
agent_id: str,
|
||||
body: AgentTriggerRequest,
|
||||
current_user: UserProfile = Depends(get_current_user),
|
||||
db: AsyncSession = Depends(get_session),
|
||||
) -> AgentRunLogResponse:
|
||||
"""Manually trigger an agent run.
|
||||
"""Trigger a local agent run using client-provided configuration."""
|
||||
_enforce_agent_limit(current_user.tier, body.active_agents)
|
||||
|
||||
Looks up the agent config (local or cloud) by ID with ownership check,
|
||||
creates a run log entry with ``status="running"``, and returns it.
|
||||
|
||||
Actual dispatch to the agent runner is wired in Step 3.4 once
|
||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||
"""
|
||||
# Determine agent type by trying local first, then cloud.
|
||||
# Keep the full config object so we can pass it to the agent runner.
|
||||
local_config: LocalAgentConfig | None = None
|
||||
cloud_config: CloudAgentConfig | None = None
|
||||
|
||||
local_result = await db.execute(
|
||||
select(LocalAgentConfig).where(
|
||||
LocalAgentConfig.id == agent_id,
|
||||
LocalAgentConfig.user_id == current_user.id,
|
||||
config = LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=current_user.id,
|
||||
device_id="",
|
||||
name="Local Directory Monitor",
|
||||
directory_paths=[body.directory],
|
||||
data_types=_to_data_types(body.what_to_extract),
|
||||
prompt_template=body.custom_agent_prompt,
|
||||
file_extensions=[],
|
||||
schedule_cron=body.batch_interval,
|
||||
enabled=True,
|
||||
)
|
||||
)
|
||||
local_config = local_result.scalar_one_or_none()
|
||||
|
||||
if local_config is not None:
|
||||
agent_type = "local"
|
||||
else:
|
||||
cloud_result = await db.execute(
|
||||
select(CloudAgentConfig).where(
|
||||
CloudAgentConfig.id == agent_id,
|
||||
CloudAgentConfig.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
cloud_config = cloud_result.scalar_one_or_none()
|
||||
if cloud_config is not None:
|
||||
agent_type = "cloud"
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||
|
||||
run_log = AgentRunLog(
|
||||
agent_id=agent_id,
|
||||
agent_type=agent_type,
|
||||
agent_id=config.id,
|
||||
agent_type="local",
|
||||
user_id=current_user.id,
|
||||
status="running",
|
||||
)
|
||||
@@ -439,14 +181,8 @@ async def trigger_agent_run(
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
|
||||
# Dispatch the run as a background task — returns 202 immediately.
|
||||
if agent_type == "local" and local_config is not None:
|
||||
asyncio.create_task(
|
||||
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||
)
|
||||
elif agent_type == "cloud" and cloud_config is not None:
|
||||
asyncio.create_task(
|
||||
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||
run_local_agent(current_user.id, config, run_log, device_manager)
|
||||
)
|
||||
|
||||
return _to_run_log_response(run_log)
|
||||
|
||||
@@ -238,17 +238,23 @@ async def run_local_agent(
|
||||
run_id = run_log.id
|
||||
|
||||
# ── 1. Device online check ─────────────────────────────────────────
|
||||
if not device_mgr.is_online(user_id, config.device_id):
|
||||
target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else ""
|
||||
if target_device_id:
|
||||
is_online = device_mgr.is_online(user_id, target_device_id)
|
||||
else:
|
||||
is_online = device_mgr.is_online(user_id)
|
||||
|
||||
if not is_online:
|
||||
logger.info(
|
||||
"agent_runner: skip run=%s — device %r offline for user=%s",
|
||||
run_id,
|
||||
config.device_id,
|
||||
target_device_id or "<any>",
|
||||
user_id,
|
||||
)
|
||||
await _finalize_run(
|
||||
run_log,
|
||||
status="error",
|
||||
errors=[f"Device {config.device_id!r} is not connected"],
|
||||
errors=[f"Device {target_device_id or '<any>'!r} is not connected"],
|
||||
)
|
||||
return
|
||||
|
||||
@@ -369,7 +375,7 @@ async def run_local_agent(
|
||||
items_processed=items_processed,
|
||||
items_created=items_created,
|
||||
errors=errors,
|
||||
update_config_last_run=True,
|
||||
update_config_last_run=False,
|
||||
config_id=config.id,
|
||||
config_type="local",
|
||||
)
|
||||
@@ -610,61 +616,12 @@ async def trigger_pending_runs(
|
||||
* Runs execute **sequentially** to avoid flooding the WS connection.
|
||||
"""
|
||||
logger.info(
|
||||
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
|
||||
"agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
|
||||
user_id,
|
||||
device_id,
|
||||
)
|
||||
async with async_session() as db:
|
||||
local_result = await db.execute(
|
||||
select(LocalAgentConfig).where(
|
||||
LocalAgentConfig.user_id == user_id,
|
||||
LocalAgentConfig.enabled == True, # noqa: E712
|
||||
LocalAgentConfig.device_id == device_id,
|
||||
)
|
||||
)
|
||||
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
|
||||
|
||||
cloud_result = await db.execute(
|
||||
select(CloudAgentConfig).where(
|
||||
CloudAgentConfig.user_id == user_id,
|
||||
CloudAgentConfig.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
|
||||
|
||||
# Build ordered list of overdue (type, config) pairs.
|
||||
pending: list[tuple[str, Any]] = []
|
||||
for cfg in local_configs:
|
||||
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||
pending.append(("local", cfg))
|
||||
for cfg in cloud_configs:
|
||||
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||
pending.append(("cloud", cfg))
|
||||
|
||||
if not pending:
|
||||
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
|
||||
)
|
||||
|
||||
for agent_type, cfg in pending:
|
||||
# Create a fresh run log for this scheduled dispatch.
|
||||
run_log = AgentRunLog(
|
||||
agent_id=cfg.id,
|
||||
agent_type=agent_type,
|
||||
user_id=user_id,
|
||||
status="running",
|
||||
)
|
||||
async with async_session() as db:
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
|
||||
if agent_type == "local":
|
||||
await run_local_agent(user_id, cfg, run_log, device_mgr)
|
||||
else:
|
||||
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
|
||||
|
||||
|
||||
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@@ -306,81 +306,27 @@ class AgentCatalogItem(BaseModel):
|
||||
config_schema: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
# ── Local Agent Config ────────────────────────────────────────────────
|
||||
|
||||
class LocalAgentConfigCreate(BaseModel):
|
||||
name: str
|
||||
device_id: str
|
||||
directory_paths: list[str]
|
||||
data_types: list[str]
|
||||
prompt_template: str
|
||||
file_extensions: list[str]
|
||||
schedule_cron: str
|
||||
class AgentCreationCheckRequest(BaseModel):
|
||||
active_agents: int = Field(ge=0, default=0)
|
||||
|
||||
|
||||
class LocalAgentConfigUpdate(BaseModel):
|
||||
name: str | None = None
|
||||
device_id: str | None = None
|
||||
directory_paths: list[str] | None = None
|
||||
data_types: list[str] | None = None
|
||||
prompt_template: str | None = None
|
||||
file_extensions: list[str] | None = None
|
||||
schedule_cron: str | None = None
|
||||
enabled: bool | None = None
|
||||
class AgentCreationCheckResponse(BaseModel):
|
||||
allowed: bool
|
||||
tier: BillingTier
|
||||
active_agents: int
|
||||
limit: int
|
||||
|
||||
|
||||
class LocalAgentConfigResponse(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
device_id: str
|
||||
directory_paths: list[str]
|
||||
data_types: list[str]
|
||||
prompt_template: str
|
||||
file_extensions: list[str]
|
||||
schedule_cron: str
|
||||
enabled: bool
|
||||
last_run_at: int | None
|
||||
created_at: int
|
||||
updated_at: int
|
||||
|
||||
|
||||
# ── Cloud Agent Config ────────────────────────────────────────────────
|
||||
|
||||
class CloudAgentConfigCreate(BaseModel):
|
||||
provider: Literal["gmail", "teams", "outlook"]
|
||||
name: str
|
||||
data_types: list[str]
|
||||
prompt_template: str
|
||||
oauth_token_encrypted: str
|
||||
schedule_cron: str
|
||||
filter_config: dict[str, Any] | None = None
|
||||
|
||||
|
||||
class CloudAgentConfigUpdate(BaseModel):
|
||||
provider: Literal["gmail", "teams", "outlook"] | None = None
|
||||
name: str | None = None
|
||||
data_types: list[str] | None = None
|
||||
prompt_template: str | None = None
|
||||
oauth_token_encrypted: str | None = None
|
||||
schedule_cron: str | None = None
|
||||
filter_config: dict[str, Any] | None = None
|
||||
enabled: bool | None = None
|
||||
|
||||
|
||||
class CloudAgentConfigResponse(BaseModel):
|
||||
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
|
||||
|
||||
id: str
|
||||
provider: Literal["gmail", "teams", "outlook"]
|
||||
name: str
|
||||
data_types: list[str]
|
||||
prompt_template: str
|
||||
schedule_cron: str
|
||||
filter_config: dict[str, Any] | None
|
||||
enabled: bool
|
||||
last_run_at: int | None
|
||||
created_at: int
|
||||
updated_at: int
|
||||
class AgentTriggerRequest(BaseModel):
|
||||
directory: str = Field(min_length=1)
|
||||
what_to_extract: list[Literal["task", "note", "timeline", "project"]] = Field(min_length=1)
|
||||
actions_by_type: dict[
|
||||
Literal["task", "note", "timeline", "project"],
|
||||
list[Literal["add", "update"]],
|
||||
] | None = None
|
||||
batch_interval: str = Field(min_length=1)
|
||||
custom_agent_prompt: str = Field(min_length=1)
|
||||
active_agents: int = Field(ge=0, default=0)
|
||||
|
||||
|
||||
# ── Agent Run Log ─────────────────────────────────────────────────────
|
||||
|
||||
@@ -10,13 +10,13 @@ Coverage:
|
||||
- run_local_agent — file-read timeout path
|
||||
- run_local_agent — LLM extraction error path
|
||||
- run_cloud_agent — stub returns error immediately
|
||||
- trigger_pending_runs — overdue local + cloud dispatched
|
||||
- trigger_pending_runs — skipped when config is client-owned
|
||||
- trigger_pending_runs — non-overdue skipped
|
||||
- trigger_pending_runs — device_id filter for local agents
|
||||
|
||||
Integration:
|
||||
- POST /agents/{id}/run — 404 on unknown agent
|
||||
- POST /agents/{id}/run — creates run log + dispatches background task
|
||||
- POST /agents/can-create — billing eligibility check
|
||||
- POST /agents/trigger — creates run log + dispatches background task
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
|
||||
assert kwargs["items_processed"] == 1
|
||||
assert kwargs["items_created"] == 1
|
||||
assert kwargs["errors"] == []
|
||||
assert kwargs["update_config_last_run"] is True
|
||||
assert kwargs["update_config_last_run"] is False
|
||||
|
||||
# Verify agent_run frame was sent.
|
||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||
@@ -690,31 +690,11 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_no_overdue():
|
||||
"""If no agents are overdue trigger_pending_runs does nothing."""
|
||||
from datetime import timedelta
|
||||
|
||||
config = _make_local_config()
|
||||
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
||||
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
||||
|
||||
mock_db_result_local = MagicMock()
|
||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||
|
||||
mock_db_result_cloud = MagicMock()
|
||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||
"""Pending-run scan is skipped because agent config is client-owned."""
|
||||
|
||||
mgr = _make_manager()
|
||||
|
||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.execute = AsyncMock(
|
||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||
)
|
||||
mock_session_factory.return_value = mock_ctx
|
||||
|
||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
mock_run.assert_not_called()
|
||||
@@ -722,31 +702,11 @@ async def test_trigger_pending_runs_no_overdue():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_device_id_filter():
|
||||
"""Local agents are only triggered for the matching device_id."""
|
||||
# The DB query already filters by device_id, so we verify the SELECT
|
||||
# includes the device_id filter by checking that a config bound to a
|
||||
# different device is never dispatched.
|
||||
#
|
||||
# Since trigger_pending_runs queries with device_id == "dev-001",
|
||||
# simulate the DB returning an empty list (as it would for a mismatch).
|
||||
mock_db_result_local = MagicMock()
|
||||
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
||||
|
||||
mock_db_result_cloud = MagicMock()
|
||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||
"""Device filtering is no longer backend-managed in pending runs."""
|
||||
|
||||
mgr = _make_manager(device_id="dev-001")
|
||||
|
||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.execute = AsyncMock(
|
||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||
)
|
||||
mock_session_factory.return_value = mock_ctx
|
||||
|
||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
mock_run.assert_not_called()
|
||||
@@ -754,56 +714,18 @@ async def test_trigger_pending_runs_device_id_filter():
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_dispatches_overdue():
|
||||
"""Overdue local agent triggers run_local_agent sequentially."""
|
||||
config = _make_local_config() # last_run_at=None → always overdue
|
||||
|
||||
mock_db_result_local = MagicMock()
|
||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||
|
||||
mock_db_result_cloud = MagicMock()
|
||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||
"""No pending runs are dispatched by backend after config deprecation."""
|
||||
|
||||
mgr = _make_manager()
|
||||
|
||||
call_order: list[str] = []
|
||||
|
||||
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
||||
call_order.append("run_local")
|
||||
|
||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
||||
# First call: query configs. Subsequent calls: create run_log.
|
||||
mock_query_ctx = AsyncMock()
|
||||
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
||||
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_query_ctx.execute = AsyncMock(
|
||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||
)
|
||||
|
||||
run_log_obj = AgentRunLog(
|
||||
id=str(uuid.uuid4()),
|
||||
agent_id=config.id,
|
||||
agent_type="local",
|
||||
user_id=_FREE_UID,
|
||||
status="running",
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
mock_insert_ctx = AsyncMock()
|
||||
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
||||
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_insert_ctx.add = MagicMock()
|
||||
mock_insert_ctx.commit = AsyncMock()
|
||||
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
||||
|
||||
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
||||
|
||||
with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
assert call_order == ["run_local"]
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: POST /agents/{id}/run
|
||||
# Integration: POST /agents/can-create and /agents/trigger
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@@ -820,50 +742,67 @@ def _override_db(db_session):
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_run_unknown_agent(client):
|
||||
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
||||
async def test_can_create_agent_allows_when_under_limit(client):
|
||||
"""POST /agents/can-create returns allowed=True when under tier limit."""
|
||||
resp = client.post(
|
||||
f"/api/v1/agents/{uuid.uuid4()}/run",
|
||||
headers=auth_header("power"),
|
||||
"/api/v1/agents/can-create",
|
||||
json={"active_agents": 0},
|
||||
headers=auth_header("free"),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["allowed"] is True
|
||||
assert body["tier"] == "free"
|
||||
assert body["active_agents"] == 0
|
||||
assert body["limit"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_can_create_agent_denies_when_at_limit(client):
|
||||
"""POST /agents/can-create returns allowed=False at free-tier limit."""
|
||||
resp = client.post(
|
||||
"/api/v1/agents/can-create",
|
||||
json={"active_agents": 2},
|
||||
headers=auth_header("free"),
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
body = resp.json()
|
||||
assert body["allowed"] is False
|
||||
assert body["limit"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
||||
# Create the local agent config in the DB.
|
||||
config = LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=TEST_USER_IDS["power"],
|
||||
device_id="dev-001",
|
||||
name="My Agent",
|
||||
directory_paths=["/home/user/docs"],
|
||||
data_types=["tasks"],
|
||||
prompt_template="Extract tasks.",
|
||||
file_extensions=[".txt"],
|
||||
schedule_cron="0 */6 * * *",
|
||||
enabled=True,
|
||||
)
|
||||
db_session.add(config)
|
||||
await db_session.commit()
|
||||
|
||||
dispatched: list = []
|
||||
"""POST /agents/trigger creates a local run log and dispatches background task."""
|
||||
dispatched: list[tuple[str, str]] = []
|
||||
|
||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||
dispatched.append((user_id, cfg.id))
|
||||
|
||||
def _fake_create_task(coro):
|
||||
coro.close()
|
||||
return MagicMock()
|
||||
|
||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
||||
patch("asyncio.create_task") as mock_create_task:
|
||||
mock_create_task.side_effect = _fake_create_task
|
||||
resp = client.post(
|
||||
f"/api/v1/agents/{config.id}/run",
|
||||
"/api/v1/agents/trigger",
|
||||
json={
|
||||
"directory": "/home/user/docs",
|
||||
"what_to_extract": ["task", "note"],
|
||||
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
|
||||
"batch_interval": "0 */6 * * *",
|
||||
"custom_agent_prompt": "Extract tasks and notes.",
|
||||
"active_agents": 0,
|
||||
},
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
|
||||
assert resp.status_code == 202
|
||||
data = resp.json()
|
||||
assert data["agent_id"] == config.id
|
||||
assert isinstance(data["agent_id"], str)
|
||||
assert data["agent_id"]
|
||||
assert data["status"] == "running"
|
||||
assert data["agent_type"] == "local"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user