271 lines
10 KiB
Python
271 lines
10 KiB
Python
"""Unit tests for ScoutEngine.

Covers trigger_scout / _process_item (relevant -> queued, spam -> archived or
dropped depending on auto_trash_spam, idempotent replay), _triage_llm (parsing
a JSON verdict from a mocked LLM response) and deliver_pending (one
"scout_proposal" frame per queued row, sent via the socket's send_json).
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import uuid
|
|
from datetime import datetime, timedelta, timezone
|
|
from unittest.mock import AsyncMock
|
|
|
|
import pytest
|
|
from sqlalchemy import select
|
|
|
|
from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
|
|
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
|
|
from app.scouts.connectors.registry import register_connector, _reset_for_tests
|
|
from app.scouts.engine import ScoutEngine
|
|
from tests.conftest import _TestSessionLocal
|
|
|
|
|
|
def _make_connector(items, content_for, *, source_type="gmail"):
    """Build an AsyncMock connector stub for the scout engine.

    Args:
        items: Value returned by ``list_new`` -- the refs the engine will process.
        content_for: Mapping from ``source_msg_ref`` to the object returned by
            ``fetch_content`` for that ref. A ref missing from the mapping
            raises ``KeyError`` from the awaited call, surfacing unexpected
            fetches instead of silently returning a mock.
        source_type: Connector key. It must equal the scout's ``provider`` so
            ``get_connector(scout.provider)`` resolves to this stub. Defaults
            to ``"gmail"``, the provider used by the scouts in these tests.

    Returns:
        An ``AsyncMock`` whose ``list_new``, ``fetch_content`` and ``archive``
        are awaitable and record their calls for later assertions.
    """
    connector = AsyncMock()
    connector.source_type = source_type
    connector.list_new = AsyncMock(return_value=items)
    # side_effect receives the same (scout, ref) arguments the engine passes.
    connector.fetch_content = AsyncMock(
        side_effect=lambda scout, ref: content_for[ref.source_msg_ref],
    )
    # Plain recorder: tests assert whether (and how often) it was awaited.
    connector.archive = AsyncMock()
    return connector
|
|
|
|
|
|
@pytest.fixture(autouse=True)
def _registry():
    """Isolate connector registrations between tests (applied automatically).

    Each test registers its own mock connector via ``register_connector``, so
    the registry is reset before the test body runs and reset again during
    teardown. This keeps one test's stub from being picked up by the next.
    """
    _reset_for_tests()  # setup: reset before the test body
    yield  # the test body runs here
    _reset_for_tests()  # teardown: discard whatever the test registered
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_relevant_item_inserted_into_queue(monkeypatch):
    """A "relevant" verdict queues the item and does not archive it at the source."""
    user_id = "00000000-0000-0000-0000-000000000003"  # power tier seeded in conftest
    scout_id = str(uuid.uuid4())

    async with _TestSessionLocal() as session:
        scout = CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
        )
        session.add(scout)
        await session.commit()

    refs = [ItemRef(source_msg_ref="msg-1")]
    content = {"msg-1": ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")}
    connector = _make_connector(refs, content)
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    # Replace the LLM call so the item is deterministically judged relevant.
    monkeypatch.setattr(
        engine,
        "_triage_llm",
        AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="task", confidence=0.9)),
    )

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        # Scope to this scout: other tests also write queue rows (including a
        # "msg-1"), so an unfiltered select is order-dependent unless the DB
        # is reset between tests.
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
        assert len(rows) == 1
        assert rows[0].source_msg_ref == "msg-1"
        assert rows[0].triage_verdict == "relevant"
        assert rows[0].status == "queued"

    connector.archive.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
    """With auto_trash_spam=True a spam verdict archives the item and queues nothing."""
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())

    async with _TestSessionLocal() as session:
        scout = CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=True, device_inactivity_pause_days=14,
        )
        session.add(scout)
        await session.commit()

    refs = [ItemRef(source_msg_ref="msg-spam")]
    content = {"msg-spam": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
    connector = _make_connector(refs, content)
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    # Replace the LLM call so the item is deterministically judged spam.
    monkeypatch.setattr(
        engine,
        "_triage_llm",
        AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
    )

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        # Scope to this scout so rows written by other tests cannot make the
        # "nothing queued" assertion fail spuriously.
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
        assert rows == []

    connector.archive.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
    """With auto_trash_spam=False a spam verdict is dropped: not archived, not queued."""
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())

    async with _TestSessionLocal() as session:
        scout = CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
        )
        session.add(scout)
        await session.commit()

    refs = [ItemRef(source_msg_ref="msg-2")]
    content = {"msg-2": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
    connector = _make_connector(refs, content)
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    # Replace the LLM call so the item is deterministically judged spam.
    monkeypatch.setattr(
        engine,
        "_triage_llm",
        AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
    )

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        # Scope to this scout so rows written by other tests cannot make the
        # "nothing queued" assertion fail spuriously.
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
        assert rows == []

    connector.archive.assert_not_awaited()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_idempotent_replay(monkeypatch):
    """Triggering the same scout twice over the same refs yields a single queue row."""
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())

    async with _TestSessionLocal() as session:
        session.add(CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
        ))
        await session.commit()

    # list_new returns the same ref on both runs, simulating a replay.
    refs = [ItemRef(source_msg_ref="msg-3")]
    content = {"msg-3": ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")}
    connector = _make_connector(refs, content)
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    monkeypatch.setattr(
        engine,
        "_triage_llm",
        AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="x", confidence=0.5)),
    )

    await engine.trigger_scout(uuid.UUID(scout_id))
    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        # Count only this scout's rows: the duplicate check is about this
        # scout's replay, not about whatever else is in the table.
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
        assert len(rows) == 1, "Replay must not create duplicate queue rows"
        assert rows[0].source_msg_ref == "msg-3"
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_triage_llm_parses_json_response(monkeypatch):
    """Real _triage_llm path: mock the LLM ainvoke, verify TriageVerdict parsed correctly."""
    from unittest.mock import MagicMock  # noqa: PLC0415

    # CloudScoutConfig is already imported at module level. The scout is never
    # persisted because _triage_llm receives it directly.
    scout = CloudScoutConfig(
        id=str(uuid.uuid4()),
        user_id="00000000-0000-0000-0000-000000000003",
        provider="gmail",
        name="test-scout",
        data_types=[],
        prompt_template="watch invoices and project updates",
        schedule_cron="0 * * * *",
        enabled=True,
        auto_trash_spam=False,
        device_inactivity_pause_days=14,
    )
    content = ItemContent(
        metadata=ItemMetadata(subject="Invoice 42", sender="billing@acme.com"),
        body_text="Payment of €1 200 is due on 2026-06-01. Please confirm receipt.",
    )

    # Fake LangChain-style response whose .content is valid JSON.
    fake_response = MagicMock()
    fake_response.content = '{"verdict": "relevant", "reason": "invoice due", "confidence": 0.92}'
    fake_response.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}

    # Fake LLM: .bind() returns the same mock, so ainvoke is reachable either way.
    fake_llm = MagicMock()
    fake_llm.bind.return_value = fake_llm
    fake_llm.ainvoke = AsyncMock(return_value=fake_response)

    # Patch get_llm inside app.scouts.engine so our fake is used.
    monkeypatch.setattr("app.scouts.engine.get_llm", lambda **kwargs: fake_llm)
    # Disable Langfuse for this test.
    monkeypatch.setattr("app.scouts.engine.get_langfuse", lambda: None)
    # Force the local fallback prompt instead of a Langfuse-managed one.
    monkeypatch.setattr(
        "app.scouts.engine.get_prompt_or_fallback",
        lambda name, fallback: (fallback, None),
    )

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    verdict = await engine._triage_llm(scout, content)

    assert verdict.verdict == "relevant"
    assert verdict.reason == "invoice due"
    assert verdict.confidence == pytest.approx(0.92)
    fake_llm.ainvoke.assert_awaited_once()
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_deliver_pending_sends_one_frame_per_queued_row(monkeypatch):
    """deliver_pending sends one "scout_proposal" frame per queued row and marks each delivered."""
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())
    now = datetime.now(tz=timezone.utc)

    async with _TestSessionLocal() as session:
        session.add(CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
        ))
        for i in range(3):
            session.add(ScoutTriageQueue(
                id=str(uuid.uuid4()), user_id=user_id, scout_id=scout_id,
                source_type="gmail", source_msg_ref=f"msg-{i}",
                triage_verdict="relevant", status="queued",
                triaged_at=now, expires_at=now + timedelta(days=30),
            ))
        await session.commit()

    # Metadata is derived from the ref so each frame can be traced back to its row.
    connector = AsyncMock()
    connector.source_type = "gmail"
    connector.fetch_metadata = AsyncMock(side_effect=lambda scout, ref: ItemMetadata(
        subject=f"sub-{ref.source_msg_ref}", snippet=f"snip-{ref.source_msg_ref}",
    ))
    register_connector(connector)

    sent = []
    ws = AsyncMock()
    ws.send_json = AsyncMock(side_effect=lambda payload: sent.append(payload))

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    await engine.deliver_pending(uuid.UUID(user_id), ws)

    # NOTE(review): deliver_pending is keyed by user, so this count assumes no
    # other queued rows exist for this user (i.e. per-test DB isolation).
    assert len(sent) == 3
    assert all(s["type"] == "scout_proposal" for s in sent)
    subjects = {s["proposal"]["raw_subject"] for s in sent}
    assert subjects == {"sub-msg-0", "sub-msg-1", "sub-msg-2"}

    async with _TestSessionLocal() as session:
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
        # Guard against a vacuous pass: all() over an empty result is True.
        assert len(rows) == 3
        assert all(r.status == "delivered" for r in rows)
        assert all(r.delivered_at is not None for r in rows)
|