diff --git a/app/scouts/engine.py b/app/scouts/engine.py
new file mode 100644
index 0000000..4cdfd1e
--- /dev/null
+++ b/app/scouts/engine.py
@@ -0,0 +1,157 @@
+"""ScoutEngine — orchestrates triage, queueing, and delivery for cloud scouts.
+
+Triage flow per scout:
+  1. Resolve scout config from the DB.
+  2. Skip if device hasn't connected within ``device_inactivity_pause_days``
+     (planned — ``trigger_scout`` does not enforce this yet).
+  3. Ask the connector to ``list_new`` — fresh items since last poll.
+  4. For each item:
+     - skip if already in the queue (idempotent on (scout_id, source_msg_ref))
+     - fetch the full content via the connector (transient, never persisted)
+     - run the triage LLM call → relevant | spam
+     - spam + auto_trash_spam → connector.archive
+     - relevant → INSERT scout_triage_queue row
+  5. Update scout.last_run_at.
+
+Delivery flow on Electron WS reconnect:
+  - drain ``status='queued'`` rows for the user
+  - fetch metadata-only for each (subject + snippet)
+  - send a ``scout_proposal`` frame
+  - flip status to ``delivered`` on ack
+"""
+
+from __future__ import annotations
+
+import logging
+import uuid
+from datetime import datetime, timedelta, timezone
+
+from sqlalchemy import select
+from sqlalchemy.exc import IntegrityError
+
+from app.db import async_session
+from app.models import CloudScoutConfig, ScoutTriageQueue
+from app.scouts.connectors.base import ItemContent, ItemRef, TriageVerdict
+from app.scouts.connectors.registry import get_connector
+
+logger = logging.getLogger(__name__)
+
+# Days a queued triage row stays valid; used to compute ``expires_at``.
+QUEUE_TTL_DAYS = 30
+
+
+class ScoutEngine:
+    """Runs one triage pass for a scout and queues the relevant items."""
+
+    def __init__(self, session_factory=None) -> None:
+        """Store the session factory, defaulting to ``app.db.async_session``."""
+        self._session_factory = session_factory or async_session
+
+    async def trigger_scout(self, scout_id: uuid.UUID) -> None:
+        """Run one triage pass for the scout identified by ``scout_id``.
+
+        Returns without doing anything when the scout is missing or disabled,
+        and without committing when the connector's ``list_new`` call fails.
+        Otherwise every listed item is processed, ``last_run_at`` is stamped,
+        and the session is committed once at the end.
+        """
+        async with self._session_factory() as session:
+            scout = await session.get(CloudScoutConfig, str(scout_id))
+            if scout is None:
+                logger.warning("trigger_scout: no such scout id=%s", scout_id)
+                return
+            if not scout.enabled:
+                return
+            # NOTE: the device-inactivity pause is not enforced here yet;
+            # ``device_inactivity_pause_days`` is never read by this method.
+            # The device-online signal lives in the DeviceConnectionManager and
+            # is consulted at delivery time.
+            connector = get_connector(scout.provider)
+            try:
+                refs = await connector.list_new(scout)
+            except Exception:
+                logger.exception("scout %s: list_new failed", scout.id)
+                return
+
+            for ref in refs:
+                await self._process_item(session, scout, connector, ref)
+
+            scout.last_run_at = datetime.now(tz=timezone.utc)
+            await session.commit()
+
+    async def _process_item(
+        self,
+        session,
+        scout: CloudScoutConfig,
+        connector,
+        ref: ItemRef,
+    ) -> None:
+        """Triage a single item and queue it when the verdict is not spam.
+
+        Connector and LLM failures are logged and the item is skipped; nothing
+        is raised for them. The caller owns the commit.
+        """
+        # Idempotency check
+        existing = await session.execute(
+            select(ScoutTriageQueue.id).where(
+                ScoutTriageQueue.scout_id == scout.id,
+                ScoutTriageQueue.source_msg_ref == ref.source_msg_ref,
+            )
+        )
+        if existing.first() is not None:
+            return
+
+        try:
+            content = await connector.fetch_content(scout, ref)
+        except Exception:
+            logger.exception("scout %s: fetch_content failed for %s", scout.id, ref.source_msg_ref)
+            return
+
+        try:
+            verdict = await self._triage_llm(scout, content)
+        except Exception:
+            logger.exception("scout %s: triage_llm failed for %s", scout.id, ref.source_msg_ref)
+            return
+
+        if verdict.verdict == "spam":
+            if scout.auto_trash_spam:
+                try:
+                    await connector.archive(scout, ref)
+                except Exception:
+                    logger.exception("scout %s: archive failed for %s", scout.id, ref.source_msg_ref)
+            return
+
+        now = datetime.now(tz=timezone.utc)
+        row = ScoutTriageQueue(
+            id=str(uuid.uuid4()),
+            user_id=scout.user_id,
+            scout_id=scout.id,
+            source_type=connector.source_type,
+            source_msg_ref=ref.source_msg_ref,
+            triage_verdict=verdict.verdict,
+            triage_reason=verdict.reason,
+            status="queued",
+            triaged_at=now,
+            expires_at=now + timedelta(days=QUEUE_TTL_DAYS),
+        )
+        try:
+            # Add the row *inside* the savepoint. ``begin_nested()`` flushes
+            # pending state before it emits SAVEPOINT, so a row added earlier
+            # would be INSERTed in the outer transaction and a unique-constraint
+            # violation would invalidate the whole session instead of only the
+            # savepoint.
+            async with session.begin_nested():
+                session.add(row)
+                await session.flush()
+        except IntegrityError:
+            # Race: another worker inserted between our SELECT and INSERT.
+            # The unique constraint did its job; safe to ignore.
+            logger.debug(
+                "scout %s: idempotent skip for %s (race on unique constraint)",
+                scout.id,
+                ref.source_msg_ref,
+            )
+
+    async def _triage_llm(self, scout: CloudScoutConfig, content: ItemContent) -> TriageVerdict:
+        """Stub — real implementation in Task 24."""
+        raise NotImplementedError("Real triage LLM call lands in Task 24")
diff --git a/tests/test_scout_engine.py b/tests/test_scout_engine.py
new file mode 100644
index 0000000..2d9d8c8
--- /dev/null
+++ b/tests/test_scout_engine.py
@@ -0,0 +1,172 @@
+"""Unit tests for ScoutEngine.trigger_scout / _process_item."""
+
+from __future__ import annotations
+
+import uuid
+from datetime import datetime, timedelta, timezone
+from unittest.mock import AsyncMock
+
+import pytest
+from sqlalchemy import select
+
+from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
+from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
+from app.scouts.connectors.registry import register_connector, _reset_for_tests
+from app.scouts.engine import ScoutEngine
+from tests.conftest import _TestSessionLocal
+
+
+def _make_connector(items, content_for):
+    c = AsyncMock()
+    # source_type must match the scout.provider ("gmail") so get_connector()
+    # finds it when the engine calls get_connector(scout.provider).
+    c.source_type = "gmail"
+    c.list_new = AsyncMock(return_value=items)
+    c.fetch_content = AsyncMock(side_effect=lambda scout, ref: content_for[ref.source_msg_ref])
+    c.archive = AsyncMock()
+    return c
+
+
+@pytest.fixture(autouse=True)
+def _registry():
+    _reset_for_tests()
+    yield
+    _reset_for_tests()
+
+
+@pytest.mark.asyncio
+async def test_relevant_item_inserted_into_queue(monkeypatch):
+    user_id = "00000000-0000-0000-0000-000000000003"  # power tier seeded in conftest
+    scout_id = str(uuid.uuid4())
+
+    async with _TestSessionLocal() as session:
+        scout = CloudScoutConfig(
+            id=scout_id, user_id=user_id, provider="gmail", name="Test",
+            data_types=[], prompt_template="", schedule_cron="0 * * * *",
+            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
+        )
+        session.add(scout)
+        await session.commit()
+
+    refs = [ItemRef(source_msg_ref="msg-1")]
+    content = {"msg-1": ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")}
+    connector = _make_connector(refs, content)
+    register_connector(connector)
+
+    engine = ScoutEngine(session_factory=_TestSessionLocal)
+    monkeypatch.setattr(
+        engine,
+        "_triage_llm",
+        AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="task", confidence=0.9)),
+    )
+
+    await engine.trigger_scout(uuid.UUID(scout_id))
+
+    async with _TestSessionLocal() as session:
+        rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
+        assert len(rows) == 1
+        assert rows[0].source_msg_ref == "msg-1"
+        assert rows[0].triage_verdict == "relevant"
+        assert rows[0].status == "queued"
+    connector.archive.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
+    user_id = "00000000-0000-0000-0000-000000000003"
+    scout_id = str(uuid.uuid4())
+
+    async with _TestSessionLocal() as session:
+        scout = CloudScoutConfig(
+            id=scout_id, user_id=user_id, provider="gmail", name="Test",
+            data_types=[], prompt_template="", schedule_cron="0 * * * *",
+            enabled=True, auto_trash_spam=True, device_inactivity_pause_days=14,
+        )
+        session.add(scout)
+        await session.commit()
+
+    refs = [ItemRef(source_msg_ref="msg-spam")]
+    content = {"msg-spam": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
+    connector = _make_connector(refs, content)
+    register_connector(connector)
+
+    engine = ScoutEngine(session_factory=_TestSessionLocal)
+    monkeypatch.setattr(
+        engine,
+        "_triage_llm",
+        AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
+    )
+
+    await engine.trigger_scout(uuid.UUID(scout_id))
+
+    async with _TestSessionLocal() as session:
+        rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
+        assert rows == []
+    connector.archive.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
+    user_id = "00000000-0000-0000-0000-000000000003"
+    scout_id = str(uuid.uuid4())
+
+    async with _TestSessionLocal() as session:
+        scout = CloudScoutConfig(
+            id=scout_id, user_id=user_id, provider="gmail", name="Test",
+            data_types=[], prompt_template="", schedule_cron="0 * * * *",
+            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
+        )
+        session.add(scout)
+        await session.commit()
+
+    refs = [ItemRef(source_msg_ref="msg-2")]
+    content = {"msg-2": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
+    connector = _make_connector(refs, content)
+    register_connector(connector)
+
+    engine = ScoutEngine(session_factory=_TestSessionLocal)
+    monkeypatch.setattr(
+        engine,
+        "_triage_llm",
+        AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
+    )
+
+    await engine.trigger_scout(uuid.UUID(scout_id))
+
+    async with _TestSessionLocal() as session:
+        rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
+        assert rows == []
+    connector.archive.assert_not_awaited()
+
+
+@pytest.mark.asyncio
+async def test_idempotent_replay(monkeypatch):
+    user_id = "00000000-0000-0000-0000-000000000003"
+    scout_id = str(uuid.uuid4())
+
+    async with _TestSessionLocal() as session:
+        session.add(CloudScoutConfig(
+            id=scout_id, user_id=user_id, provider="gmail", name="Test",
+            data_types=[], prompt_template="", schedule_cron="0 * * * *",
+            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
+        ))
+        await session.commit()
+
+    refs = [ItemRef(source_msg_ref="msg-3")]
+    content = {"msg-3": ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")}
+    connector = _make_connector(refs, content)
+    register_connector(connector)
+
+    engine = ScoutEngine(session_factory=_TestSessionLocal)
+    monkeypatch.setattr(
+        engine,
+        "_triage_llm",
+        AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="x", confidence=0.5)),
+    )
+
+    await engine.trigger_scout(uuid.UUID(scout_id))
+    await engine.trigger_scout(uuid.UUID(scout_id))
+
+    async with _TestSessionLocal() as session:
+        rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
+        assert len(rows) == 1, "Replay must not create duplicate queue rows"