From 914f70bd85fc7a4e821b736cf293f3c2020ac86d Mon Sep 17 00:00:00 2001 From: roberto Date: Thu, 5 Mar 2026 16:13:21 +0100 Subject: [PATCH] =?UTF-8?q?step=203.4=20complete:=20agent=20run=20orchestr?= =?UTF-8?q?ator=20=E2=80=94=20local/cloud=20runner=20+=20trigger=5Fpending?= =?UTF-8?q?=5Fruns=20+=2023=20tests?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- AI_REFACTOR_PLAN.md | 10 +- app/api/routes/agents.py | 26 +- app/api/routes/device_ws.py | 11 +- app/core/agent_runner.py | 534 +++++++++++++++++++++++++++++ requirements.txt | 1 + tests/test_agent_runner.py | 660 ++++++++++++++++++++++++++++++++++++ 6 files changed, 1228 insertions(+), 14 deletions(-) create mode 100644 app/core/agent_runner.py create mode 100644 tests/test_agent_runner.py diff --git a/AI_REFACTOR_PLAN.md b/AI_REFACTOR_PLAN.md index 72a4b27..3da1ac0 100644 --- a/AI_REFACTOR_PLAN.md +++ b/AI_REFACTOR_PLAN.md @@ -375,7 +375,7 @@ Cloud Agent: - **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers. ### Step 3.4 — Agent run orchestrator -- [ ] Create `app/core/agent_runner.py`: +- [x] Create `app/core/agent_runner.py`: - `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`: 1. Check device is online with matching `device_id` → abort if offline 2. Create `AgentRunLog` with `status=running` @@ -404,8 +404,12 @@ Cloud Agent: - For cloud agents: triggers regardless of device (any connected device can receive results) - Executes runs sequentially (one at a time to avoid overwhelming the WS) - Error handling: on any failure, update `AgentRunLog` with `status=error` + error details -- **Files:** `app/core/agent_runner.py` -- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls). 
+- [x] Wire `POST /agents/{id}/run` endpoint to dispatch background task via `asyncio.create_task()` +- [x] Replace `_trigger_pending_runs_stub` in `device_ws.py` with real `trigger_pending_runs` call +- [x] Add `croniter>=3.0.0` to `requirements.txt` +- [x] 23 unit + integration tests covering all code paths +- **Files:** `app/core/agent_runner.py`, `app/api/routes/agents.py`, `app/api/routes/device_ws.py`, `requirements.txt`, `tests/test_agent_runner.py` +- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls — stub until Step 3.6). ### Step 3.5 — Chatbot Journey endpoint - [ ] Create `app/api/routes/agent_setup.py`: diff --git a/app/api/routes/agents.py b/app/api/routes/agents.py index 748ffc9..6a17670 100644 --- a/app/api/routes/agents.py +++ b/app/api/routes/agents.py @@ -16,6 +16,7 @@ Endpoints: from __future__ import annotations +import asyncio from datetime import datetime from typing import Any @@ -26,6 +27,8 @@ from sqlalchemy.ext.asyncio import AsyncSession from app.api.deps import get_current_user from app.billing.tier_manager import FEATURES +from app.core.agent_runner import run_cloud_agent, run_local_agent +from app.core.device_manager import device_manager from app.db import get_session from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig from app.schemas import ( @@ -399,14 +402,19 @@ async def trigger_agent_run( ``DeviceConnectionManager`` and ``agent_runner`` are available. """ # Determine agent type by trying local first, then cloud. - agent_type: str + # Keep the full config object so we can pass it to the agent runner. 
+ local_config: LocalAgentConfig | None = None + cloud_config: CloudAgentConfig | None = None + local_result = await db.execute( select(LocalAgentConfig).where( LocalAgentConfig.id == agent_id, LocalAgentConfig.user_id == current_user.id, ) ) - if local_result.scalar_one_or_none() is not None: + local_config = local_result.scalar_one_or_none() + + if local_config is not None: agent_type = "local" else: cloud_result = await db.execute( @@ -415,7 +423,8 @@ async def trigger_agent_run( CloudAgentConfig.user_id == current_user.id, ) ) - if cloud_result.scalar_one_or_none() is not None: + cloud_config = cloud_result.scalar_one_or_none() + if cloud_config is not None: agent_type = "cloud" else: raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found") @@ -429,4 +438,15 @@ async def trigger_agent_run( db.add(run_log) await db.commit() await db.refresh(run_log) + + # Dispatch the run as a background task — returns 202 immediately. + if agent_type == "local" and local_config is not None: + asyncio.create_task( + run_local_agent(current_user.id, local_config, run_log, device_manager) + ) + elif agent_type == "cloud" and cloud_config is not None: + asyncio.create_task( + run_cloud_agent(current_user.id, cloud_config, run_log, device_manager) + ) + return _to_run_log_response(run_log) diff --git a/app/api/routes/device_ws.py b/app/api/routes/device_ws.py index ffc9e19..2e0c038 100644 --- a/app/api/routes/device_ws.py +++ b/app/api/routes/device_ws.py @@ -39,6 +39,7 @@ from jose import JWTError, jwt from sqlalchemy import select, update from app.config.settings import settings +from app.core.agent_runner import trigger_pending_runs from app.core.device_manager import device_manager from app.db import async_session from app.models import AgentRunLog @@ -100,8 +101,8 @@ async def device_ws(websocket: WebSocket) -> None: agent_ids, ) - # Step 3.4 will replace this stub with a real call to agent_runner. 
- asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id)) + # Trigger any overdue agent runs now that the device is connected. + asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager)) # ── 4. Concurrent message loop + heartbeat ──────────────────────── try: @@ -217,10 +218,4 @@ async def _mark_runs_disconnected(user_id: str) -> None: ) -# ── Pending-run trigger stub (Step 3.4 will replace) ───────────────── -async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None: - """No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs.""" - logger.debug( - "device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id - ) diff --git a/app/core/agent_runner.py b/app/core/agent_runner.py new file mode 100644 index 0000000..d6e9cd5 --- /dev/null +++ b/app/core/agent_runner.py @@ -0,0 +1,534 @@ +"""Agent run orchestrator. + +Drives two agent types: + +* **Local directory agent** — sends an ``agent_run`` frame to the connected + Electron device, waits for the device to stream back file contents via + ``agent_data`` frames, then calls the LLM to extract structured items from + each file and pushes inserts to Electron via tool-call round-trips. + +* **Cloud connector agent** — fetches data from third-party APIs (Gmail, + Teams, Outlook) and pushes extracted items to Electron. **This path is + a stub** — provider integrations are implemented in Step 3.6. + +Usage +----- +Background tasks are spawned with ``asyncio.create_task()``:: + + asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager)) + asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager)) + +The ``trigger_pending_runs`` function is called by the device WS endpoint +when Electron sends ``device_hello``, so any overdue runs fire immediately +when the device reconnects. 
+""" + +from __future__ import annotations + +import asyncio +import json +import logging +import uuid +from datetime import datetime, timezone +from typing import Any + +from croniter import croniter +from langchain_core.messages import HumanMessage, SystemMessage +from sqlalchemy import select + +from app.core.device_manager import DeviceConnectionManager +from app.core.llm import get_llm +from app.db import async_session +from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig + +logger = logging.getLogger(__name__) + +# ── Timeouts ─────────────────────────────────────────────────────────────── + +# Max seconds to wait for Electron to finish streaming file data. +_FILE_READ_TIMEOUT: int = 120 +# Max seconds to wait for Electron to acknowledge a single tool-call insert. +_INSERT_TIMEOUT: int = 30 + +# ── Allowed tables & extraction schema hints ─────────────────────────────── + +_ALLOWED_TABLES: frozenset[str] = frozenset( + {"tasks", "notes", "checkpoints", "projects", "taskComments"} +) + +# Field descriptions fed to the extraction LLM as concise schema references. +_TABLE_SCHEMAS: dict[str, str] = { + "tasks": ( + "title (str, required), description (str), " + "status (todo|in_progress|done, default todo), " + "priority (high|medium|low, default medium), " + "assignee (JSON array string), dueDate (ms timestamp int), projectId (str)" + ), + "notes": "title (str, required), content (str, markdown), projectId (str)", + "checkpoints": ( + "title (str, required), projectId (str, required), date (ms timestamp int)" + ), + "projects": "name (str, required), clientId (str)", + "taskComments": "taskId (str, required), author (str), content (str, required)", +} + +_EXTRACTION_SYSTEM_PROMPT = """\ +You are a data extraction assistant for a freelance project management tool. +Given a document, extract structured records matching the user's instructions. 
+ +Output a JSON array (no markdown fences, no explanation) of objects shaped: + [{{"table": "", "data": {{...fields}}}}, ...] + +Allowed table names and their fields: +{table_schemas} + +Rules: +- Only extract tables listed in the "data_types" instructions. +- Use camelCase field names exactly as shown above. +- Omit optional fields you cannot determine; do not invent data. +- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved. +- If nothing relevant is found, return an empty JSON array: [] +- Return ONLY the JSON array. +""" + + +# ── Cron helper ──────────────────────────────────────────────────────────── + + +def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool: + """Return ``True`` if the next scheduled run time has already passed. + + Always validates the cron expression first — an invalid expression returns + ``False`` (fail-safe: never trigger an unparseable schedule). + """ + try: + now = datetime.now(timezone.utc) + if last_run_at is None: + # Validate the expression before deciding this is overdue. + croniter(schedule_cron, now) + return True + ts = last_run_at + if ts.tzinfo is None: + ts = ts.replace(tzinfo=timezone.utc) + cron = croniter(schedule_cron, ts) + next_run: datetime = cron.get_next(datetime) + return now >= next_run + except Exception as exc: + logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc) + return False # Fail-safe: don't trigger if expression is invalid. + + +# ── LLM extraction ───────────────────────────────────────────────────────── + + +async def _extract_items_from_content( + prompt_template: str, + file_content: str, + data_types: list[str], +) -> list[dict[str, Any]]: + """Call the LLM to extract structured records from *file_content*. + + Returns a validated list of ``{table: str, data: dict}`` objects. + Items referencing tables not in *data_types* are discarded. 
+ """ + allowed = [t for t in data_types if t in _ALLOWED_TABLES] + if not allowed: + return [] + + schema_text = "\n".join( + f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed + ) + system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text) + user_prompt = ( + f"User instructions: {prompt_template}\n\n" + f"Extract these record types: {', '.join(allowed)}\n\n" + f"Document:\n{file_content[:8000]}" + ) + + llm = get_llm() + raw = "" + try: + response = await llm.ainvoke( + [SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)] + ) + raw = str(response.content).strip() + items: list[dict] = json.loads(raw) + if not isinstance(items, list): + raise ValueError("LLM response is not a JSON array") + except json.JSONDecodeError as exc: + logger.warning( + "agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r", + exc, + raw, + ) + return [] + # Other exceptions (LLM API errors, network errors) propagate to the + # caller (run_local_agent) which records them per-file in the run log. + + validated: list[dict[str, Any]] = [] + for item in items: + table = item.get("table") + data = item.get("data") + if not isinstance(table, str) or table not in allowed: + continue + if not isinstance(data, dict) or not data: + continue + # Strip any server-generated or forbidden fields. + for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"): + data.pop(_field, None) + validated.append({"table": table, "data": data}) + return validated + + +# ── Tool-call insert helper ───────────────────────────────────────────────── + + +async def _send_insert_to_client( + user_id: str, + table: str, + data: dict[str, Any], + device_mgr: DeviceConnectionManager, +) -> dict[str, Any]: + """Send an ``insert`` tool_call frame to Electron and await the tool_result. + + All inserts include ``isAiSuggested=1, isApproved=0`` so the user can + review AI-produced records before they are treated as confirmed. 
+ + Raises ``asyncio.TimeoutError`` if Electron does not respond within + ``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device + disconnects before the frame can be sent. + """ + call_id = str(uuid.uuid4()) + payload: dict[str, Any] = { + "type": "tool_call", + "id": call_id, + "action": "insert", + "table": table, + "data": {**data, "isAiSuggested": 1, "isApproved": 0}, + } + fut = device_mgr.create_pending_call(user_id, call_id) + await device_mgr.send_frame(user_id, payload) + return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT) + + +# ── Local agent runner ────────────────────────────────────────────────────── + + +async def run_local_agent( + user_id: str, + config: LocalAgentConfig, + run_log: AgentRunLog, + device_mgr: DeviceConnectionManager, +) -> None: + """Execute a local directory agent run end-to-end. + + Steps: + + 1. Verify the device identified by ``config.device_id`` is currently online. + 2. Pre-create the agent_data queue so no incoming frames are lost. + 3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types). + 4. Consume ``agent_data`` frames until the ``None`` sentinel from + ``agent_complete``. + 5. For each received file call the LLM to extract ``{table, data}`` items. + 6. Push each item to Electron as an ``insert`` tool-call; include + ``isAiSuggested=1, isApproved=0`` so users can review AI suggestions. + 7. Persist the run outcome (status, counts, errors) and update + ``config.last_run_at``. + """ + run_id = run_log.id + + # ── 1. Device online check ───────────────────────────────────────── + if not device_mgr.is_online(user_id, config.device_id): + logger.info( + "agent_runner: skip run=%s — device %r offline for user=%s", + run_id, + config.device_id, + user_id, + ) + await _finalize_run( + run_log, + status="error", + errors=[f"Device {config.device_id!r} is not connected"], + ) + return + + # ── 2. 
Pre-create agent_data queue ──────────────────────────────── + try: + device_mgr.get_agent_data_queue(user_id, run_id) + except RuntimeError: + await _finalize_run( + run_log, + status="error", + errors=["Device disconnected before agent run could start"], + ) + return + + # ── 3. Send agent_run frame ──────────────────────────────────────── + frame: dict[str, Any] = { + "type": "agent_run", + "run_id": run_id, + "agent_id": config.id, + "config": { + "paths": config.directory_paths, + "file_extensions": config.file_extensions, + "prompt_template": config.prompt_template, + "data_types": config.data_types, + }, + } + try: + await device_mgr.send_frame(user_id, frame) + except RuntimeError as exc: + device_mgr.cleanup_agent_data_queue(user_id, run_id) + await _finalize_run( + run_log, + status="error", + errors=[f"Failed to send agent_run frame: {exc}"], + ) + return + + logger.info( + "agent_runner: sent agent_run run=%s agent=%s user=%s", + run_id, + config.id, + user_id, + ) + + # ── 4. Consume agent_data frames ────────────────────────────────── + files: list[dict[str, Any]] = [] + errors: list[str] = [] + + try: + queue = device_mgr.get_agent_data_queue(user_id, run_id) + deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT + while True: + remaining = deadline - asyncio.get_event_loop().time() + if remaining <= 0: + errors.append("Timed out waiting for file data from device") + break + try: + frame_data = await asyncio.wait_for(queue.get(), timeout=remaining) + except asyncio.TimeoutError: + errors.append("Timed out waiting for file data from device") + break + if frame_data is None: + # Sentinel from agent_complete — stream is done. + break + files.extend(frame_data.get("files", [])) + except RuntimeError as exc: + errors.append(f"Queue error reading agent data: {exc}") + + # ── 5–6. 
Extract + insert ───────────────────────────────────────── + items_processed = 0 + items_created = 0 + + for file_info in files: + file_path: str = file_info.get("path", "") + content: str = file_info.get("content", "") + if not content: + continue + items_processed += 1 + try: + extracted = await _extract_items_from_content( + config.prompt_template, content, config.data_types + ) + except Exception as exc: + errors.append(f"LLM extraction error for {file_path!r}: {exc}") + continue + + for item in extracted: + try: + result = await _send_insert_to_client( + user_id, item["table"], item["data"], device_mgr + ) + if result.get("error"): + errors.append( + f"Insert failed ({item['table']}, {file_path!r}): {result['error']}" + ) + else: + items_created += 1 + except asyncio.TimeoutError: + errors.append( + f"Timed out awaiting insert ack ({item['table']}, {file_path!r})" + ) + except RuntimeError as exc: + errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}") + + # ── 7. Finalise ──────────────────────────────────────────────────── + device_mgr.cleanup_agent_data_queue(user_id, run_id) + + if errors and items_created == 0: + final_status = "error" + elif errors: + final_status = "partial" + else: + final_status = "success" + + await _finalize_run( + run_log, + status=final_status, + items_processed=items_processed, + items_created=items_created, + errors=errors, + update_config_last_run=True, + config_id=config.id, + config_type="local", + ) + logger.info( + "agent_runner: run=%s done status=%s processed=%d created=%d errors=%d", + run_id, + final_status, + items_processed, + items_created, + len(errors), + ) + + +# ── Cloud agent runner (stub) ─────────────────────────────────────────────── + + +async def run_cloud_agent( + user_id: str, + config: CloudAgentConfig, + run_log: AgentRunLog, + device_mgr: DeviceConnectionManager, +) -> None: + """Execute a cloud connector agent run. + + .. 
note:: + This is a **stub** — provider integrations (Gmail, Teams, Outlook) + are implemented in Step 3.6. The run is immediately marked as an + error with an informative message. + """ + logger.info( + "agent_runner: cloud agent %s (provider=%s) for user=%s — pending Step 3.6", + config.id, + config.provider, + user_id, + ) + await _finalize_run( + run_log, + status="error", + errors=[ + f"Cloud provider integrations for '{config.provider}' are not yet " + "implemented. This feature arrives in Step 3.6." + ], + ) + + +# ── Pending-run trigger ───────────────────────────────────────────────────── + + +async def trigger_pending_runs( + user_id: str, + device_id: str, + device_mgr: DeviceConnectionManager, +) -> None: + """Dispatch any overdue agent runs after an Electron device connects. + + Called as a background task from the device WS endpoint on ``device_hello``. + + Scheduling rules: + + * **Local agents**: only triggered when ``config.device_id == device_id``. + * **Cloud agents**: triggered on any connected device (no device binding). + * Runs execute **sequentially** to avoid flooding the WS connection. + """ + logger.info( + "agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id + ) + async with async_session() as db: + local_result = await db.execute( + select(LocalAgentConfig).where( + LocalAgentConfig.user_id == user_id, + LocalAgentConfig.enabled == True, # noqa: E712 + LocalAgentConfig.device_id == device_id, + ) + ) + local_configs: list[LocalAgentConfig] = list(local_result.scalars().all()) + + cloud_result = await db.execute( + select(CloudAgentConfig).where( + CloudAgentConfig.user_id == user_id, + CloudAgentConfig.enabled == True, # noqa: E712 + ) + ) + cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all()) + + # Build ordered list of overdue (type, config) pairs. 
+ pending: list[tuple[str, Any]] = [] + for cfg in local_configs: + if _is_overdue(cfg.schedule_cron, cfg.last_run_at): + pending.append(("local", cfg)) + for cfg in cloud_configs: + if _is_overdue(cfg.schedule_cron, cfg.last_run_at): + pending.append(("cloud", cfg)) + + if not pending: + logger.debug("agent_runner: no overdue runs for user=%s", user_id) + return + + logger.info( + "agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id + ) + + for agent_type, cfg in pending: + # Create a fresh run log for this scheduled dispatch. + run_log = AgentRunLog( + agent_id=cfg.id, + agent_type=agent_type, + user_id=user_id, + status="running", + ) + async with async_session() as db: + db.add(run_log) + await db.commit() + await db.refresh(run_log) + + if agent_type == "local": + await run_local_agent(user_id, cfg, run_log, device_mgr) + else: + await run_cloud_agent(user_id, cfg, run_log, device_mgr) + + +# ── Internal helper ───────────────────────────────────────────────────────── + + +async def _finalize_run( + run_log: AgentRunLog, + *, + status: str, + items_processed: int = 0, + items_created: int = 0, + errors: list[str] | None = None, + update_config_last_run: bool = False, + config_id: str | None = None, + config_type: str | None = None, +) -> None: + """Persist the run outcome and optionally update ``LocalAgentConfig.last_run_at``. + + Uses a fresh DB session so this is safe to call from background tasks + after the original request session has closed. 
+ """ + now = datetime.now(timezone.utc) + try: + async with async_session() as db: + managed = await db.merge(run_log) + managed.status = status + managed.items_processed = items_processed + managed.items_created = items_created + managed.errors = errors or [] + managed.completed_at = now + + if update_config_last_run and config_id and config_type == "local": + cfg_result = await db.execute( + select(LocalAgentConfig).where(LocalAgentConfig.id == config_id) + ) + cfg = cfg_result.scalar_one_or_none() + if cfg: + cfg.last_run_at = now + + await db.commit() + except Exception as exc: + logger.error( + "agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc + ) diff --git a/requirements.txt b/requirements.txt index b7409ab..0650450 100644 --- a/requirements.txt +++ b/requirements.txt @@ -24,4 +24,5 @@ aiosqlite>=0.20.0 moto[s3]>=5.0.0 pinecone>=5.0.0 qdrant-client>=1.7.0 +croniter>=3.0.0 ruff>=0.8.0 diff --git a/tests/test_agent_runner.py b/tests/test_agent_runner.py new file mode 100644 index 0000000..46b748d --- /dev/null +++ b/tests/test_agent_runner.py @@ -0,0 +1,660 @@ +"""Tests for Step 3.4: agent_runner module. 
+ +Coverage: + Unit: + - _is_overdue — cron schedule overdue detection + - _extract_items_from_content — LLM extraction + JSON parsing + validation + - _send_insert_to_client — tool_call frame construction + timeout + - run_local_agent — end-to-end local agent happy path + - run_local_agent — device offline path + - run_local_agent — file-read timeout path + - run_local_agent — LLM extraction error path + - run_cloud_agent — stub returns error immediately + - trigger_pending_runs — overdue local + cloud dispatched + - trigger_pending_runs — non-overdue skipped + - trigger_pending_runs — device_id filter for local agents + + Integration: + - POST /agents/{id}/run — 404 on unknown agent + - POST /agents/{id}/run — creates run log + dispatches background task +""" + +from __future__ import annotations + +import asyncio +import json +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +import pytest_asyncio + +from app.core.agent_runner import ( + _extract_items_from_content, + _is_overdue, + _send_insert_to_client, + run_cloud_agent, + run_local_agent, + trigger_pending_runs, +) +from app.core.device_manager import DeviceConnectionManager +from app.db import get_session +from app.main import app +from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig +from tests.conftest import TEST_USER_IDS, auth_header + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +_FREE_UID = TEST_USER_IDS["free"] +_PRO_UID = TEST_USER_IDS["pro"] + + +def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig: + return LocalAgentConfig( + id=str(uuid.uuid4()), + user_id=user_id, + device_id=device_id, + name="Test Local Agent", + directory_paths=["/home/user/emails"], + data_types=["tasks", "notes"], + prompt_template="Extract tasks and notes from 
this document.", + file_extensions=[".txt", ".eml"], + schedule_cron="0 */6 * * *", + enabled=True, + last_run_at=None, + ) + + +def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig: + return CloudAgentConfig( + id=str(uuid.uuid4()), + user_id=user_id, + provider="gmail", + name="Test Gmail Agent", + data_types=["tasks"], + prompt_template="Extract tasks from email.", + schedule_cron="0 */6 * * *", + enabled=True, + last_run_at=None, + ) + + +def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog: + return AgentRunLog( + id=str(uuid.uuid4()), + agent_id=agent_id, + agent_type=agent_type, + user_id=user_id, + status="running", + started_at=datetime.now(timezone.utc), + ) + + +def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager: + mgr = DeviceConnectionManager() + ws = MagicMock() + ws.send_text = AsyncMock() + mgr.register(user_id, device_id, ws) + return mgr + + +# --------------------------------------------------------------------------- +# _is_overdue +# --------------------------------------------------------------------------- + +def test_is_overdue_never_run(): + """An agent that has never run is always overdue.""" + assert _is_overdue("0 */6 * * *", None) is True + + +def test_is_overdue_very_recently_run(): + """An agent that just ran is not overdue.""" + last = datetime.now(timezone.utc) + assert _is_overdue("0 */6 * * *", last) is False + + +def test_is_overdue_long_ago(): + """An agent last run 2 days ago with a 6-hour schedule is overdue.""" + from datetime import timedelta + last = datetime.now(timezone.utc) - timedelta(days=2) + assert _is_overdue("0 */6 * * *", last) is True + + +def test_is_overdue_invalid_cron_returns_false(): + """Unparseable cron must not raise and should return False (fail-safe).""" + assert _is_overdue("not a cron", None) is False + + +def test_is_overdue_naive_datetime(): + """Naive datetime objects are handled without 
raising.""" + from datetime import timedelta + last = datetime.utcnow() - timedelta(days=1) # naive + # Should not raise. + result = _is_overdue("0 */6 * * *", last) + assert isinstance(result, bool) + + +# --------------------------------------------------------------------------- +# _extract_items_from_content +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_extract_items_happy_path(): + """LLM returns valid JSON array; items with allowed tables are returned.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = json.dumps([ + {"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}}, + {"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}}, + ]) + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + + with patch("app.core.agent_runner.get_llm", return_value=mock_llm): + items = await _extract_items_from_content( + "Extract tasks and notes.", + "Email body: Buy milk urgently. 
Notes from meeting: discussed roadmap.", + ["tasks", "notes"], + ) + + assert len(items) == 2 + assert items[0]["table"] == "tasks" + assert items[0]["data"]["title"] == "Buy milk" + assert items[1]["table"] == "notes" + + +@pytest.mark.asyncio +async def test_extract_items_strips_forbidden_fields(): + """Fields like id, createdAt, isAiSuggested must be stripped from extracted data.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = json.dumps([ + { + "table": "tasks", + "data": { + "title": "Review PR", + "id": "should-be-removed", + "createdAt": 99999, + "isAiSuggested": 0, + "isApproved": 1, + }, + } + ]) + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + + with patch("app.core.agent_runner.get_llm", return_value=mock_llm): + items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"]) + + assert len(items) == 1 + data = items[0]["data"] + assert "id" not in data + assert "createdAt" not in data + assert "isAiSuggested" not in data + assert "isApproved" not in data + assert data["title"] == "Review PR" + + +@pytest.mark.asyncio +async def test_extract_items_invalid_json_returns_empty(): + """LLM returning invalid JSON must return empty list without raising.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = "Sorry, I cannot extract anything." 
+ mock_llm.ainvoke = AsyncMock(return_value=mock_response) + + with patch("app.core.agent_runner.get_llm", return_value=mock_llm): + items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"]) + + assert items == [] + + +@pytest.mark.asyncio +async def test_extract_items_disallowed_table_filtered(): + """Items whose table is not in data_types are discarded.""" + mock_llm = MagicMock() + mock_response = MagicMock() + mock_response.content = json.dumps([ + {"table": "tasks", "data": {"title": "Valid task"}}, + {"table": "projects", "data": {"name": "Should be filtered"}}, + ]) + mock_llm.ainvoke = AsyncMock(return_value=mock_response) + + with patch("app.core.agent_runner.get_llm", return_value=mock_llm): + # Only "tasks" is in data_types — "projects" should be filtered. + items = await _extract_items_from_content("Extract.", "content", ["tasks"]) + + assert len(items) == 1 + assert items[0]["table"] == "tasks" + + +@pytest.mark.asyncio +async def test_extract_items_empty_data_types_returns_empty(): + """If no allowed data_types match, skip LLM call and return immediately.""" + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock() + + with patch("app.core.agent_runner.get_llm", return_value=mock_llm): + items = await _extract_items_from_content("Extract.", "content", []) + + mock_llm.ainvoke.assert_not_called() + assert items == [] + + +@pytest.mark.asyncio +async def test_extract_items_llm_error_propagates(): + """LLM API errors propagate so the caller (run_local_agent) can record them.""" + mock_llm = MagicMock() + mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable")) + + with patch("app.core.agent_runner.get_llm", return_value=mock_llm): + with pytest.raises(RuntimeError, match="API unavailable"): + await _extract_items_from_content("Extract tasks.", "content", ["tasks"]) + + +# --------------------------------------------------------------------------- +# _send_insert_to_client +# 
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_send_insert_to_client_happy_path():
    """Frame is sent with isAiSuggested/isApproved added; result is returned."""
    mgr = _make_manager()

    sent_payloads: list[dict] = []

    async def _capture_send(uid: str, frame: dict) -> None:
        """Record the outgoing frame and acknowledge it straight away."""
        sent_payloads.append(frame)
        # Immediately resolve the pending call with a success result so the
        # awaiting _send_insert_to_client call returns without hitting its timeout.
        call_id = frame["id"]
        mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})

    mgr.send_frame = _capture_send  # type: ignore[method-assign]

    result = await _send_insert_to_client(
        _FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
    )

    # Exactly one frame goes out, carrying the caller's data plus the AI flags.
    assert len(sent_payloads) == 1
    payload = sent_payloads[0]
    assert payload["action"] == "insert"
    assert payload["table"] == "tasks"
    assert payload["data"]["title"] == "Buy milk"
    assert payload["data"]["isAiSuggested"] == 1
    assert payload["data"]["isApproved"] == 0
    # The value passed to resolve_pending_call is what the helper returns.
    assert result["row"]["title"] == "Buy milk"


@pytest.mark.asyncio
async def test_send_insert_to_client_timeout():
    """asyncio.TimeoutError is raised when Electron does not respond."""
    mgr = _make_manager()

    async def _slow_send(uid: str, frame: dict) -> None:
        """Accept the frame but never resolve the pending call."""

    mgr.send_frame = _slow_send  # type: ignore[method-assign]

    # Shrink the timeout so the test finishes in milliseconds.
    with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05), pytest.raises(asyncio.TimeoutError):
        await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)


# ---------------------------------------------------------------------------
# run_local_agent
# ---------------------------------------------------------------------------


@pytest.mark.asyncio
async def test_run_local_agent_device_offline():
    """run_local_agent marks run as error when device is offline."""
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    mgr = DeviceConnectionManager()  # Empty — no device registered.

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, mgr)

    mock_finalize.assert_called_once()
    _args, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert any("not connected" in e for e in kwargs["errors"])


@pytest.mark.asyncio
async def test_run_local_agent_happy_path():
    """End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    mgr = _make_manager()

    # Build a fake agent_data frame (will be queued after send).
    file_frame = {
        "type": "agent_data",
        "run_id": run_log.id,
        "files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
    }
    agent_complete_frame = None  # sentinel

    sent_frames: list[dict] = []

    async def _mock_send(uid: str, frame: dict) -> None:
        """Record every outgoing frame and play the device's side of the exchange."""
        sent_frames.append(frame)
        if frame.get("type") == "agent_run":
            # Simulate Electron responding with file data then agent_complete.
            q = mgr.get_agent_data_queue(uid, frame["run_id"])
            await q.put(file_frame)
            await q.put(agent_complete_frame)
        elif frame.get("type") == "tool_call":
            # Resolve the pending insert immediately.
            mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})

    mgr.send_frame = _mock_send  # type: ignore[method-assign]

    # The fake LLM returns one extractable task as a JSON array.
    mock_llm = MagicMock()
    mock_response = MagicMock()
    mock_response.content = json.dumps([
        {"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
    ])
    mock_llm.ainvoke = AsyncMock(return_value=mock_response)

    with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, mgr)

    mock_finalize.assert_called_once()
    _args, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "success"
    assert kwargs["items_processed"] == 1
    assert kwargs["items_created"] == 1
    assert kwargs["errors"] == []
    assert kwargs["update_config_last_run"] is True

    # Verify agent_run frame was sent.
    agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
    assert len(agent_run_frames) == 1
    assert agent_run_frames[0]["agent_id"] == config.id
    assert "paths" in agent_run_frames[0]["config"]

    # Verify insert frame was sent with AI flags.
    insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
    assert len(insert_frames) == 1
    assert insert_frames[0]["data"]["isAiSuggested"] == 1
    assert insert_frames[0]["data"]["isApproved"] == 0


@pytest.mark.asyncio
async def test_run_local_agent_file_read_timeout():
    """run_local_agent marks run as partial/error when device stops sending files."""
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    mgr = _make_manager()

    async def _mock_send(uid: str, frame: dict) -> None:
        """Swallow the frame: nothing is queued, simulating a stalled device."""

    mgr.send_frame = _mock_send  # type: ignore[method-assign]

    # Shorten the read timeout so the stalled-device path triggers quickly.
    with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, mgr)

    mock_finalize.assert_called_once()
    _args, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"  # No items created, so error (not partial).
    assert any("timed out" in e.lower() for e in kwargs["errors"])


@pytest.mark.asyncio
async def test_run_local_agent_llm_extraction_error():
    """LLM errors per-file are recorded; run continues for remaining files."""
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    mgr = _make_manager()

    file_frame = {
        "type": "agent_data",
        "run_id": run_log.id,
        "files": [
            {"path": "/file1.eml", "content": "Email one."},
            {"path": "/file2.eml", "content": "Email two."},
        ],
    }

    async def _mock_send(uid: str, frame: dict) -> None:
        """Answer the agent_run request with two files, then the end sentinel."""
        if frame.get("type") == "agent_run":
            q = mgr.get_agent_data_queue(uid, frame["run_id"])
            await q.put(file_frame)
            await q.put(None)  # agent_complete sentinel

    mgr.send_frame = _mock_send  # type: ignore[method-assign]

    # Every LLM call fails, so each file should contribute one error.
    mock_llm = MagicMock()
    mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))

    with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, mgr)

    # Assert the call happened before unpacking call_args; otherwise a missing
    # call shows up as an opaque TypeError instead of a clear assertion failure.
    mock_finalize.assert_called_once()
    _args, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert kwargs["items_processed"] == 2  # Both files attempted.
    assert kwargs["items_created"] == 0
    assert len(kwargs["errors"]) == 2  # One error per file.
+ + +# --------------------------------------------------------------------------- +# run_cloud_agent (stub) +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_run_cloud_agent_stub_returns_error(): + """Cloud agent stub immediately marks run as error with informative message.""" + config = _make_cloud_config() + run_log = _make_run_log(config.id, agent_type="cloud") + mgr = _make_manager() + + with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize: + await run_cloud_agent(_FREE_UID, config, run_log, mgr) + + mock_finalize.assert_called_once() + _args, kwargs = mock_finalize.call_args + assert kwargs["status"] == "error" + assert len(kwargs["errors"]) == 1 + assert "gmail" in kwargs["errors"][0].lower() + assert "3.6" in kwargs["errors"][0] + + +# --------------------------------------------------------------------------- +# trigger_pending_runs +# --------------------------------------------------------------------------- + + +@pytest.mark.asyncio +async def test_trigger_pending_runs_no_overdue(): + """If no agents are overdue trigger_pending_runs does nothing.""" + from datetime import timedelta + + config = _make_local_config() + config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago + config.schedule_cron = "0 */6 * * *" # every 6h — not due yet + + mock_db_result_local = MagicMock() + mock_db_result_local.scalars.return_value.all.return_value = [config] + + mock_db_result_cloud = MagicMock() + mock_db_result_cloud.scalars.return_value.all.return_value = [] + + mgr = _make_manager() + + with patch("app.core.agent_runner.async_session") as mock_session_factory, \ + patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + mock_ctx.execute = AsyncMock( + 
side_effect=[mock_db_result_local, mock_db_result_cloud] + ) + mock_session_factory.return_value = mock_ctx + + await trigger_pending_runs(_FREE_UID, "dev-001", mgr) + + mock_run.assert_not_called() + + +@pytest.mark.asyncio +async def test_trigger_pending_runs_device_id_filter(): + """Local agents are only triggered for the matching device_id.""" + # The DB query already filters by device_id, so we verify the SELECT + # includes the device_id filter by checking that a config bound to a + # different device is never dispatched. + # + # Since trigger_pending_runs queries with device_id == "dev-001", + # simulate the DB returning an empty list (as it would for a mismatch). + mock_db_result_local = MagicMock() + mock_db_result_local.scalars.return_value.all.return_value = [] # no match + + mock_db_result_cloud = MagicMock() + mock_db_result_cloud.scalars.return_value.all.return_value = [] + + mgr = _make_manager(device_id="dev-001") + + with patch("app.core.agent_runner.async_session") as mock_session_factory, \ + patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run: + mock_ctx = AsyncMock() + mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx) + mock_ctx.__aexit__ = AsyncMock(return_value=False) + mock_ctx.execute = AsyncMock( + side_effect=[mock_db_result_local, mock_db_result_cloud] + ) + mock_session_factory.return_value = mock_ctx + + await trigger_pending_runs(_FREE_UID, "dev-001", mgr) + + mock_run.assert_not_called() + + +@pytest.mark.asyncio +async def test_trigger_pending_runs_dispatches_overdue(): + """Overdue local agent triggers run_local_agent sequentially.""" + config = _make_local_config() # last_run_at=None → always overdue + + mock_db_result_local = MagicMock() + mock_db_result_local.scalars.return_value.all.return_value = [config] + + mock_db_result_cloud = MagicMock() + mock_db_result_cloud.scalars.return_value.all.return_value = [] + + mgr = _make_manager() + + call_order: list[str] = [] + + async def 
_mock_run_local(user_id, cfg, run_log, device_mgr): + call_order.append("run_local") + + with patch("app.core.agent_runner.async_session") as mock_session_factory, \ + patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local): + # First call: query configs. Subsequent calls: create run_log. + mock_query_ctx = AsyncMock() + mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx) + mock_query_ctx.__aexit__ = AsyncMock(return_value=False) + mock_query_ctx.execute = AsyncMock( + side_effect=[mock_db_result_local, mock_db_result_cloud] + ) + + run_log_obj = AgentRunLog( + id=str(uuid.uuid4()), + agent_id=config.id, + agent_type="local", + user_id=_FREE_UID, + status="running", + started_at=datetime.now(timezone.utc), + ) + mock_insert_ctx = AsyncMock() + mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx) + mock_insert_ctx.__aexit__ = AsyncMock(return_value=False) + mock_insert_ctx.add = MagicMock() + mock_insert_ctx.commit = AsyncMock() + mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None) + + mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx] + + await trigger_pending_runs(_FREE_UID, "dev-001", mgr) + + assert call_order == ["run_local"] + + +# --------------------------------------------------------------------------- +# Integration: POST /agents/{id}/run +# --------------------------------------------------------------------------- + + +@pytest.fixture(autouse=True) +def _override_db(db_session): + """Route all get_session calls to the test SQLite session.""" + + async def _gen(): + yield db_session + + app.dependency_overrides[get_session] = _gen + yield + app.dependency_overrides.pop(get_session, None) + + +@pytest.mark.asyncio +async def test_trigger_run_unknown_agent(client): + """POST /agents/{id}/run returns 404 for unknown agent id.""" + resp = client.post( + f"/api/v1/agents/{uuid.uuid4()}/run", + headers=auth_header("power"), + ) + assert resp.status_code == 404 + + 
@pytest.mark.asyncio
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
    """POST /agents/{id}/run creates a run log and dispatches a background task."""
    # Create the local agent config in the DB.
    config = LocalAgentConfig(
        id=str(uuid.uuid4()),
        user_id=TEST_USER_IDS["power"],
        device_id="dev-001",
        name="My Agent",
        directory_paths=["/home/user/docs"],
        data_types=["tasks"],
        prompt_template="Extract tasks.",
        file_extensions=[".txt"],
        schedule_cron="0 */6 * * *",
        enabled=True,
    )
    db_session.add(config)
    await db_session.commit()

    # NOTE: asyncio.create_task is mocked below, so the background task is never
    # scheduled and _fake_run never executes; `dispatched` therefore stays empty.
    dispatched: list = []

    async def _fake_run(user_id, cfg, run_log, device_mgr):
        """Stand-in runner that would record the dispatch if it were awaited."""
        dispatched.append((user_id, cfg.id))

    with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
            patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
            patch("asyncio.create_task") as mock_create_task:
        resp = client.post(
            f"/api/v1/agents/{config.id}/run",
            headers=auth_header("power"),
        )

    assert resp.status_code == 202
    data = resp.json()
    assert data["agent_id"] == config.id
    assert data["status"] == "running"
    assert data["agent_type"] == "local"

    # Verify create_task was called (dispatching background run).
    mock_create_task.assert_called_once()

    # The mocked create_task never schedules what it was given. Close any
    # coroutine it received so the interpreter does not emit a
    # "coroutine ... was never awaited" RuntimeWarning when it is collected.
    for task_arg in mock_create_task.call_args.args:
        if asyncio.iscoroutine(task_arg):
            task_arg.close()