fix: clean up stale and obsolete tests
- test_deep_agent: update patch target get_llm -> get_agent_llm (8 tests)
- test_device_ws: remove 5 tests for deleted agent_data_queue API
- test_schemas_v3: remove agent_run/agent_data/agent_complete from v2 compat list
- Delete test_agent_runner.py (superseded by test_agent_runner_v2.py)
- Delete test_agent_setup.py (superseded by test_journey_v2.py)
- Delete test_classify_file.py (_classify_file removed in v2 rewrite)
This commit is contained in:
@@ -1,808 +0,0 @@
|
||||
"""Tests for Step 3.4: agent_runner module.

Coverage:
  Unit:
    - _is_overdue — cron schedule overdue detection
    - _extract_items_from_content — LLM extraction + JSON parsing + validation
    - _send_insert_to_client — tool_call frame construction + timeout
    - run_local_agent — end-to-end local agent happy path
    - run_local_agent — device offline path
    - run_local_agent — file-read timeout path
    - run_local_agent — LLM extraction error path
    - run_cloud_agent — stub returns error immediately
    - trigger_pending_runs — skipped when config is client-owned
    - trigger_pending_runs — non-overdue skipped
    - trigger_pending_runs — device_id filter for local agents

  Integration:
    - POST /agents/can-create — billing eligibility check
    - POST /agents/trigger — creates run log + dispatches background task
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_runner import (
|
||||
_extract_items_from_content,
|
||||
_is_overdue,
|
||||
_send_insert_to_client,
|
||||
run_cloud_agent,
|
||||
run_local_agent,
|
||||
trigger_pending_runs,
|
||||
)
|
||||
from app.core.device_manager import DeviceConnectionManager
|
||||
from app.db import get_session
|
||||
from app.main import app
|
||||
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_FREE_UID = TEST_USER_IDS["free"]
|
||||
_PRO_UID = TEST_USER_IDS["pro"]
|
||||
|
||||
|
||||
def _make_local_config(user_id: str = _FREE_UID, device_id: str = "dev-001") -> LocalAgentConfig:
    """Build an enabled LocalAgentConfig with a fresh UUID id and no previous run."""
    fields = {
        "id": str(uuid.uuid4()),
        "user_id": user_id,
        "device_id": device_id,
        "name": "Test Local Agent",
        "directory_paths": ["/home/user/emails"],
        "data_types": ["tasks", "notes"],
        "prompt_template": "Extract tasks and notes from this document.",
        "file_extensions": [".txt", ".eml"],
        "schedule_cron": "0 */6 * * *",
        "enabled": True,
        "last_run_at": None,
    }
    return LocalAgentConfig(**fields)
|
||||
|
||||
|
||||
def _make_cloud_config(user_id: str = _FREE_UID) -> CloudAgentConfig:
    """Build an enabled Gmail CloudAgentConfig with a fresh UUID id and no previous run."""
    fields = {
        "id": str(uuid.uuid4()),
        "user_id": user_id,
        "provider": "gmail",
        "name": "Test Gmail Agent",
        "data_types": ["tasks"],
        "prompt_template": "Extract tasks from email.",
        "schedule_cron": "0 */6 * * *",
        "enabled": True,
        "last_run_at": None,
    }
    return CloudAgentConfig(**fields)
|
||||
|
||||
|
||||
def _make_run_log(agent_id: str, agent_type: str = "local", user_id: str = _FREE_UID) -> AgentRunLog:
    """Build an AgentRunLog in the "running" state, started now (UTC), with a fresh UUID id."""
    run_id = str(uuid.uuid4())
    started = datetime.now(timezone.utc)
    return AgentRunLog(
        id=run_id,
        agent_id=agent_id,
        agent_type=agent_type,
        user_id=user_id,
        status="running",
        started_at=started,
    )
|
||||
|
||||
|
||||
def _make_manager(user_id: str = _FREE_UID, device_id: str = "dev-001") -> DeviceConnectionManager:
    """Return a DeviceConnectionManager with one mocked websocket registered.

    The websocket is a MagicMock whose ``send_text`` is an AsyncMock so it can
    be awaited.
    """
    manager = DeviceConnectionManager()
    fake_socket = MagicMock()
    fake_socket.send_text = AsyncMock()
    manager.register(user_id, device_id, fake_socket)
    return manager
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _is_overdue
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
def test_is_overdue_never_run():
    """With no last-run timestamp, the schedule is reported as overdue."""
    overdue = _is_overdue("0 */6 * * *", None)
    assert overdue is True
|
||||
|
||||
|
||||
def test_is_overdue_very_recently_run():
    """A last-run timestamp of "now" is not overdue on a 6-hour schedule."""
    just_now = datetime.now(timezone.utc)
    overdue = _is_overdue("0 */6 * * *", just_now)
    assert overdue is False
|
||||
|
||||
|
||||
def test_is_overdue_long_ago():
    """A run two days in the past is overdue on a 6-hour schedule."""
    from datetime import timedelta

    two_days_ago = datetime.now(timezone.utc) - timedelta(days=2)
    overdue = _is_overdue("0 */6 * * *", two_days_ago)
    assert overdue is True
|
||||
|
||||
|
||||
def test_is_overdue_invalid_cron_returns_false():
    """An unparseable cron expression yields False instead of raising (fail-safe)."""
    overdue = _is_overdue("not a cron", None)
    assert overdue is False
|
||||
|
||||
|
||||
def test_is_overdue_naive_datetime():
    """Naive datetime objects are handled without raising.

    The naive value is derived from an aware UTC timestamp with its tzinfo
    stripped, avoiding ``datetime.utcnow()``, which is deprecated since
    Python 3.12. The resulting value is the same naive UTC wall-clock time.
    """
    from datetime import timedelta

    aware_yesterday = datetime.now(timezone.utc) - timedelta(days=1)
    last = aware_yesterday.replace(tzinfo=None)  # naive

    # Should not raise; only the return type is checked here.
    result = _is_overdue("0 */6 * * *", last)
    assert isinstance(result, bool)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _extract_items_from_content
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_items_happy_path():
    """A valid JSON array from the LLM is returned item-for-item when every table is allowed."""
    llm_payload = [
        {"table": "tasks", "data": {"title": "Buy milk", "priority": "high"}},
        {"table": "notes", "data": {"title": "Meeting recap", "content": "Discussed roadmap"}},
    ]
    reply = MagicMock()
    reply.content = json.dumps(llm_payload)
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(return_value=reply)

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        extracted = await _extract_items_from_content(
            "Extract tasks and notes.",
            "Email body: Buy milk urgently. Notes from meeting: discussed roadmap.",
            ["tasks", "notes"],
        )

    assert len(extracted) == 2
    task_item = extracted[0]
    assert task_item["table"] == "tasks"
    assert task_item["data"]["title"] == "Buy milk"
    assert extracted[1]["table"] == "notes"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_items_strips_forbidden_fields():
    """id, createdAt, isAiSuggested and isApproved are removed from extracted data."""
    raw_item = {
        "table": "tasks",
        "data": {
            "title": "Review PR",
            "id": "should-be-removed",
            "createdAt": 99999,
            "isAiSuggested": 0,
            "isApproved": 1,
        },
    }
    reply = MagicMock()
    reply.content = json.dumps([raw_item])
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(return_value=reply)

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        extracted = await _extract_items_from_content("Extract tasks.", "Review the PR.", ["tasks"])

    assert len(extracted) == 1
    cleaned = extracted[0]["data"]
    for forbidden in ("id", "createdAt", "isAiSuggested", "isApproved"):
        assert forbidden not in cleaned
    assert cleaned["title"] == "Review PR"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_items_invalid_json_returns_empty():
    """A non-JSON LLM reply produces an empty list rather than an exception."""
    reply = MagicMock()
    reply.content = "Sorry, I cannot extract anything."
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(return_value=reply)

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        extracted = await _extract_items_from_content("Extract tasks.", "content", ["tasks"])

    assert extracted == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_items_disallowed_table_filtered():
    """Items naming a table outside data_types are dropped from the result."""
    llm_payload = [
        {"table": "tasks", "data": {"title": "Valid task"}},
        {"table": "projects", "data": {"name": "Should be filtered"}},
    ]
    reply = MagicMock()
    reply.content = json.dumps(llm_payload)
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(return_value=reply)

    # data_types lists only "tasks", so the "projects" item must not survive.
    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        extracted = await _extract_items_from_content("Extract.", "content", ["tasks"])

    assert len(extracted) == 1
    assert extracted[0]["table"] == "tasks"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_items_empty_data_types_returns_empty():
    """With an empty data_types list the LLM is never invoked and [] is returned."""
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock()

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        extracted = await _extract_items_from_content("Extract.", "content", [])

    fake_llm.ainvoke.assert_not_called()
    assert extracted == []
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_extract_items_llm_error_propagates():
    """An exception raised by the LLM call is re-raised to the caller unchanged."""
    failing_llm = MagicMock()
    failing_llm.ainvoke = AsyncMock(side_effect=RuntimeError("API unavailable"))

    llm_patch = patch("app.core.agent_runner.get_llm", return_value=failing_llm)
    with llm_patch, pytest.raises(RuntimeError, match="API unavailable"):
        await _extract_items_from_content("Extract tasks.", "content", ["tasks"])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _send_insert_to_client
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_insert_to_client_happy_path():
    """Frame is sent with isAiSuggested/isApproved added; result is returned.

    ``send_frame`` is replaced by a stub that records each outgoing frame and
    immediately resolves the matching pending call, standing in for the
    client's acknowledgement.
    """
    mgr = _make_manager()
    sent_payloads: list[dict] = []

    async def _capture_send(uid: str, frame: dict) -> None:
        sent_payloads.append(frame)
        # Immediately resolve the pending call with a success result.
        mgr.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-id", "title": "Buy milk"}})

    mgr.send_frame = _capture_send  # type: ignore[method-assign]

    result = await _send_insert_to_client(
        _FREE_UID, "tasks", {"title": "Buy milk", "priority": "high"}, mgr
    )

    # Exactly one insert frame goes out, carrying the caller's data plus AI flags.
    assert len(sent_payloads) == 1
    payload = sent_payloads[0]
    assert payload["action"] == "insert"
    assert payload["table"] == "tasks"
    assert payload["data"]["title"] == "Buy milk"
    assert payload["data"]["isAiSuggested"] == 1
    assert payload["data"]["isApproved"] == 0
    # The value passed to resolve_pending_call is what the helper returns.
    assert result["row"]["title"] == "Buy milk"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_send_insert_to_client_timeout():
    """asyncio.TimeoutError surfaces when the pending insert call is never resolved."""
    manager = _make_manager()

    async def _swallow_frame(uid: str, frame: dict) -> None:
        # Deliberately leave the pending call unresolved.
        return None

    manager.send_frame = _swallow_frame  # type: ignore[method-assign]

    short_timeout = patch("app.core.agent_runner._INSERT_TIMEOUT", 0.05)
    with short_timeout, pytest.raises(asyncio.TimeoutError):
        await _send_insert_to_client(_FREE_UID, "tasks", {"title": "X"}, manager)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_local_agent
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_local_agent_device_offline():
    """With no device registered, the run is finalized once with status "error"."""
    cfg = _make_local_config()
    log = _make_run_log(cfg.id)
    empty_manager = DeviceConnectionManager()  # Nothing registered on purpose.

    finalize_patch = patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock)
    with finalize_patch as mock_finalize:
        await run_local_agent(_FREE_UID, cfg, log, empty_manager)

    mock_finalize.assert_called_once()
    finalize_kwargs = mock_finalize.call_args[1]
    assert finalize_kwargs["status"] == "error"
    assert any("not connected" in message for message in finalize_kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_local_agent_happy_path():
    """Full local run: one file delivered, one task extracted, one insert acknowledged."""
    cfg = _make_local_config()
    log = _make_run_log(cfg.id)
    manager = _make_manager()

    # Fake agent_data frame that the stubbed send pushes onto the run's queue.
    data_frame = {
        "type": "agent_data",
        "run_id": log.id,
        "files": [{"path": "/email.eml", "content": "Urgent: fix the bug by Friday."}],
    }
    outgoing: list[dict] = []

    async def _fake_send(uid: str, frame: dict) -> None:
        outgoing.append(frame)
        frame_type = frame.get("type")
        if frame_type == "agent_run":
            # Play the device: deliver the file data, then the None completion sentinel.
            queue = manager.get_agent_data_queue(uid, frame["run_id"])
            await queue.put(data_frame)
            await queue.put(None)
        elif frame_type == "tool_call":
            # Acknowledge the insert straight away.
            manager.resolve_pending_call(uid, frame["id"], {"row": {"id": "new-task", "title": "Fix the bug"}})

    manager.send_frame = _fake_send  # type: ignore[method-assign]

    reply = MagicMock()
    reply.content = json.dumps([
        {"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}
    ])
    fake_llm = MagicMock()
    fake_llm.ainvoke = AsyncMock(return_value=reply)

    with patch("app.core.agent_runner.get_llm", return_value=fake_llm):
        with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
            await run_local_agent(_FREE_UID, cfg, log, manager)

    mock_finalize.assert_called_once()
    finalize_kwargs = mock_finalize.call_args[1]
    assert finalize_kwargs["status"] == "success"
    assert finalize_kwargs["items_processed"] == 1
    assert finalize_kwargs["items_created"] == 1
    assert finalize_kwargs["errors"] == []
    assert finalize_kwargs["update_config_last_run"] is False

    # Exactly one agent_run frame, addressed to this config and carrying paths.
    run_frames = [f for f in outgoing if f.get("type") == "agent_run"]
    assert len(run_frames) == 1
    run_frame = run_frames[0]
    assert run_frame["agent_id"] == cfg.id
    assert "paths" in run_frame["config"]

    # Exactly one insert frame, flagged as an unapproved AI suggestion.
    tool_frames = [f for f in outgoing if f.get("type") == "tool_call"]
    assert len(tool_frames) == 1
    inserted = tool_frames[0]["data"]
    assert inserted["isAiSuggested"] == 1
    assert inserted["isApproved"] == 0
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_local_agent_file_read_timeout():
    """A device that never delivers file data ends the run with a timed-out error."""
    cfg = _make_local_config()
    log = _make_run_log(cfg.id)
    manager = _make_manager()

    async def _silent_send(uid: str, frame: dict) -> None:
        # Stalled device: nothing is ever queued in response.
        return None

    manager.send_frame = _silent_send  # type: ignore[method-assign]

    with patch("app.core.agent_runner._FILE_READ_TIMEOUT", 0.1):
        with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
            await run_local_agent(_FREE_UID, cfg, log, manager)

    mock_finalize.assert_called_once()
    finalize_kwargs = mock_finalize.call_args[1]
    # Nothing was created, so the status is "error" rather than partial.
    assert finalize_kwargs["status"] == "error"
    assert any("timed out" in message.lower() for message in finalize_kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_local_agent_llm_extraction_error():
    """LLM errors per-file are recorded; run continues for remaining files.

    The stubbed ``send_frame`` answers the ``agent_run`` frame with one
    ``agent_data`` frame holding two files, followed by the ``None`` completion
    sentinel. The LLM mock raises on every call, so both files should fail.
    """
    config = _make_local_config()
    run_log = _make_run_log(config.id)
    mgr = _make_manager()

    file_frame = {
        "type": "agent_data",
        "run_id": run_log.id,
        "files": [
            {"path": "/file1.eml", "content": "Email one."},
            {"path": "/file2.eml", "content": "Email two."},
        ],
    }

    async def _mock_send(uid: str, frame: dict) -> None:
        if frame.get("type") == "agent_run":
            q = mgr.get_agent_data_queue(uid, frame["run_id"])
            await q.put(file_frame)
            await q.put(None)  # agent_complete sentinel

    mgr.send_frame = _mock_send  # type: ignore[method-assign]

    mock_llm = MagicMock()
    mock_llm.ainvoke = AsyncMock(side_effect=RuntimeError("LLM boom"))

    with patch("app.core.agent_runner.get_llm", return_value=mock_llm), \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_local_agent(_FREE_UID, config, run_log, mgr)

    # Guard the unpack below: call_args is None when the mock was never called,
    # which would otherwise fail with an opaque TypeError.
    mock_finalize.assert_called()
    _args, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert kwargs["items_processed"] == 2  # Both files attempted.
    assert kwargs["items_created"] == 0
    assert len(kwargs["errors"]) == 2  # One error per file.
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# run_cloud_agent (stub)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cloud_agent_device_offline():
    """With no device registered, the cloud run is finalized once with status "error"."""
    cfg = _make_cloud_config()
    log = _make_run_log(cfg.id, agent_type="cloud")
    empty_manager = DeviceConnectionManager()  # no devices registered

    finalize_patch = patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock)
    with finalize_patch as mock_finalize:
        await run_cloud_agent(_FREE_UID, cfg, log, empty_manager)

    mock_finalize.assert_called_once()
    finalize_kwargs = mock_finalize.call_args[1]
    assert finalize_kwargs["status"] == "error"
    assert any(
        "device" in message.lower() or "connected" in message.lower()
        for message in finalize_kwargs["errors"]
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cloud_agent_no_oauth_token():
    """Cloud agent errors when no OAuth token is stored.

    A device is registered, so the run is expected to fail on the missing
    token rather than on connectivity.
    """
    config = _make_cloud_config()
    config.oauth_token_encrypted = None
    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize:
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # Guard the unpack below: call_args is None when the mock was never called.
    mock_finalize.assert_called()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert any("oauth" in e.lower() or "token" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cloud_agent_token_decrypt_failure():
    """Cloud agent errors gracefully when the stored token cannot be decrypted.

    The stored token is a string that is not Fernet ciphertext. The patched
    settings object is given a freshly generated Fernet key.
    """
    config = _make_cloud_config()
    config.oauth_token_encrypted = "this-is-not-valid-fernet-ciphertext"
    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    from cryptography.fernet import Fernet as _Fernet
    valid_key = _Fernet.generate_key().decode()

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
            patch("app.integrations.settings") as mock_settings:
        mock_settings.OAUTH_ENCRYPTION_KEY = valid_key
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # Guard the unpack below: call_args is None when the mock was never called.
    mock_finalize.assert_called()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert any("decrypt" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cloud_agent_happy_path_gmail():
    """Cloud agent happy path: Gmail fetch → LLM extraction → inserts → success.

    The provider, token decryption, extraction and client insert are all
    mocked. The test checks only the orchestration and the keyword arguments
    passed to ``_finalize_run``.
    """
    from app.integrations import EmailMessage, encrypt_token
    from cryptography.fernet import Fernet as _Fernet

    fernet_key = _Fernet.generate_key().decode()
    credentials = {
        "token": "access_abc",
        "refresh_token": "refresh_xyz",
        "token_uri": "https://oauth2.googleapis.com/token",
        "client_id": "cid",
        "client_secret": "csec",
    }

    config = _make_cloud_config()
    config.provider = "gmail"
    config.prompt_template = "Extract tasks from this email."
    config.data_types = ["tasks"]

    # Encrypt the credentials while settings carries the generated key.
    with patch("app.integrations.settings") as ms:
        ms.OAUTH_ENCRYPTION_KEY = fernet_key
        config.oauth_token_encrypted = encrypt_token(credentials)

    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    sample_email = EmailMessage(
        id="msg001",
        subject="Action required",
        sender="boss@company.com",
        body_text="Please fix the bug by Friday.",
        date=datetime(2025, 6, 1, 10, 0, tzinfo=timezone.utc),
    )

    extracted_items = [{"table": "tasks", "data": {"title": "Fix the bug", "priority": "high"}}]

    with patch("app.integrations.settings") as mock_int_settings, \
            patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
            patch("app.core.agent_runner._extract_items_from_content", new_callable=AsyncMock, return_value=extracted_items) as mock_extract, \
            patch("app.core.agent_runner._send_insert_to_client", new_callable=AsyncMock, return_value={"ok": True}) as mock_insert, \
            patch("app.core.agent_runner.async_session"):
        mock_int_settings.OAUTH_ENCRYPTION_KEY = fernet_key

        mock_gmail = AsyncMock()
        mock_gmail.fetch_messages = AsyncMock(return_value=[sample_email])
        mock_gmail.refreshed_credentials = None

        with patch("app.integrations.decrypt_token", return_value=credentials), \
                patch("app.integrations.get_provider", return_value=mock_gmail):
            await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    mock_extract.assert_called_once()
    mock_insert.assert_called_once()
    # Guard the unpack below: call_args is None when the mock was never called.
    mock_finalize.assert_called()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "success"
    assert kwargs["items_processed"] == 1
    assert kwargs["items_created"] == 1
    assert kwargs["config_type"] == "cloud"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cloud_agent_provider_fetch_error():
    """Cloud agent records error status when provider fetch raises RuntimeError.

    ``decrypt_token`` and ``get_provider`` are patched, so the stored token
    only needs to be non-empty. The provider's ``fetch_messages`` raises.
    """
    credentials = {"token": "abc"}
    config = _make_cloud_config()
    config.oauth_token_encrypted = "some_encrypted_value"  # non-empty so decrypt step is reached
    config.prompt_template = "Extract tasks."
    config.data_types = ["tasks"]
    run_log = _make_run_log(config.id, agent_type="cloud")
    mgr = _make_manager()

    mock_provider = AsyncMock()
    mock_provider.fetch_messages = AsyncMock(side_effect=RuntimeError("API quota exceeded"))
    mock_provider.refreshed_credentials = None

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_finalize, \
            patch("app.integrations.decrypt_token", return_value=credentials), \
            patch("app.integrations.get_provider", return_value=mock_provider), \
            patch("app.core.agent_runner.async_session"):
        await run_cloud_agent(_FREE_UID, config, run_log, mgr)

    # Guard the unpack below: call_args is None when the mock was never called.
    mock_finalize.assert_called()
    _, kwargs = mock_finalize.call_args
    assert kwargs["status"] == "error"
    assert any("quota" in e.lower() or "fetch" in e.lower() for e in kwargs["errors"])
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_run_cloud_agent_refreshed_token_persisted():
    """Refreshed provider credentials are re-encrypted and stored on the config row."""
    from app.integrations import encrypt_token
    from cryptography.fernet import Fernet as _Fernet

    fernet_key = _Fernet.generate_key().decode()
    old_creds = {"token": "old_token", "refresh_token": "rt_old"}
    new_creds = {"token": "new_token", "refresh_token": "rt_new"}

    cfg = _make_cloud_config()
    cfg.prompt_template = "Extract tasks."
    cfg.data_types = ["tasks"]

    with patch("app.integrations.settings") as key_settings:
        key_settings.OAUTH_ENCRYPTION_KEY = fernet_key
        cfg.oauth_token_encrypted = encrypt_token(old_creds)

    log = _make_run_log(cfg.id, agent_type="cloud")
    manager = _make_manager()

    provider = AsyncMock()
    provider.fetch_messages = AsyncMock(return_value=[])
    provider.refreshed_credentials = new_creds  # token was refreshed

    # Mocked DB session: execute() returns a result whose scalar_one_or_none()
    # hands back the config row that the test inspects afterwards.
    stored_row = MagicMock()
    stored_row.oauth_token_encrypted = None

    session = AsyncMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=False)
    session.scalar_one_or_none = AsyncMock(return_value=stored_row)
    query_result = MagicMock()
    query_result.scalar_one_or_none.return_value = stored_row
    session.execute = AsyncMock(return_value=query_result)
    session.commit = AsyncMock()

    with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock), \
            patch("app.integrations.decrypt_token", return_value=old_creds), \
            patch("app.integrations.get_provider", return_value=provider), \
            patch("app.integrations.encrypt_token", return_value="new_encrypted") as mock_encrypt, \
            patch("app.core.agent_runner.async_session", return_value=session), \
            patch("app.integrations.settings") as run_settings:
        run_settings.OAUTH_ENCRYPTION_KEY = fernet_key
        await run_cloud_agent(_FREE_UID, cfg, log, manager)

    # encrypt_token saw the fresh credentials and its output landed on the row.
    mock_encrypt.assert_called_once_with(new_creds)
    assert stored_row.oauth_token_encrypted == "new_encrypted"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_finalize_run_updates_cloud_config_last_run_at():
    """_finalize_run with config_type="cloud" sets last_run_at on the fetched config and commits."""
    from app.core.agent_runner import _finalize_run

    log = _make_run_log(str(uuid.uuid4()), agent_type="cloud")
    log.id = str(uuid.uuid4())

    stored_cfg = MagicMock()
    stored_cfg.last_run_at = None

    query_result = MagicMock()
    query_result.scalar_one_or_none.return_value = stored_cfg

    session = AsyncMock()
    session.__aenter__ = AsyncMock(return_value=session)
    session.__aexit__ = AsyncMock(return_value=False)
    session.merge = AsyncMock(return_value=log)
    session.execute = AsyncMock(return_value=query_result)
    session.commit = AsyncMock()

    cfg_id = str(uuid.uuid4())

    with patch("app.core.agent_runner.async_session", return_value=session):
        await _finalize_run(
            log,
            status="success",
            update_config_last_run=True,
            config_id=cfg_id,
            config_type="cloud",
        )

    # The mocked config row must now carry a timestamp, and a commit happened.
    assert stored_cfg.last_run_at is not None
    session.commit.assert_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# trigger_pending_runs
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_trigger_pending_runs_no_overdue():
    """trigger_pending_runs does not invoke run_local_agent (agent config is client-owned)."""
    manager = _make_manager()

    runner_patch = patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock)
    with runner_patch as mock_run:
        await trigger_pending_runs(_FREE_UID, "dev-001", manager)

    mock_run.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_trigger_pending_runs_device_id_filter():
    """No run is dispatched for the connecting device; device filtering is not backend-managed."""
    manager = _make_manager(device_id="dev-001")

    runner_patch = patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock)
    with runner_patch as mock_run:
        await trigger_pending_runs(_FREE_UID, "dev-001", manager)

    mock_run.assert_not_called()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_trigger_pending_runs_dispatches_overdue():
    """After config deprecation the backend dispatches no pending runs at all."""
    manager = _make_manager()

    runner_patch = patch("app.core.agent_runner.run_local_agent", new_callable=AsyncMock)
    with runner_patch as mock_run:
        await trigger_pending_runs(_FREE_UID, "dev-001", manager)

    mock_run.assert_not_called()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration: POST /agents/can-create and /agents/trigger
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _override_db(db_session):
    """Point the app's get_session dependency at the test session for each test.

    The override is installed before the test body runs and removed afterwards.
    """

    async def _yield_test_session():
        yield db_session

    app.dependency_overrides[get_session] = _yield_test_session
    yield
    # Teardown: drop the override; the default avoids a KeyError if it is already gone.
    app.dependency_overrides.pop(get_session, None)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_can_create_agent_allows_when_under_limit(client):
    """A free-tier user with zero active agents is allowed to create one (limit 2)."""
    request_body = {"active_agents": 0}
    response = client.post(
        "/api/v1/agents/can-create",
        json=request_body,
        headers=auth_header("free"),
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["allowed"] is True
    assert payload["tier"] == "free"
    assert payload["active_agents"] == 0
    assert payload["limit"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_can_create_agent_denies_when_at_limit(client):
    """A free-tier user already at two active agents is refused (limit 2)."""
    request_body = {"active_agents": 2}
    response = client.post(
        "/api/v1/agents/can-create",
        json=request_body,
        headers=auth_header("free"),
    )

    assert response.status_code == 200
    payload = response.json()
    assert payload["allowed"] is False
    assert payload["limit"] == 2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_trigger_run_local_agent_creates_run_log(client, db_session):
    """POST /agents/trigger answers 202 with a running local agent and schedules one task."""
    dispatched: list[tuple[str, str]] = []

    async def _record_run(user_id, cfg, run_log, device_mgr):
        dispatched.append((user_id, cfg.id))

    def _discard_coroutine(coro):
        # Close the coroutine instead of scheduling it.
        coro.close()
        return MagicMock()

    trigger_body = {
        "directory": "/home/user/docs",
        "what_to_extract": ["task", "note"],
        "batch_interval": "0 */6 * * *",
        "custom_agent_prompt": "Extract tasks and notes.",
        "active_agents": 0,
    }

    with patch("app.api.routes.agents.run_local_agent", new_callable=AsyncMock, side_effect=_record_run):
        with patch("asyncio.create_task") as mock_create_task:
            mock_create_task.side_effect = _discard_coroutine
            response = client.post(
                "/api/v1/agents/trigger",
                json=trigger_body,
                headers=auth_header("power"),
            )

    assert response.status_code == 202
    payload = response.json()
    assert isinstance(payload["agent_id"], str)
    assert payload["agent_id"]
    assert payload["status"] == "running"
    assert payload["agent_type"] == "local"

    # The background run must have been handed to create_task exactly once.
    mock_create_task.assert_called_once()
|
||||
@@ -1,242 +0,0 @@
|
||||
"""Tests for the Chatbot Journey endpoints.
|
||||
|
||||
Covers:
|
||||
1. Start journey for local agent → session_id + first question, done=False
|
||||
2. Start journey for cloud agent → contextual email-focused question
|
||||
3. Start journey with existing agent_id → session seeded, first question returned
|
||||
4. Start journey with non-existent agent_id → still succeeds (graceful fallback)
|
||||
5. Message: continue conversation → done=False, follow-up question returned
|
||||
6. Message: LLM wraps up → done=True + prompt_template extracted correctly
|
||||
7. Message with max-turns nudge → no crash, returns response
|
||||
8. Invalid session_id → 404
|
||||
9. Expired session → 404
|
||||
10. Session ownership: user B cannot access user A's session
|
||||
11. No JWT on /start → 401
|
||||
12. No JWT on /message → 401
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from fastapi.testclient import TestClient
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from app.api.routes.agent_setup import (
|
||||
_SESSION_TTL_SECONDS,
|
||||
_TEMPLATE_END,
|
||||
_TEMPLATE_START,
|
||||
_extract_template,
|
||||
_sessions,
|
||||
)
|
||||
from app.models import LocalAgentConfig
|
||||
from tests.conftest import TEST_USER_IDS, auth_header
|
||||
|
||||
# ── Helpers ──────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def _start(client: TestClient, agent_type: str = "local", agent_id: str | None = None, tier: str = "power"):
    """POST to the journey ``/start`` endpoint and return the raw response.

    Args:
        client: Test client used to issue the request.
        agent_type: Value sent as ``agent_type`` in the JSON body.
        agent_id: Optional agent id; only included in the body when truthy.
        tier: Tier name passed to ``auth_header`` to build the request headers.

    Returns:
        The response object produced by ``client.post`` (not a dict) — callers
        read ``.status_code`` and ``.json()`` from it.
    """
    body: dict = {"agent_type": agent_type}
    if agent_id:
        body["agent_id"] = agent_id
    return client.post("/api/v1/agents/journey/start", json=body, headers=auth_header(tier))
|
||||
|
||||
|
||||
def _message(client: TestClient, session_id: str, message: str, tier: str = "power"):
    """POST to the journey ``/message`` endpoint and return the raw response.

    Args:
        client: Test client used to issue the request.
        session_id: Journey session id sent in the JSON body.
        message: User message sent in the JSON body.
        tier: Tier name passed to ``auth_header`` to build the request headers.

    Returns:
        The response object produced by ``client.post`` (not a dict) — callers
        read ``.status_code`` and ``.json()`` from it.
    """
    return client.post(
        "/api/v1/agents/journey/message",
        json={"session_id": session_id, "message": message},
        headers=auth_header(tier),
    )
|
||||
|
||||
|
||||
# ── Unit: _extract_template ───────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_extract_template_present():
    """Text between the start and end markers is returned without the surrounding newlines."""
    wrapped = f"Some preamble.\n{_TEMPLATE_START}\nExtract tasks from emails.\n{_TEMPLATE_END}\nTrailing text."
    assert _extract_template(wrapped) == "Extract tasks from emails."
|
||||
|
||||
|
||||
def test_extract_template_absent():
    """Input without any template markers yields ``None``."""
    extracted = _extract_template("No markers here.")
    assert extracted is None
|
||||
|
||||
|
||||
def test_extract_template_empty_content():
    """Markers with nothing but a newline between them yield ``None``."""
    markers_only = f"{_TEMPLATE_START}\n{_TEMPLATE_END}"
    extracted = _extract_template(markers_only)
    assert extracted is None
|
||||
|
||||
|
||||
# ── Start journey ─────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_start_journey_local(client: TestClient):
    """Starting a local journey returns a session and a non-empty, file-oriented first question."""
    response = _start(client, agent_type="local")
    assert response.status_code == 200

    payload = response.json()
    assert "session_id" in payload
    assert payload["done"] is False
    assert payload["prompt_template"] is None
    assert len(payload["message"]) > 0

    # The opening question for a local agent should mention files/directories.
    question = payload["message"].lower()
    local_keywords = ("file", "director", "document", "monitor")
    assert any(keyword in question for keyword in local_keywords)
|
||||
|
||||
|
||||
def test_start_journey_cloud(client: TestClient):
    """Starting a cloud journey returns an unfinished session with an email-oriented question."""
    response = _start(client, agent_type="cloud")
    assert response.status_code == 200

    payload = response.json()
    assert payload["done"] is False

    # The opening question for a cloud agent should mention emails or messages.
    question = payload["message"].lower()
    cloud_keywords = ("email", "message", "communication")
    assert any(keyword in question for keyword in cloud_keywords)
|
||||
|
||||
|
||||
def test_start_journey_with_agent_id(client: TestClient, db_session: AsyncSession):
    """An ``agent_id`` that matches no stored agent still starts a journey (200, not done)."""
    unknown_agent_id = str(uuid.uuid4())

    response = _start(client, agent_type="local", agent_id=unknown_agent_id)

    # Graceful fallback: the random id was never persisted, yet the call succeeds.
    assert response.status_code == 200
    payload = response.json()
    assert payload["done"] is False
|
||||
|
||||
|
||||
def test_start_journey_with_existing_agent(client: TestClient, db_session: AsyncSession):
    """Starting a journey for a persisted local agent succeeds and registers the session."""
    import asyncio

    agent_fields = {
        "id": str(uuid.uuid4()),
        "user_id": TEST_USER_IDS["power"],
        "name": "Test Agent",
        "device_id": "device-1",
        "directory_paths": ["/home/user/emails"],
        "data_types": ["tasks"],
        "prompt_template": "Extract tasks from .eml files.",
        "file_extensions": [".eml"],
        "schedule_cron": "0 */6 * * *",
        "enabled": True,
    }
    agent = LocalAgentConfig(**agent_fields)

    async def _persist_agent():
        db_session.add(agent)
        await db_session.commit()

    # The test itself is synchronous, so drive the async insert on the current loop.
    loop = asyncio.get_event_loop()
    loop.run_until_complete(_persist_agent())

    response = _start(client, agent_type="local", agent_id=agent.id)
    assert response.status_code == 200

    payload = response.json()
    assert payload["done"] is False
    # The new session must be present in the in-memory session registry.
    assert payload["session_id"] in _sessions
|
||||
|
||||
|
||||
def test_start_journey_requires_auth(client: TestClient):
    """``/start`` without any auth header is rejected with 401."""
    request_body = {"agent_type": "local"}
    response = client.post("/api/v1/agents/journey/start", json=request_body)
    assert response.status_code == 401
|
||||
|
||||
|
||||
# ── Message ───────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_message_continues_conversation(client: TestClient):
    """An LLM reply without template markers keeps the journey open (``done=False``)."""
    follow_up = "That looks good. Can you tell me more about priority rules?"
    fake_llm = AsyncMock(return_value=follow_up)

    with patch("app.api.routes.agent_setup._call_llm", new=fake_llm):
        start_resp = _start(client, agent_type="local")
        assert start_resp.status_code == 200
        session_id = start_resp.json()["session_id"]

        msg_resp = _message(client, session_id, "I have .eml and .txt files")
        assert msg_resp.status_code == 200

        reply = msg_resp.json()
        assert reply["done"] is False
        assert reply["prompt_template"] is None
        # The mocked LLM text is passed through unchanged, on the same session.
        assert reply["message"] == follow_up
        assert reply["session_id"] == session_id
|
||||
|
||||
|
||||
def test_message_produces_template(client: TestClient):
    """An LLM reply containing the template markers ends the journey and returns the template."""
    final_template = "Extract tasks from email. Subject → title. 'urgent' → high priority."
    # Closing line, then the template wrapped in its markers, each followed by a newline.
    reply_lines = [
        "Great, I have all the information I need.",
        f"{_TEMPLATE_START}",
        final_template,
        f"{_TEMPLATE_END}",
    ]
    llm_response = "\n".join(reply_lines) + "\n"
    fake_llm = AsyncMock(return_value=llm_response)

    with patch("app.api.routes.agent_setup._call_llm", new=fake_llm):
        start_resp = _start(client, agent_type="cloud")
        assert start_resp.status_code == 200
        session_id = start_resp.json()["session_id"]

        msg_resp = _message(client, session_id, "Only invoices from clients")
        assert msg_resp.status_code == 200

        reply = msg_resp.json()
        assert reply["done"] is True
        assert reply["prompt_template"] == final_template
        # A finished journey must no longer be in the session registry.
        assert session_id not in _sessions
|
||||
|
||||
|
||||
def test_message_invalid_session(client: TestClient):
    """A message for a session id that was never created returns 404."""
    response = _message(client, "nonexistent-session-id", "hello")
    status = response.status_code
    assert status == 404
|
||||
|
||||
|
||||
def test_message_wrong_owner(client: TestClient):
    """A session started with the "power" credentials is a 404 for the "pro" credentials."""
    owner_start = _start(client, agent_type="local", tier="power")
    session_id = owner_start.json()["session_id"]

    # Same session id, but sent with the headers of a different tier/user.
    intruder_request = {"session_id": session_id, "message": "hello"}
    intruder_headers = auth_header("pro")
    response = client.post(
        "/api/v1/agents/journey/message",
        json=intruder_request,
        headers=intruder_headers,
    )
    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_message_expired_session(client: TestClient):
    """A session whose ``created_at`` lies beyond the TTL is answered with 404."""
    session_id = _start(client, agent_type="local").json()["session_id"]

    # Back-date the stored session to one second past the TTL.
    stored_session = _sessions[session_id]
    stored_session.created_at = time.monotonic() - _SESSION_TTL_SECONDS - 1

    response = _message(client, session_id, "hello")
    assert response.status_code == 404
|
||||
|
||||
|
||||
def test_message_requires_auth(client: TestClient):
    """``/message`` without any auth header is rejected with 401."""
    request_body = {"session_id": "any", "message": "hello"}
    response = client.post("/api/v1/agents/journey/message", json=request_body)
    assert response.status_code == 401
|
||||
|
||||
|
||||
def test_message_max_turns_nudge(client: TestClient):
    """Sending ``_MAX_TURNS`` user replies never yields a non-200 response.

    The LLM is mocked to always return a plain follow-up question.
    NOTE(review): the max-turns nudge itself is not asserted here; only the
    absence of errors is checked.
    """
    from app.api.routes.agent_setup import _MAX_TURNS

    follow_up = "Tell me more about priority rules."
    fake_llm = AsyncMock(return_value=follow_up)

    with patch("app.api.routes.agent_setup._call_llm", new=fake_llm):
        session_id = _start(client, agent_type="local").json()["session_id"]

        for turn in range(1, _MAX_TURNS + 1):
            resp = _message(client, session_id, f"Answer {turn}")
            assert resp.status_code == 200
            finished = resp.json()["done"]
            if finished:
                # Journey wrapped up before the turn limit — also acceptable.
                break
|
||||
@@ -1,184 +0,0 @@
|
||||
"""Unit tests for Step 1 file classification (_classify_file).
|
||||
|
||||
These tests call the real LLM so they require OPENAI_API_KEY / LLM env vars.
|
||||
Run with: pytest tests/test_classify_file.py -v
|
||||
|
||||
To run a quick manual check against a real file without the full UI:
|
||||
python -m tests.test_classify_file <path/to/file.txt> [project_name...]
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import sys
|
||||
|
||||
import pytest
|
||||
|
||||
from app.core.agent_runner import _classify_file
|
||||
|
||||
|
||||
# ── Fixtures ──────────────────────────────────────────────────────────────
|
||||
|
||||
PROJECTS_SAMPLE = [
|
||||
{
|
||||
"id": "aaaa-0001-0000-0000-000000000001",
|
||||
"name": "ARPA Sicilia POC",
|
||||
"status": "active",
|
||||
"aiSummary": "Proof of concept for AI features targeting ARPA Sicilia agency.",
|
||||
},
|
||||
{
|
||||
"id": "bbbb-0002-0000-0000-000000000002",
|
||||
"name": "SNAM AI Meeting Prep",
|
||||
"status": "active",
|
||||
"aiSummary": "AI-assisted preparation of meeting materials for SNAM.",
|
||||
},
|
||||
{
|
||||
"id": "cccc-0003-0000-0000-000000000003",
|
||||
"name": "SFERA+ Wave 2",
|
||||
"status": "active",
|
||||
"aiSummary": "Second wave of the SFERA+ whitelist project.",
|
||||
},
|
||||
]
|
||||
|
||||
ARPA_EMAIL = """\
|
||||
to: roberto.musso@hpe.com; luca.tondin@hpecds.com
|
||||
isImportance: normal
|
||||
hasAttachment: True
|
||||
---
|
||||
## Body
|
||||
Buongiorno,
|
||||
|
||||
In riferimento alla riunione di ieri sul POC ARPA Sicilia, vi invio il riassunto
|
||||
dei deliverable concordati:
|
||||
- Preparare demo entro il 30 marzo
|
||||
- Condividere documentazione tecnica con il team ARPA
|
||||
- Fissare call di follow-up la prossima settimana
|
||||
|
||||
Cordiali saluti
|
||||
Roberto Marchetti
|
||||
"""
|
||||
|
||||
SNAM_EMAIL = """\
|
||||
to: roberto.musso@hpe.com
|
||||
isImportance: high
|
||||
hasAttachment: False
|
||||
---
|
||||
## Body
|
||||
Ciao,
|
||||
ti invio l'agenda per la riunione SNAM di domani.
|
||||
Per favore conferma la tua presenza.
|
||||
"""
|
||||
|
||||
UNRELATED_EMAIL = """\
|
||||
to: roberto.musso@hpe.com
|
||||
isImportance: normal
|
||||
---
|
||||
## Body
|
||||
Benvenuto nel programma HPE Employee Learning Series.
|
||||
Completa la formazione richiesta entro la fine del trimestre.
|
||||
"""
|
||||
|
||||
|
||||
# ── Tests ─────────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_classify_arpa_matches_existing():
    """The ARPA email is assigned to the existing ARPA project and no new name is proposed."""
    arpa_project_id = "aaaa-0001-0000-0000-000000000001"
    classify_kwargs = {
        "file_path": "arpa_email.txt",
        "file_content": ARPA_EMAIL,
        "projects": PROJECTS_SAMPLE,
        "config_data_types": ["tasks", "notes", "timelines"],
    }

    project_id, _domains, new_name = await _classify_file(**classify_kwargs)

    assert project_id == arpa_project_id, (
        f"Expected ARPA project, got project_id={project_id!r} new_name={new_name!r}"
    )
    assert new_name is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_classify_snam_matches_existing():
    """The SNAM email is assigned to the existing SNAM project."""
    snam_project_id = "bbbb-0002-0000-0000-000000000002"
    classify_kwargs = {
        "file_path": "snam_email.txt",
        "file_content": SNAM_EMAIL,
        "projects": PROJECTS_SAMPLE,
        "config_data_types": ["tasks", "notes"],
    }

    project_id, _domains, new_name = await _classify_file(**classify_kwargs)

    assert project_id == snam_project_id, (
        f"Expected SNAM project, got project_id={project_id!r} new_name={new_name!r}"
    )
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_classify_unrelated_returns_new():
    """An email unrelated to every sample project is classified as "new" with a suggested name."""
    classify_kwargs = {
        "file_path": "learning_email.txt",
        "file_content": UNRELATED_EMAIL,
        "projects": PROJECTS_SAMPLE,
        "config_data_types": ["tasks", "notes"],
    }

    project_id, _domains, new_name = await _classify_file(**classify_kwargs)

    assert project_id == "new"
    # A name for the new project is expected alongside the "new" verdict.
    assert new_name is not None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_classify_empty_file_returns_new():
    """Whitespace-only file content is classified as "new"."""
    classify_kwargs = {
        "file_path": "empty.txt",
        "file_content": "   ",
        "projects": PROJECTS_SAMPLE,
        "config_data_types": ["tasks"],
    }

    project_id, _domains, _new_name = await _classify_file(**classify_kwargs)

    assert project_id == "new"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_classify_no_projects_returns_new():
    """With an empty project list the result is "new" and a name is suggested."""
    classify_kwargs = {
        "file_path": "arpa_email.txt",
        "file_content": ARPA_EMAIL,
        "projects": [],
        "config_data_types": ["tasks", "notes"],
    }

    project_id, _domains, new_name = await _classify_file(**classify_kwargs)

    assert project_id == "new"
    assert new_name is not None
|
||||
|
||||
|
||||
# ── CLI quick-test runner ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
async def _cli_test(file_path: str, project_names: list[str]) -> None:
    """Classify one file from the command line and print the outcome as JSON.

    Args:
        file_path: Path of the file to read (UTF-8, undecodable bytes replaced).
        project_names: Names turned into placeholder projects with ids
            ``test-id-0000``, ``test-id-0001``, … and an empty summary.
    """
    import json
    from pathlib import Path

    content = Path(file_path).read_text(encoding="utf-8", errors="replace")

    projects = []
    for index, name in enumerate(project_names):
        projects.append({"id": f"test-id-{index:04d}", "name": name, "status": "active", "aiSummary": ""})
    names_in_context = [project["name"] for project in projects]

    print(f"\nClassifying: {file_path}")
    print(f"Projects in context: {names_in_context}\n")

    project_id, domains, new_name = await _classify_file(
        file_path=file_path,
        file_content=content,
        projects=projects,
        config_data_types=["tasks", "notes", "timelines"],
    )

    # Resolve the returned id back to a placeholder project name, if it is one of ours.
    matched_name = None
    for project in projects:
        if project["id"] == project_id:
            matched_name = project["name"]
            break

    result = {
        "project_id": project_id,
        "matched_name": matched_name,
        "new_project_name": new_name,
        "domains": domains,
    }
    print(json.dumps(result, indent=2, ensure_ascii=False))
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Script entry point: first argument is the file, the rest are project names.
    cli_args = sys.argv[1:]
    if not cli_args:
        print("Usage: python -m tests.test_classify_file <file_path> [project_name ...]")
        sys.exit(1)
    target_file, *project_names = cli_args
    asyncio.run(_cli_test(target_file, project_names))
|
||||
@@ -63,7 +63,7 @@ class _FakeLLM:
|
||||
async def test_run_home_uses_mocked_tool_result():
|
||||
fake_llm = _FakeLLM()
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||
):
|
||||
out = await run_home("user-1", "list my tasks", {})
|
||||
@@ -76,7 +76,7 @@ async def test_run_home_uses_mocked_tool_result():
|
||||
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
|
||||
fake_llm = _FakeLLM()
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||
):
|
||||
events = []
|
||||
@@ -103,7 +103,7 @@ async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
|
||||
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
|
||||
)
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=_ClassifierOnlyLLM()):
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
|
||||
domain = await _infer_floating_domain(
|
||||
"Quali sono i miei task per il progetto X",
|
||||
{
|
||||
@@ -165,7 +165,7 @@ async def test_run_floating_strips_xml_like_tags_from_final_text():
|
||||
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||
)
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||
):
|
||||
text, _domain = await run_floating(
|
||||
@@ -187,7 +187,7 @@ async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
|
||||
yield "token", "Hai 1 task:\\n"
|
||||
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||
):
|
||||
events = []
|
||||
@@ -233,7 +233,7 @@ async def test_run_floating_stream_falls_back_to_final_response_content_when_ast
|
||||
if False:
|
||||
yield None
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=_NoChunkLLM()), patch(
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
|
||||
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
|
||||
):
|
||||
events = []
|
||||
@@ -255,7 +255,7 @@ async def test_run_floating_returns_fallback_when_sanitization_would_empty_text(
|
||||
async def _fake_run_single_agent(**_kwargs):
|
||||
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
|
||||
):
|
||||
text, _domain = await run_floating(
|
||||
@@ -274,7 +274,7 @@ async def test_run_floating_stream_returns_fallback_when_sanitization_would_empt
|
||||
async def _fake_stream(**_kwargs):
|
||||
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
|
||||
|
||||
with patch("app.core.deep_agent.get_llm", return_value=fake_llm), patch(
|
||||
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
|
||||
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
|
||||
):
|
||||
events = []
|
||||
|
||||
@@ -156,40 +156,6 @@ async def test_manager_unregister_cancels_pending_calls(manager, mock_ws):
|
||||
assert fut.cancelled()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_manager_agent_data_queue(manager, mock_ws):
    """A frame put on a run's agent-data queue comes back out unchanged."""
    manager.register("user1", "dev-A", mock_ws)
    run_queue = manager.get_agent_data_queue("user1", "run-xyz")

    # Round-trip a single frame through the queue.
    sent_frame = {"type": "agent_data", "run_id": "run-xyz", "files": []}
    await run_queue.put(sent_frame)
    received_frame = await run_queue.get()

    assert received_frame == sent_frame
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_manager_agent_data_queue_creates_once(manager, mock_ws):
    """Asking twice for the same user/run returns the very same queue object."""
    manager.register("user1", "dev-A", mock_ws)

    lookups = [manager.get_agent_data_queue("user1", "run-1") for _ in range(2)]
    first_lookup, second_lookup = lookups

    assert first_lookup is second_lookup
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_manager_agent_data_queue_raises_when_offline(manager):
    """Requesting a queue for a user that was never registered raises ``RuntimeError``."""
    expect_not_connected = pytest.raises(RuntimeError, match="not connected")
    with expect_not_connected:
        manager.get_agent_data_queue("ghost", "run-1")
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_manager_cleanup_agent_data_queue(manager, mock_ws):
    """After cleanup, the next lookup for the same run yields a fresh queue.

    The previous version only asserted ``q_new is not None``, which would pass
    even if ``cleanup_agent_data_queue`` did nothing.  Comparing identities
    checks the behaviour the original comment described ("not the same
    object").
    """
    manager.register("user1", "dev-A", mock_ws)
    q_old = manager.get_agent_data_queue("user1", "run-1")

    manager.cleanup_agent_data_queue("user1", "run-1")

    # After cleanup a new queue is created (not the same object).
    q_new = manager.get_agent_data_queue("user1", "run-1")
    assert q_new is not None
    assert q_new is not q_old
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Integration tests — /api/v1/ws/device endpoint
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -266,43 +232,6 @@ def test_ws_device_tool_result_dispatched(client):
|
||||
assert any(c["call_id"] == "call-123" for c in captured)
|
||||
|
||||
|
||||
def test_ws_device_agent_data_enqueued(client):
    """An ``agent_data`` frame sent over the device socket ends up on a per-run queue.

    ``get_agent_data_queue`` is wrapped by a spy that remembers the first queue
    handed out, so the test can inspect it after the socket is closed.
    """
    from app.core.device_manager import device_manager as dm

    token = make_jwt(tier="free")
    user_id = TEST_USER_IDS["free"]

    # Holds the first queue object the message loop obtains.
    captured_queue: list[asyncio.Queue] = []
    real_get_queue = dm.get_agent_data_queue

    def _spy_get_queue(uid, run_id):
        queue = real_get_queue(uid, run_id)
        if len(captured_queue) == 0:
            captured_queue.append(queue)
        return queue

    agent_data_frame = json.dumps(
        {
            "type": "agent_data",
            "run_id": "run-XYZ",
            "files": [{"path": "/tmp/file.txt", "content": "hello"}],
        }
    )

    # The heartbeat interval is patched to a large value for the duration of the connection.
    with patch.object(dm, "get_agent_data_queue", side_effect=_spy_get_queue), \
            patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 9999), \
            client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
        ws.send_text(_device_hello("dev-001"))
        ws.send_text(agent_data_frame)
        ws.close()

    # The queue must have been requested and must hold at least one frame.
    assert captured_queue, "queue was never accessed"
    assert not captured_queue[0].empty()
|
||||
|
||||
|
||||
def test_ws_device_disconnect_marks_run_logs_as_error(client, db_session):
|
||||
"""On disconnect, _mark_runs_disconnected is called with the correct user_id."""
|
||||
from app.api.routes import device_ws as _dws
|
||||
|
||||
@@ -45,9 +45,6 @@ def test_v2_frame_types_still_exist():
|
||||
"tool_result",
|
||||
"final",
|
||||
"ping",
|
||||
"agent_run",
|
||||
"agent_data",
|
||||
"agent_complete",
|
||||
"device_hello",
|
||||
]
|
||||
for name in v2_types:
|
||||
|
||||
Reference in New Issue
Block a user