173 lines
6.2 KiB
Python
173 lines
6.2 KiB
Python
"""Unit tests for ScoutEngine.trigger_scout / _process_item."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
|
|
from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
|
|
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
|
|
from app.scouts.connectors.registry import register_connector, _reset_for_tests
|
|
from app.scouts.engine import ScoutEngine
|
|
from tests.conftest import _TestSessionLocal
|
|
|
|
|
|
def _make_connector(items, content_for):
|
|
c = AsyncMock()
|
|
# source_type must match the scout.provider ("gmail") so get_connector()
|
|
# finds it when the engine calls get_connector(scout.provider).
|
|
c.source_type = "gmail"
|
|
c.list_new = AsyncMock(return_value=items)
|
|
c.fetch_content = AsyncMock(side_effect=lambda scout, ref: content_for[ref.source_msg_ref])
|
|
c.archive = AsyncMock()
|
|
return c
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _registry():
    """Reset the connector registry around every test in this module.

    ``_reset_for_tests`` runs once before the test body, so connectors
    registered elsewhere are not picked up. It runs again after the test
    finishes, so connectors registered by this test do not leak into later ones.
    """
    clear_registry = _reset_for_tests
    clear_registry()  # setup
    yield None
    clear_registry()  # teardown
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_relevant_item_inserted_into_queue(monkeypatch):
    """A "relevant" verdict queues the item as "queued" and does not archive it."""
    owner_id = "00000000-0000-0000-0000-000000000003"  # power tier seeded in conftest
    scout_uuid = uuid.uuid4()

    async with _TestSessionLocal() as db:
        db.add(
            CloudScoutConfig(
                id=str(scout_uuid),
                user_id=owner_id,
                provider="gmail",
                name="Test",
                data_types=[],
                prompt_template="",
                schedule_cron="0 * * * *",
                enabled=True,
                auto_trash_spam=False,
                device_inactivity_pause_days=14,
            )
        )
        await db.commit()

    fake = _make_connector(
        [ItemRef(source_msg_ref="msg-1")],
        {"msg-1": ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")},
    )
    register_connector(fake)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    canned = TriageVerdict(verdict="relevant", reason="task", confidence=0.9)
    monkeypatch.setattr(engine, "_triage_llm", AsyncMock(return_value=canned))

    await engine.trigger_scout(scout_uuid)

    async with _TestSessionLocal() as db:
        result = await db.execute(select(ScoutTriageQueue))
        queued = result.scalars().all()
        assert len(queued) == 1
        entry = queued[0]
        assert entry.source_msg_ref == "msg-1"
        assert entry.triage_verdict == "relevant"
        assert entry.status == "queued"

    fake.archive.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
    """With auto_trash_spam on, a "spam" verdict archives the item once and queues nothing."""
    owner_id = "00000000-0000-0000-0000-000000000003"
    scout_uuid = uuid.uuid4()

    scout_row = CloudScoutConfig(
        id=str(scout_uuid),
        user_id=owner_id,
        provider="gmail",
        name="Test",
        data_types=[],
        prompt_template="",
        schedule_cron="0 * * * *",
        enabled=True,
        auto_trash_spam=True,
        device_inactivity_pause_days=14,
    )
    async with _TestSessionLocal() as db:
        db.add(scout_row)
        await db.commit()

    spam_ref = ItemRef(source_msg_ref="msg-spam")
    spam_body = ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")
    fake = _make_connector([spam_ref], {"msg-spam": spam_body})
    register_connector(fake)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    spam_verdict = TriageVerdict(verdict="spam", reason="bait", confidence=0.99)
    monkeypatch.setattr(engine, "_triage_llm", AsyncMock(return_value=spam_verdict))

    await engine.trigger_scout(scout_uuid)

    async with _TestSessionLocal() as db:
        queued = (await db.execute(select(ScoutTriageQueue))).scalars().all()
        assert queued == []

    fake.archive.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
    """With auto_trash_spam off, a "spam" verdict neither archives nor queues the item."""
    owner_id = "00000000-0000-0000-0000-000000000003"
    scout_uuid = uuid.uuid4()

    async with _TestSessionLocal() as db:
        db.add(
            CloudScoutConfig(
                id=str(scout_uuid),
                user_id=owner_id,
                provider="gmail",
                name="Test",
                data_types=[],
                prompt_template="",
                schedule_cron="0 * * * *",
                enabled=True,
                auto_trash_spam=False,
                device_inactivity_pause_days=14,
            )
        )
        await db.commit()

    message_ref = ItemRef(source_msg_ref="msg-2")
    fake = _make_connector(
        [message_ref],
        {"msg-2": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")},
    )
    register_connector(fake)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    triage_stub = AsyncMock(
        return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)
    )
    monkeypatch.setattr(engine, "_triage_llm", triage_stub)

    await engine.trigger_scout(scout_uuid)

    async with _TestSessionLocal() as db:
        result = await db.execute(select(ScoutTriageQueue))
        assert result.scalars().all() == []

    fake.archive.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_idempotent_replay(monkeypatch):
    """Triggering the same scout twice over the same item must yield one queue row."""
    owner_id = "00000000-0000-0000-0000-000000000003"
    scout_uuid = uuid.uuid4()

    async with _TestSessionLocal() as db:
        db.add(
            CloudScoutConfig(
                id=str(scout_uuid),
                user_id=owner_id,
                provider="gmail",
                name="Test",
                data_types=[],
                prompt_template="",
                schedule_cron="0 * * * *",
                enabled=True,
                auto_trash_spam=False,
                device_inactivity_pause_days=14,
            )
        )
        await db.commit()

    fake = _make_connector(
        [ItemRef(source_msg_ref="msg-3")],
        {"msg-3": ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")},
    )
    register_connector(fake)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    verdict = TriageVerdict(verdict="relevant", reason="x", confidence=0.5)
    monkeypatch.setattr(engine, "_triage_llm", AsyncMock(return_value=verdict))

    # Run the identical trigger twice; the second pass is the replay.
    for _ in range(2):
        await engine.trigger_scout(scout_uuid)

    async with _TestSessionLocal() as db:
        result = await db.execute(select(ScoutTriageQueue))
        queued = result.scalars().all()
        assert len(queued) == 1, "Replay must not create duplicate queue rows"
|