refactor agents to client-owned config flow

This commit is contained in:
2026-03-16 22:35:46 +01:00
parent 02a9684cd6
commit 5faa6b1d7c
6 changed files with 259 additions and 589 deletions

View File

@@ -0,0 +1,92 @@
"""Deprecate backend agent config tables.
The Electron client is now the source of truth for agent configuration
(directory, extract targets, batch interval, custom prompt). Backend keeps
billing checks and trigger/run logs only.
Revision ID: 9a1f2d0b6c7e
Revises: 818478c251dc
Create Date: 2026-03-16
"""
from __future__ import annotations
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
revision: str = "9a1f2d0b6c7e"
down_revision: Union[str, None] = "818478c251dc"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
bind = op.get_bind()
inspector = sa.inspect(bind)
existing = set(inspector.get_table_names())
if "cloud_agent_configs" in existing:
op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs")
op.drop_table("cloud_agent_configs")
if "local_agent_configs" in existing:
op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs")
op.drop_table("local_agent_configs")
def downgrade() -> None:
op.create_table(
"local_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("device_id", sa.String(255), nullable=False),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"])
op.execute(
"""
DO $$ BEGIN
CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook');
EXCEPTION WHEN duplicate_object THEN NULL;
END $$;
"""
)
op.create_table(
"cloud_agent_configs",
sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False),
sa.Column(
"provider",
postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False),
nullable=False,
),
sa.Column("name", sa.String(255), nullable=False),
sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"),
sa.Column("prompt_template", sa.Text, nullable=False, server_default=""),
sa.Column("oauth_token_encrypted", sa.Text, nullable=True),
sa.Column("filter_config", sa.JSON, nullable=True),
sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"),
sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()),
sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")),
sa.PrimaryKeyConstraint("id"),
sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"),
)
op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"])

View File

@@ -16,9 +16,9 @@ Journey flow:
delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``. delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``.
5. Server parses the block, sets ``done=True``, and returns the template. 5. Server parses the block, sets ``done=True``, and returns the template.
The ``prompt_template`` from the final response is meant to be stored in The ``prompt_template`` from the final response is meant to be stored by
``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template`` the Electron client in local agent settings and later sent to
by the Electron client (via the agent CRUD endpoints). ``POST /agents/trigger`` when a run is executed.
""" """
from __future__ import annotations from __future__ import annotations

View File

@@ -1,45 +1,35 @@
"""Agent CRUD routes: local directory agents and cloud connector agents. """Agent routes.
Endpoints: Backend responsibilities are intentionally minimal:
GET /agents/catalog — hardcoded agent type catalog GET /agents/catalog — static catalog for UI display
GET /agents/local — list user's local agent configs POST /agents/can-create — billing eligibility check
POST /agents/local create local agent (tier-gated) POST /agents/triggertrigger a local agent run
PUT /agents/local/{agent_id} — partial update (ownership check)
DELETE /agents/local/{agent_id} — delete + cascade run logs Agent configuration is owned by the Electron app and is not persisted
GET /agents/cloud — list user's cloud agent configs in backend agent-config tables.
POST /agents/cloud — create cloud agent (tier-gated)
PUT /agents/cloud/{agent_id} — partial update (ownership check)
DELETE /agents/cloud/{agent_id} — delete + cascade run logs
GET /agents/runs — paginated run logs (agent_id, page, limit)
POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4)
""" """
from __future__ import annotations from __future__ import annotations
import asyncio import asyncio
import uuid
from datetime import datetime from datetime import datetime
from typing import Any
from fastapi import APIRouter, Depends, HTTPException, Query, status from fastapi import APIRouter, Depends, HTTPException, status
from pydantic import BaseModel
from sqlalchemy import func, or_, select
from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES from app.billing.tier_manager import FEATURES
from app.core.agent_runner import run_cloud_agent, run_local_agent from app.core.agent_runner import run_local_agent
from app.core.device_manager import device_manager from app.core.device_manager import device_manager
from app.db import get_session from app.db import get_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import ( from app.schemas import (
AgentCatalogItem, AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse, AgentRunLogResponse,
CloudAgentConfigCreate, AgentTriggerRequest,
CloudAgentConfigResponse,
CloudAgentConfigUpdate,
LocalAgentConfigCreate,
LocalAgentConfigResponse,
LocalAgentConfigUpdate,
UserProfile, UserProfile,
) )
@@ -56,39 +46,14 @@ def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None return int(dt.timestamp() * 1000) if dt else None
# ── Model → schema converters ───────────────────────────────────────── def _to_data_types(values: list[str]) -> list[str]:
normalize = {
def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse: "task": "tasks",
return LocalAgentConfigResponse( "note": "notes",
id=a.id, "timeline": "timelines",
name=a.name, "project": "projects",
device_id=a.device_id, }
directory_paths=a.directory_paths, return [normalize[v] for v in values if v in normalize]
data_types=a.data_types,
prompt_template=a.prompt_template,
file_extensions=a.file_extensions,
schedule_cron=a.schedule_cron,
enabled=a.enabled,
last_run_at=_dt_ms_opt(a.last_run_at),
created_at=_dt_ms(a.created_at),
updated_at=_dt_ms(a.updated_at),
)
def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse:
return CloudAgentConfigResponse(
id=a.id,
provider=a.provider, # type: ignore[arg-type]
name=a.name,
data_types=a.data_types,
prompt_template=a.prompt_template,
schedule_cron=a.schedule_cron,
filter_config=a.filter_config,
enabled=a.enabled,
last_run_at=_dt_ms_opt(a.last_run_at),
created_at=_dt_ms(a.created_at),
updated_at=_dt_ms(a.updated_at),
)
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse: def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
@@ -105,77 +70,14 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
) )
# ── Ownership-checked lookups ───────────────────────────────────────── def _enforce_agent_limit(tier: str, current_count: int) -> int:
async def _get_local_agent_for_user(
agent_id: str, user_id: str, db: AsyncSession
) -> LocalAgentConfig:
result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == user_id,
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
return record
async def _get_cloud_agent_for_user(
agent_id: str, user_id: str, db: AsyncSession
) -> CloudAgentConfig:
result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == user_id,
)
)
record = result.scalar_one_or_none()
if record is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
return record
# ── Tier limit helper ─────────────────────────────────────────────────
async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int:
"""Return combined enabled local + cloud agent count for the user."""
local_count = (
await db.execute(
select(func.count(LocalAgentConfig.id)).where(
LocalAgentConfig.user_id == user_id,
LocalAgentConfig.enabled == True, # noqa: E712
)
)
).scalar_one()
cloud_count = (
await db.execute(
select(func.count(CloudAgentConfig.id)).where(
CloudAgentConfig.user_id == user_id,
CloudAgentConfig.enabled == True, # noqa: E712
)
)
).scalar_one()
return local_count + cloud_count
def _enforce_agent_limit(tier: str, current_count: int) -> None:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"] limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit: if limit != -1 and current_count >= limit:
raise HTTPException( raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN, status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.", detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
) )
return limit
# ── Local page schema (used by runs endpoint) ─────────────────────────
class _RunsPage(BaseModel):
total: int
page: int
limit: int
items: list[AgentRunLogResponse]
# ── Catalog ─────────────────────────────────────────────────────────── # ── Catalog ───────────────────────────────────────────────────────────
@@ -190,6 +92,24 @@ async def get_agent_catalog(
type="local_directory", type="local_directory",
name="Local Directory Monitor", name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI", description="Watches local directories, extracts data from files using AI",
config_schema={
"directory": {"type": "string", "required": True},
"what_to_extract": {
"type": "array",
"items": ["task", "note", "timeline", "project"],
"required": True,
},
"actions_by_type": {
"type": "object",
"example": {
"task": ["add", "update"],
"note": ["add", "update"],
},
"required": False,
},
"batch_interval": {"type": "string", "required": True},
"custom_agent_prompt": {"type": "string", "required": True},
},
), ),
AgentCatalogItem( AgentCatalogItem(
type="gmail", type="gmail",
@@ -209,229 +129,51 @@ async def get_agent_catalog(
] ]
# ── Local agent CRUD ────────────────────────────────────────────────── @router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
@router.get("/local", response_model=list[LocalAgentConfigResponse]) body: AgentCreationCheckRequest,
async def list_local_agents(
current_user: UserProfile = Depends(get_current_user), current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session), ) -> AgentCreationCheckResponse:
) -> list[LocalAgentConfigResponse]: """Check if the user can create one more agent based on billing tier.
"""List all local directory agent configs owned by the authenticated user."""
result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id)
)
return [_to_local_response(a) for a in result.scalars().all()]
Since configuration is client-owned, the Electron app sends its current
@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED) active agent count and the backend applies tier limits.
async def create_local_agent(
body: LocalAgentConfigCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> LocalAgentConfigResponse:
"""Create a new local directory agent config.
The combined count of enabled local and cloud agents for the user is
checked against the ``batch_active`` limit for their billing tier.
""" """
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db)) limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
agent = LocalAgentConfig( allowed = limit == -1 or body.active_agents < limit
user_id=current_user.id, return AgentCreationCheckResponse(
name=body.name, allowed=allowed,
device_id=body.device_id, tier=current_user.tier,
directory_paths=body.directory_paths, active_agents=body.active_agents,
data_types=body.data_types, limit=limit,
prompt_template=body.prompt_template,
file_extensions=body.file_extensions,
schedule_cron=body.schedule_cron,
) )
db.add(agent)
await db.commit()
await db.refresh(agent)
return _to_local_response(agent)
@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse) @router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def update_local_agent(
agent_id: str,
body: LocalAgentConfigUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> LocalAgentConfigResponse:
"""Partially update a local agent config. Only provided fields are changed."""
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(agent, field, value)
await db.commit()
await db.refresh(agent)
return _to_local_response(agent)
@router.delete("/local/{agent_id}", response_model=dict)
async def delete_local_agent(
agent_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a local agent config. Associated run logs are cascade-deleted."""
agent = await _get_local_agent_for_user(agent_id, current_user.id, db)
await db.delete(agent)
await db.commit()
return {"ok": True}
# ── Cloud agent CRUD ──────────────────────────────────────────────────
@router.get("/cloud", response_model=list[CloudAgentConfigResponse])
async def list_cloud_agents(
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> list[CloudAgentConfigResponse]:
"""List all cloud connector agent configs owned by the authenticated user."""
result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id)
)
return [_to_cloud_response(a) for a in result.scalars().all()]
@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED)
async def create_cloud_agent(
body: CloudAgentConfigCreate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> CloudAgentConfigResponse:
"""Create a new cloud connector agent config.
The combined count of enabled local and cloud agents for the user is
checked against the ``batch_active`` limit for their billing tier.
"""
_enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db))
agent = CloudAgentConfig(
user_id=current_user.id,
provider=body.provider,
name=body.name,
data_types=body.data_types,
prompt_template=body.prompt_template,
oauth_token_encrypted=body.oauth_token_encrypted,
schedule_cron=body.schedule_cron,
filter_config=body.filter_config,
)
db.add(agent)
await db.commit()
await db.refresh(agent)
return _to_cloud_response(agent)
@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse)
async def update_cloud_agent(
agent_id: str,
body: CloudAgentConfigUpdate,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> CloudAgentConfigResponse:
"""Partially update a cloud agent config. Only provided fields are changed."""
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
for field, value in body.model_dump(exclude_unset=True).items():
setattr(agent, field, value)
await db.commit()
await db.refresh(agent)
return _to_cloud_response(agent)
@router.delete("/cloud/{agent_id}", response_model=dict)
async def delete_cloud_agent(
agent_id: str,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> dict[str, bool]:
"""Delete a cloud agent config. Associated run logs are cascade-deleted."""
agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db)
await db.delete(agent)
await db.commit()
return {"ok": True}
# ── Run logs ──────────────────────────────────────────────────────────
@router.get("/runs", response_model=_RunsPage)
async def list_run_logs(
agent_id: str | None = Query(default=None),
page: int = Query(default=1, ge=1),
limit: int = Query(default=20, ge=1, le=100),
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> _RunsPage:
"""Return paginated run logs for the authenticated user.
Optionally filter by ``agent_id``. Results are ordered from newest to oldest.
"""
base_filter = [AgentRunLog.user_id == current_user.id]
if agent_id:
base_filter.append(AgentRunLog.agent_id == agent_id)
total = (
await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter))
).scalar_one()
result = await db.execute(
select(AgentRunLog)
.where(*base_filter)
.order_by(AgentRunLog.started_at.desc())
.offset((page - 1) * limit)
.limit(limit)
)
items = [_to_run_log_response(log) for log in result.scalars().all()]
return _RunsPage(total=total, page=page, limit=limit, items=items)
# ── Manual trigger stub ───────────────────────────────────────────────
@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run( async def trigger_agent_run(
agent_id: str, body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user), current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session), db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse: ) -> AgentRunLogResponse:
"""Manually trigger an agent run. """Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
Looks up the agent config (local or cloud) by ID with ownership check, config = LocalAgentConfig(
creates a run log entry with ``status="running"``, and returns it. id=str(uuid.uuid4()),
user_id=current_user.id,
Actual dispatch to the agent runner is wired in Step 3.4 once device_id="",
``DeviceConnectionManager`` and ``agent_runner`` are available. name="Local Directory Monitor",
""" directory_paths=[body.directory],
# Determine agent type by trying local first, then cloud. data_types=_to_data_types(body.what_to_extract),
# Keep the full config object so we can pass it to the agent runner. prompt_template=body.custom_agent_prompt,
local_config: LocalAgentConfig | None = None file_extensions=[],
cloud_config: CloudAgentConfig | None = None schedule_cron=body.batch_interval,
enabled=True,
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == current_user.id,
) )
)
local_config = local_result.scalar_one_or_none()
if local_config is not None:
agent_type = "local"
else:
cloud_result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.id == agent_id,
CloudAgentConfig.user_id == current_user.id,
)
)
cloud_config = cloud_result.scalar_one_or_none()
if cloud_config is not None:
agent_type = "cloud"
else:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
run_log = AgentRunLog( run_log = AgentRunLog(
agent_id=agent_id, agent_id=config.id,
agent_type=agent_type, agent_type="local",
user_id=current_user.id, user_id=current_user.id,
status="running", status="running",
) )
@@ -439,14 +181,8 @@ async def trigger_agent_run(
await db.commit() await db.commit()
await db.refresh(run_log) await db.refresh(run_log)
# Dispatch the run as a background task — returns 202 immediately.
if agent_type == "local" and local_config is not None:
asyncio.create_task( asyncio.create_task(
run_local_agent(current_user.id, local_config, run_log, device_manager) run_local_agent(current_user.id, config, run_log, device_manager)
)
elif agent_type == "cloud" and cloud_config is not None:
asyncio.create_task(
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
) )
return _to_run_log_response(run_log) return _to_run_log_response(run_log)

View File

@@ -238,17 +238,23 @@ async def run_local_agent(
run_id = run_log.id run_id = run_log.id
# ── 1. Device online check ───────────────────────────────────────── # ── 1. Device online check ─────────────────────────────────────────
if not device_mgr.is_online(user_id, config.device_id): target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else ""
if target_device_id:
is_online = device_mgr.is_online(user_id, target_device_id)
else:
is_online = device_mgr.is_online(user_id)
if not is_online:
logger.info( logger.info(
"agent_runner: skip run=%s — device %r offline for user=%s", "agent_runner: skip run=%s — device %r offline for user=%s",
run_id, run_id,
config.device_id, target_device_id or "<any>",
user_id, user_id,
) )
await _finalize_run( await _finalize_run(
run_log, run_log,
status="error", status="error",
errors=[f"Device {config.device_id!r} is not connected"], errors=[f"Device {target_device_id or '<any>'!r} is not connected"],
) )
return return
@@ -369,7 +375,7 @@ async def run_local_agent(
items_processed=items_processed, items_processed=items_processed,
items_created=items_created, items_created=items_created,
errors=errors, errors=errors,
update_config_last_run=True, update_config_last_run=False,
config_id=config.id, config_id=config.id,
config_type="local", config_type="local",
) )
@@ -610,61 +616,12 @@ async def trigger_pending_runs(
* Runs execute **sequentially** to avoid flooding the WS connection. * Runs execute **sequentially** to avoid flooding the WS connection.
""" """
logger.info( logger.info(
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id "agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
user_id,
device_id,
) )
async with async_session() as db:
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.user_id == user_id,
LocalAgentConfig.enabled == True, # noqa: E712
LocalAgentConfig.device_id == device_id,
)
)
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
cloud_result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.user_id == user_id,
CloudAgentConfig.enabled == True, # noqa: E712
)
)
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
# Build ordered list of overdue (type, config) pairs.
pending: list[tuple[str, Any]] = []
for cfg in local_configs:
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
pending.append(("local", cfg))
for cfg in cloud_configs:
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
pending.append(("cloud", cfg))
if not pending:
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
return return
logger.info(
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
)
for agent_type, cfg in pending:
# Create a fresh run log for this scheduled dispatch.
run_log = AgentRunLog(
agent_id=cfg.id,
agent_type=agent_type,
user_id=user_id,
status="running",
)
async with async_session() as db:
db.add(run_log)
await db.commit()
await db.refresh(run_log)
if agent_type == "local":
await run_local_agent(user_id, cfg, run_log, device_mgr)
else:
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
# ── Internal helper ───────────────────────────────────────────────────────── # ── Internal helper ─────────────────────────────────────────────────────────

View File

@@ -306,81 +306,27 @@ class AgentCatalogItem(BaseModel):
config_schema: dict[str, Any] = Field(default_factory=dict) config_schema: dict[str, Any] = Field(default_factory=dict)
# ── Local Agent Config ──────────────────────────────────────────────── class AgentCreationCheckRequest(BaseModel):
active_agents: int = Field(ge=0, default=0)
class LocalAgentConfigCreate(BaseModel):
name: str
device_id: str
directory_paths: list[str]
data_types: list[str]
prompt_template: str
file_extensions: list[str]
schedule_cron: str
class LocalAgentConfigUpdate(BaseModel): class AgentCreationCheckResponse(BaseModel):
name: str | None = None allowed: bool
device_id: str | None = None tier: BillingTier
directory_paths: list[str] | None = None active_agents: int
data_types: list[str] | None = None limit: int
prompt_template: str | None = None
file_extensions: list[str] | None = None
schedule_cron: str | None = None
enabled: bool | None = None
class LocalAgentConfigResponse(BaseModel): class AgentTriggerRequest(BaseModel):
id: str directory: str = Field(min_length=1)
name: str what_to_extract: list[Literal["task", "note", "timeline", "project"]] = Field(min_length=1)
device_id: str actions_by_type: dict[
directory_paths: list[str] Literal["task", "note", "timeline", "project"],
data_types: list[str] list[Literal["add", "update"]],
prompt_template: str ] | None = None
file_extensions: list[str] batch_interval: str = Field(min_length=1)
schedule_cron: str custom_agent_prompt: str = Field(min_length=1)
enabled: bool active_agents: int = Field(ge=0, default=0)
last_run_at: int | None
created_at: int
updated_at: int
# ── Cloud Agent Config ────────────────────────────────────────────────
class CloudAgentConfigCreate(BaseModel):
provider: Literal["gmail", "teams", "outlook"]
name: str
data_types: list[str]
prompt_template: str
oauth_token_encrypted: str
schedule_cron: str
filter_config: dict[str, Any] | None = None
class CloudAgentConfigUpdate(BaseModel):
provider: Literal["gmail", "teams", "outlook"] | None = None
name: str | None = None
data_types: list[str] | None = None
prompt_template: str | None = None
oauth_token_encrypted: str | None = None
schedule_cron: str | None = None
filter_config: dict[str, Any] | None = None
enabled: bool | None = None
class CloudAgentConfigResponse(BaseModel):
"""oauth_token_encrypted is intentionally excluded — never returned to clients."""
id: str
provider: Literal["gmail", "teams", "outlook"]
name: str
data_types: list[str]
prompt_template: str
schedule_cron: str
filter_config: dict[str, Any] | None
enabled: bool
last_run_at: int | None
created_at: int
updated_at: int
# ── Agent Run Log ───────────────────────────────────────────────────── # ── Agent Run Log ─────────────────────────────────────────────────────

View File

@@ -10,13 +10,13 @@ Coverage:
- run_local_agent — file-read timeout path - run_local_agent — file-read timeout path
- run_local_agent — LLM extraction error path - run_local_agent — LLM extraction error path
- run_cloud_agent — stub returns error immediately - run_cloud_agent — stub returns error immediately
- trigger_pending_runs — overdue local + cloud dispatched - trigger_pending_runs — skipped when config is client-owned
- trigger_pending_runs — non-overdue skipped - trigger_pending_runs — non-overdue skipped
- trigger_pending_runs — device_id filter for local agents - trigger_pending_runs — device_id filter for local agents
Integration: Integration:
- POST /agents/{id}/run — 404 on unknown agent - POST /agents/can-create — billing eligibility check
- POST /agents/{id}/run — creates run log + dispatches background task - POST /agents/trigger — creates run log + dispatches background task
""" """
from __future__ import annotations from __future__ import annotations
@@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path():
assert kwargs["items_processed"] == 1 assert kwargs["items_processed"] == 1
assert kwargs["items_created"] == 1 assert kwargs["items_created"] == 1
assert kwargs["errors"] == [] assert kwargs["errors"] == []
assert kwargs["update_config_last_run"] is True assert kwargs["update_config_last_run"] is False
# Verify agent_run frame was sent. # Verify agent_run frame was sent.
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"] agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
@@ -690,31 +690,11 @@ async def test_finalize_run_updates_cloud_config_last_run_at():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_pending_runs_no_overdue(): async def test_trigger_pending_runs_no_overdue():
"""If no agents are overdue trigger_pending_runs does nothing.""" """Pending-run scan is skipped because agent config is client-owned."""
from datetime import timedelta
config = _make_local_config()
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [config]
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
mgr = _make_manager() mgr = _make_manager()
with patch("app.core.agent_runner.async_session") as mock_session_factory, \ with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
mock_session_factory.return_value = mock_ctx
await trigger_pending_runs(_FREE_UID, "dev-001", mgr) await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called() mock_run.assert_not_called()
@@ -722,31 +702,11 @@ async def test_trigger_pending_runs_no_overdue():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_pending_runs_device_id_filter(): async def test_trigger_pending_runs_device_id_filter():
"""Local agents are only triggered for the matching device_id.""" """Device filtering is no longer backend-managed in pending runs."""
# The DB query already filters by device_id, so we verify the SELECT
# includes the device_id filter by checking that a config bound to a
# different device is never dispatched.
#
# Since trigger_pending_runs queries with device_id == "dev-001",
# simulate the DB returning an empty list (as it would for a mismatch).
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
mgr = _make_manager(device_id="dev-001") mgr = _make_manager(device_id="dev-001")
with patch("app.core.agent_runner.async_session") as mock_session_factory, \ with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
mock_ctx = AsyncMock()
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
mock_ctx.__aexit__ = AsyncMock(return_value=False)
mock_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
mock_session_factory.return_value = mock_ctx
await trigger_pending_runs(_FREE_UID, "dev-001", mgr) await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
mock_run.assert_not_called() mock_run.assert_not_called()
@@ -754,56 +714,18 @@ async def test_trigger_pending_runs_device_id_filter():
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_pending_runs_dispatches_overdue(): async def test_trigger_pending_runs_dispatches_overdue():
"""Overdue local agent triggers run_local_agent sequentially.""" """No pending runs are dispatched by backend after config deprecation."""
config = _make_local_config() # last_run_at=None → always overdue
mock_db_result_local = MagicMock()
mock_db_result_local.scalars.return_value.all.return_value = [config]
mock_db_result_cloud = MagicMock()
mock_db_result_cloud.scalars.return_value.all.return_value = []
mgr = _make_manager() mgr = _make_manager()
call_order: list[str] = [] with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
call_order.append("run_local")
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
# First call: query configs. Subsequent calls: create run_log.
mock_query_ctx = AsyncMock()
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
mock_query_ctx.execute = AsyncMock(
side_effect=[mock_db_result_local, mock_db_result_cloud]
)
run_log_obj = AgentRunLog(
id=str(uuid.uuid4()),
agent_id=config.id,
agent_type="local",
user_id=_FREE_UID,
status="running",
started_at=datetime.now(timezone.utc),
)
mock_insert_ctx = AsyncMock()
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
mock_insert_ctx.add = MagicMock()
mock_insert_ctx.commit = AsyncMock()
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
await trigger_pending_runs(_FREE_UID, "dev-001", mgr) await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
assert call_order == ["run_local"] mock_run.assert_not_called()
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# Integration: POST /agents/{id}/run # Integration: POST /agents/can-create and /agents/trigger
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -820,50 +742,67 @@ def _override_db(db_session):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_run_unknown_agent(client): async def test_can_create_agent_allows_when_under_limit(client):
"""POST /agents/{id}/run returns 404 for unknown agent id.""" """POST /agents/can-create returns allowed=True when under tier limit."""
resp = client.post( resp = client.post(
f"/api/v1/agents/{uuid.uuid4()}/run", "/api/v1/agents/can-create",
headers=auth_header("power"), json={"active_agents": 0},
headers=auth_header("free"),
) )
assert resp.status_code == 404 assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is True
assert body["tier"] == "free"
assert body["active_agents"] == 0
assert body["limit"] == 2
@pytest.mark.asyncio
async def test_can_create_agent_denies_when_at_limit(client):
"""POST /agents/can-create returns allowed=False at free-tier limit."""
resp = client.post(
"/api/v1/agents/can-create",
json={"active_agents": 2},
headers=auth_header("free"),
)
assert resp.status_code == 200
body = resp.json()
assert body["allowed"] is False
assert body["limit"] == 2
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_trigger_run_local_agent_creates_run_log(client, db_session): async def test_trigger_run_local_agent_creates_run_log(client, db_session):
"""POST /agents/{id}/run creates a run log and dispatches a background task.""" """POST /agents/trigger creates a local run log and dispatches background task."""
# Create the local agent config in the DB. dispatched: list[tuple[str, str]] = []
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=TEST_USER_IDS["power"],
device_id="dev-001",
name="My Agent",
directory_paths=["/home/user/docs"],
data_types=["tasks"],
prompt_template="Extract tasks.",
file_extensions=[".txt"],
schedule_cron="0 */6 * * *",
enabled=True,
)
db_session.add(config)
await db_session.commit()
dispatched: list = []
async def _fake_run(user_id, cfg, run_log, device_mgr): async def _fake_run(user_id, cfg, run_log, device_mgr):
dispatched.append((user_id, cfg.id)) dispatched.append((user_id, cfg.id))
def _fake_create_task(coro):
coro.close()
return MagicMock()
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \ with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
patch("asyncio.create_task") as mock_create_task: patch("asyncio.create_task") as mock_create_task:
mock_create_task.side_effect = _fake_create_task
resp = client.post( resp = client.post(
f"/api/v1/agents/{config.id}/run", "/api/v1/agents/trigger",
json={
"directory": "/home/user/docs",
"what_to_extract": ["task", "note"],
"actions_by_type": {"task": ["add", "update"], "note": ["add"]},
"batch_interval": "0 */6 * * *",
"custom_agent_prompt": "Extract tasks and notes.",
"active_agents": 0,
},
headers=auth_header("power"), headers=auth_header("power"),
) )
assert resp.status_code == 202 assert resp.status_code == 202
data = resp.json() data = resp.json()
assert data["agent_id"] == config.id assert isinstance(data["agent_id"], str)
assert data["agent_id"]
assert data["status"] == "running" assert data["status"] == "running"
assert data["agent_type"] == "local" assert data["agent_type"] == "local"