step 3.4 complete: agent run orchestrator — local/cloud runner + trigger_pending_runs + 23 tests
This commit is contained in:
@@ -16,6 +16,7 @@ Endpoints:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
@@ -26,6 +27,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import FEATURES
|
||||
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||
from app.core.device_manager import device_manager
|
||||
from app.db import get_session
|
||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||
from app.schemas import (
|
||||
@@ -399,14 +402,19 @@ async def trigger_agent_run(
|
||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||
"""
|
||||
# Determine agent type by trying local first, then cloud.
|
||||
agent_type: str
|
||||
# Keep the full config object so we can pass it to the agent runner.
|
||||
local_config: LocalAgentConfig | None = None
|
||||
cloud_config: CloudAgentConfig | None = None
|
||||
|
||||
local_result = await db.execute(
|
||||
select(LocalAgentConfig).where(
|
||||
LocalAgentConfig.id == agent_id,
|
||||
LocalAgentConfig.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
if local_result.scalar_one_or_none() is not None:
|
||||
local_config = local_result.scalar_one_or_none()
|
||||
|
||||
if local_config is not None:
|
||||
agent_type = "local"
|
||||
else:
|
||||
cloud_result = await db.execute(
|
||||
@@ -415,7 +423,8 @@ async def trigger_agent_run(
|
||||
CloudAgentConfig.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
if cloud_result.scalar_one_or_none() is not None:
|
||||
cloud_config = cloud_result.scalar_one_or_none()
|
||||
if cloud_config is not None:
|
||||
agent_type = "cloud"
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||
@@ -429,4 +438,15 @@ async def trigger_agent_run(
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
|
||||
# Dispatch the run as a background task — returns 202 immediately.
|
||||
if agent_type == "local" and local_config is not None:
|
||||
asyncio.create_task(
|
||||
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||
)
|
||||
elif agent_type == "cloud" and cloud_config is not None:
|
||||
asyncio.create_task(
|
||||
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||
)
|
||||
|
||||
return _to_run_log_response(run_log)
|
||||
|
||||
@@ -39,6 +39,7 @@ from jose import JWTError, jwt
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_runner import trigger_pending_runs
|
||||
from app.core.device_manager import device_manager
|
||||
from app.db import async_session
|
||||
from app.models import AgentRunLog
|
||||
@@ -100,8 +101,8 @@ async def device_ws(websocket: WebSocket) -> None:
|
||||
agent_ids,
|
||||
)
|
||||
|
||||
# Step 3.4 will replace this stub with a real call to agent_runner.
|
||||
asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id))
|
||||
# Trigger any overdue agent runs now that the device is connected.
|
||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||
|
||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||
try:
|
||||
@@ -217,10 +218,4 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
# ── Pending-run trigger stub (Step 3.4 will replace) ─────────────────
|
||||
|
||||
async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None:
|
||||
"""No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs."""
|
||||
logger.debug(
|
||||
"device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user