step 3.4 complete: agent run orchestrator — local/cloud runner + trigger_pending_runs + 23 tests
This commit is contained in:
@@ -375,7 +375,7 @@ Cloud Agent:
|
||||
- **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers.
|
||||
|
||||
### Step 3.4 — Agent run orchestrator
|
||||
- [ ] Create `app/core/agent_runner.py`:
|
||||
- [x] Create `app/core/agent_runner.py`:
|
||||
- `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`:
|
||||
1. Check device is online with matching `device_id` → abort if offline
|
||||
2. Create `AgentRunLog` with `status=running`
|
||||
@@ -404,8 +404,12 @@ Cloud Agent:
|
||||
- For cloud agents: triggers regardless of device (any connected device can receive results)
|
||||
- Executes runs sequentially (one at a time to avoid overwhelming the WS)
|
||||
- Error handling: on any failure, update `AgentRunLog` with `status=error` + error details
|
||||
- **Files:** `app/core/agent_runner.py`
|
||||
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls).
|
||||
- [x] Wire `POST /agents/{id}/run` endpoint to dispatch background task via `asyncio.create_task()`
|
||||
- [x] Replace `_trigger_pending_runs_stub` in `device_ws.py` with real `trigger_pending_runs` call
|
||||
- [x] Add `croniter>=3.0.0` to `requirements.txt`
|
||||
- [x] 23 unit + integration tests covering all code paths
|
||||
- **Files:** `app/core/agent_runner.py`, `app/api/routes/agents.py`, `app/api/routes/device_ws.py`, `requirements.txt`, `tests/test_agent_runner.py`
|
||||
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls — stub until Step 3.6).
|
||||
|
||||
### Step 3.5 — Chatbot Journey endpoint
|
||||
- [ ] Create `app/api/routes/agent_setup.py`:
|
||||
|
||||
@@ -16,6 +16,7 @@ Endpoints:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
@@ -26,6 +27,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.billing.tier_manager import FEATURES
|
||||
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||
from app.core.device_manager import device_manager
|
||||
from app.db import get_session
|
||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||
from app.schemas import (
|
||||
@@ -399,14 +402,19 @@ async def trigger_agent_run(
|
||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||
"""
|
||||
# Determine agent type by trying local first, then cloud.
|
||||
agent_type: str
|
||||
# Keep the full config object so we can pass it to the agent runner.
|
||||
local_config: LocalAgentConfig | None = None
|
||||
cloud_config: CloudAgentConfig | None = None
|
||||
|
||||
local_result = await db.execute(
|
||||
select(LocalAgentConfig).where(
|
||||
LocalAgentConfig.id == agent_id,
|
||||
LocalAgentConfig.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
if local_result.scalar_one_or_none() is not None:
|
||||
local_config = local_result.scalar_one_or_none()
|
||||
|
||||
if local_config is not None:
|
||||
agent_type = "local"
|
||||
else:
|
||||
cloud_result = await db.execute(
|
||||
@@ -415,7 +423,8 @@ async def trigger_agent_run(
|
||||
CloudAgentConfig.user_id == current_user.id,
|
||||
)
|
||||
)
|
||||
if cloud_result.scalar_one_or_none() is not None:
|
||||
cloud_config = cloud_result.scalar_one_or_none()
|
||||
if cloud_config is not None:
|
||||
agent_type = "cloud"
|
||||
else:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||
@@ -429,4 +438,15 @@ async def trigger_agent_run(
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
|
||||
# Dispatch the run as a background task — returns 202 immediately.
|
||||
if agent_type == "local" and local_config is not None:
|
||||
asyncio.create_task(
|
||||
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||
)
|
||||
elif agent_type == "cloud" and cloud_config is not None:
|
||||
asyncio.create_task(
|
||||
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||
)
|
||||
|
||||
return _to_run_log_response(run_log)
|
||||
|
||||
@@ -39,6 +39,7 @@ from jose import JWTError, jwt
|
||||
from sqlalchemy import select, update
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.core.agent_runner import trigger_pending_runs
|
||||
from app.core.device_manager import device_manager
|
||||
from app.db import async_session
|
||||
from app.models import AgentRunLog
|
||||
@@ -100,8 +101,8 @@ async def device_ws(websocket: WebSocket) -> None:
|
||||
agent_ids,
|
||||
)
|
||||
|
||||
# Step 3.4 will replace this stub with a real call to agent_runner.
|
||||
asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id))
|
||||
# Trigger any overdue agent runs now that the device is connected.
|
||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||
|
||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||
try:
|
||||
@@ -217,10 +218,4 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
# ── Pending-run trigger stub (Step 3.4 will replace) ─────────────────
|
||||
|
||||
async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None:
|
||||
"""No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs."""
|
||||
logger.debug(
|
||||
"device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id
|
||||
)
|
||||
|
||||
534
app/core/agent_runner.py
Normal file
534
app/core/agent_runner.py
Normal file
@@ -0,0 +1,534 @@
|
||||
"""Agent run orchestrator.
|
||||
|
||||
Drives two agent types:
|
||||
|
||||
* **Local directory agent** — sends an ``agent_run`` frame to the connected
|
||||
Electron device, waits for the device to stream back file contents via
|
||||
``agent_data`` frames, then calls the LLM to extract structured items from
|
||||
each file and pushes inserts to Electron via tool-call round-trips.
|
||||
|
||||
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
||||
Teams, Outlook) and pushes extracted items to Electron. **This path is
|
||||
a stub** — provider integrations are implemented in Step 3.6.
|
||||
|
||||
Usage
|
||||
-----
|
||||
Background tasks are spawned with ``asyncio.create_task()``::
|
||||
|
||||
asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager))
|
||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||
|
||||
The ``trigger_pending_runs`` function is called by the device WS endpoint
|
||||
when Electron sends ``device_hello``, so any overdue runs fire immediately
|
||||
when the device reconnects.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
|
||||
from croniter import croniter
|
||||
from langchain_core.messages import HumanMessage, SystemMessage
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
from app.core.llm import get_llm
|
||||
from app.db import async_session
|
||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||
|
||||
# Max seconds to wait for Electron to finish streaming file data.
|
||||
_FILE_READ_TIMEOUT: int = 120
|
||||
# Max seconds to wait for Electron to acknowledge a single tool-call insert.
|
||||
_INSERT_TIMEOUT: int = 30
|
||||
|
||||
# ── Allowed tables & extraction schema hints ───────────────────────────────
|
||||
|
||||
_ALLOWED_TABLES: frozenset[str] = frozenset(
|
||||
{"tasks", "notes", "checkpoints", "projects", "taskComments"}
|
||||
)
|
||||
|
||||
# Field descriptions fed to the extraction LLM as concise schema references.
|
||||
_TABLE_SCHEMAS: dict[str, str] = {
|
||||
"tasks": (
|
||||
"title (str, required), description (str), "
|
||||
"status (todo|in_progress|done, default todo), "
|
||||
"priority (high|medium|low, default medium), "
|
||||
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
|
||||
),
|
||||
"notes": "title (str, required), content (str, markdown), projectId (str)",
|
||||
"checkpoints": (
|
||||
"title (str, required), projectId (str, required), date (ms timestamp int)"
|
||||
),
|
||||
"projects": "name (str, required), clientId (str)",
|
||||
"taskComments": "taskId (str, required), author (str), content (str, required)",
|
||||
}
|
||||
|
||||
_EXTRACTION_SYSTEM_PROMPT = """\
|
||||
You are a data extraction assistant for a freelance project management tool.
|
||||
Given a document, extract structured records matching the user's instructions.
|
||||
|
||||
Output a JSON array (no markdown fences, no explanation) of objects shaped:
|
||||
[{{"table": "<table_name>", "data": {{...fields}}}}, ...]
|
||||
|
||||
Allowed table names and their fields:
|
||||
{table_schemas}
|
||||
|
||||
Rules:
|
||||
- Only extract tables listed in the "data_types" instructions.
|
||||
- Use camelCase field names exactly as shown above.
|
||||
- Omit optional fields you cannot determine; do not invent data.
|
||||
- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved.
|
||||
- If nothing relevant is found, return an empty JSON array: []
|
||||
- Return ONLY the JSON array.
|
||||
"""
|
||||
|
||||
|
||||
# ── Cron helper ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
|
||||
"""Return ``True`` if the next scheduled run time has already passed.
|
||||
|
||||
Always validates the cron expression first — an invalid expression returns
|
||||
``False`` (fail-safe: never trigger an unparseable schedule).
|
||||
"""
|
||||
try:
|
||||
now = datetime.now(timezone.utc)
|
||||
if last_run_at is None:
|
||||
# Validate the expression before deciding this is overdue.
|
||||
croniter(schedule_cron, now)
|
||||
return True
|
||||
ts = last_run_at
|
||||
if ts.tzinfo is None:
|
||||
ts = ts.replace(tzinfo=timezone.utc)
|
||||
cron = croniter(schedule_cron, ts)
|
||||
next_run: datetime = cron.get_next(datetime)
|
||||
return now >= next_run
|
||||
except Exception as exc:
|
||||
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
|
||||
return False # Fail-safe: don't trigger if expression is invalid.
|
||||
|
||||
|
||||
# ── LLM extraction ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _extract_items_from_content(
|
||||
prompt_template: str,
|
||||
file_content: str,
|
||||
data_types: list[str],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Call the LLM to extract structured records from *file_content*.
|
||||
|
||||
Returns a validated list of ``{table: str, data: dict}`` objects.
|
||||
Items referencing tables not in *data_types* are discarded.
|
||||
"""
|
||||
allowed = [t for t in data_types if t in _ALLOWED_TABLES]
|
||||
if not allowed:
|
||||
return []
|
||||
|
||||
schema_text = "\n".join(
|
||||
f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed
|
||||
)
|
||||
system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text)
|
||||
user_prompt = (
|
||||
f"User instructions: {prompt_template}\n\n"
|
||||
f"Extract these record types: {', '.join(allowed)}\n\n"
|
||||
f"Document:\n{file_content[:8000]}"
|
||||
)
|
||||
|
||||
llm = get_llm()
|
||||
raw = ""
|
||||
try:
|
||||
response = await llm.ainvoke(
|
||||
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
|
||||
)
|
||||
raw = str(response.content).strip()
|
||||
items: list[dict] = json.loads(raw)
|
||||
if not isinstance(items, list):
|
||||
raise ValueError("LLM response is not a JSON array")
|
||||
except json.JSONDecodeError as exc:
|
||||
logger.warning(
|
||||
"agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r",
|
||||
exc,
|
||||
raw,
|
||||
)
|
||||
return []
|
||||
# Other exceptions (LLM API errors, network errors) propagate to the
|
||||
# caller (run_local_agent) which records them per-file in the run log.
|
||||
|
||||
validated: list[dict[str, Any]] = []
|
||||
for item in items:
|
||||
table = item.get("table")
|
||||
data = item.get("data")
|
||||
if not isinstance(table, str) or table not in allowed:
|
||||
continue
|
||||
if not isinstance(data, dict) or not data:
|
||||
continue
|
||||
# Strip any server-generated or forbidden fields.
|
||||
for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"):
|
||||
data.pop(_field, None)
|
||||
validated.append({"table": table, "data": data})
|
||||
return validated
|
||||
|
||||
|
||||
# ── Tool-call insert helper ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _send_insert_to_client(
|
||||
user_id: str,
|
||||
table: str,
|
||||
data: dict[str, Any],
|
||||
device_mgr: DeviceConnectionManager,
|
||||
) -> dict[str, Any]:
|
||||
"""Send an ``insert`` tool_call frame to Electron and await the tool_result.
|
||||
|
||||
All inserts include ``isAiSuggested=1, isApproved=0`` so the user can
|
||||
review AI-produced records before they are treated as confirmed.
|
||||
|
||||
Raises ``asyncio.TimeoutError`` if Electron does not respond within
|
||||
``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device
|
||||
disconnects before the frame can be sent.
|
||||
"""
|
||||
call_id = str(uuid.uuid4())
|
||||
payload: dict[str, Any] = {
|
||||
"type": "tool_call",
|
||||
"id": call_id,
|
||||
"action": "insert",
|
||||
"table": table,
|
||||
"data": {**data, "isAiSuggested": 1, "isApproved": 0},
|
||||
}
|
||||
fut = device_mgr.create_pending_call(user_id, call_id)
|
||||
await device_mgr.send_frame(user_id, payload)
|
||||
return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT)
|
||||
|
||||
|
||||
# ── Local agent runner ──────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def run_local_agent(
|
||||
user_id: str,
|
||||
config: LocalAgentConfig,
|
||||
run_log: AgentRunLog,
|
||||
device_mgr: DeviceConnectionManager,
|
||||
) -> None:
|
||||
"""Execute a local directory agent run end-to-end.
|
||||
|
||||
Steps:
|
||||
|
||||
1. Verify the device identified by ``config.device_id`` is currently online.
|
||||
2. Pre-create the agent_data queue so no incoming frames are lost.
|
||||
3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types).
|
||||
4. Consume ``agent_data`` frames until the ``None`` sentinel from
|
||||
``agent_complete``.
|
||||
5. For each received file call the LLM to extract ``{table, data}`` items.
|
||||
6. Push each item to Electron as an ``insert`` tool-call; include
|
||||
``isAiSuggested=1, isApproved=0`` so users can review AI suggestions.
|
||||
7. Persist the run outcome (status, counts, errors) and update
|
||||
``config.last_run_at``.
|
||||
"""
|
||||
run_id = run_log.id
|
||||
|
||||
# ── 1. Device online check ─────────────────────────────────────────
|
||||
if not device_mgr.is_online(user_id, config.device_id):
|
||||
logger.info(
|
||||
"agent_runner: skip run=%s — device %r offline for user=%s",
|
||||
run_id,
|
||||
config.device_id,
|
||||
user_id,
|
||||
)
|
||||
await _finalize_run(
|
||||
run_log,
|
||||
status="error",
|
||||
errors=[f"Device {config.device_id!r} is not connected"],
|
||||
)
|
||||
return
|
||||
|
||||
# ── 2. Pre-create agent_data queue ────────────────────────────────
|
||||
try:
|
||||
device_mgr.get_agent_data_queue(user_id, run_id)
|
||||
except RuntimeError:
|
||||
await _finalize_run(
|
||||
run_log,
|
||||
status="error",
|
||||
errors=["Device disconnected before agent run could start"],
|
||||
)
|
||||
return
|
||||
|
||||
# ── 3. Send agent_run frame ────────────────────────────────────────
|
||||
frame: dict[str, Any] = {
|
||||
"type": "agent_run",
|
||||
"run_id": run_id,
|
||||
"agent_id": config.id,
|
||||
"config": {
|
||||
"paths": config.directory_paths,
|
||||
"file_extensions": config.file_extensions,
|
||||
"prompt_template": config.prompt_template,
|
||||
"data_types": config.data_types,
|
||||
},
|
||||
}
|
||||
try:
|
||||
await device_mgr.send_frame(user_id, frame)
|
||||
except RuntimeError as exc:
|
||||
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
||||
await _finalize_run(
|
||||
run_log,
|
||||
status="error",
|
||||
errors=[f"Failed to send agent_run frame: {exc}"],
|
||||
)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"agent_runner: sent agent_run run=%s agent=%s user=%s",
|
||||
run_id,
|
||||
config.id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
# ── 4. Consume agent_data frames ──────────────────────────────────
|
||||
files: list[dict[str, Any]] = []
|
||||
errors: list[str] = []
|
||||
|
||||
try:
|
||||
queue = device_mgr.get_agent_data_queue(user_id, run_id)
|
||||
deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT
|
||||
while True:
|
||||
remaining = deadline - asyncio.get_event_loop().time()
|
||||
if remaining <= 0:
|
||||
errors.append("Timed out waiting for file data from device")
|
||||
break
|
||||
try:
|
||||
frame_data = await asyncio.wait_for(queue.get(), timeout=remaining)
|
||||
except asyncio.TimeoutError:
|
||||
errors.append("Timed out waiting for file data from device")
|
||||
break
|
||||
if frame_data is None:
|
||||
# Sentinel from agent_complete — stream is done.
|
||||
break
|
||||
files.extend(frame_data.get("files", []))
|
||||
except RuntimeError as exc:
|
||||
errors.append(f"Queue error reading agent data: {exc}")
|
||||
|
||||
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
||||
items_processed = 0
|
||||
items_created = 0
|
||||
|
||||
for file_info in files:
|
||||
file_path: str = file_info.get("path", "<unknown>")
|
||||
content: str = file_info.get("content", "")
|
||||
if not content:
|
||||
continue
|
||||
items_processed += 1
|
||||
try:
|
||||
extracted = await _extract_items_from_content(
|
||||
config.prompt_template, content, config.data_types
|
||||
)
|
||||
except Exception as exc:
|
||||
errors.append(f"LLM extraction error for {file_path!r}: {exc}")
|
||||
continue
|
||||
|
||||
for item in extracted:
|
||||
try:
|
||||
result = await _send_insert_to_client(
|
||||
user_id, item["table"], item["data"], device_mgr
|
||||
)
|
||||
if result.get("error"):
|
||||
errors.append(
|
||||
f"Insert failed ({item['table']}, {file_path!r}): {result['error']}"
|
||||
)
|
||||
else:
|
||||
items_created += 1
|
||||
except asyncio.TimeoutError:
|
||||
errors.append(
|
||||
f"Timed out awaiting insert ack ({item['table']}, {file_path!r})"
|
||||
)
|
||||
except RuntimeError as exc:
|
||||
errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}")
|
||||
|
||||
# ── 7. Finalise ────────────────────────────────────────────────────
|
||||
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
||||
|
||||
if errors and items_created == 0:
|
||||
final_status = "error"
|
||||
elif errors:
|
||||
final_status = "partial"
|
||||
else:
|
||||
final_status = "success"
|
||||
|
||||
await _finalize_run(
|
||||
run_log,
|
||||
status=final_status,
|
||||
items_processed=items_processed,
|
||||
items_created=items_created,
|
||||
errors=errors,
|
||||
update_config_last_run=True,
|
||||
config_id=config.id,
|
||||
config_type="local",
|
||||
)
|
||||
logger.info(
|
||||
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
||||
run_id,
|
||||
final_status,
|
||||
items_processed,
|
||||
items_created,
|
||||
len(errors),
|
||||
)
|
||||
|
||||
|
||||
# ── Cloud agent runner (stub) ───────────────────────────────────────────────
|
||||
|
||||
|
||||
async def run_cloud_agent(
|
||||
user_id: str,
|
||||
config: CloudAgentConfig,
|
||||
run_log: AgentRunLog,
|
||||
device_mgr: DeviceConnectionManager,
|
||||
) -> None:
|
||||
"""Execute a cloud connector agent run.
|
||||
|
||||
.. note::
|
||||
This is a **stub** — provider integrations (Gmail, Teams, Outlook)
|
||||
are implemented in Step 3.6. The run is immediately marked as an
|
||||
error with an informative message.
|
||||
"""
|
||||
logger.info(
|
||||
"agent_runner: cloud agent %s (provider=%s) for user=%s — pending Step 3.6",
|
||||
config.id,
|
||||
config.provider,
|
||||
user_id,
|
||||
)
|
||||
await _finalize_run(
|
||||
run_log,
|
||||
status="error",
|
||||
errors=[
|
||||
f"Cloud provider integrations for '{config.provider}' are not yet "
|
||||
"implemented. This feature arrives in Step 3.6."
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
# ── Pending-run trigger ─────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def trigger_pending_runs(
|
||||
user_id: str,
|
||||
device_id: str,
|
||||
device_mgr: DeviceConnectionManager,
|
||||
) -> None:
|
||||
"""Dispatch any overdue agent runs after an Electron device connects.
|
||||
|
||||
Called as a background task from the device WS endpoint on ``device_hello``.
|
||||
|
||||
Scheduling rules:
|
||||
|
||||
* **Local agents**: only triggered when ``config.device_id == device_id``.
|
||||
* **Cloud agents**: triggered on any connected device (no device binding).
|
||||
* Runs execute **sequentially** to avoid flooding the WS connection.
|
||||
"""
|
||||
logger.info(
|
||||
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
|
||||
)
|
||||
async with async_session() as db:
|
||||
local_result = await db.execute(
|
||||
select(LocalAgentConfig).where(
|
||||
LocalAgentConfig.user_id == user_id,
|
||||
LocalAgentConfig.enabled == True, # noqa: E712
|
||||
LocalAgentConfig.device_id == device_id,
|
||||
)
|
||||
)
|
||||
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
|
||||
|
||||
cloud_result = await db.execute(
|
||||
select(CloudAgentConfig).where(
|
||||
CloudAgentConfig.user_id == user_id,
|
||||
CloudAgentConfig.enabled == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
|
||||
|
||||
# Build ordered list of overdue (type, config) pairs.
|
||||
pending: list[tuple[str, Any]] = []
|
||||
for cfg in local_configs:
|
||||
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||
pending.append(("local", cfg))
|
||||
for cfg in cloud_configs:
|
||||
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||
pending.append(("cloud", cfg))
|
||||
|
||||
if not pending:
|
||||
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
|
||||
return
|
||||
|
||||
logger.info(
|
||||
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
|
||||
)
|
||||
|
||||
for agent_type, cfg in pending:
|
||||
# Create a fresh run log for this scheduled dispatch.
|
||||
run_log = AgentRunLog(
|
||||
agent_id=cfg.id,
|
||||
agent_type=agent_type,
|
||||
user_id=user_id,
|
||||
status="running",
|
||||
)
|
||||
async with async_session() as db:
|
||||
db.add(run_log)
|
||||
await db.commit()
|
||||
await db.refresh(run_log)
|
||||
|
||||
if agent_type == "local":
|
||||
await run_local_agent(user_id, cfg, run_log, device_mgr)
|
||||
else:
|
||||
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
|
||||
|
||||
|
||||
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _finalize_run(
|
||||
run_log: AgentRunLog,
|
||||
*,
|
||||
status: str,
|
||||
items_processed: int = 0,
|
||||
items_created: int = 0,
|
||||
errors: list[str] | None = None,
|
||||
update_config_last_run: bool = False,
|
||||
config_id: str | None = None,
|
||||
config_type: str | None = None,
|
||||
) -> None:
|
||||
"""Persist the run outcome and optionally update ``LocalAgentConfig.last_run_at``.
|
||||
|
||||
Uses a fresh DB session so this is safe to call from background tasks
|
||||
after the original request session has closed.
|
||||
"""
|
||||
now = datetime.now(timezone.utc)
|
||||
try:
|
||||
async with async_session() as db:
|
||||
managed = await db.merge(run_log)
|
||||
managed.status = status
|
||||
managed.items_processed = items_processed
|
||||
managed.items_created = items_created
|
||||
managed.errors = errors or []
|
||||
managed.completed_at = now
|
||||
|
||||
if update_config_last_run and config_id and config_type == "local":
|
||||
cfg_result = await db.execute(
|
||||
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||
)
|
||||
cfg = cfg_result.scalar_one_or_none()
|
||||
if cfg:
|
||||
cfg.last_run_at = now
|
||||
|
||||
await db.commit()
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
|
||||
)
|
||||
@@ -24,4 +24,5 @@ aiosqlite>=0.20.0
|
||||
moto[s3]>=5.0.0
|
||||
pinecone>=5.0.0
|
||||
qdrant-client>=1.7.0
|
||||
croniter>=3.0.0
|
||||
ruff>=0.8.0
|
||||
|
||||
660
tests/test_agent_runner.py
Normal file
660
tests/test_agent_runner.py
Normal file
@@ -0,0 +1,660 @@
|
||||
"""Tests for Step 3.4: agent_runner module.
|
||||
|
||||
Coverage:
|
||||
Unit:
|
||||
- _is_overdue — cron schedule overdue detection
|
||||
- _extract_items_from_content — LLM extraction + JSON parsing + validation
|
||||
- _send_insert_to_client — tool_call frame construction + timeout
|
||||
- run_local_agent — end-to-end local agent happy path
|
||||
- run_local_agent — device offline path
|
||||
- run_local_agent — file-read timeout path
|
||||
- run_local_agent — LLM extraction error path
|
||||
- run_cloud_agent — stub returns error immediately
|
||||
- trigger_pending_runs — overdue local + cloud dispatched
|
||||
- trigger_pending_runs — non-overdue skipped
|
||||
- trigger_pending_runs — device_id filter for local agents
|
||||
|
||||
Integration:
|
||||
- POST /agents/{id}/run — 404 on unknown agent
|
||||
- POST /agents/{id}/run — creates run log + dispatches background task
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from app.core.agent_runner import (
|
||||
_extract_items_from_content,
|
||||
_is_overdue,
|
||||
_send_insert_to_client,
|
||||
run_cloud_agent,
|
||||
run_local_agent,
|
||||
trigger_pending_runs,
|
||||
)
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FREE_UID = TEST_USER_IDS["free"]
|
||||
_PRO_UID = TEST_USER_IDS["pro"]
|
||||
|
||||
|
||||
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
|
||||
return LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
device_id=device_id,
|
||||
name="Test Local Agent",
|
||||
directory_paths=["/home/user/emails"],
|
||||
data_types=["tasks", "notes"],
|
||||
prompt_template="Extract tasks and notes from this document.",
|
||||
file_extensions=[".txt", ".eml"],
|
||||
schedule_cron="0 */6 * * *",
|
||||
enabled=True,
|
||||
last_run_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
|
||||
return CloudAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=user_id,
|
||||
provider="gmail",
|
||||
name="Test Gmail Agent",
|
||||
data_types=["tasks"],
|
||||
prompt_template="Extract tasks from email.",
|
||||
schedule_cron="0 */6 * * *",
|
||||
enabled=True,
|
||||
last_run_at=None,
|
||||
)
|
||||
|
||||
|
||||
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
|
||||
return AgentRunLog(
|
||||
id=str(uuid.uuid4()),
|
||||
agent_id=agent_id,
|
||||
agent_type=agent_type,
|
||||
user_id=user_id,
|
||||
status="running",
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
|
||||
|
||||
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
|
||||
mgr = DeviceConnectionManager()
|
||||
ws = MagicMock()
|
||||
ws.send_text = AsyncMock()
|
||||
mgr.register(user_id, device_id, ws)
|
||||
return mgr
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_overdue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_is_overdue_never_run():
|
||||
"""An agent that has never run is always overdue."""
|
||||
assert _is_overdue("0 */6 * * *", None) is True
|
||||
|
||||
|
||||
def test_is_overdue_very_recently_run():
|
||||
"""An agent that just ran is not overdue."""
|
||||
last = datetime.now(timezone.utc)
|
||||
assert _is_overdue("0 */6 * * *", last) is False
|
||||
|
||||
|
||||
def test_is_overdue_long_ago():
|
||||
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
|
||||
from datetime import timedelta
|
||||
last = datetime.now(timezone.utc) - timedelta(days=2)
|
||||
assert _is_overdue("0 */6 * * *", last) is True
|
||||
|
||||
|
||||
def test_is_overdue_invalid_cron_returns_false():
|
||||
"""Unparseable cron must not raise and should return False (fail-safe)."""
|
||||
assert _is_overdue("not a cron", None) is False
|
||||
|
||||
|
||||
def test_is_overdue_naive_datetime():
|
||||
"""Naive datetime objects are handled without raising."""
|
||||
from datetime import timedelta
|
||||
last = datetime.utcnow() - timedelta(days=1) # naive
|
||||
# Should not raise.
|
||||
result = _is_overdue("0 */6 * * *", last)
|
||||
assert isinstance(result, bool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_items_from_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_happy_path():
|
||||
"""LLM returns valid JSON array; items with allowed tables are returned."""
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps([
|
||||
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
|
||||
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
|
||||
])
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
items = await _extract_items_from_content(
|
||||
"Extract tasks and notes.",
|
||||
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
|
||||
["tasks", "notes"],
|
||||
)
|
||||
|
||||
assert len(items) == 2
|
||||
assert items[0]["table"] == "tasks"
|
||||
assert items[0]["data"]["title"] == "Buy milk"
|
||||
assert items[1]["table"] == "notes"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_strips_forbidden_fields():
|
||||
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps([
|
||||
{
|
||||
"table": "tasks",
|
||||
"data": {
|
||||
"title": "Review PR",
|
||||
"id": "should-be-removed",
|
||||
"createdAt": 99999,
|
||||
"isAiSuggested": 0,
|
||||
"isApproved": 1,
|
||||
},
|
||||
}
|
||||
])
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
|
||||
|
||||
assert len(items) == 1
|
||||
data = items[0]["data"]
|
||||
assert "id" not in data
|
||||
assert "createdAt" not in data
|
||||
assert "isAiSuggested" not in data
|
||||
assert "isApproved" not in data
|
||||
assert data["title"] == "Review PR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_invalid_json_returns_empty():
|
||||
"""LLM returning invalid JSON must return empty list without raising."""
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = "Sorry, I cannot extract anything."
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||
|
||||
assert items == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_disallowed_table_filtered():
|
||||
"""Items whose table is not in data_types are discarded."""
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps([
|
||||
{"table": "tasks", "data": {"title": "Valid task"}},
|
||||
{"table": "projects", "data": {"name": "Should be filtered"}},
|
||||
])
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
# Only "tasks" is in data_types — "projects" should be filtered.
|
||||
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
|
||||
|
||||
assert len(items) == 1
|
||||
assert items[0]["table"] == "tasks"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_empty_data_types_returns_empty():
|
||||
"""If no allowed data_types match, skip LLM call and return immediately."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.ainvoke = AsyncMock()
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
items = await _extract_items_from_content("Extract.", "content", [])
|
||||
|
||||
mock_llm.ainvoke.assert_not_called()
|
||||
assert items == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_extract_items_llm_error_propagates():
|
||||
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||
with pytest.raises(RuntimeError, match="API unavailable"):
|
||||
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_insert_to_client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_insert_to_client_happy_path():
|
||||
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
|
||||
mgr = _make_manager()
|
||||
|
||||
sent_payloads: list[dict] = []
|
||||
original_send = mgr.send_frame
|
||||
|
||||
async def _capture_send(uid: str, frame: dict) -> None:
|
||||
sent_payloads.append(frame)
|
||||
# Immediately resolve the pending call with a success result.
|
||||
call_id = frame["id"]
|
||||
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
|
||||
|
||||
mgr.send_frame = _capture_send # type: ignore[method-assign]
|
||||
|
||||
result = await _send_insert_to_client(
|
||||
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
|
||||
)
|
||||
|
||||
assert len(sent_payloads) == 1
|
||||
payload = sent_payloads[0]
|
||||
assert payload["action"] == "insert"
|
||||
assert payload["table"] == "tasks"
|
||||
assert payload["data"]["title"] == "Buy milk"
|
||||
assert payload["data"]["isAiSuggested"] == 1
|
||||
assert payload["data"]["isApproved"] == 0
|
||||
assert result["row"]["title"] == "Buy milk"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_insert_to_client_timeout():
|
||||
"""asyncio.TimeoutError is raised when Electron does not respond."""
|
||||
mgr = _make_manager()
|
||||
|
||||
async def _slow_send(uid: str, frame: dict) -> None:
|
||||
# Never resolve the pending call.
|
||||
pass
|
||||
|
||||
mgr.send_frame = _slow_send # type: ignore[method-assign]
|
||||
|
||||
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
|
||||
with pytest.raises(asyncio.TimeoutError):
|
||||
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_local_agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_local_agent_device_offline():
|
||||
"""run_local_agent marks run as error when device is offline."""
|
||||
config = _make_local_config()
|
||||
run_log = _make_run_log(config.id)
|
||||
mgr = DeviceConnectionManager() # Empty — no device registered.
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert any("not connected" in e for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_local_agent_happy_path():
|
||||
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
|
||||
config = _make_local_config()
|
||||
run_log = _make_run_log(config.id)
|
||||
mgr = _make_manager()
|
||||
|
||||
# Build a fake agent_data frame (will be queued after send).
|
||||
file_frame = {
|
||||
"type": "agent_data",
|
||||
"run_id": run_log.id,
|
||||
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
|
||||
}
|
||||
agent_complete_frame = None # sentinel
|
||||
|
||||
sent_frames: list[dict] = []
|
||||
|
||||
async def _mock_send(uid: str, frame: dict) -> None:
|
||||
sent_frames.append(frame)
|
||||
if frame.get("type") == "agent_run":
|
||||
# Simulate Electron responding with file data then agent_complete.
|
||||
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||
await q.put(file_frame)
|
||||
await q.put(agent_complete_frame)
|
||||
elif frame.get("type") == "tool_call":
|
||||
# Resolve the pending insert immediately.
|
||||
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
|
||||
|
||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = json.dumps([
|
||||
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
|
||||
])
|
||||
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "success"
|
||||
assert kwargs["items_processed"] == 1
|
||||
assert kwargs["items_created"] == 1
|
||||
assert kwargs["errors"] == []
|
||||
assert kwargs["update_config_last_run"] is True
|
||||
|
||||
# Verify agent_run frame was sent.
|
||||
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||
assert len(agent_run_frames) == 1
|
||||
assert agent_run_frames[0]["agent_id"] == config.id
|
||||
assert "paths" in agent_run_frames[0]["config"]
|
||||
|
||||
# Verify insert frame was sent with AI flags.
|
||||
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
|
||||
assert len(insert_frames) == 1
|
||||
assert insert_frames[0]["data"]["isAiSuggested"] == 1
|
||||
assert insert_frames[0]["data"]["isApproved"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_local_agent_file_read_timeout():
|
||||
"""run_local_agent marks run as partial/error when device stops sending files."""
|
||||
config = _make_local_config()
|
||||
run_log = _make_run_log(config.id)
|
||||
mgr = _make_manager()
|
||||
|
||||
async def _mock_send(uid: str, frame: dict) -> None:
|
||||
# Don't put anything in the queue — simulate stalled device.
|
||||
pass
|
||||
|
||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||
|
||||
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
|
||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error" # No items created, so error (not partial).
|
||||
assert any("timed out" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_local_agent_llm_extraction_error():
|
||||
"""LLM errors per-file are recorded; run continues for remaining files."""
|
||||
config = _make_local_config()
|
||||
run_log = _make_run_log(config.id)
|
||||
mgr = _make_manager()
|
||||
|
||||
file_frame = {
|
||||
"type": "agent_data",
|
||||
"run_id": run_log.id,
|
||||
"files": [
|
||||
{"path": "/file1.eml", "content": "Email one."},
|
||||
{"path": "/file2.eml", "content": "Email two."},
|
||||
],
|
||||
}
|
||||
|
||||
async def _mock_send(uid: str, frame: dict) -> None:
|
||||
if frame.get("type") == "agent_run":
|
||||
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||
await q.put(file_frame)
|
||||
await q.put(None) # agent_complete sentinel
|
||||
|
||||
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||
|
||||
mock_llm = MagicMock()
|
||||
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
|
||||
|
||||
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert kwargs["items_processed"] == 2 # Both files attempted.
|
||||
assert kwargs["items_created"] == 0
|
||||
assert len(kwargs["errors"]) == 2 # One error per file.
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_cloud_agent (stub)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_run_cloud_agent_stub_returns_error():
|
||||
"""Cloud agent stub immediately marks run as error with informative message."""
|
||||
config = _make_cloud_config()
|
||||
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||
mgr = _make_manager()
|
||||
|
||||
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||
|
||||
mock_finalize.assert_called_once()
|
||||
_args, kwargs = mock_finalize.call_args
|
||||
assert kwargs["status"] == "error"
|
||||
assert len(kwargs["errors"]) == 1
|
||||
assert "gmail" in kwargs["errors"][0].lower()
|
||||
assert "3.6" in kwargs["errors"][0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# trigger_pending_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_no_overdue():
|
||||
"""If no agents are overdue trigger_pending_runs does nothing."""
|
||||
from datetime import timedelta
|
||||
|
||||
config = _make_local_config()
|
||||
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
||||
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
||||
|
||||
mock_db_result_local = MagicMock()
|
||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||
|
||||
mock_db_result_cloud = MagicMock()
|
||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||
|
||||
mgr = _make_manager()
|
||||
|
||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.execute = AsyncMock(
|
||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||
)
|
||||
mock_session_factory.return_value = mock_ctx
|
||||
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_device_id_filter():
|
||||
"""Local agents are only triggered for the matching device_id."""
|
||||
# The DB query already filters by device_id, so we verify the SELECT
|
||||
# includes the device_id filter by checking that a config bound to a
|
||||
# different device is never dispatched.
|
||||
#
|
||||
# Since trigger_pending_runs queries with device_id == "dev-001",
|
||||
# simulate the DB returning an empty list (as it would for a mismatch).
|
||||
mock_db_result_local = MagicMock()
|
||||
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
||||
|
||||
mock_db_result_cloud = MagicMock()
|
||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||
|
||||
mgr = _make_manager(device_id="dev-001")
|
||||
|
||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||
mock_ctx = AsyncMock()
|
||||
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_ctx.execute = AsyncMock(
|
||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||
)
|
||||
mock_session_factory.return_value = mock_ctx
|
||||
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
mock_run.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_pending_runs_dispatches_overdue():
|
||||
"""Overdue local agent triggers run_local_agent sequentially."""
|
||||
config = _make_local_config() # last_run_at=None → always overdue
|
||||
|
||||
mock_db_result_local = MagicMock()
|
||||
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||
|
||||
mock_db_result_cloud = MagicMock()
|
||||
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||
|
||||
mgr = _make_manager()
|
||||
|
||||
call_order: list[str] = []
|
||||
|
||||
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
||||
call_order.append("run_local")
|
||||
|
||||
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
||||
# First call: query configs. Subsequent calls: create run_log.
|
||||
mock_query_ctx = AsyncMock()
|
||||
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
||||
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_query_ctx.execute = AsyncMock(
|
||||
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||
)
|
||||
|
||||
run_log_obj = AgentRunLog(
|
||||
id=str(uuid.uuid4()),
|
||||
agent_id=config.id,
|
||||
agent_type="local",
|
||||
user_id=_FREE_UID,
|
||||
status="running",
|
||||
started_at=datetime.now(timezone.utc),
|
||||
)
|
||||
mock_insert_ctx = AsyncMock()
|
||||
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
||||
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||
mock_insert_ctx.add = MagicMock()
|
||||
mock_insert_ctx.commit = AsyncMock()
|
||||
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
||||
|
||||
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
||||
|
||||
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||
|
||||
assert call_order == ["run_local"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: POST /agents/{id}/run
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _override_db(db_session):
|
||||
"""Route all get_session calls to the test SQLite session."""
|
||||
|
||||
async def _gen():
|
||||
yield db_session
|
||||
|
||||
app.dependency_overrides[get_session] = _gen
|
||||
yield
|
||||
app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_run_unknown_agent(client):
|
||||
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
||||
resp = client.post(
|
||||
f"/api/v1/agents/{uuid.uuid4()}/run",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
assert resp.status_code == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
||||
# Create the local agent config in the DB.
|
||||
config = LocalAgentConfig(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=TEST_USER_IDS["power"],
|
||||
device_id="dev-001",
|
||||
name="My Agent",
|
||||
directory_paths=["/home/user/docs"],
|
||||
data_types=["tasks"],
|
||||
prompt_template="Extract tasks.",
|
||||
file_extensions=[".txt"],
|
||||
schedule_cron="0 */6 * * *",
|
||||
enabled=True,
|
||||
)
|
||||
db_session.add(config)
|
||||
await db_session.commit()
|
||||
|
||||
dispatched: list = []
|
||||
|
||||
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||
dispatched.append((user_id, cfg.id))
|
||||
|
||||
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
||||
patch("asyncio.create_task") as mock_create_task:
|
||||
resp = client.post(
|
||||
f"/api/v1/agents/{config.id}/run",
|
||||
headers=auth_header("power"),
|
||||
)
|
||||
|
||||
assert resp.status_code == 202
|
||||
data = resp.json()
|
||||
assert data["agent_id"] == config.id
|
||||
assert data["status"] == "running"
|
||||
assert data["agent_type"] == "local"
|
||||
|
||||
# Verify create_task was called (dispatching background run).
|
||||
mock_create_task.assert_called_once()
|
||||
Reference in New Issue
Block a user