step 3.4 complete: agent run orchestrator — local/cloud runner + trigger_pending_runs + 23 tests
This commit is contained in:
@@ -375,7 +375,7 @@ Cloud Agent:
|
|||||||
- **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers.
|
- **Outcome:** Backend maintains persistent WS connections to Electron devices for agent triggers.
|
||||||
|
|
||||||
### Step 3.4 — Agent run orchestrator
|
### Step 3.4 — Agent run orchestrator
|
||||||
- [ ] Create `app/core/agent_runner.py`:
|
- [x] Create `app/core/agent_runner.py`:
|
||||||
- `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`:
|
- `async run_local_agent(user_id, config: LocalAgentConfig, device_mgr: DeviceConnectionManager)`:
|
||||||
1. Check device is online with matching `device_id` → abort if offline
|
1. Check device is online with matching `device_id` → abort if offline
|
||||||
2. Create `AgentRunLog` with `status=running`
|
2. Create `AgentRunLog` with `status=running`
|
||||||
@@ -404,8 +404,12 @@ Cloud Agent:
|
|||||||
- For cloud agents: triggers regardless of device (any connected device can receive results)
|
- For cloud agents: triggers regardless of device (any connected device can receive results)
|
||||||
- Executes runs sequentially (one at a time to avoid overwhelming the WS)
|
- Executes runs sequentially (one at a time to avoid overwhelming the WS)
|
||||||
- Error handling: on any failure, update `AgentRunLog` with `status=error` + error details
|
- Error handling: on any failure, update `AgentRunLog` with `status=error` + error details
|
||||||
- **Files:** `app/core/agent_runner.py`
|
- [x] Wire `POST /agents/{id}/run` endpoint to dispatch background task via `asyncio.create_task()`
|
||||||
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls).
|
- [x] Replace `_trigger_pending_runs_stub` in `device_ws.py` with real `trigger_pending_runs` call
|
||||||
|
- [x] Add `croniter>=3.0.0` to `requirements.txt`
|
||||||
|
- [x] 23 unit + integration tests covering all code paths
|
||||||
|
- **Files:** `app/core/agent_runner.py`, `app/api/routes/agents.py`, `app/api/routes/device_ws.py`, `requirements.txt`, `tests/test_agent_runner.py`
|
||||||
|
- **Outcome:** Backend drives all agent execution — both local (via WS file request) and cloud (direct API calls — stub until Step 3.6).
|
||||||
|
|
||||||
### Step 3.5 — Chatbot Journey endpoint
|
### Step 3.5 — Chatbot Journey endpoint
|
||||||
- [ ] Create `app/api/routes/agent_setup.py`:
|
- [ ] Create `app/api/routes/agent_setup.py`:
|
||||||
|
|||||||
@@ -16,6 +16,7 @@ Endpoints:
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
@@ -26,6 +27,8 @@ from sqlalchemy.ext.asyncio import AsyncSession
|
|||||||
|
|
||||||
from app.api.deps import get_current_user
|
from app.api.deps import get_current_user
|
||||||
from app.billing.tier_manager import FEATURES
|
from app.billing.tier_manager import FEATURES
|
||||||
|
from app.core.agent_runner import run_cloud_agent, run_local_agent
|
||||||
|
from app.core.device_manager import device_manager
|
||||||
from app.db import get_session
|
from app.db import get_session
|
||||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
from app.schemas import (
|
from app.schemas import (
|
||||||
@@ -399,14 +402,19 @@ async def trigger_agent_run(
|
|||||||
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
``DeviceConnectionManager`` and ``agent_runner`` are available.
|
||||||
"""
|
"""
|
||||||
# Determine agent type by trying local first, then cloud.
|
# Determine agent type by trying local first, then cloud.
|
||||||
agent_type: str
|
# Keep the full config object so we can pass it to the agent runner.
|
||||||
|
local_config: LocalAgentConfig | None = None
|
||||||
|
cloud_config: CloudAgentConfig | None = None
|
||||||
|
|
||||||
local_result = await db.execute(
|
local_result = await db.execute(
|
||||||
select(LocalAgentConfig).where(
|
select(LocalAgentConfig).where(
|
||||||
LocalAgentConfig.id == agent_id,
|
LocalAgentConfig.id == agent_id,
|
||||||
LocalAgentConfig.user_id == current_user.id,
|
LocalAgentConfig.user_id == current_user.id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if local_result.scalar_one_or_none() is not None:
|
local_config = local_result.scalar_one_or_none()
|
||||||
|
|
||||||
|
if local_config is not None:
|
||||||
agent_type = "local"
|
agent_type = "local"
|
||||||
else:
|
else:
|
||||||
cloud_result = await db.execute(
|
cloud_result = await db.execute(
|
||||||
@@ -415,7 +423,8 @@ async def trigger_agent_run(
|
|||||||
CloudAgentConfig.user_id == current_user.id,
|
CloudAgentConfig.user_id == current_user.id,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
if cloud_result.scalar_one_or_none() is not None:
|
cloud_config = cloud_result.scalar_one_or_none()
|
||||||
|
if cloud_config is not None:
|
||||||
agent_type = "cloud"
|
agent_type = "cloud"
|
||||||
else:
|
else:
|
||||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Agent not found")
|
||||||
@@ -429,4 +438,15 @@ async def trigger_agent_run(
|
|||||||
db.add(run_log)
|
db.add(run_log)
|
||||||
await db.commit()
|
await db.commit()
|
||||||
await db.refresh(run_log)
|
await db.refresh(run_log)
|
||||||
|
|
||||||
|
# Dispatch the run as a background task — returns 202 immediately.
|
||||||
|
if agent_type == "local" and local_config is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
run_local_agent(current_user.id, local_config, run_log, device_manager)
|
||||||
|
)
|
||||||
|
elif agent_type == "cloud" and cloud_config is not None:
|
||||||
|
asyncio.create_task(
|
||||||
|
run_cloud_agent(current_user.id, cloud_config, run_log, device_manager)
|
||||||
|
)
|
||||||
|
|
||||||
return _to_run_log_response(run_log)
|
return _to_run_log_response(run_log)
|
||||||
|
|||||||
@@ -39,6 +39,7 @@ from jose import JWTError, jwt
|
|||||||
from sqlalchemy import select, update
|
from sqlalchemy import select, update
|
||||||
|
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.core.agent_runner import trigger_pending_runs
|
||||||
from app.core.device_manager import device_manager
|
from app.core.device_manager import device_manager
|
||||||
from app.db import async_session
|
from app.db import async_session
|
||||||
from app.models import AgentRunLog
|
from app.models import AgentRunLog
|
||||||
@@ -100,8 +101,8 @@ async def device_ws(websocket: WebSocket) -> None:
|
|||||||
agent_ids,
|
agent_ids,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Step 3.4 will replace this stub with a real call to agent_runner.
|
# Trigger any overdue agent runs now that the device is connected.
|
||||||
asyncio.create_task(_trigger_pending_runs_stub(user_id, device_id))
|
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||||
|
|
||||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||||
try:
|
try:
|
||||||
@@ -217,10 +218,4 @@ async def _mark_runs_disconnected(user_id: str) -> None:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# ── Pending-run trigger stub (Step 3.4 will replace) ─────────────────
|
|
||||||
|
|
||||||
async def _trigger_pending_runs_stub(user_id: str, device_id: str) -> None:
|
|
||||||
"""No-op stub. Step 3.4 wires this to agent_runner.trigger_pending_runs."""
|
|
||||||
logger.debug(
|
|
||||||
"device_ws: _trigger_pending_runs stub user=%s device=%s", user_id, device_id
|
|
||||||
)
|
|
||||||
|
|||||||
534
app/core/agent_runner.py
Normal file
534
app/core/agent_runner.py
Normal file
@@ -0,0 +1,534 @@
|
|||||||
|
"""Agent run orchestrator.
|
||||||
|
|
||||||
|
Drives two agent types:
|
||||||
|
|
||||||
|
* **Local directory agent** — sends an ``agent_run`` frame to the connected
|
||||||
|
Electron device, waits for the device to stream back file contents via
|
||||||
|
``agent_data`` frames, then calls the LLM to extract structured items from
|
||||||
|
each file and pushes inserts to Electron via tool-call round-trips.
|
||||||
|
|
||||||
|
* **Cloud connector agent** — fetches data from third-party APIs (Gmail,
|
||||||
|
Teams, Outlook) and pushes extracted items to Electron. **This path is
|
||||||
|
a stub** — provider integrations are implemented in Step 3.6.
|
||||||
|
|
||||||
|
Usage
|
||||||
|
-----
|
||||||
|
Background tasks are spawned with ``asyncio.create_task()``::
|
||||||
|
|
||||||
|
asyncio.create_task(run_local_agent(user_id, config, run_log, device_manager))
|
||||||
|
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||||
|
|
||||||
|
The ``trigger_pending_runs`` function is called by the device WS endpoint
|
||||||
|
when Electron sends ``device_hello``, so any overdue runs fire immediately
|
||||||
|
when the device reconnects.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from croniter import croniter
|
||||||
|
from langchain_core.messages import HumanMessage, SystemMessage
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
|
from app.core.llm import get_llm
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ── Timeouts ───────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
# Max seconds to wait for Electron to finish streaming file data.
|
||||||
|
_FILE_READ_TIMEOUT: int = 120
|
||||||
|
# Max seconds to wait for Electron to acknowledge a single tool-call insert.
|
||||||
|
_INSERT_TIMEOUT: int = 30
|
||||||
|
|
||||||
|
# ── Allowed tables & extraction schema hints ───────────────────────────────
|
||||||
|
|
||||||
|
_ALLOWED_TABLES: frozenset[str] = frozenset(
|
||||||
|
{"tasks", "notes", "checkpoints", "projects", "taskComments"}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Field descriptions fed to the extraction LLM as concise schema references.
|
||||||
|
_TABLE_SCHEMAS: dict[str, str] = {
|
||||||
|
"tasks": (
|
||||||
|
"title (str, required), description (str), "
|
||||||
|
"status (todo|in_progress|done, default todo), "
|
||||||
|
"priority (high|medium|low, default medium), "
|
||||||
|
"assignee (JSON array string), dueDate (ms timestamp int), projectId (str)"
|
||||||
|
),
|
||||||
|
"notes": "title (str, required), content (str, markdown), projectId (str)",
|
||||||
|
"checkpoints": (
|
||||||
|
"title (str, required), projectId (str, required), date (ms timestamp int)"
|
||||||
|
),
|
||||||
|
"projects": "name (str, required), clientId (str)",
|
||||||
|
"taskComments": "taskId (str, required), author (str), content (str, required)",
|
||||||
|
}
|
||||||
|
|
||||||
|
_EXTRACTION_SYSTEM_PROMPT = """\
|
||||||
|
You are a data extraction assistant for a freelance project management tool.
|
||||||
|
Given a document, extract structured records matching the user's instructions.
|
||||||
|
|
||||||
|
Output a JSON array (no markdown fences, no explanation) of objects shaped:
|
||||||
|
[{{"table": "<table_name>", "data": {{...fields}}}}, ...]
|
||||||
|
|
||||||
|
Allowed table names and their fields:
|
||||||
|
{table_schemas}
|
||||||
|
|
||||||
|
Rules:
|
||||||
|
- Only extract tables listed in the "data_types" instructions.
|
||||||
|
- Use camelCase field names exactly as shown above.
|
||||||
|
- Omit optional fields you cannot determine; do not invent data.
|
||||||
|
- Never include id, createdAt, updatedAt, isAiSuggested, or isApproved.
|
||||||
|
- If nothing relevant is found, return an empty JSON array: []
|
||||||
|
- Return ONLY the JSON array.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cron helper ────────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
|
||||||
|
"""Return ``True`` if the next scheduled run time has already passed.
|
||||||
|
|
||||||
|
Always validates the cron expression first — an invalid expression returns
|
||||||
|
``False`` (fail-safe: never trigger an unparseable schedule).
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
if last_run_at is None:
|
||||||
|
# Validate the expression before deciding this is overdue.
|
||||||
|
croniter(schedule_cron, now)
|
||||||
|
return True
|
||||||
|
ts = last_run_at
|
||||||
|
if ts.tzinfo is None:
|
||||||
|
ts = ts.replace(tzinfo=timezone.utc)
|
||||||
|
cron = croniter(schedule_cron, ts)
|
||||||
|
next_run: datetime = cron.get_next(datetime)
|
||||||
|
return now >= next_run
|
||||||
|
except Exception as exc:
|
||||||
|
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
|
||||||
|
return False # Fail-safe: don't trigger if expression is invalid.
|
||||||
|
|
||||||
|
|
||||||
|
# ── LLM extraction ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _extract_items_from_content(
|
||||||
|
prompt_template: str,
|
||||||
|
file_content: str,
|
||||||
|
data_types: list[str],
|
||||||
|
) -> list[dict[str, Any]]:
|
||||||
|
"""Call the LLM to extract structured records from *file_content*.
|
||||||
|
|
||||||
|
Returns a validated list of ``{table: str, data: dict}`` objects.
|
||||||
|
Items referencing tables not in *data_types* are discarded.
|
||||||
|
"""
|
||||||
|
allowed = [t for t in data_types if t in _ALLOWED_TABLES]
|
||||||
|
if not allowed:
|
||||||
|
return []
|
||||||
|
|
||||||
|
schema_text = "\n".join(
|
||||||
|
f" {table}: {_TABLE_SCHEMAS.get(table, '(unknown)')}" for table in allowed
|
||||||
|
)
|
||||||
|
system_prompt = _EXTRACTION_SYSTEM_PROMPT.format(table_schemas=schema_text)
|
||||||
|
user_prompt = (
|
||||||
|
f"User instructions: {prompt_template}\n\n"
|
||||||
|
f"Extract these record types: {', '.join(allowed)}\n\n"
|
||||||
|
f"Document:\n{file_content[:8000]}"
|
||||||
|
)
|
||||||
|
|
||||||
|
llm = get_llm()
|
||||||
|
raw = ""
|
||||||
|
try:
|
||||||
|
response = await llm.ainvoke(
|
||||||
|
[SystemMessage(content=system_prompt), HumanMessage(content=user_prompt)]
|
||||||
|
)
|
||||||
|
raw = str(response.content).strip()
|
||||||
|
items: list[dict] = json.loads(raw)
|
||||||
|
if not isinstance(items, list):
|
||||||
|
raise ValueError("LLM response is not a JSON array")
|
||||||
|
except json.JSONDecodeError as exc:
|
||||||
|
logger.warning(
|
||||||
|
"agent_runner: LLM extraction returned invalid JSON: %s — snippet: %.200r",
|
||||||
|
exc,
|
||||||
|
raw,
|
||||||
|
)
|
||||||
|
return []
|
||||||
|
# Other exceptions (LLM API errors, network errors) propagate to the
|
||||||
|
# caller (run_local_agent) which records them per-file in the run log.
|
||||||
|
|
||||||
|
validated: list[dict[str, Any]] = []
|
||||||
|
for item in items:
|
||||||
|
table = item.get("table")
|
||||||
|
data = item.get("data")
|
||||||
|
if not isinstance(table, str) or table not in allowed:
|
||||||
|
continue
|
||||||
|
if not isinstance(data, dict) or not data:
|
||||||
|
continue
|
||||||
|
# Strip any server-generated or forbidden fields.
|
||||||
|
for _field in ("id", "createdAt", "updatedAt", "isAiSuggested", "isApproved"):
|
||||||
|
data.pop(_field, None)
|
||||||
|
validated.append({"table": table, "data": data})
|
||||||
|
return validated
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tool-call insert helper ─────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _send_insert_to_client(
|
||||||
|
user_id: str,
|
||||||
|
table: str,
|
||||||
|
data: dict[str, Any],
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> dict[str, Any]:
|
||||||
|
"""Send an ``insert`` tool_call frame to Electron and await the tool_result.
|
||||||
|
|
||||||
|
All inserts include ``isAiSuggested=1, isApproved=0`` so the user can
|
||||||
|
review AI-produced records before they are treated as confirmed.
|
||||||
|
|
||||||
|
Raises ``asyncio.TimeoutError`` if Electron does not respond within
|
||||||
|
``_INSERT_TIMEOUT`` seconds. Raises ``RuntimeError`` if the device
|
||||||
|
disconnects before the frame can be sent.
|
||||||
|
"""
|
||||||
|
call_id = str(uuid.uuid4())
|
||||||
|
payload: dict[str, Any] = {
|
||||||
|
"type": "tool_call",
|
||||||
|
"id": call_id,
|
||||||
|
"action": "insert",
|
||||||
|
"table": table,
|
||||||
|
"data": {**data, "isAiSuggested": 1, "isApproved": 0},
|
||||||
|
}
|
||||||
|
fut = device_mgr.create_pending_call(user_id, call_id)
|
||||||
|
await device_mgr.send_frame(user_id, payload)
|
||||||
|
return await asyncio.wait_for(fut, timeout=_INSERT_TIMEOUT)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Local agent runner ──────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_local_agent(
|
||||||
|
user_id: str,
|
||||||
|
config: LocalAgentConfig,
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a local directory agent run end-to-end.
|
||||||
|
|
||||||
|
Steps:
|
||||||
|
|
||||||
|
1. Verify the device identified by ``config.device_id`` is currently online.
|
||||||
|
2. Pre-create the agent_data queue so no incoming frames are lost.
|
||||||
|
3. Send ``agent_run`` frame to Electron (paths, extensions, prompt, data_types).
|
||||||
|
4. Consume ``agent_data`` frames until the ``None`` sentinel from
|
||||||
|
``agent_complete``.
|
||||||
|
5. For each received file call the LLM to extract ``{table, data}`` items.
|
||||||
|
6. Push each item to Electron as an ``insert`` tool-call; include
|
||||||
|
``isAiSuggested=1, isApproved=0`` so users can review AI suggestions.
|
||||||
|
7. Persist the run outcome (status, counts, errors) and update
|
||||||
|
``config.last_run_at``.
|
||||||
|
"""
|
||||||
|
run_id = run_log.id
|
||||||
|
|
||||||
|
# ── 1. Device online check ─────────────────────────────────────────
|
||||||
|
if not device_mgr.is_online(user_id, config.device_id):
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: skip run=%s — device %r offline for user=%s",
|
||||||
|
run_id,
|
||||||
|
config.device_id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Device {config.device_id!r} is not connected"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 2. Pre-create agent_data queue ────────────────────────────────
|
||||||
|
try:
|
||||||
|
device_mgr.get_agent_data_queue(user_id, run_id)
|
||||||
|
except RuntimeError:
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=["Device disconnected before agent run could start"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
# ── 3. Send agent_run frame ────────────────────────────────────────
|
||||||
|
frame: dict[str, Any] = {
|
||||||
|
"type": "agent_run",
|
||||||
|
"run_id": run_id,
|
||||||
|
"agent_id": config.id,
|
||||||
|
"config": {
|
||||||
|
"paths": config.directory_paths,
|
||||||
|
"file_extensions": config.file_extensions,
|
||||||
|
"prompt_template": config.prompt_template,
|
||||||
|
"data_types": config.data_types,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
try:
|
||||||
|
await device_mgr.send_frame(user_id, frame)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[f"Failed to send agent_run frame: {exc}"],
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: sent agent_run run=%s agent=%s user=%s",
|
||||||
|
run_id,
|
||||||
|
config.id,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── 4. Consume agent_data frames ──────────────────────────────────
|
||||||
|
files: list[dict[str, Any]] = []
|
||||||
|
errors: list[str] = []
|
||||||
|
|
||||||
|
try:
|
||||||
|
queue = device_mgr.get_agent_data_queue(user_id, run_id)
|
||||||
|
deadline = asyncio.get_event_loop().time() + _FILE_READ_TIMEOUT
|
||||||
|
while True:
|
||||||
|
remaining = deadline - asyncio.get_event_loop().time()
|
||||||
|
if remaining <= 0:
|
||||||
|
errors.append("Timed out waiting for file data from device")
|
||||||
|
break
|
||||||
|
try:
|
||||||
|
frame_data = await asyncio.wait_for(queue.get(), timeout=remaining)
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append("Timed out waiting for file data from device")
|
||||||
|
break
|
||||||
|
if frame_data is None:
|
||||||
|
# Sentinel from agent_complete — stream is done.
|
||||||
|
break
|
||||||
|
files.extend(frame_data.get("files", []))
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Queue error reading agent data: {exc}")
|
||||||
|
|
||||||
|
# ── 5–6. Extract + insert ─────────────────────────────────────────
|
||||||
|
items_processed = 0
|
||||||
|
items_created = 0
|
||||||
|
|
||||||
|
for file_info in files:
|
||||||
|
file_path: str = file_info.get("path", "<unknown>")
|
||||||
|
content: str = file_info.get("content", "")
|
||||||
|
if not content:
|
||||||
|
continue
|
||||||
|
items_processed += 1
|
||||||
|
try:
|
||||||
|
extracted = await _extract_items_from_content(
|
||||||
|
config.prompt_template, content, config.data_types
|
||||||
|
)
|
||||||
|
except Exception as exc:
|
||||||
|
errors.append(f"LLM extraction error for {file_path!r}: {exc}")
|
||||||
|
continue
|
||||||
|
|
||||||
|
for item in extracted:
|
||||||
|
try:
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
user_id, item["table"], item["data"], device_mgr
|
||||||
|
)
|
||||||
|
if result.get("error"):
|
||||||
|
errors.append(
|
||||||
|
f"Insert failed ({item['table']}, {file_path!r}): {result['error']}"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
items_created += 1
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
errors.append(
|
||||||
|
f"Timed out awaiting insert ack ({item['table']}, {file_path!r})"
|
||||||
|
)
|
||||||
|
except RuntimeError as exc:
|
||||||
|
errors.append(f"Insert error ({item['table']}, {file_path!r}): {exc}")
|
||||||
|
|
||||||
|
# ── 7. Finalise ────────────────────────────────────────────────────
|
||||||
|
device_mgr.cleanup_agent_data_queue(user_id, run_id)
|
||||||
|
|
||||||
|
if errors and items_created == 0:
|
||||||
|
final_status = "error"
|
||||||
|
elif errors:
|
||||||
|
final_status = "partial"
|
||||||
|
else:
|
||||||
|
final_status = "success"
|
||||||
|
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status=final_status,
|
||||||
|
items_processed=items_processed,
|
||||||
|
items_created=items_created,
|
||||||
|
errors=errors,
|
||||||
|
update_config_last_run=True,
|
||||||
|
config_id=config.id,
|
||||||
|
config_type="local",
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
|
||||||
|
run_id,
|
||||||
|
final_status,
|
||||||
|
items_processed,
|
||||||
|
items_created,
|
||||||
|
len(errors),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Cloud agent runner (stub) ───────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def run_cloud_agent(
|
||||||
|
user_id: str,
|
||||||
|
config: CloudAgentConfig,
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Execute a cloud connector agent run.
|
||||||
|
|
||||||
|
.. note::
|
||||||
|
This is a **stub** — provider integrations (Gmail, Teams, Outlook)
|
||||||
|
are implemented in Step 3.6. The run is immediately marked as an
|
||||||
|
error with an informative message.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: cloud agent %s (provider=%s) for user=%s — pending Step 3.6",
|
||||||
|
config.id,
|
||||||
|
config.provider,
|
||||||
|
user_id,
|
||||||
|
)
|
||||||
|
await _finalize_run(
|
||||||
|
run_log,
|
||||||
|
status="error",
|
||||||
|
errors=[
|
||||||
|
f"Cloud provider integrations for '{config.provider}' are not yet "
|
||||||
|
"implemented. This feature arrives in Step 3.6."
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Pending-run trigger ─────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def trigger_pending_runs(
|
||||||
|
user_id: str,
|
||||||
|
device_id: str,
|
||||||
|
device_mgr: DeviceConnectionManager,
|
||||||
|
) -> None:
|
||||||
|
"""Dispatch any overdue agent runs after an Electron device connects.
|
||||||
|
|
||||||
|
Called as a background task from the device WS endpoint on ``device_hello``.
|
||||||
|
|
||||||
|
Scheduling rules:
|
||||||
|
|
||||||
|
* **Local agents**: only triggered when ``config.device_id == device_id``.
|
||||||
|
* **Cloud agents**: triggered on any connected device (no device binding).
|
||||||
|
* Runs execute **sequentially** to avoid flooding the WS connection.
|
||||||
|
"""
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: scanning overdue runs for user=%s device=%s", user_id, device_id
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
local_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(
|
||||||
|
LocalAgentConfig.user_id == user_id,
|
||||||
|
LocalAgentConfig.enabled == True, # noqa: E712
|
||||||
|
LocalAgentConfig.device_id == device_id,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
local_configs: list[LocalAgentConfig] = list(local_result.scalars().all())
|
||||||
|
|
||||||
|
cloud_result = await db.execute(
|
||||||
|
select(CloudAgentConfig).where(
|
||||||
|
CloudAgentConfig.user_id == user_id,
|
||||||
|
CloudAgentConfig.enabled == True, # noqa: E712
|
||||||
|
)
|
||||||
|
)
|
||||||
|
cloud_configs: list[CloudAgentConfig] = list(cloud_result.scalars().all())
|
||||||
|
|
||||||
|
# Build ordered list of overdue (type, config) pairs.
|
||||||
|
pending: list[tuple[str, Any]] = []
|
||||||
|
for cfg in local_configs:
|
||||||
|
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||||
|
pending.append(("local", cfg))
|
||||||
|
for cfg in cloud_configs:
|
||||||
|
if _is_overdue(cfg.schedule_cron, cfg.last_run_at):
|
||||||
|
pending.append(("cloud", cfg))
|
||||||
|
|
||||||
|
if not pending:
|
||||||
|
logger.debug("agent_runner: no overdue runs for user=%s", user_id)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
"agent_runner: %d overdue run(s) to dispatch for user=%s", len(pending), user_id
|
||||||
|
)
|
||||||
|
|
||||||
|
for agent_type, cfg in pending:
|
||||||
|
# Create a fresh run log for this scheduled dispatch.
|
||||||
|
run_log = AgentRunLog(
|
||||||
|
agent_id=cfg.id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
)
|
||||||
|
async with async_session() as db:
|
||||||
|
db.add(run_log)
|
||||||
|
await db.commit()
|
||||||
|
await db.refresh(run_log)
|
||||||
|
|
||||||
|
if agent_type == "local":
|
||||||
|
await run_local_agent(user_id, cfg, run_log, device_mgr)
|
||||||
|
else:
|
||||||
|
await run_cloud_agent(user_id, cfg, run_log, device_mgr)
|
||||||
|
|
||||||
|
|
||||||
|
# ── Internal helper ─────────────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
async def _finalize_run(
|
||||||
|
run_log: AgentRunLog,
|
||||||
|
*,
|
||||||
|
status: str,
|
||||||
|
items_processed: int = 0,
|
||||||
|
items_created: int = 0,
|
||||||
|
errors: list[str] | None = None,
|
||||||
|
update_config_last_run: bool = False,
|
||||||
|
config_id: str | None = None,
|
||||||
|
config_type: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Persist the run outcome and optionally update ``LocalAgentConfig.last_run_at``.
|
||||||
|
|
||||||
|
Uses a fresh DB session so this is safe to call from background tasks
|
||||||
|
after the original request session has closed.
|
||||||
|
"""
|
||||||
|
now = datetime.now(timezone.utc)
|
||||||
|
try:
|
||||||
|
async with async_session() as db:
|
||||||
|
managed = await db.merge(run_log)
|
||||||
|
managed.status = status
|
||||||
|
managed.items_processed = items_processed
|
||||||
|
managed.items_created = items_created
|
||||||
|
managed.errors = errors or []
|
||||||
|
managed.completed_at = now
|
||||||
|
|
||||||
|
if update_config_last_run and config_id and config_type == "local":
|
||||||
|
cfg_result = await db.execute(
|
||||||
|
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
|
||||||
|
)
|
||||||
|
cfg = cfg_result.scalar_one_or_none()
|
||||||
|
if cfg:
|
||||||
|
cfg.last_run_at = now
|
||||||
|
|
||||||
|
await db.commit()
|
||||||
|
except Exception as exc:
|
||||||
|
logger.error(
|
||||||
|
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
|
||||||
|
)
|
||||||
@@ -24,4 +24,5 @@ aiosqlite>=0.20.0
|
|||||||
moto[s3]>=5.0.0
|
moto[s3]>=5.0.0
|
||||||
pinecone>=5.0.0
|
pinecone>=5.0.0
|
||||||
qdrant-client>=1.7.0
|
qdrant-client>=1.7.0
|
||||||
|
croniter>=3.0.0
|
||||||
ruff>=0.8.0
|
ruff>=0.8.0
|
||||||
|
|||||||
660
tests/test_agent_runner.py
Normal file
660
tests/test_agent_runner.py
Normal file
@@ -0,0 +1,660 @@
|
|||||||
|
"""Tests for Step 3.4: agent_runner module.
|
||||||
|
|
||||||
|
Coverage:
|
||||||
|
Unit:
|
||||||
|
- _is_overdue — cron schedule overdue detection
|
||||||
|
- _extract_items_from_content — LLM extraction + JSON parsing + validation
|
||||||
|
- _send_insert_to_client — tool_call frame construction + timeout
|
||||||
|
- run_local_agent — end-to-end local agent happy path
|
||||||
|
- run_local_agent — device offline path
|
||||||
|
- run_local_agent — file-read timeout path
|
||||||
|
- run_local_agent — LLM extraction error path
|
||||||
|
- run_cloud_agent — stub returns error immediately
|
||||||
|
- trigger_pending_runs — overdue local + cloud dispatched
|
||||||
|
- trigger_pending_runs — non-overdue skipped
|
||||||
|
- trigger_pending_runs — device_id filter for local agents
|
||||||
|
|
||||||
|
Integration:
|
||||||
|
- POST /agents/{id}/run — 404 on unknown agent
|
||||||
|
- POST /agents/{id}/run — creates run log + dispatches background task
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import pytest_asyncio
|
||||||
|
|
||||||
|
from app.core.agent_runner import (
|
||||||
|
_extract_items_from_content,
|
||||||
|
_is_overdue,
|
||||||
|
_send_insert_to_client,
|
||||||
|
run_cloud_agent,
|
||||||
|
run_local_agent,
|
||||||
|
trigger_pending_runs,
|
||||||
|
)
|
||||||
|
from app.core.device_manager import DeviceConnectionManager
|
||||||
|
from app.db import get_session
|
||||||
|
from app.main import app
|
||||||
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||||
|
from tests.conftest import TEST_USER_IDS, auth_header
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Helpers
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
_FREE_UID = TEST_USER_IDS["free"]
|
||||||
|
_PRO_UID = TEST_USER_IDS["pro"]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
|
||||||
|
return LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
device_id=device_id,
|
||||||
|
name="Test Local Agent",
|
||||||
|
directory_paths=["/home/user/emails"],
|
||||||
|
data_types=["tasks", "notes"],
|
||||||
|
prompt_template="Extract tasks and notes from this document.",
|
||||||
|
file_extensions=[".txt", ".eml"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
|
||||||
|
return CloudAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=user_id,
|
||||||
|
provider="gmail",
|
||||||
|
name="Test Gmail Agent",
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks from email.",
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
last_run_at=None,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
|
||||||
|
return AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=agent_id,
|
||||||
|
agent_type=agent_type,
|
||||||
|
user_id=user_id,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
|
||||||
|
mgr = DeviceConnectionManager()
|
||||||
|
ws = MagicMock()
|
||||||
|
ws.send_text = AsyncMock()
|
||||||
|
mgr.register(user_id, device_id, ws)
|
||||||
|
return mgr
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _is_overdue
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
def test_is_overdue_never_run():
|
||||||
|
"""An agent that has never run is always overdue."""
|
||||||
|
assert _is_overdue("0 */6 * * *", None) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_very_recently_run():
|
||||||
|
"""An agent that just ran is not overdue."""
|
||||||
|
last = datetime.now(timezone.utc)
|
||||||
|
assert _is_overdue("0 */6 * * *", last) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_long_ago():
|
||||||
|
"""An agent last run 2 days ago with a 6-hour schedule is overdue."""
|
||||||
|
from datetime import timedelta
|
||||||
|
last = datetime.now(timezone.utc) - timedelta(days=2)
|
||||||
|
assert _is_overdue("0 */6 * * *", last) is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_invalid_cron_returns_false():
|
||||||
|
"""Unparseable cron must not raise and should return False (fail-safe)."""
|
||||||
|
assert _is_overdue("not a cron", None) is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_overdue_naive_datetime():
|
||||||
|
"""Naive datetime objects are handled without raising."""
|
||||||
|
from datetime import timedelta
|
||||||
|
last = datetime.utcnow() - timedelta(days=1) # naive
|
||||||
|
# Should not raise.
|
||||||
|
result = _is_overdue("0 */6 * * *", last)
|
||||||
|
assert isinstance(result, bool)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _extract_items_from_content
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_happy_path():
|
||||||
|
"""LLM returns valid JSON array; items with allowed tables are returned."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
|
||||||
|
{"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content(
|
||||||
|
"Extract tasks and notes.",
|
||||||
|
"Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
|
||||||
|
["tasks", "notes"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(items) == 2
|
||||||
|
assert items[0]["table"] == "tasks"
|
||||||
|
assert items[0]["data"]["title"] == "Buy milk"
|
||||||
|
assert items[1]["table"] == "notes"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_strips_forbidden_fields():
|
||||||
|
"""Fields like id, createdAt, isAiSuggested must be stripped from extracted data."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{
|
||||||
|
"table": "tasks",
|
||||||
|
"data": {
|
||||||
|
"title": "Review PR",
|
||||||
|
"id": "should-be-removed",
|
||||||
|
"createdAt": 99999,
|
||||||
|
"isAiSuggested": 0,
|
||||||
|
"isApproved": 1,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])
|
||||||
|
|
||||||
|
assert len(items) == 1
|
||||||
|
data = items[0]["data"]
|
||||||
|
assert "id" not in data
|
||||||
|
assert "createdAt" not in data
|
||||||
|
assert "isAiSuggested" not in data
|
||||||
|
assert "isApproved" not in data
|
||||||
|
assert data["title"] == "Review PR"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_invalid_json_returns_empty():
|
||||||
|
"""LLM returning invalid JSON must return empty list without raising."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = "Sorry, I cannot extract anything."
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||||
|
|
||||||
|
assert items == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_disallowed_table_filtered():
|
||||||
|
"""Items whose table is not in data_types are discarded."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Valid task"}},
|
||||||
|
{"table": "projects", "data": {"name": "Should be filtered"}},
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
# Only "tasks" is in data_types — "projects" should be filtered.
|
||||||
|
items = await _extract_items_from_content("Extract.", "content", ["tasks"])
|
||||||
|
|
||||||
|
assert len(items) == 1
|
||||||
|
assert items[0]["table"] == "tasks"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_empty_data_types_returns_empty():
|
||||||
|
"""If no allowed data_types match, skip LLM call and return immediately."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
items = await _extract_items_from_content("Extract.", "content", [])
|
||||||
|
|
||||||
|
mock_llm.ainvoke.assert_not_called()
|
||||||
|
assert items == []
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_extract_items_llm_error_propagates():
|
||||||
|
"""LLM API errors propagate so the caller (run_local_agent) can record them."""
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm):
|
||||||
|
with pytest.raises(RuntimeError, match="API unavailable"):
|
||||||
|
await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _send_insert_to_client
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_insert_to_client_happy_path():
|
||||||
|
"""Frame is sent with isAiSuggested/isApproved added; result is returned."""
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
sent_payloads: list[dict] = []
|
||||||
|
original_send = mgr.send_frame
|
||||||
|
|
||||||
|
async def _capture_send(uid: str, frame: dict) -> None:
|
||||||
|
sent_payloads.append(frame)
|
||||||
|
# Immediately resolve the pending call with a success result.
|
||||||
|
call_id = frame["id"]
|
||||||
|
mgr.resolve_pending_call(uid, call_id, {"row": {"id": "new-id", "title": "Buy milk"}})
|
||||||
|
|
||||||
|
mgr.send_frame = _capture_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
result = await _send_insert_to_client(
|
||||||
|
_FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
|
||||||
|
)
|
||||||
|
|
||||||
|
assert len(sent_payloads) == 1
|
||||||
|
payload = sent_payloads[0]
|
||||||
|
assert payload["action"] == "insert"
|
||||||
|
assert payload["table"] == "tasks"
|
||||||
|
assert payload["data"]["title"] == "Buy milk"
|
||||||
|
assert payload["data"]["isAiSuggested"] == 1
|
||||||
|
assert payload["data"]["isApproved"] == 0
|
||||||
|
assert result["row"]["title"] == "Buy milk"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_send_insert_to_client_timeout():
|
||||||
|
"""asyncio.TimeoutError is raised when Electron does not respond."""
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
async def _slow_send(uid: str, frame: dict) -> None:
|
||||||
|
# Never resolve the pending call.
|
||||||
|
pass
|
||||||
|
|
||||||
|
mgr.send_frame = _slow_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05):
|
||||||
|
with pytest.raises(asyncio.TimeoutError):
|
||||||
|
await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_local_agent
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_device_offline():
|
||||||
|
"""run_local_agent marks run as error when device is offline."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = DeviceConnectionManager() # Empty — no device registered.
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert any("not connected" in e for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_happy_path():
|
||||||
|
"""End-to-end: files received, LLM extracts one task, insert sent + ack'd."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
# Build a fake agent_data frame (will be queued after send).
|
||||||
|
file_frame = {
|
||||||
|
"type": "agent_data",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
|
||||||
|
}
|
||||||
|
agent_complete_frame = None # sentinel
|
||||||
|
|
||||||
|
sent_frames: list[dict] = []
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
sent_frames.append(frame)
|
||||||
|
if frame.get("type") == "agent_run":
|
||||||
|
# Simulate Electron responding with file data then agent_complete.
|
||||||
|
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||||
|
await q.put(file_frame)
|
||||||
|
await q.put(agent_complete_frame)
|
||||||
|
elif frame.get("type") == "tool_call":
|
||||||
|
# Resolve the pending insert immediately.
|
||||||
|
mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_response = MagicMock()
|
||||||
|
mock_response.content = json.dumps([
|
||||||
|
{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
|
||||||
|
])
|
||||||
|
mock_llm.ainvoke = AsyncMock(return_value=mock_response)
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "success"
|
||||||
|
assert kwargs["items_processed"] == 1
|
||||||
|
assert kwargs["items_created"] == 1
|
||||||
|
assert kwargs["errors"] == []
|
||||||
|
assert kwargs["update_config_last_run"] is True
|
||||||
|
|
||||||
|
# Verify agent_run frame was sent.
|
||||||
|
agent_run_frames = [f for f in sent_frames if f.get("type") == "agent_run"]
|
||||||
|
assert len(agent_run_frames) == 1
|
||||||
|
assert agent_run_frames[0]["agent_id"] == config.id
|
||||||
|
assert "paths" in agent_run_frames[0]["config"]
|
||||||
|
|
||||||
|
# Verify insert frame was sent with AI flags.
|
||||||
|
insert_frames = [f for f in sent_frames if f.get("type") == "tool_call"]
|
||||||
|
assert len(insert_frames) == 1
|
||||||
|
assert insert_frames[0]["data"]["isAiSuggested"] == 1
|
||||||
|
assert insert_frames[0]["data"]["isApproved"] == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_file_read_timeout():
|
||||||
|
"""run_local_agent marks run as partial/error when device stops sending files."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
# Don't put anything in the queue — simulate stalled device.
|
||||||
|
pass
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error" # No items created, so error (not partial).
|
||||||
|
assert any("timed out" in e.lower() for e in kwargs["errors"])
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_local_agent_llm_extraction_error():
|
||||||
|
"""LLM errors per-file are recorded; run continues for remaining files."""
|
||||||
|
config = _make_local_config()
|
||||||
|
run_log = _make_run_log(config.id)
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
file_frame = {
|
||||||
|
"type": "agent_data",
|
||||||
|
"run_id": run_log.id,
|
||||||
|
"files": [
|
||||||
|
{"path": "/file1.eml", "content": "Email one."},
|
||||||
|
{"path": "/file2.eml", "content": "Email two."},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
async def _mock_send(uid: str, frame: dict) -> None:
|
||||||
|
if frame.get("type") == "agent_run":
|
||||||
|
q = mgr.get_agent_data_queue(uid, frame["run_id"])
|
||||||
|
await q.put(file_frame)
|
||||||
|
await q.put(None) # agent_complete sentinel
|
||||||
|
|
||||||
|
mgr.send_frame = _mock_send # type: ignore[method-assign]
|
||||||
|
|
||||||
|
mock_llm = MagicMock()
|
||||||
|
mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
|
||||||
|
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_local_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert kwargs["items_processed"] == 2 # Both files attempted.
|
||||||
|
assert kwargs["items_created"] == 0
|
||||||
|
assert len(kwargs["errors"]) == 2 # One error per file.
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# run_cloud_agent (stub)
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_run_cloud_agent_stub_returns_error():
|
||||||
|
"""Cloud agent stub immediately marks run as error with informative message."""
|
||||||
|
config = _make_cloud_config()
|
||||||
|
run_log = _make_run_log(config.id, agent_type="cloud")
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
|
||||||
|
await run_cloud_agent(_FREE_UID, config, run_log, mgr)
|
||||||
|
|
||||||
|
mock_finalize.assert_called_once()
|
||||||
|
_args, kwargs = mock_finalize.call_args
|
||||||
|
assert kwargs["status"] == "error"
|
||||||
|
assert len(kwargs["errors"]) == 1
|
||||||
|
assert "gmail" in kwargs["errors"][0].lower()
|
||||||
|
assert "3.6" in kwargs["errors"][0]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# trigger_pending_runs
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_no_overdue():
|
||||||
|
"""If no agents are overdue trigger_pending_runs does nothing."""
|
||||||
|
from datetime import timedelta
|
||||||
|
|
||||||
|
config = _make_local_config()
|
||||||
|
config.last_run_at = datetime.now(timezone.utc) - timedelta(minutes=30) # ran 30m ago
|
||||||
|
config.schedule_cron = "0 */6 * * *" # every 6h — not due yet
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_device_id_filter():
|
||||||
|
"""Local agents are only triggered for the matching device_id."""
|
||||||
|
# The DB query already filters by device_id, so we verify the SELECT
|
||||||
|
# includes the device_id filter by checking that a config bound to a
|
||||||
|
# different device is never dispatched.
|
||||||
|
#
|
||||||
|
# Since trigger_pending_runs queries with device_id == "dev-001",
|
||||||
|
# simulate the DB returning an empty list (as it would for a mismatch).
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [] # no match
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager(device_id="dev-001")
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as mock_run:
|
||||||
|
mock_ctx = AsyncMock()
|
||||||
|
mock_ctx.__aenter__ = AsyncMock(return_value=mock_ctx)
|
||||||
|
mock_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
mock_session_factory.return_value = mock_ctx
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
mock_run.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_pending_runs_dispatches_overdue():
|
||||||
|
"""Overdue local agent triggers run_local_agent sequentially."""
|
||||||
|
config = _make_local_config() # last_run_at=None → always overdue
|
||||||
|
|
||||||
|
mock_db_result_local = MagicMock()
|
||||||
|
mock_db_result_local.scalars.return_value.all.return_value = [config]
|
||||||
|
|
||||||
|
mock_db_result_cloud = MagicMock()
|
||||||
|
mock_db_result_cloud.scalars.return_value.all.return_value = []
|
||||||
|
|
||||||
|
mgr = _make_manager()
|
||||||
|
|
||||||
|
call_order: list[str] = []
|
||||||
|
|
||||||
|
async def _mock_run_local(user_id, cfg, run_log, device_mgr):
|
||||||
|
call_order.append("run_local")
|
||||||
|
|
||||||
|
with patch("app.core.agent_runner.async_session") as mock_session_factory, \
|
||||||
|
patch("app.core.agent_runner.run_local_agent", side_effect=_mock_run_local):
|
||||||
|
# First call: query configs. Subsequent calls: create run_log.
|
||||||
|
mock_query_ctx = AsyncMock()
|
||||||
|
mock_query_ctx.__aenter__ = AsyncMock(return_value=mock_query_ctx)
|
||||||
|
mock_query_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_query_ctx.execute = AsyncMock(
|
||||||
|
side_effect=[mock_db_result_local, mock_db_result_cloud]
|
||||||
|
)
|
||||||
|
|
||||||
|
run_log_obj = AgentRunLog(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
agent_id=config.id,
|
||||||
|
agent_type="local",
|
||||||
|
user_id=_FREE_UID,
|
||||||
|
status="running",
|
||||||
|
started_at=datetime.now(timezone.utc),
|
||||||
|
)
|
||||||
|
mock_insert_ctx = AsyncMock()
|
||||||
|
mock_insert_ctx.__aenter__ = AsyncMock(return_value=mock_insert_ctx)
|
||||||
|
mock_insert_ctx.__aexit__ = AsyncMock(return_value=False)
|
||||||
|
mock_insert_ctx.add = MagicMock()
|
||||||
|
mock_insert_ctx.commit = AsyncMock()
|
||||||
|
mock_insert_ctx.refresh = AsyncMock(side_effect=lambda obj: None)
|
||||||
|
|
||||||
|
mock_session_factory.side_effect = [mock_query_ctx, mock_insert_ctx]
|
||||||
|
|
||||||
|
await trigger_pending_runs(_FREE_UID, "dev-001", mgr)
|
||||||
|
|
||||||
|
assert call_order == ["run_local"]
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Integration: POST /agents/{id}/run
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def _override_db(db_session):
|
||||||
|
"""Route all get_session calls to the test SQLite session."""
|
||||||
|
|
||||||
|
async def _gen():
|
||||||
|
yield db_session
|
||||||
|
|
||||||
|
app.dependency_overrides[get_session] = _gen
|
||||||
|
yield
|
||||||
|
app.dependency_overrides.pop(get_session, None)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_run_unknown_agent(client):
|
||||||
|
"""POST /agents/{id}/run returns 404 for unknown agent id."""
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/v1/agents/{uuid.uuid4()}/run",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
assert resp.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
|
||||||
|
"""POST /agents/{id}/run creates a run log and dispatches a background task."""
|
||||||
|
# Create the local agent config in the DB.
|
||||||
|
config = LocalAgentConfig(
|
||||||
|
id=str(uuid.uuid4()),
|
||||||
|
user_id=TEST_USER_IDS["power"],
|
||||||
|
device_id="dev-001",
|
||||||
|
name="My Agent",
|
||||||
|
directory_paths=["/home/user/docs"],
|
||||||
|
data_types=["tasks"],
|
||||||
|
prompt_template="Extract tasks.",
|
||||||
|
file_extensions=[".txt"],
|
||||||
|
schedule_cron="0 */6 * * *",
|
||||||
|
enabled=True,
|
||||||
|
)
|
||||||
|
db_session.add(config)
|
||||||
|
await db_session.commit()
|
||||||
|
|
||||||
|
dispatched: list = []
|
||||||
|
|
||||||
|
async def _fake_run(user_id, cfg, run_log, device_mgr):
|
||||||
|
dispatched.append((user_id, cfg.id))
|
||||||
|
|
||||||
|
with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_fake_run), \
|
||||||
|
patch("app.api.routes.agents.run_cloud_agent", new_callable=AsyncMock), \
|
||||||
|
patch("asyncio.create_task") as mock_create_task:
|
||||||
|
resp = client.post(
|
||||||
|
f"/api/v1/agents/{config.id}/run",
|
||||||
|
headers=auth_header("power"),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert resp.status_code == 202
|
||||||
|
data = resp.json()
|
||||||
|
assert data["agent_id"] == config.id
|
||||||
|
assert data["status"] == "running"
|
||||||
|
assert data["agent_type"] == "local"
|
||||||
|
|
||||||
|
# Verify create_task was called (dispatching background run).
|
||||||
|
mock_create_task.assert_called_once()
|
||||||
Reference in New Issue
Block a user