- Add _language_instruction() to deep_agent.py, reads language from core memory - Append language directive to all 4 run_* functions (task/project/checkpoint/note) - Minor fixes: alembic env, route imports, test cleanup
809 lines
29 KiB
Python
809 lines
29 KiB
Python
"""Tests for Step 3.4: agent_runner module.
|
|
|
|
Coverage:
|
|
Unit:
|
|
- _is_overdue — cron schedule overdue detection
|
|
- _extract_items_from_content — LLM extraction + JSON parsing + validation
|
|
- _send_insert_to_client — tool_call frame construction + timeout
|
|
- run_local_agent — end-to-end local agent happy path
|
|
- run_local_agent — device offline path
|
|
- run_local_agent — file-read timeout path
|
|
- run_local_agent — LLM extraction error path
|
|
- run_cloud_agent — stub returns error immediately
|
|
- trigger_pending_runs — skipped when config is client-owned
|
|
- trigger_pending_runs — non-overdue skipped
|
|
- trigger_pending_runs — device_id filter for local agents
|
|
|
|
Integration:
|
|
- POST /agents/can-create — billing eligibility check
|
|
- POST /agents/trigger — creates run log + dispatches background task
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import uuid
|
|
from datetime import datetime, timezone
|
|
from unittest.mock import AsyncMock, MagicMock, patch
|
|
|
|
import pytest
|
|
|
|
from app.core.agent_runner import (
|
|
_extract_items_from_content,
|
|
_is_overdue,
|
|
_send_insert_to_client,
|
|
run_cloud_agent,
|
|
run_local_agent,
|
|
trigger_pending_runs,
|
|
)
|
|
from app.core.device_manager import DeviceConnectionManager
|
|
from app.db import get_session
|
|
from app.main import app
|
|
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
|
from tests.conftest import TEST_USER_IDS, auth_header
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
# Test-user ids provided by tests.conftest, looked up by billing tier.
_FREE_UID, _PRO_UID = TEST_USER_IDS["free"], TEST_USER_IDS["pro"]
|
|
|
|
|
|
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
    """Build an enabled, never-run ``LocalAgentConfig`` for *user_id* on *device_id*.

    The config scans ``/home/user/emails`` for ``.txt``/``.eml`` files, asks for
    tasks and notes, and is scheduled with cron ``0 */6 * * *``. Every call gets
    a fresh UUID as its id.
    """
    fields = {
        "id": str(uuid.uuid4()),
        "user_id": user_id,
        "device_id": device_id,
        "name": "Test Local Agent",
        "directory_paths": ["/home/user/emails"],
        "data_types": ["tasks", "notes"],
        "prompt_template": "Extract tasks and notes from this document.",
        "file_extensions": [".txt", ".eml"],
        "schedule_cron": "0 */6 * * *",
        "enabled": True,
        "last_run_at": None,  # never run -> always overdue
    }
    return LocalAgentConfig(**fields)
|
|
|
|
|
|
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
    """Build an enabled, never-run Gmail ``CloudAgentConfig`` for *user_id*.

    The config extracts tasks only, on cron ``0 */6 * * *``, and gets a fresh
    UUID as its id on every call.
    """
    fields = {
        "id": str(uuid.uuid4()),
        "user_id": user_id,
        "provider": "gmail",
        "name": "Test Gmail Agent",
        "data_types": ["tasks"],
        "prompt_template": "Extract tasks from email.",
        "schedule_cron": "0 */6 * * *",
        "enabled": True,
        "last_run_at": None,
    }
    return CloudAgentConfig(**fields)
|
|
|
|
|
|
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
    """Return an in-progress ``AgentRunLog`` (status ``"running"``) for *agent_id*.

    The log gets its own fresh UUID, and ``started_at`` is the current
    timezone-aware UTC time.
    """
    log_id = str(uuid.uuid4())
    started = datetime.now(timezone.utc)
    return AgentRunLog(
        id=log_id,
        agent_id=agent_id,
        agent_type=agent_type,
        user_id=user_id,
        status="running",
        started_at=started,
    )
|
|
|
|
|
|
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
    """Return a ``DeviceConnectionManager`` with one fake device registered.

    The registered websocket is a ``MagicMock`` whose ``send_text`` is an
    ``AsyncMock``, so sends are recorded rather than transmitted.
    """
    manager = DeviceConnectionManager()
    fake_ws = MagicMock(send_text=AsyncMock())
    manager.register(user_id, device_id, fake_ws)
    return manager
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _is_overdue
|
|
# ---------------------------------------------------------------------------
|
|
|
|
def test_is_overdue_never_run():
    """A config whose ``last_run_at`` is ``None`` counts as overdue."""
    overdue = _is_overdue("0 */6 * * *", None)
    assert overdue is True
|
|
|
|
|
|
def test_is_overdue_very_recently_run():
    """A config that last ran at "now" is not yet due on a six-hourly cron."""
    just_now = datetime.now(timezone.utc)
    overdue = _is_overdue("0 */6 * * *", just_now)
    assert overdue is False
|
|
|
|
|
|
def test_is_overdue_long_ago():
    """A six-hourly config whose last run was two days ago is overdue."""
    from datetime import timedelta

    two_days_ago = datetime.now(timezone.utc) - timedelta(days=2)
    overdue = _is_overdue("0 */6 * * *", two_days_ago)
    assert overdue is True
|
|
|
|
|
|
def test_is_overdue_invalid_cron_returns_false():
    """A malformed cron expression yields ``False`` instead of raising (fail-safe)."""
    result = _is_overdue("not a cron", None)
    assert result is False
|
|
|
|
|
|
def test_is_overdue_naive_datetime():
    """A timezone-naive ``last_run_at`` is accepted without raising.

    The naive timestamp is derived from an aware UTC time rather than
    ``datetime.utcnow()``. ``utcnow()`` is deprecated since Python 3.12 and
    emits a ``DeprecationWarning``, which breaks the test when warnings are
    treated as errors.
    """
    from datetime import timedelta

    # Dropping tzinfo yields a naive datetime that still holds UTC wall-clock time.
    last = datetime.now(timezone.utc).replace(tzinfo=None) - timedelta(days=1)
    assert last.tzinfo is None  # guard: the input really is naive

    result = _is_overdue("0 */6 * * *", last)

    assert isinstance(result, bool)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _extract_items_from_content
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_extract_items_happy_path():
    """Every LLM item whose table is allowed comes back, in the LLM's order."""
    llm_payload = [
        {"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
        {"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
    ]
    llm_reply = MagicMock(content=json.dumps(llm_payload))
    fake_llm = MagicMock(ainvoke=AsyncMock(return_value=llm_reply))

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        items = await _extract_items_from_content(
            "Extract tasks and notes.",
            "Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
            ["tasks", "notes"],
        )

    assert len(items) == 2
    first, second = items
    assert first["table"] == "tasks"
    assert first["data"]["title"] == "Buy milk"
    assert second["table"] == "notes"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_extract_items_strips_forbidden_fields():
    """``id``, ``createdAt``, ``isAiSuggested`` and ``isApproved`` are removed
    from LLM-extracted data, while ordinary fields such as ``title`` survive."""
    raw_data = {
        "title": "Review PR",
        "id": "should-be-removed",
        "createdAt": 99999,
        "isAiSuggested": 0,
        "isApproved": 1,
    }
    llm_reply = MagicMock(content=json.dumps([{"table": "tasks", "data": raw_data}]))
    fake_llm = MagicMock(ainvoke=AsyncMock(return_value=llm_reply))

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        items = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])

    assert len(items) == 1
    cleaned = items[0]["data"]
    for forbidden_key in ("id", "createdAt", "isAiSuggested", "isApproved"):
        assert forbidden_key not in cleaned
    assert cleaned["title"] == "Review PR"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_extract_items_invalid_json_returns_empty():
    """A reply that is not JSON produces an empty list rather than an exception."""
    llm_reply = MagicMock(content="Sorry, I cannot extract anything.")
    fake_llm = MagicMock(ainvoke=AsyncMock(return_value=llm_reply))

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        extracted = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])

    assert extracted == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_extract_items_disallowed_table_filtered():
    """Items that target a table outside ``data_types`` are dropped."""
    llm_payload = [
        {"table": "tasks", "data": {"title": "Valid task"}},
        {"table": "projects", "data": {"name": "Should be filtered"}},
    ]
    llm_reply = MagicMock(content=json.dumps(llm_payload))
    fake_llm = MagicMock(ainvoke=AsyncMock(return_value=llm_reply))
    allowed_tables = ["tasks"]  # "projects" is deliberately absent

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        items = await _extract_items_from_content("Extract.", "content", allowed_tables)

    assert [item["table"] for item in items] == ["tasks"]
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_extract_items_empty_data_types_returns_empty():
    """With no allowed ``data_types`` the LLM is never invoked and the result is empty."""
    fake_llm = MagicMock(ainvoke=AsyncMock())

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        extracted = await _extract_items_from_content("Extract.", "content", [])

    fake_llm.ainvoke.assert_not_called()
    assert extracted == []
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_extract_items_llm_error_propagates():
    """LLM exceptions are re-raised so the caller (``run_local_agent``) can record them."""
    fake_llm = MagicMock(ainvoke=AsyncMock(side_effect=RuntimeError("API unavailable")))

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm), \
            pytest.raises(RuntimeError, match="API unavailable"):
        await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# _send_insert_to_client
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_insert_to_client_happy_path():
    """The insert frame carries ``isAiSuggested=1``/``isApproved=0`` and the
    client's reply is returned.

    ``send_frame`` is replaced by a stub that records each frame and resolves
    the pending call for that frame's id right away, simulating an instant ack.
    """
    mgr = _make_manager()
    sent_payloads: list[dict] = []

    async def _capture_send(uid: str, frame: dict) -> None:
        sent_payloads.append(frame)
        # Immediately resolve the pending call with a success result.
        mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-id", "title": "Buy milk"}})

    mgr.send_frame = _capture_send  # type: ignore[method-assign]

    result = await _send_insert_to_client(
        _FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
    )

    assert len(sent_payloads) == 1
    payload = sent_payloads[0]
    assert payload["action"] == "insert"
    assert payload["table"] == "tasks"
    assert payload["data"]["title"] == "Buy milk"
    assert payload["data"]["isAiSuggested"] == 1
    assert payload["data"]["isApproved"] == 0
    assert result["row"]["title"] == "Buy milk"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_send_insert_to_client_timeout():
    """``asyncio.TimeoutError`` surfaces when the client never answers the insert."""
    mgr = _make_manager()

    async def _never_ack(uid: str, frame: dict) -> None:
        """Accept the frame but never resolve its pending call."""

    mgr.send_frame = _never_ack  # type: ignore[method-assign]

    with patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05), \
            pytest.raises(asyncio.TimeoutError):
        await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, mgr)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# run_local_agent
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_local_agent_device_offline():
    """With no device registered, the run is finalized as an error saying "not connected"."""
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    offline_mgr = DeviceConnectionManager()  # nothing registered

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, offline_mgr)

    mock_finalize.assert_called_once()
    finalize_kwargs = mock_finalize.call_args.kwargs
    assert finalize_kwargs["status"] == "error"
    assert any("not connected" in err for err in finalize_kwargs["errors"])
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_local_agent_happy_path():
    """Full local run that finalizes as a success.

    The device returns one file, the LLM extracts one task, and the resulting
    insert is acknowledged.
    """
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    mgr = _make_manager()

    # One batch of file contents the fake device hands back for the run.
    device_files_frame = {
        "type": "agent_data",
        "run_id": run_log.id,
        "files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
    }
    outgoing: list[dict] = []

    async def _fake_device(uid: str, frame: dict) -> None:
        outgoing.append(frame)
        frame_type = frame.get("type")
        if frame_type == "agent_run":
            # Reply with the file batch, then None as the agent_complete sentinel.
            queue = mgr.get_agent_data_queue(uid, frame["run_id"])
            await queue.put(device_files_frame)
            await queue.put(None)
        elif frame_type == "tool_call":
            # Acknowledge the insert straight away.
            mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})

    mgr.send_frame = _fake_device  # type: ignore[method-assign]

    llm_reply = MagicMock(content=json.dumps([
        {"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
    ]))
    fake_llm = MagicMock(ainvoke=AsyncMock(return_value=llm_reply))

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm), \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, mgr)

    mock_finalize.assert_called_once()
    finalize_kwargs = mock_finalize.call_args.kwargs
    assert finalize_kwargs["status"] == "success"
    assert finalize_kwargs["items_processed"] == 1
    assert finalize_kwargs["items_created"] == 1
    assert finalize_kwargs["errors"] == []
    assert finalize_kwargs["update_config_last_run"] is False

    # Exactly one agent_run frame, addressed to this config and carrying scan paths.
    run_frames = [f for f in outgoing if f.get("type") == "agent_run"]
    assert len(run_frames) == 1
    assert run_frames[0]["agent_id"] == config.id
    assert "paths" in run_frames[0]["config"]

    # Exactly one insert (tool_call) frame, flagged as an unapproved AI suggestion.
    insert_frames = [f for f in outgoing if f.get("type") == "tool_call"]
    assert len(insert_frames) == 1
    assert insert_frames[0]["data"]["isAiSuggested"] == 1
    assert insert_frames[0]["data"]["isApproved"] == 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_local_agent_file_read_timeout():
    """If the device never sends file data, the read times out and the run is an error."""
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    mgr = _make_manager()

    async def _stalled_device(uid: str, frame: dict) -> None:
        """Accept the frame but never enqueue any agent data."""

    mgr.send_frame = _stalled_device  # type: ignore[method-assign]

    with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1), \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, mgr)

    mock_finalize.assert_called_once()
    finalize_kwargs = mock_finalize.call_args.kwargs
    # Nothing was created, so the outcome is "error" rather than "partial".
    assert finalize_kwargs["status"] == "error"
    assert any("timed out" in err.lower() for err in finalize_kwargs["errors"])
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_local_agent_llm_extraction_error():
    """LLM errors per file are recorded and the run continues with the remaining files.

    Both files fail extraction, so the run ends as an error with two processed
    items, nothing created, and one error entry per file.
    """
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    mgr = _make_manager()

    file_frame = {
        "type": "agent_data",
        "run_id": run_log.id,
        "files": [
            {"path": "/file1.eml", "content": "Email one."},
            {"path": "/file2.eml", "content": "Email two."},
        ],
    }

    async def _mock_send(uid: str, frame: dict) -> None:
        if frame.get("type") == "agent_run":
            q = mgr.get_agent_data_queue(uid, frame["run_id"])
            await q.put(file_frame)
            await q.put(None)  # agent_complete sentinel

    mgr.send_frame = _mock_send  # type: ignore[method-assign]

    mock_llm = MagicMock()
    mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))

    with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, mgr)

    # Assert the call first; otherwise call_args is None and unpacking fails opaquely.
    mock_finalize.assert_called_once()
    _args, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert kwargs["items_processed"] == 2  # Both files attempted.
    assert kwargs["items_created"] == 0
    assert len(kwargs["errors"]) == 2  # One error per file.
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# run_cloud_agent
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_cloud_agent_device_offline():
    """A cloud run aborts with an error when the user has no connected device."""
    config = _make_cloud_config()
    run_log = _make_run_log(config.id, agent_type="cloud")
    offline_mgr = DeviceConnectionManager()  # no devices registered

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_cloud_agent(_FREE_UID, config, run_log, offline_mgr)

    mock_finalize.assert_called_once()
    finalize_kwargs = mock_finalize.call_args.kwargs
    assert finalize_kwargs["status"] == "error"
    assert any(
        ("device" in msg.lower()) or ("connected" in msg.lower())
        for msg in finalize_kwargs["errors"]
    )
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_cloud_agent_no_oauth_token():
    """With no stored OAuth token, the cloud run is finalized as an error."""
    config = _make_cloud_config()
    config.oauth_token_encrypted = None
    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # Assert the call first; otherwise call_args is None and unpacking fails opaquely.
    mock_finalize.assert_called_once()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_cloud_agent_token_decrypt_failure():
    """A stored token that cannot be decrypted finalizes the run as an error.

    A real, valid Fernet key is configured, so the failure comes from the
    ciphertext itself rather than from a missing or malformed key.
    """
    from cryptography.fernet import Fernet as _Fernet

    config = _make_cloud_config()
    config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    valid_key = _Fernet.generate_key().decode()

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
            patch("app.integrations.settings") as mock_settings:
        mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # Assert the call first; otherwise call_args is None and unpacking fails opaquely.
    mock_finalize.assert_called_once()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_cloud_agent_happy_path_gmail():
    """Gmail happy path: one message fetched, one task extracted and inserted, run succeeds."""
    from app.integrations import EmailMessage, encrypt_token
    from cryptography.fernet import Fernet as _Fernet

    fernet_key = _Fernet.generate_key().decode()
    credentials = {
        "token": "access_abc",
        "refresh_token": "refresh_xyz",
        "token_uri": "https://oauth2.googleapis.com/token",
        "client_id": "cid",
        "client_secret": "csec",
    }

    config = _make_cloud_config()
    config.provider = "gmail"
    config.prompt_template = "Extract tasks from this email."
    config.data_types = ["tasks"]

    # Encrypt the starting credentials with a real key.
    with patch("app.integrations.settings") as ms:
        ms.OAUTH_ENCRYPTION_KEY = fernet_key
        config.oauth_token_encrypted = encrypt_token(credentials)

    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    sample_email = EmailMessage(
        id="msg001",
        subject="Action required",
        sender="boss@company.com",
        body_text="Please fix the bug by Friday.",
        date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
    )

    extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]

    with patch("app.integrations.settings") as mock_int_settings, \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
            patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
            patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
            patch("app.core.agent_runner.async_session"):
        mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key

        mock_gmail = AsyncMock()
        mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
        mock_gmail.refreshed_credentials = None  # provider did not refresh the token

        with patch("app.integrations.decrypt_token", return_value=credentials), \
                patch("app.integrations.get_provider", return_value=mock_gmail):
            await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    mock_extract.assert_called_once()
    mock_insert.assert_called_once()
    # Assert the call first; otherwise call_args is None and unpacking fails opaquely.
    mock_finalize.assert_called_once()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "success"
    assert kwargs["items_processed"] == 1
    assert kwargs["items_created"] == 1
    assert kwargs["config_type"] == "cloud"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_cloud_agent_provider_fetch_error():
    """A ``RuntimeError`` from the provider's ``fetch_messages`` finalizes the run as an error."""
    credentials = {"token": "abc"}
    config = _make_cloud_config()
    config.oauth_token_encrypted = "some_encrypted_value"  # non-empty so decrypt step is reached
    config.prompt_template = "Extract tasks."
    config.data_types = ["tasks"]
    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    mock_provider = AsyncMock()
    mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
    mock_provider.refreshed_credentials = None

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
            patch("app.integrations.decrypt_token", return_value=credentials), \
            patch("app.integrations.get_provider", return_value=mock_provider), \
            patch("app.core.agent_runner.async_session"):
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # Assert the call first; otherwise call_args is None and unpacking fails opaquely.
    mock_finalize.assert_called_once()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_run_cloud_agent_refreshed_token_persisted():
    """Refreshed provider credentials are re-encrypted and written back.

    ``async_session`` is replaced by a mock session whose ``execute`` result
    yields ``mock_cfg_row``. The test checks two things:

    * the refreshed credentials are passed to ``encrypt_token``;
    * the resulting ciphertext lands on ``mock_cfg_row.oauth_token_encrypted``.
    """
    from app.integrations import encrypt_token
    from cryptography.fernet import Fernet as _Fernet

    fernet_key = _Fernet.generate_key().decode()
    credentials = {"token": "old_token", "refresh_token": "rt_old"}
    fresh_credentials = {"token": "new_token", "refresh_token": "rt_new"}

    config = _make_cloud_config()
    config.prompt_template = "Extract tasks."
    config.data_types = ["tasks"]

    # Encrypt the starting credentials with a real key.
    with patch("app.integrations.settings") as ms:
        ms.OAUTH_ENCRYPTION_KEY = fernet_key
        config.oauth_token_encrypted = encrypt_token(credentials)

    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    mock_provider = AsyncMock()
    mock_provider.fetch_messages = AsyncMock(return_value=[])
    mock_provider.refreshed_credentials = fresh_credentials  # token was refreshed

    # Mock session: entering it as an async context manager yields itself.
    # `execute` returns a result whose `scalar_one_or_none()` is the config row.
    # (A session object has no `scalar_one_or_none` of its own, so none is mocked.)
    mock_cfg_row = MagicMock()
    mock_cfg_row.oauth_token_encrypted = None

    cfg_result = MagicMock()
    cfg_result.scalar_one_or_none.return_value = mock_cfg_row

    mock_db = AsyncMock()
    mock_db.__aenter__ = AsyncMock(return_value=mock_db)
    mock_db.__aexit__ = AsyncMock(return_value=False)
    mock_db.execute = AsyncMock(return_value=cfg_result)
    mock_db.commit = AsyncMock()

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
            patch("app.integrations.decrypt_token", return_value=credentials), \
            patch("app.integrations.get_provider", return_value=mock_provider), \
            patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
            patch("app.core.agent_runner.async_session", return_value=mock_db), \
            patch("app.integrations.settings") as mock_int_settings:
        mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # The new encrypted token should have been written to the config row.
    mock_encrypt.assert_called_once_with(fresh_credentials)
    assert mock_cfg_row.oauth_token_encrypted == "new_encrypted"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_finalize_run_updates_cloud_config_last_run_at():
    """``_finalize_run`` with ``config_type='cloud'`` and ``update_config_last_run=True``
    sets the cloud config's ``last_run_at`` and commits."""
    from app.core.agent_runner import _finalize_run

    # _make_run_log already assigns a fresh UUID as the log id; no reassignment needed.
    run_log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")

    mock_cfg = MagicMock()
    mock_cfg.last_run_at = None

    cfg_result = MagicMock()
    cfg_result.scalar_one_or_none.return_value = mock_cfg

    mock_db = AsyncMock()
    mock_db.__aenter__ = AsyncMock(return_value=mock_db)
    mock_db.__aexit__ = AsyncMock(return_value=False)
    mock_db.merge = AsyncMock(return_value=run_log)
    mock_db.execute = AsyncMock(return_value=cfg_result)
    mock_db.commit = AsyncMock()

    config_id = str(uuid.uuid4())

    with patch("app.core.agent_runner.async_session", return_value=mock_db):
        await _finalize_run(
            run_log,
            status="success",
            update_config_last_run=True,
            config_id=config_id,
            config_type="cloud",
        )

    # CloudAgentConfig.last_run_at should have been set.
    assert mock_cfg.last_run_at is not None
    mock_db.commit.assert_called()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# trigger_pending_runs
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_trigger_pending_runs_no_overdue():
    """``trigger_pending_runs`` dispatches no local run.

    Agent config is client-owned, so the backend's pending-run scan is skipped.
    """
    manager = _make_manager()

    with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as run_spy:
        await trigger_pending_runs(_FREE_UID, "dev-001", manager)

    run_spy.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_trigger_pending_runs_device_id_filter():
    """Passing a device_id still dispatches nothing.

    Device filtering of pending runs is no longer backend-managed.
    """
    manager = _make_manager(device_id="dev-001")

    with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as run_spy:
        await trigger_pending_runs(_FREE_UID, "dev-001", manager)

    run_spy.assert_not_called()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_trigger_pending_runs_dispatches_overdue():
    """The backend no longer dispatches pending runs, despite this test's historical name.

    Agent config has moved to the client, so ``run_local_agent`` must not be called.
    """
    manager = _make_manager()

    with patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock) as run_spy:
        await trigger_pending_runs(_FREE_UID, "dev-001", manager)

    run_spy.assert_not_called()
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Integration: POST /agents/can-create and /agents/trigger
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Serve the ``db_session`` fixture wherever the app depends on ``get_session``.

    The fixture is autouse, so the override is active for every test in this
    module. It is removed again after the test finishes.
    """

    async def _session_override():
        yield db_session

    app.dependency_overrides[get_session] = _session_override
    yield
    app.dependency_overrides.pop(get_session, None)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_can_create_agent_allows_when_under_limit(client):
    """A free-tier user with zero active agents is allowed, and the free limit of 2 is reported."""
    request_body = {"active_agents": 0}
    resp = client.post(
        "/api/v1/agents/can-create",
        json=request_body,
        headers=auth_header("free"),
    )

    assert resp.status_code == 200
    body = resp.json()
    assert body["allowed"] is True
    assert (body["tier"], body["active_agents"], body["limit"]) == ("free", 0, 2)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_can_create_agent_denies_when_at_limit(client):
    """A free-tier user already at 2 active agents is refused."""
    request_body = {"active_agents": 2}
    resp = client.post(
        "/api/v1/agents/can-create",
        json=request_body,
        headers=auth_header("free"),
    )

    assert resp.status_code == 200
    body = resp.json()
    assert body["allowed"] is False
    assert body["limit"] == 2
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
    """POST /agents/trigger answers 202 for a local agent and schedules one background task.

    Two replacements keep any agent work from actually running:

    * ``run_local_agent`` becomes an ``AsyncMock``;
    * ``asyncio.create_task`` becomes a stub that closes the coroutine it receives.

    The test checks the response body and that exactly one task was scheduled.
    """

    def _fake_create_task(coro, *_args, **_kwargs):
        # Accept create_task's optional name=/context= arguments, then close the
        # coroutine so it is never run and no "never awaited" warning is emitted.
        coro.close()
        return MagicMock()

    with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock), \
            patch("asyncio.create_task", side_effect=_fake_create_task) as mock_create_task:
        resp = client.post(
            "/api/v1/agents/trigger",
            json={
                "directory": "/home/user/docs",
                "what_to_extract": ["task", "note"],
                "batch_interval": "0 */6 * * *",
                "custom_agent_prompt": "Extract tasks and notes.",
                "active_agents": 0,
            },
            headers=auth_header("power"),
        )

    assert resp.status_code == 202
    data = resp.json()
    assert isinstance(data["agent_id"], str)
    assert data["agent_id"]
    assert data["status"] == "running"
    assert data["agent_type"] == "local"

    # Verify create_task was called (dispatching background run).
    mock_create_task.assert_called_once()
|