step 3.4 complete: agent run orchestrator — local/cloud runner + trigger_pending_runs + 23 tests

This commit is contained in:
2026-03-05 16:13:21 +01:00
parent 608d6c784f
commit 914f70bd85
6 changed files with 1228 additions and 14 deletions

View File

@@ -16,6 +16,7 @@ Endpoints:
from __future__ import annotations
import asyncio
from datetime import datetime
from typing import Any
@@ -26,6 +27,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import run_cloud_agent, run_local_agent
from app.core.device_manager import device_manager
from app.db import get_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from app.schemas import (
@@ -399,14 +402,19 @@ async def trigger_agent_run(
``DeviceConnectionManager`` and ``agent_runner`` are available.
"""
# Determine agent type by trying local first, then cloud.
agent_type: str
# Keep the full config object so we can pass it to the agent runner.
local_config: LocalAgentConfig | None = None
cloud_config: CloudAgentConfig | None = None
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.id == agent_id,
LocalAgentConfig.user_id == current_user.id,
)
)
if local_result.scalar_one_or_none() is not None:
local_config = local_result.scalar_one_or_none()
if local_config is not None:
agent_type = "local"
else:
cloud_result = await db.execute(
@@ -415,7 +423,8 @@ async def trigger_agent_run(
CloudAgentConfig.user_id == current_user.id,
)
)
if cloud_result.scalar_one_or_none() is not None:
cloud_config = cloud_result.scalar_one_or_none()
if cloud_config is not None:
agent_type = "cloud"
else:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
@@ -429,4 +438,15 @@ async def trigger_agent_run(
db.add(run_log)
await db.commit()
await db.refresh(run_log)
# Dispatch the run as a background task — returns 202 immediately.
if agent_type == "local" and local_config is not None:
asyncio.create_task(
run_local_agent(current_user.id, local_config, run_log, device_manager)
)
elif agent_type == "cloud" and cloud_config is not None:
asyncio.create_task(
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
)
return _to_run_log_response(run_log)

View File

@@ -39,6 +39,7 @@ from jose import JWTError, jwt
from sqlalchemy import select, update
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.core.device_manager import device_manager
from app.db import async_session
from app.models import AgentRunLog
@@ -100,8 +101,8 @@ async def device_ws(websocket: WebSocket) -> None:
agent_ids,
)
# Step 3.4 will replace this stub with a real call to agent_runner.
asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id))
# Trigger any overdue agent runs now that the device is connected.
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
# ── 4. Concurrent message loop + heartbeat ────────────────────────
try:
@@ -217,10 +218,4 @@ async def _mark_runs_disconnected(user_id: str) -> None:
)
# ── Pending-run trigger stub (Step 3.4 will replace) ─────────────────
async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None:
"""No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs."""
logger.debug(
"device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id
)

534
app/core/agent_runner.py Normal file
View File

@@ -0,0 +1,534 @@
"""Agent run orchestrator.
Drives two agent types:
* **Local directory agent** — sends an ``agent_run`` frame to the connected
Electron device, waits for the device to stream back file contents via
``agent_data`` frames, then calls the LLM to extract structured items from
each file and pushes inserts to Electron via tool-call round-trips.
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
Teams, Outlook) and pushes extracted items to Electron. **This path is
a stub** — provider integrations are implemented in Step 3.6.
Usage
-----
Background tasks are spawned with ``asyncio.create_task()``::
asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager))
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
The ``trigger_pending_runs`` function is called by the device WS endpoint
when Electron sends ``device_hello``, so any overdue runs fire immediately
when the device reconnects.
"""
from __future__ import annotations
import asyncio
import json
import logging
import uuid
from datetime import datetime, timezone
from typing import Any
from croniter import croniter
from langchain_core.messages import HumanMessage, SystemMessage
from sqlalchemy import select
from app.core.device_manager import DeviceConnectionManager
from app.core.llm import get_llm
from app.db import async_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
logger = logging.getLogger(__name__)
# ── Timeouts ───────────────────────────────────────────────────────────────
# Max seconds to wait for Electron to finish streaming file data.
_FILE_READ_TIMEOUT: int = 120
# Max seconds to wait for Electron to acknowledge a single tool-call insert.
_INSERT_TIMEOUT: int = 30
# ── Allowed tables & extraction schema hints ───────────────────────────────
_ALLOWED_TABLES: frozenset[str] = frozenset(
{"tasks", "notes", "checkpoints", "projects", "taskComments"}
)
# Field descriptions fed to the extraction LLM as concise schema references.
_TABLE_SCHEMAS: dict[str, str] = {
"tasks": (
"title (str, required), description (str), "
"status (todo|in_progress|done, default todo), "
"priority (high|medium|low, default medium), "
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
),
"notes": "title (str, required), content (str, markdown), projectId (str)",
"checkpoints": (
"title (str, required), projectId (str, required), date (ms timestamp int)"
),
"projects": "name (str, required), clientId (str)",
"taskComments": "taskId (str, required), author (str), content (str, required)",
}
_EXTRACTION_SYSTEM_PROMPT = """\
You are a data extraction assistant for a freelance project management tool.
Given a document, extract structured records matching the user's instructions.
Output a JSON array (no markdown fences, no explanation) of objects shaped:
[{{"table": "<table_name>", "data": {{...fields}}}}, ...]
Allowed table names and their fields:
{table_schemas}
Rules:
- Only extract tables listed in the "data_types" instructions.
- Use camelCase field names exactly as shown above.
- Omit optional fields you cannot determine; do not invent data.
- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved.
- If nothing relevant is found, return an empty JSON array: []
- Return ONLY the JSON array.
"""
# ── Cron helper ────────────────────────────────────────────────────────────
def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
"""Return ``True`` if the next scheduled run time has already passed.
Always validates the cron expression first — an invalid expression returns
``False`` (fail-safe: never trigger an unparseable schedule).
"""
try:
now = datetime.now(timezone.utc)
if last_run_at is None:
# Validate the expression before deciding this is overdue.
croniter(schedule_cron, now)
return True
ts = last_run_at
if ts.tzinfo is None:
ts = ts.replace(tzinfo=timezone.utc)
cron = croniter(schedule_cron, ts)
next_run: datetime = cron.get_next(datetime)
return now >= next_run
except Exception as exc:
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
return False # Fail-safe: don't trigger if expression is invalid.
# ── LLM extraction ─────────────────────────────────────────────────────────
async def _extract_items_from_content(
prompt_template: str,
file_content: str,
data_types: list[str],
) -> list[dict[str, Any]]:
"""Call the LLM to extract structured records from *file_content*.
Returns a validated list of ``{table: str, data: dict}`` objects.
Items referencing tables not in *data_types* are discarded.
"""
allowed = [t for t in data_types if t in _ALLOWED_TABLES]
if not allowed:
return []
schema_text = "\n".join(
f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed
)
system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text)
user_prompt = (
f"User instructions: {prompt_template}\n\n"
f"Extract these record types: {', '.join(allowed)}\n\n"
f"Document:\n{file_content[:8000]}"
)
llm = get_llm()
raw = ""
try:
response = await llm.ainvoke(
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
)
raw = str(response.content).strip()
items: list[dict] = json.loads(raw)
if not isinstance(items, list):
raise ValueError("LLM response is not a JSON array")
except json.JSONDecodeError as exc:
logger.warning(
"agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r",
exc,
raw,
)
return []
# Other exceptions (LLM API errors, network errors) propagate to the
# caller (run_local_agent) which records them per-file in the run log.
validated: list[dict[str, Any]] = []
for item in items:
table = item.get("table")
data = item.get("data")
if not isinstance(table, str) or table not in allowed:
continue
if not isinstance(data, dict) or not data:
continue
# Strip any server-generated or forbidden fields.
for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"):
data.pop(_field, None)
validated.append({"table": table, "data": data})
return validated
# ── Tool-call insert helper ─────────────────────────────────────────────────
async def _send_insert_to_client(
user_id: str,
table: str,
data: dict[str, Any],
device_mgr: DeviceConnectionManager,
) -> dict[str, Any]:
"""Send an ``insert`` tool_call frame to Electron and await the tool_result.
All inserts include ``isAiSuggested=1, isApproved=0`` so the user can
review AI-produced records before they are treated as confirmed.
Raises ``asyncio.TimeoutError`` if Electron does not respond within
``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device
disconnects before the frame can be sent.
"""
call_id = str(uuid.uuid4())
payload: dict[str, Any] = {
"type": "tool_call",
"id": call_id,
"action": "insert",
"table": table,
"data": {**data, "isAiSuggested": 1, "isApproved": 0},
}
fut = device_mgr.create_pending_call(user_id, call_id)
await device_mgr.send_frame(user_id, payload)
return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT)
# ── Local agent runner ──────────────────────────────────────────────────────
async def run_local_agent(
user_id: str,
config: LocalAgentConfig,
run_log: AgentRunLog,
device_mgr: DeviceConnectionManager,
) -> None:
"""Execute a local directory agent run end-to-end.
Steps:
1. Verify the device identified by ``config.device_id`` is currently online.
2. Pre-create the agent_data queue so no incoming frames are lost.
3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types).
4. Consume ``agent_data`` frames until the ``None`` sentinel from
``agent_complete``.
5. For each received file call the LLM to extract ``{table, data}`` items.
6. Push each item to Electron as an ``insert`` tool-call; include
``isAiSuggested=1, isApproved=0`` so users can review AI suggestions.
7. Persist the run outcome (status, counts, errors) and update
``config.last_run_at``.
"""
run_id = run_log.id
# ── 1. Device online check ─────────────────────────────────────────
if not device_mgr.is_online(user_id, config.device_id):
logger.info(
"agent_runner: skip run=%s — device %r offline for user=%s",
run_id,
config.device_id,
user_id,
)
await _finalize_run(
run_log,
status="error",
errors=[f"Device {config.device_id!r} is not connected"],
)
return
# ── 2. Pre-create agent_data queue ────────────────────────────────
try:
device_mgr.get_agent_data_queue(user_id, run_id)
except RuntimeError:
await _finalize_run(
run_log,
status="error",
errors=["Device disconnected before agent run could start"],
)
return
# ── 3. Send agent_run frame ────────────────────────────────────────
frame: dict[str, Any] = {
"type": "agent_run",
"run_id": run_id,
"agent_id": config.id,
"config": {
"paths": config.directory_paths,
"file_extensions": config.file_extensions,
"prompt_template": config.prompt_template,
"data_types": config.data_types,
},
}
try:
await device_mgr.send_frame(user_id, frame)
except RuntimeError as exc:
device_mgr.cleanup_agent_data_queue(user_id, run_id)
await _finalize_run(
run_log,
status="error",
errors=[f"Failed to send agent_run frame: {exc}"],
)
return
logger.info(
"agent_runner: sent agent_run run=%s agent=%s user=%s",
run_id,
config.id,
user_id,
)
# ── 4. Consume agent_data frames ──────────────────────────────────
files: list[dict[str, Any]] = []
errors: list[str] = []
try:
queue = device_mgr.get_agent_data_queue(user_id, run_id)
deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT
while True:
remaining = deadline - asyncio.get_event_loop().time()
if remaining <= 0:
errors.append("Timed out waiting for file data from device")
break
try:
frame_data = await asyncio.wait_for(queue.get(), timeout=remaining)
except asyncio.TimeoutError:
errors.append("Timed out waiting for file data from device")
break
if frame_data is None:
# Sentinel from agent_complete — stream is done.
break
files.extend(frame_data.get("files", []))
except RuntimeError as exc:
errors.append(f"Queue error reading agent data: {exc}")
# ── 56. Extract + insert ─────────────────────────────────────────
items_processed = 0
items_created = 0
for file_info in files:
file_path: str = file_info.get("path", "<unknown>")
content: str = file_info.get("content", "")
if not content:
continue
items_processed += 1
try:
extracted = await _extract_items_from_content(
config.prompt_template, content, config.data_types
)
except Exception as exc:
errors.append(f"LLM extraction error for {file_path!r}: {exc}")
continue
for item in extracted:
try:
result = await _send_insert_to_client(
user_id, item["table"], item["data"], device_mgr
)
if result.get("error"):
errors.append(
f"Insert failed ({item['table']}, {file_path!r}): {result['error']}"
)
else:
items_created += 1
except asyncio.TimeoutError:
errors.append(
f"Timed out awaiting insert ack ({item['table']}, {file_path!r})"
)
except RuntimeError as exc:
errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}")
# ── 7. Finalise ────────────────────────────────────────────────────
device_mgr.cleanup_agent_data_queue(user_id, run_id)
if errors and items_created == 0:
final_status = "error"
elif errors:
final_status = "partial"
else:
final_status = "success"
await _finalize_run(
run_log,
status=final_status,
items_processed=items_processed,
items_created=items_created,
errors=errors,
update_config_last_run=True,
config_id=config.id,
config_type="local",
)
logger.info(
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
run_id,
final_status,
items_processed,
items_created,
len(errors),
)
# ── Cloud agent runner (stub) ───────────────────────────────────────────────
async def run_cloud_agent(
user_id: str,
config: CloudAgentConfig,
run_log: AgentRunLog,
device_mgr: DeviceConnectionManager,
) -> None:
"""Execute a cloud connector agent run.
.. note::
This is a **stub** — provider integrations (Gmail, Teams, Outlook)
are implemented in Step 3.6. The run is immediately marked as an
error with an informative message.
"""
logger.info(
"agent_runner: cloud agent %s (provider=%s) for user=%s — pending Step 3.6",
config.id,
config.provider,
user_id,
)
await _finalize_run(
run_log,
status="error",
errors=[
f"Cloud provider integrations for '{config.provider}' are not yet "
"implemented. This feature arrives in Step 3.6."
],
)
# ── Pending-run trigger ─────────────────────────────────────────────────────
async def trigger_pending_runs(
user_id: str,
device_id: str,
device_mgr: DeviceConnectionManager,
) -> None:
"""Dispatch any overdue agent runs after an Electron device connects.
Called as a background task from the device WS endpoint on ``device_hello``.
Scheduling rules:
* **Local agents**: only triggered when ``config.device_id == device_id``.
* **Cloud agents**: triggered on any connected device (no device binding).
* Runs execute **sequentially** to avoid flooding the WS connection.
"""
logger.info(
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
)
async with async_session() as db:
local_result = await db.execute(
select(LocalAgentConfig).where(
LocalAgentConfig.user_id == user_id,
LocalAgentConfig.enabled == True, # noqa: E712
LocalAgentConfig.device_id == device_id,
)
)
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
cloud_result = await db.execute(
select(CloudAgentConfig).where(
CloudAgentConfig.user_id == user_id,
CloudAgentConfig.enabled == True, # noqa: E712
)
)
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
# Build ordered list of overdue (type, config) pairs.
pending: list[tuple[str, Any]] = []
for cfg in local_configs:
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
pending.append(("local", cfg))
for cfg in cloud_configs:
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
pending.append(("cloud", cfg))
if not pending:
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
return
logger.info(
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
)
for agent_type, cfg in pending:
# Create a fresh run log for this scheduled dispatch.
run_log = AgentRunLog(
agent_id=cfg.id,
agent_type=agent_type,
user_id=user_id,
status="running",
)
async with async_session() as db:
db.add(run_log)
await db.commit()
await db.refresh(run_log)
if agent_type == "local":
await run_local_agent(user_id, cfg, run_log, device_mgr)
else:
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
# ── Internal helper ─────────────────────────────────────────────────────────
async def _finalize_run(
run_log: AgentRunLog,
*,
status: str,
items_processed: int = 0,
items_created: int = 0,
errors: list[str] | None = None,
update_config_last_run: bool = False,
config_id: str | None = None,
config_type: str | None = None,
) -> None:
"""Persist the run outcome and optionally update ``LocalAgentConfig.last_run_at``.
Uses a fresh DB session so this is safe to call from background tasks
after the original request session has closed.
"""
now = datetime.now(timezone.utc)
try:
async with async_session() as db:
managed = await db.merge(run_log)
managed.status = status
managed.items_processed = items_processed
managed.items_created = items_created
managed.errors = errors or []
managed.completed_at = now
if update_config_last_run and config_id and config_type == "local":
cfg_result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
await db.commit()
except Exception as exc:
logger.error(
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
)