diff --git a/alembic/versions/9a1f2d0b6c7e_deprecate_backend_agent_config_tables.py b/alembic/versions/9a1f2d0b6c7e_deprecate_backend_agent_config_tables.py new file mode 100644 index 0000000..549c11c --- /dev/null +++ b/alembic/versions/9a1f2d0b6c7e_deprecate_backend_agent_config_tables.py @@ -0,0 +1,92 @@ +"""Deprecate backend agent config tables. + +The Electron client is now the source of truth for agent configuration +(directory, extract targets, batch interval, custom prompt). Backend keeps +billing checks and trigger/run logs only. + +Revision ID: 9a1f2d0b6c7e +Revises: 818478c251dc +Create Date: 2026-03-16 +""" + +from __future__ import annotations + +from typing import Sequence, Union + +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects import postgresql + +revision: str = "9a1f2d0b6c7e" +down_revision: Union[str, None] = "818478c251dc" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + bind = op.get_bind() + inspector = sa.inspect(bind) + existing = set(inspector.get_table_names()) + + if "cloud_agent_configs" in existing: + op.drop_index("ix_cloud_agent_configs_user_id", table_name="cloud_agent_configs") + op.drop_table("cloud_agent_configs") + + if "local_agent_configs" in existing: + op.drop_index("ix_local_agent_configs_user_id", table_name="local_agent_configs") + op.drop_table("local_agent_configs") + + +def downgrade() -> None: + op.create_table( + "local_agent_configs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("device_id", sa.String(255), nullable=False), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("directory_paths", sa.JSON, nullable=False, server_default="[]"), + sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"), + sa.Column("prompt_template", sa.Text, nullable=False, server_default=""), + sa.Column("file_extensions", sa.JSON, nullable=False, server_default="[]"), + sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_local_agent_configs_user_id", "local_agent_configs", ["user_id"]) + + op.execute( + """ + DO $$ BEGIN + CREATE TYPE cloud_provider AS ENUM ('gmail', 'teams', 'outlook'); + EXCEPTION WHEN duplicate_object THEN NULL; + END $$; + """ + ) + + op.create_table( + "cloud_agent_configs", + sa.Column("id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column("user_id", postgresql.UUID(as_uuid=False), nullable=False), + sa.Column( + "provider", + postgresql.ENUM("gmail", "teams", "outlook", name="cloud_provider", create_type=False), + nullable=False, + ), + sa.Column("name", sa.String(255), nullable=False), + sa.Column("data_types", sa.JSON, nullable=False, server_default="[]"), + sa.Column("prompt_template", sa.Text, nullable=False, server_default=""), + sa.Column("oauth_token_encrypted", sa.Text, nullable=True), + sa.Column("filter_config", sa.JSON, nullable=True), + sa.Column("schedule_cron", sa.String(100), nullable=False, server_default="0 */6 * * *"), + sa.Column("enabled", sa.Boolean, nullable=False, server_default=sa.true()), + sa.Column("last_run_at", sa.DateTime(timezone=True), nullable=True), + sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()")), + sa.PrimaryKeyConstraint("id"), + sa.ForeignKeyConstraint(["user_id"], ["users.id"], ondelete="CASCADE"), + ) + op.create_index("ix_cloud_agent_configs_user_id", "cloud_agent_configs", ["user_id"]) diff --git a/app/api/routes/agent_setup.py b/app/api/routes/agent_setup.py index e78bf75..ce71b72 100644 --- a/app/api/routes/agent_setup.py +++ b/app/api/routes/agent_setup.py @@ -16,9 +16,9 @@ Journey flow: delimited by ``PROMPT_TEMPLATE_START`` / ``PROMPT_TEMPLATE_END``. 5. Server parses the block, sets ``done=True``, and returns the template. -The ``prompt_template`` from the final response is meant to be stored in -``LocalAgentConfig.prompt_template`` or ``CloudAgentConfig.prompt_template`` -by the Electron client (via the agent CRUD endpoints). +The ``prompt_template`` from the final response is meant to be stored by +the Electron client in local agent settings and later sent to +``POST /agents/trigger`` when a run is executed. """ from __future__ import annotations diff --git a/app/api/routes/agents.py b/app/api/routes/agents.py index 6a17670..5e8fa47 100644 --- a/app/api/routes/agents.py +++ b/app/api/routes/agents.py @@ -1,45 +1,35 @@ -"""Agent CRUD routes: local directory agents and cloud connector agents. +"""Agent routes. -Endpoints: - GET /agents/catalog — hardcoded agent type catalog - GET /agents/local — list user's local agent configs - POST /agents/local — create local agent (tier-gated) - PUT /agents/local/{agent_id} — partial update (ownership check) - DELETE /agents/local/{agent_id} — delete + cascade run logs - GET /agents/cloud — list user's cloud agent configs - POST /agents/cloud — create cloud agent (tier-gated) - PUT /agents/cloud/{agent_id} — partial update (ownership check) - DELETE /agents/cloud/{agent_id} — delete + cascade run logs - GET /agents/runs — paginated run logs (agent_id, page, limit) - POST /agents/{agent_id}/run — manual trigger stub (dispatch in Step 3.4) +Backend responsibilities are intentionally minimal: + GET /agents/catalog — static catalog for UI display + POST /agents/can-create — billing eligibility check + POST /agents/trigger — trigger a local agent run + +Agent configuration is owned by the Electron app and is not persisted +in backend agent-config tables. """ from __future__ import annotations import asyncio +import uuid from datetime import datetime -from typing import Any -from fastapi import APIRouter, Depends, HTTPException, Query, status -from pydantic import BaseModel -from sqlalchemy import func, or_, select +from fastapi import APIRouter, Depends, HTTPException, status from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.billing.tier_manager import FEATURES -from app.core.agent_runner import run_cloud_agent, run_local_agent +from app.core.agent_runner import run_local_agent from app.core.device_manager import device_manager from app.db import get_session -from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig +from app.models import AgentRunLog, LocalAgentConfig from app.schemas import ( AgentCatalogItem, + AgentCreationCheckRequest, + AgentCreationCheckResponse, AgentRunLogResponse, - CloudAgentConfigCreate, - CloudAgentConfigResponse, - CloudAgentConfigUpdate, - LocalAgentConfigCreate, - LocalAgentConfigResponse, - LocalAgentConfigUpdate, + AgentTriggerRequest, UserProfile, ) @@ -56,39 +46,14 @@ def _dt_ms_opt(dt: datetime | None) -> int | None: return int(dt.timestamp() * 1000) if dt else None -# ── Model → schema converters ───────────────────────────────────────── - -def _to_local_response(a: LocalAgentConfig) -> LocalAgentConfigResponse: - return LocalAgentConfigResponse( - id=a.id, - name=a.name, - device_id=a.device_id, - directory_paths=a.directory_paths, - data_types=a.data_types, - prompt_template=a.prompt_template, - file_extensions=a.file_extensions, - schedule_cron=a.schedule_cron, - enabled=a.enabled, - last_run_at=_dt_ms_opt(a.last_run_at), - created_at=_dt_ms(a.created_at), - updated_at=_dt_ms(a.updated_at), - ) - - -def _to_cloud_response(a: CloudAgentConfig) -> CloudAgentConfigResponse: - return CloudAgentConfigResponse( - id=a.id, - provider=a.provider, # type: ignore[arg-type] - name=a.name, - data_types=a.data_types, - prompt_template=a.prompt_template, - schedule_cron=a.schedule_cron, - filter_config=a.filter_config, - enabled=a.enabled, - last_run_at=_dt_ms_opt(a.last_run_at), - created_at=_dt_ms(a.created_at), - updated_at=_dt_ms(a.updated_at), - ) +def _to_data_types(values: list[str]) -> list[str]: + normalize = { + "task": "tasks", + "note": "notes", + "timeline": "timelines", + "project": "projects", + } + return [normalize[v] for v in values if v in normalize] def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse: @@ -105,77 +70,14 @@ def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse: ) -# ── Ownership-checked lookups ───────────────────────────────────────── - -async def _get_local_agent_for_user( - agent_id: str, user_id: str, db: AsyncSession -) -> LocalAgentConfig: - result = await db.execute( - select(LocalAgentConfig).where( - LocalAgentConfig.id == agent_id, - LocalAgentConfig.user_id == user_id, - ) - ) - record = result.scalar_one_or_none() - if record is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found") - return record - - -async def _get_cloud_agent_for_user( - agent_id: str, user_id: str, db: AsyncSession -) -> CloudAgentConfig: - result = await db.execute( - select(CloudAgentConfig).where( - CloudAgentConfig.id == agent_id, - CloudAgentConfig.user_id == user_id, - ) - ) - record = result.scalar_one_or_none() - if record is None: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found") - return record - - -# ── Tier limit helper ───────────────────────────────────────────────── - -async def _count_enabled_agents(user_id: str, db: AsyncSession) -> int: - """Return combined enabled local + cloud agent count for the user.""" - local_count = ( - await db.execute( - select(func.count(LocalAgentConfig.id)).where( - LocalAgentConfig.user_id == user_id, - LocalAgentConfig.enabled == True, # noqa: E712 - ) - ) - ).scalar_one() - cloud_count = ( - await db.execute( - select(func.count(CloudAgentConfig.id)).where( - CloudAgentConfig.user_id == user_id, - CloudAgentConfig.enabled == True, # noqa: E712 - ) - ) - ).scalar_one() - return local_count + cloud_count - - -def _enforce_agent_limit(tier: str, current_count: int) -> None: +def _enforce_agent_limit(tier: str, current_count: int) -> int: limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"] if limit != -1 and current_count >= limit: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.", ) - - -# ── Local page schema (used by runs endpoint) ───────────────────────── - -class _RunsPage(BaseModel): - total: int - page: int - limit: int - items: list[AgentRunLogResponse] + return limit # ── Catalog ─────────────────────────────────────────────────────────── @@ -190,6 +92,24 @@ async def get_agent_catalog( type="local_directory", name="Local Directory Monitor", description="Watches local directories, extracts data from files using AI", + config_schema={ + "directory": {"type": "string", "required": True}, + "what_to_extract": { + "type": "array", + "items": ["task", "note", "timeline", "project"], + "required": True, + }, + "actions_by_type": { + "type": "object", + "example": { + "task": ["add", "update"], + "note": ["add", "update"], + }, + "required": False, + }, + "batch_interval": {"type": "string", "required": True}, + "custom_agent_prompt": {"type": "string", "required": True}, + }, ), AgentCatalogItem( type="gmail", @@ -209,229 +129,51 @@ async def get_agent_catalog( ] -# ── Local agent CRUD ────────────────────────────────────────────────── - -@router.get("/local", response_model=list[LocalAgentConfigResponse]) -async def list_local_agents( +@router.post("/can-create", response_model=AgentCreationCheckResponse) +async def can_create_agent( + body: AgentCreationCheckRequest, current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> list[LocalAgentConfigResponse]: - """List all local directory agent configs owned by the authenticated user.""" - result = await db.execute( - select(LocalAgentConfig).where(LocalAgentConfig.user_id == current_user.id) - ) - return [_to_local_response(a) for a in result.scalars().all()] +) -> AgentCreationCheckResponse: + """Check if the user can create one more agent based on billing tier. - -@router.post("/local", response_model=LocalAgentConfigResponse, status_code=status.HTTP_201_CREATED) -async def create_local_agent( - body: LocalAgentConfigCreate, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> LocalAgentConfigResponse: - """Create a new local directory agent config. - - The combined count of enabled local and cloud agents for the user is - checked against the ``batch_active`` limit for their billing tier. + Since configuration is client-owned, the Electron app sends its current + active agent count and the backend applies tier limits. """ - _enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db)) - agent = LocalAgentConfig( - user_id=current_user.id, - name=body.name, - device_id=body.device_id, - directory_paths=body.directory_paths, - data_types=body.data_types, - prompt_template=body.prompt_template, - file_extensions=body.file_extensions, - schedule_cron=body.schedule_cron, + limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"] + allowed = limit == -1 or body.active_agents < limit + return AgentCreationCheckResponse( + allowed=allowed, + tier=current_user.tier, + active_agents=body.active_agents, + limit=limit, ) - db.add(agent) - await db.commit() - await db.refresh(agent) - return _to_local_response(agent) -@router.put("/local/{agent_id}", response_model=LocalAgentConfigResponse) -async def update_local_agent( - agent_id: str, - body: LocalAgentConfigUpdate, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> LocalAgentConfigResponse: - """Partially update a local agent config. Only provided fields are changed.""" - agent = await _get_local_agent_for_user(agent_id, current_user.id, db) - for field, value in body.model_dump(exclude_unset=True).items(): - setattr(agent, field, value) - await db.commit() - await db.refresh(agent) - return _to_local_response(agent) - - -@router.delete("/local/{agent_id}", response_model=dict) -async def delete_local_agent( - agent_id: str, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Delete a local agent config. Associated run logs are cascade-deleted.""" - agent = await _get_local_agent_for_user(agent_id, current_user.id, db) - await db.delete(agent) - await db.commit() - return {"ok": True} - - -# ── Cloud agent CRUD ────────────────────────────────────────────────── - -@router.get("/cloud", response_model=list[CloudAgentConfigResponse]) -async def list_cloud_agents( - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> list[CloudAgentConfigResponse]: - """List all cloud connector agent configs owned by the authenticated user.""" - result = await db.execute( - select(CloudAgentConfig).where(CloudAgentConfig.user_id == current_user.id) - ) - return [_to_cloud_response(a) for a in result.scalars().all()] - - -@router.post("/cloud", response_model=CloudAgentConfigResponse, status_code=status.HTTP_201_CREATED) -async def create_cloud_agent( - body: CloudAgentConfigCreate, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> CloudAgentConfigResponse: - """Create a new cloud connector agent config. - - The combined count of enabled local and cloud agents for the user is - checked against the ``batch_active`` limit for their billing tier. - """ - _enforce_agent_limit(current_user.tier, await _count_enabled_agents(current_user.id, db)) - agent = CloudAgentConfig( - user_id=current_user.id, - provider=body.provider, - name=body.name, - data_types=body.data_types, - prompt_template=body.prompt_template, - oauth_token_encrypted=body.oauth_token_encrypted, - schedule_cron=body.schedule_cron, - filter_config=body.filter_config, - ) - db.add(agent) - await db.commit() - await db.refresh(agent) - return _to_cloud_response(agent) - - -@router.put("/cloud/{agent_id}", response_model=CloudAgentConfigResponse) -async def update_cloud_agent( - agent_id: str, - body: CloudAgentConfigUpdate, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> CloudAgentConfigResponse: - """Partially update a cloud agent config. Only provided fields are changed.""" - agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db) - for field, value in body.model_dump(exclude_unset=True).items(): - setattr(agent, field, value) - await db.commit() - await db.refresh(agent) - return _to_cloud_response(agent) - - -@router.delete("/cloud/{agent_id}", response_model=dict) -async def delete_cloud_agent( - agent_id: str, - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> dict[str, bool]: - """Delete a cloud agent config. Associated run logs are cascade-deleted.""" - agent = await _get_cloud_agent_for_user(agent_id, current_user.id, db) - await db.delete(agent) - await db.commit() - return {"ok": True} - - -# ── Run logs ────────────────────────────────────────────────────────── - -@router.get("/runs", response_model=_RunsPage) -async def list_run_logs( - agent_id: str | None = Query(default=None), - page: int = Query(default=1, ge=1), - limit: int = Query(default=20, ge=1, le=100), - current_user: UserProfile = Depends(get_current_user), - db: AsyncSession = Depends(get_session), -) -> _RunsPage: - """Return paginated run logs for the authenticated user. - - Optionally filter by ``agent_id``. Results are ordered from newest to oldest. - """ - base_filter = [AgentRunLog.user_id == current_user.id] - if agent_id: - base_filter.append(AgentRunLog.agent_id == agent_id) - - total = ( - await db.execute(select(func.count(AgentRunLog.id)).where(*base_filter)) - ).scalar_one() - - result = await db.execute( - select(AgentRunLog) - .where(*base_filter) - .order_by(AgentRunLog.started_at.desc()) - .offset((page - 1) * limit) - .limit(limit) - ) - items = [_to_run_log_response(log) for log in result.scalars().all()] - - return _RunsPage(total=total, page=page, limit=limit, items=items) - - -# ── Manual trigger stub ─────────────────────────────────────────────── - -@router.post("/{agent_id}/run", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED) +@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED) async def trigger_agent_run( - agent_id: str, + body: AgentTriggerRequest, current_user: UserProfile = Depends(get_current_user), db: AsyncSession = Depends(get_session), ) -> AgentRunLogResponse: - """Manually trigger an agent run. + """Trigger a local agent run using client-provided configuration.""" + _enforce_agent_limit(current_user.tier, body.active_agents) - Looks up the agent config (local or cloud) by ID with ownership check, - creates a run log entry with ``status="running"``, and returns it. - - Actual dispatch to the agent runner is wired in Step 3.4 once - ``DeviceConnectionManager`` and ``agent_runner`` are available. - """ - # Determine agent type by trying local first, then cloud. - # Keep the full config object so we can pass it to the agent runner. - local_config: LocalAgentConfig | None = None - cloud_config: CloudAgentConfig | None = None - - local_result = await db.execute( - select(LocalAgentConfig).where( - LocalAgentConfig.id == agent_id, - LocalAgentConfig.user_id == current_user.id, - ) + config = LocalAgentConfig( + id=str(uuid.uuid4()), + user_id=current_user.id, + device_id="", + name="Local Directory Monitor", + directory_paths=[body.directory], + data_types=_to_data_types(body.what_to_extract), + prompt_template=body.custom_agent_prompt, + file_extensions=[], + schedule_cron=body.batch_interval, + enabled=True, ) - local_config = local_result.scalar_one_or_none() - - if local_config is not None: - agent_type = "local" - else: - cloud_result = await db.execute( - select(CloudAgentConfig).where( - CloudAgentConfig.id == agent_id, - CloudAgentConfig.user_id == current_user.id, - ) - ) - cloud_config = cloud_result.scalar_one_or_none() - if cloud_config is not None: - agent_type = "cloud" - else: - raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found") run_log = AgentRunLog( - agent_id=agent_id, - agent_type=agent_type, + agent_id=config.id, + agent_type="local", user_id=current_user.id, status="running", ) @@ -439,14 +181,8 @@ async def trigger_agent_run( await db.commit() await db.refresh(run_log) - # Dispatch the run as a background task — returns 202 immediately. - if agent_type == "local" and local_config is not None: - asyncio.create_task( - run_local_agent(current_user.id, local_config, run_log, device_manager) - ) - elif agent_type == "cloud" and cloud_config is not None: - asyncio.create_task( - run_cloud_agent(current_user.id, cloud_config, run_log, device_manager) - ) + asyncio.create_task( + run_local_agent(current_user.id, config, run_log, device_manager) + ) return _to_run_log_response(run_log) diff --git a/app/core/agent_runner.py b/app/core/agent_runner.py index 0d25f65..51d8745 100644 --- a/app/core/agent_runner.py +++ b/app/core/agent_runner.py @@ -238,17 +238,23 @@ async def run_local_agent( run_id = run_log.id # ── 1. Device online check ───────────────────────────────────────── - if not device_mgr.is_online(user_id, config.device_id): + target_device_id = config.device_id.strip() if isinstance(config.device_id, str) else "" + if target_device_id: + is_online = device_mgr.is_online(user_id, target_device_id) + else: + is_online = device_mgr.is_online(user_id) + + if not is_online: logger.info( "agent_runner: skip run=%s — device %r offline for user=%s", run_id, - config.device_id, + target_device_id or "", user_id, ) await _finalize_run( run_log, status="error", - errors=[f"Device {config.device_id!r} is not connected"], + errors=[f"Device {target_device_id or ''!r} is not connected"], ) return @@ -369,7 +375,7 @@ async def run_local_agent( items_processed=items_processed, items_created=items_created, errors=errors, - update_config_last_run=True, + update_config_last_run=False, config_id=config.id, config_type="local", ) @@ -610,60 +616,11 @@ async def trigger_pending_runs( * Runs execute **sequentially** to avoid flooding the WS connection. """ logger.info( - "agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id + "agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)", + user_id, + device_id, ) - async with async_session() as db: - local_result = await db.execute( - select(LocalAgentConfig).where( - LocalAgentConfig.user_id == user_id, - LocalAgentConfig.enabled == True, # noqa: E712 - LocalAgentConfig.device_id == device_id, - ) - ) - local_configs: list[LocalAgentConfig] = list(local_result.scalars().all()) - - cloud_result = await db.execute( - select(CloudAgentConfig).where( - CloudAgentConfig.user_id == user_id, - CloudAgentConfig.enabled == True, # noqa: E712 - ) - ) - cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all()) - - # Build ordered list of overdue (type, config) pairs. - pending: list[tuple[str, Any]] = [] - for cfg in local_configs: - if _is_overdue(cfg.schedule_cron, cfg.last_run_at): - pending.append(("local", cfg)) - for cfg in cloud_configs: - if _is_overdue(cfg.schedule_cron, cfg.last_run_at): - pending.append(("cloud", cfg)) - - if not pending: - logger.debug("agent_runner: no overdue runs for user=%s", user_id) - return - - logger.info( - "agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id - ) - - for agent_type, cfg in pending: - # Create a fresh run log for this scheduled dispatch. - run_log = AgentRunLog( - agent_id=cfg.id, - agent_type=agent_type, - user_id=user_id, - status="running", - ) - async with async_session() as db: - db.add(run_log) - await db.commit() - await db.refresh(run_log) - - if agent_type == "local": - await run_local_agent(user_id, cfg, run_log, device_mgr) - else: - await run_cloud_agent(user_id, cfg, run_log, device_mgr) + return # ── Internal helper ───────────────────────────────────────────────────────── diff --git a/app/schemas.py b/app/schemas.py index 3f0d227..33bf986 100644 --- a/app/schemas.py +++ b/app/schemas.py @@ -306,81 +306,27 @@ class AgentCatalogItem(BaseModel): config_schema: dict[str, Any] = Field(default_factory=dict) -# ── Local Agent Config ──────────────────────────────────────────────── - -class LocalAgentConfigCreate(BaseModel): - name: str - device_id: str - directory_paths: list[str] - data_types: list[str] - prompt_template: str - file_extensions: list[str] - schedule_cron: str +class AgentCreationCheckRequest(BaseModel): + active_agents: int = Field(ge=0, default=0) -class LocalAgentConfigUpdate(BaseModel): - name: str | None = None - device_id: str | None = None - directory_paths: list[str] | None = None - data_types: list[str] | None = None - prompt_template: str | None = None - file_extensions: list[str] | None = None - schedule_cron: str | None = None - enabled: bool | None = None +class AgentCreationCheckResponse(BaseModel): + allowed: bool + tier: BillingTier + active_agents: int + limit: int -class LocalAgentConfigResponse(BaseModel): - id: str - name: str - device_id: str - directory_paths: list[str] - data_types: list[str] - prompt_template: str - file_extensions: list[str] - schedule_cron: str - enabled: bool - last_run_at: int | None - created_at: int - updated_at: int - - -# ── Cloud Agent Config ──────────────────────────────────────────────── - -class CloudAgentConfigCreate(BaseModel): - provider: Literal["gmail", "teams", "outlook"] - name: str - data_types: list[str] - prompt_template: str - oauth_token_encrypted: str - schedule_cron: str - filter_config: dict[str, Any] | None = None - - -class CloudAgentConfigUpdate(BaseModel): - provider: Literal["gmail", "teams", "outlook"] | None = None - name: str | None = None - data_types: list[str] | None = None - prompt_template: str | None = None - oauth_token_encrypted: str | None = None - schedule_cron: str | None = None - filter_config: dict[str, Any] | None = None - enabled: bool | None = None - - -class CloudAgentConfigResponse(BaseModel): - """oauth_token_encrypted is intentionally excluded — never returned to clients.""" - - id: str - provider: Literal["gmail", "teams", "outlook"] - name: str - data_types: list[str] - prompt_template: str - schedule_cron: str - filter_config: dict[str, Any] | None - enabled: bool - last_run_at: int | None - created_at: int - updated_at: int +class AgentTriggerRequest(BaseModel): + directory: str = Field(min_length=1) + what_to_extract: list[Literal["task", "note", "timeline", "project"]] = Field(min_length=1) + actions_by_type: dict[ + Literal["task", "note", "timeline", "project"], + list[Literal["add", "update"]], + ] | None = None + batch_interval: str = Field(min_length=1) + custom_agent_prompt: str = Field(min_length=1) + active_agents: int = Field(ge=0, default=0) # ── Agent Run Log ───────────────────────────────────────────────────── diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py index d1d58d5..2764f77 100644 --- a/tests/test_agent_runner.py +++ b/tests/test_agent_runner.py @@ -10,13 +10,13 @@ Coverage: - run_local_agent — file-read timeout path - run_local_agent — LLM extraction error path - run_cloud_agent — stub returns error immediately - - trigger_pending_runs — overdue local + cloud dispatched + - trigger_pending_runs — skipped when config is client-owned - trigger_pending_runs — non-overdue skipped - trigger_pending_runs — device_id filter for local agents - Integration: - - POST /agents/{id}/run — 404 on unknown agent - - POST /agents/{id}/run — creates run log + dispatches background task + Integration: + - POST /agents/can-create — billing eligibility check + - POST /agents/trigger — creates run log + dispatches background task """ from __future__ import annotations @@ -373,7 +373,7 @@ async def test_run_local_agent_happy_path(): assert kwargs["items_processed"] == 1 assert kwargs["items_created"] == 1 assert kwargs["errors"] == [] - assert kwargs["update_config_last_run"] is True + assert kwargs["update_config_last_run"] is False # Verify agent_run frame was sent. agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"] @@ -690,31 +690,11 @@ async def test_finalize_run_updates_cloud_config_last_run_at(): @pytest.mark.asyncio async def test_trigger_pending_runs_no_overdue(): - """If no agents are overdue trigger_pending_runs does nothing.""" - from datetime import timedelta - - config = _make_local_config() - config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago - config.schedule_cron = "0 */6 * * *" # every 6h — not due yet - - mock_db_result_local = MagicMock() - mock_db_result_local.scalars.return_value.all.return_value = [config] - - mock_db_result_cloud = MagicMock() - mock_db_result_cloud.scalars.return_value.all.return_value = [] + """Pending-run scan is skipped because agent config is client-owned.""" mgr = _make_manager() - with patch("app.core.agent_runner.async_session") as mock_session_factory, \ - patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=False) - mock_ctx.execute = AsyncMock( - side_effect=[mock_db_result_local, mock_db_result_cloud] - ) - mock_session_factory.return_value = mock_ctx - + with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: await trigger_pending_runs(_FREE_UID, "dev-001", mgr) mock_run.assert_not_called() @@ -722,31 +702,11 @@ async def test_trigger_pending_runs_no_overdue(): @pytest.mark.asyncio async def test_trigger_pending_runs_device_id_filter(): - """Local agents are only triggered for the matching device_id.""" - # The DB query already filters by device_id, so we verify the SELECT - # includes the device_id filter by checking that a config bound to a - # different device is never dispatched. - # - # Since trigger_pending_runs queries with device_id == "dev-001", - # simulate the DB returning an empty list (as it would for a mismatch). - mock_db_result_local = MagicMock() - mock_db_result_local.scalars.return_value.all.return_value = [] # no match - - mock_db_result_cloud = MagicMock() - mock_db_result_cloud.scalars.return_value.all.return_value = [] + """Device filtering is no longer backend-managed in pending runs.""" mgr = _make_manager(device_id="dev-001") - with patch("app.core.agent_runner.async_session") as mock_session_factory, \ - patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: - mock_ctx = AsyncMock() - mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) - mock_ctx.__aexit__ = AsyncMock(return_value=False) - mock_ctx.execute = AsyncMock( - side_effect=[mock_db_result_local, mock_db_result_cloud] - ) - mock_session_factory.return_value = mock_ctx - + with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: await trigger_pending_runs(_FREE_UID, "dev-001", mgr) mock_run.assert_not_called() @@ -754,56 +714,18 @@ async def test_trigger_pending_runs_device_id_filter(): @pytest.mark.asyncio async def test_trigger_pending_runs_dispatches_overdue(): - """Overdue local agent triggers run_local_agent sequentially.""" - config = _make_local_config() # last_run_at=None → always overdue - - mock_db_result_local = MagicMock() - mock_db_result_local.scalars.return_value.all.return_value = [config] - - mock_db_result_cloud = MagicMock() - mock_db_result_cloud.scalars.return_value.all.return_value = [] + """No pending runs are dispatched by backend after config deprecation.""" mgr = _make_manager() - call_order: list[str] = [] - - async def _mock_run_local(user_id, cfg, run_log, device_mgr): - call_order.append("run_local") - - with patch("app.core.agent_runner.async_session") as mock_session_factory, \ - patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local): - # First call: query configs. Subsequent calls: create run_log. - mock_query_ctx = AsyncMock() - mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx) - mock_query_ctx.__aexit__ = AsyncMock(return_value=False) - mock_query_ctx.execute = AsyncMock( - side_effect=[mock_db_result_local, mock_db_result_cloud] - ) - - run_log_obj = AgentRunLog( - id=str(uuid.uuid4()), - agent_id=config.id, - agent_type="local", - user_id=_FREE_UID, - status="running", - started_at=datetime.now(timezone.utc), - ) - mock_insert_ctx = AsyncMock() - mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx) - mock_insert_ctx.__aexit__ = AsyncMock(return_value=False) - mock_insert_ctx.add = MagicMock() - mock_insert_ctx.commit = AsyncMock() - mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None) - - mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx] - + with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: await trigger_pending_runs(_FREE_UID, "dev-001", mgr) - assert call_order == ["run_local"] + mock_run.assert_not_called() # --------------------------------------------------------------------------- -# Integration: POST /agents/{id}/run +# Integration: POST /agents/can-create and /agents/trigger # --------------------------------------------------------------------------- @@ -820,50 +742,67 @@ def _override_db(db_session): @pytest.mark.asyncio -async def test_trigger_run_unknown_agent(client): - """POST /agents/{id}/run returns 404 for unknown agent id.""" +async def test_can_create_agent_allows_when_under_limit(client): + """POST /agents/can-create returns allowed=True when under tier limit.""" resp = client.post( - f"/api/v1/agents/{uuid.uuid4()}/run", - headers=auth_header("power"), + "/api/v1/agents/can-create", + json={"active_agents": 0}, + headers=auth_header("free"), ) - assert resp.status_code == 404 + assert resp.status_code == 200 + body = resp.json() + assert body["allowed"] is True + assert body["tier"] == "free" + assert body["active_agents"] == 0 + assert body["limit"] == 2 + + +@pytest.mark.asyncio +async def test_can_create_agent_denies_when_at_limit(client): + """POST /agents/can-create returns allowed=False at free-tier limit.""" + resp = client.post( + "/api/v1/agents/can-create", + json={"active_agents": 2}, + headers=auth_header("free"), + ) + assert resp.status_code == 200 + body = resp.json() + assert body["allowed"] is False + assert body["limit"] == 2 @pytest.mark.asyncio async def test_trigger_run_local_agent_creates_run_log(client, db_session): - """POST /agents/{id}/run creates a run log and dispatches a background task.""" - # Create the local agent config in the DB. - config = LocalAgentConfig( - id=str(uuid.uuid4()), - user_id=TEST_USER_IDS["power"], - device_id="dev-001", - name="My Agent", - directory_paths=["/home/user/docs"], - data_types=["tasks"], - prompt_template="Extract tasks.", - file_extensions=[".txt"], - schedule_cron="0 */6 * * *", - enabled=True, - ) - db_session.add(config) - await db_session.commit() - - dispatched: list = [] + """POST /agents/trigger creates a local run log and dispatches background task.""" + dispatched: list[tuple[str, str]] = [] async def _fake_run(user_id, cfg, run_log, device_mgr): dispatched.append((user_id, cfg.id)) + def _fake_create_task(coro): + coro.close() + return MagicMock() + with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \ - patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \ patch("asyncio.create_task") as mock_create_task: + mock_create_task.side_effect = _fake_create_task resp = client.post( - f"/api/v1/agents/{config.id}/run", + "/api/v1/agents/trigger", + json={ + "directory": "/home/user/docs", + "what_to_extract": ["task", "note"], + "actions_by_type": {"task": ["add", "update"], "note": ["add"]}, + "batch_interval": "0 */6 * * *", + "custom_agent_prompt": "Extract tasks and notes.", + "active_agents": 0, + }, headers=auth_header("power"), ) assert resp.status_code == 202 data = resp.json() - assert data["agent_id"] == config.id + assert isinstance(data["agent_id"], str) + assert data["agent_id"] assert data["status"] == "running" assert data["agent_type"] == "local"