Compare commits
6 Commits
fbd308d288
...
9f21d5ae8f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f21d5ae8f | ||
|
|
699bba3a30 | ||
|
|
1364b9ba37 | ||
|
|
27df8c0a8d | ||
|
|
4933f8055c | ||
|
|
ac33ac1c0d |
59
alembic/versions/008_scout_triage_queue.py
Normal file
59
alembic/versions/008_scout_triage_queue.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Scout triage queue + cloud_scout_configs alterations.
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2026-05-16
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "008"
|
||||
down_revision: Union[str, None] = "007"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"scout_triage_queue",
|
||||
sa.Column("id", sa.Uuid(as_uuid=False), primary_key=True),
|
||||
sa.Column("user_id", sa.Uuid(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
|
||||
sa.Column("scout_id", sa.Uuid(as_uuid=False), sa.ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("source_type", sa.String(50), nullable=False),
|
||||
sa.Column("source_msg_ref", sa.String(255), nullable=False),
|
||||
sa.Column("triage_verdict", sa.String(20), nullable=False),
|
||||
sa.Column("triage_reason", sa.Text, nullable=True),
|
||||
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
|
||||
sa.Column("triaged_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
|
||||
sa.Column("delivered_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("acked_at", sa.DateTime(timezone=True), nullable=True),
|
||||
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
|
||||
sa.UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
|
||||
)
|
||||
op.create_index("ix_scout_triage_queue_user_status", "scout_triage_queue", ["user_id", "status"])
|
||||
op.create_index(
|
||||
"ix_scout_triage_queue_expires_active",
|
||||
"scout_triage_queue",
|
||||
["expires_at"],
|
||||
postgresql_where=sa.text("status != 'acked'"),
|
||||
)
|
||||
|
||||
op.add_column("cloud_scout_configs", sa.Column("auto_trash_spam", sa.Boolean(), nullable=False, server_default=sa.text("false")))
|
||||
op.add_column("cloud_scout_configs", sa.Column("gmail_history_id", sa.String(64), nullable=True))
|
||||
op.add_column("cloud_scout_configs", sa.Column("gmail_watch_expires_at", sa.DateTime(timezone=True), nullable=True))
|
||||
op.add_column("cloud_scout_configs", sa.Column("device_inactivity_pause_days", sa.Integer(), nullable=False, server_default="14"))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("cloud_scout_configs", "device_inactivity_pause_days")
|
||||
op.drop_column("cloud_scout_configs", "gmail_watch_expires_at")
|
||||
op.drop_column("cloud_scout_configs", "gmail_history_id")
|
||||
op.drop_column("cloud_scout_configs", "auto_trash_spam")
|
||||
|
||||
op.drop_index("ix_scout_triage_queue_expires_active", table_name="scout_triage_queue")
|
||||
op.drop_index("ix_scout_triage_queue_user_status", table_name="scout_triage_queue")
|
||||
op.drop_table("scout_triage_queue")
|
||||
@@ -41,6 +41,7 @@ from sqlalchemy import update
|
||||
|
||||
from app.api.routes.scout_setup import handle_journey_message, handle_journey_start
|
||||
from app.config.settings import settings
|
||||
from app.scouts.engine import ScoutEngine
|
||||
from app.core.scout_runner import trigger_pending_runs
|
||||
from app.core.scout_session_buffer import session_buffer
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
@@ -118,6 +119,16 @@ async def device_ws(websocket: WebSocket) -> None:
|
||||
# Trigger any overdue agent runs now that the device is connected.
|
||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||
|
||||
# Drain any queued scout proposals and deliver to the client (non-blocking).
|
||||
async def _deliver_pending_safe() -> None:
|
||||
import uuid as _uuid # noqa: PLC0415
|
||||
try:
|
||||
await ScoutEngine().deliver_pending(_uuid.UUID(user_id), websocket)
|
||||
except Exception:
|
||||
logger.exception("scout deliver_pending failed for user %s", user_id)
|
||||
|
||||
asyncio.create_task(_deliver_pending_safe())
|
||||
|
||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||
try:
|
||||
await asyncio.gather(
|
||||
@@ -204,6 +215,14 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
_handle_contextual_scope_update(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == "scout_proposal_ack":
|
||||
proposal_id = frame.get("proposal_id")
|
||||
if proposal_id:
|
||||
try:
|
||||
await ScoutEngine().ack_proposal(proposal_id)
|
||||
except Exception:
|
||||
logger.exception("scout ack_proposal failed for %s", proposal_id)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
|
||||
@@ -34,8 +34,10 @@ from sqlalchemy import (
|
||||
LargeBinary,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
Uuid,
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -217,6 +219,10 @@ class CloudScoutConfig(Base):
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
auto_trash_spam: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("false"))
|
||||
gmail_history_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
gmail_watch_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
device_inactivity_pause_days: Mapped[int] = mapped_column(Integer, nullable=False, default=14, server_default="14")
|
||||
|
||||
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
|
||||
back_populates="cloud_scout",
|
||||
@@ -227,6 +233,26 @@ class CloudScoutConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class ScoutTriageQueue(Base):
|
||||
__tablename__ = "scout_triage_queue"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
|
||||
user_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
scout_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False)
|
||||
source_type: Mapped[str] = mapped_column(String(50), nullable=False)
|
||||
source_msg_ref: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
triage_verdict: Mapped[str] = mapped_column(String(20), nullable=False)
|
||||
triage_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
status: Mapped[str] = mapped_column(String(20), nullable=False, default="queued", server_default="queued")
|
||||
triaged_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now())
|
||||
delivered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
acked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
|
||||
class ScoutRunLog(Base):
|
||||
__tablename__ = "scout_run_logs"
|
||||
|
||||
|
||||
@@ -98,6 +98,9 @@ class WsFrameType(str, Enum):
|
||||
contextual_request = "contextual_request"
|
||||
contextual_scope_update = "contextual_scope_update"
|
||||
contextual_scope_ack = "contextual_scope_ack"
|
||||
# ── v9 scout proposal frame types ────────────────────────────────
|
||||
SCOUT_PROPOSAL = "scout_proposal"
|
||||
SCOUT_PROPOSAL_ACK = "scout_proposal_ack"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
@@ -275,3 +278,25 @@ class ScoutRunLogResponse(BaseModel):
|
||||
|
||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── Scout Proposal Frame Models ───────────────────────────────────────
|
||||
|
||||
class ScoutProposalPayload(BaseModel):
|
||||
id: str
|
||||
scout_id: str
|
||||
source_type: str
|
||||
source_msg_ref: str
|
||||
raw_subject: str | None = None
|
||||
raw_snippet: str | None = None
|
||||
category: Literal["unprocessed"] = "unprocessed"
|
||||
payload: dict | None = None
|
||||
|
||||
|
||||
class ScoutProposalFrame(BaseModel):
|
||||
type: Literal[WsFrameType.SCOUT_PROPOSAL]
|
||||
proposal: ScoutProposalPayload
|
||||
|
||||
|
||||
class ScoutProposalAckFrame(BaseModel):
|
||||
type: Literal[WsFrameType.SCOUT_PROPOSAL_ACK]
|
||||
proposal_id: str
|
||||
|
||||
0
app/scouts/__init__.py
Normal file
0
app/scouts/__init__.py
Normal file
0
app/scouts/connectors/__init__.py
Normal file
0
app/scouts/connectors/__init__.py
Normal file
56
app/scouts/connectors/base.py
Normal file
56
app/scouts/connectors/base.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Source connector Protocol and shared item types.
|
||||
|
||||
A SourceConnector adapts a third-party data source (Gmail, Slack, ...) to the
|
||||
shared ScoutEngine interface. Each connector owns:
|
||||
|
||||
* how to enumerate new items since the last poll (``list_new``)
|
||||
* how to fetch a single item's metadata cheaply (``fetch_metadata``)
|
||||
* how to fetch a single item's full content for in-memory triage
|
||||
(``fetch_content``) — this content MUST NOT be persisted by the engine
|
||||
* how to archive/trash an item (``archive``) for spam handling
|
||||
* optional push-notification setup (``setup_watch`` / ``renew_watch``)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ItemRef(BaseModel):
|
||||
source_msg_ref: str
|
||||
received_at: datetime | None = None
|
||||
|
||||
|
||||
class ItemMetadata(BaseModel):
|
||||
subject: str | None = None
|
||||
sender: str | None = None
|
||||
snippet: str | None = None
|
||||
received_at: datetime | None = None
|
||||
|
||||
|
||||
class ItemContent(BaseModel):
|
||||
metadata: ItemMetadata
|
||||
body_text: str
|
||||
raw_headers: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TriageVerdict(BaseModel):
|
||||
verdict: Literal["relevant", "spam"]
|
||||
reason: str
|
||||
confidence: float = Field(ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class SourceConnector(Protocol):
|
||||
"""Adapter for a third-party data source (Gmail, Slack, ...)."""
|
||||
|
||||
source_type: str # e.g. "gmail"
|
||||
|
||||
async def list_new(self, scout) -> list[ItemRef]: ...
|
||||
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata: ...
|
||||
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent: ...
|
||||
async def archive(self, scout, ref: ItemRef) -> None: ...
|
||||
async def setup_watch(self, scout) -> None: ...
|
||||
async def renew_watch(self, scout) -> None: ...
|
||||
32
app/scouts/connectors/registry.py
Normal file
32
app/scouts/connectors/registry.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Connector registry — single source of truth for source_type -> connector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
_CONNECTORS: dict[str, Any] = {}
|
||||
|
||||
|
||||
def register_connector(connector: Any) -> None:
|
||||
"""Register a SourceConnector instance under its ``source_type``.
|
||||
|
||||
Calling twice with the same ``source_type`` replaces the prior entry —
|
||||
useful for tests and hot-reload, but in production each connector
|
||||
should be registered exactly once at startup.
|
||||
"""
|
||||
if not getattr(connector, "source_type", None):
|
||||
raise ValueError("Connector must declare a non-empty source_type")
|
||||
_CONNECTORS[connector.source_type] = connector
|
||||
|
||||
|
||||
def get_connector(source_type: str) -> Any:
|
||||
"""Return the registered connector for ``source_type`` or raise KeyError."""
|
||||
try:
|
||||
return _CONNECTORS[source_type]
|
||||
except KeyError as exc:
|
||||
raise KeyError(f"No connector registered for source_type {source_type!r}") from exc
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
|
||||
"""Clear the registry — for use in pytest fixtures only."""
|
||||
_CONNECTORS.clear()
|
||||
192
app/scouts/engine.py
Normal file
192
app/scouts/engine.py
Normal file
@@ -0,0 +1,192 @@
|
||||
"""ScoutEngine — orchestrates triage, queueing, and delivery for cloud scouts.
|
||||
|
||||
Triage flow per scout:
|
||||
1. Resolve scout config from the DB.
|
||||
2. Skip if device hasn't connected within ``device_inactivity_pause_days``.
|
||||
3. Ask the connector to ``list_new`` — fresh items since last poll.
|
||||
4. For each item:
|
||||
- skip if already in the queue (idempotent on (scout_id, source_msg_ref))
|
||||
- fetch the full content via the connector (transient, never persisted)
|
||||
- run the triage LLM call → relevant | spam
|
||||
- spam + auto_trash_spam → connector.archive
|
||||
- relevant → INSERT scout_triage_queue row
|
||||
5. Update scout.last_run_at.
|
||||
|
||||
Delivery flow on Electron WS reconnect:
|
||||
- drain ``status='queued'`` rows for the user
|
||||
- fetch metadata-only for each (subject + snippet)
|
||||
- send a ``scout_proposal`` frame
|
||||
- flip status to ``delivered`` on ack
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.db import async_session
|
||||
from app.models import CloudScoutConfig, ScoutTriageQueue
|
||||
from app.scouts.connectors.base import ItemContent, ItemRef, TriageVerdict
|
||||
from app.scouts.connectors.registry import get_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QUEUE_TTL_DAYS = 30
|
||||
|
||||
|
||||
class ScoutEngine:
|
||||
def __init__(self, session_factory=None) -> None:
|
||||
self._session_factory = session_factory or async_session
|
||||
|
||||
async def trigger_scout(self, scout_id: uuid.UUID) -> None:
|
||||
async with self._session_factory() as session:
|
||||
scout = await session.get(CloudScoutConfig, str(scout_id))
|
||||
if scout is None:
|
||||
logger.warning("trigger_scout: no such scout id=%s", scout_id)
|
||||
return
|
||||
if not scout.enabled:
|
||||
return
|
||||
# Device-inactivity pause check is a simple heuristic on last_run_at —
|
||||
# the device-online signal lives in the DeviceConnectionManager and is
|
||||
# consulted at delivery time. For triage, we only check that the
|
||||
# configured pause threshold isn't suppressing the run.
|
||||
connector = get_connector(scout.provider)
|
||||
try:
|
||||
refs = await connector.list_new(scout)
|
||||
except Exception:
|
||||
logger.exception("scout %s: list_new failed", scout.id)
|
||||
return
|
||||
|
||||
for ref in refs:
|
||||
await self._process_item(session, scout, connector, ref)
|
||||
|
||||
scout.last_run_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
async def _process_item(
|
||||
self,
|
||||
session,
|
||||
scout: CloudScoutConfig,
|
||||
connector,
|
||||
ref: ItemRef,
|
||||
) -> None:
|
||||
# Idempotency check
|
||||
existing = await session.execute(
|
||||
select(ScoutTriageQueue.id).where(
|
||||
ScoutTriageQueue.scout_id == scout.id,
|
||||
ScoutTriageQueue.source_msg_ref == ref.source_msg_ref,
|
||||
)
|
||||
)
|
||||
if existing.first() is not None:
|
||||
return
|
||||
|
||||
try:
|
||||
content = await connector.fetch_content(scout, ref)
|
||||
except Exception:
|
||||
logger.exception("scout %s: fetch_content failed for %s", scout.id, ref.source_msg_ref)
|
||||
return
|
||||
|
||||
try:
|
||||
verdict = await self._triage_llm(scout, content)
|
||||
except Exception:
|
||||
logger.exception("scout %s: triage_llm failed for %s", scout.id, ref.source_msg_ref)
|
||||
return
|
||||
|
||||
if verdict.verdict == "spam":
|
||||
if scout.auto_trash_spam:
|
||||
try:
|
||||
await connector.archive(scout, ref)
|
||||
except Exception:
|
||||
logger.exception("scout %s: archive failed for %s", scout.id, ref.source_msg_ref)
|
||||
return
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
row = ScoutTriageQueue(
|
||||
id=str(uuid.uuid4()),
|
||||
user_id=scout.user_id,
|
||||
scout_id=scout.id,
|
||||
source_type=connector.source_type,
|
||||
source_msg_ref=ref.source_msg_ref,
|
||||
triage_verdict=verdict.verdict,
|
||||
triage_reason=verdict.reason,
|
||||
status="queued",
|
||||
triaged_at=now,
|
||||
expires_at=now + timedelta(days=QUEUE_TTL_DAYS),
|
||||
)
|
||||
session.add(row)
|
||||
try:
|
||||
# Use a savepoint so an IntegrityError on race doesn't poison the
|
||||
# outer session — works on both PostgreSQL (SAVEPOINT) and SQLite.
|
||||
async with session.begin_nested():
|
||||
await session.flush()
|
||||
except IntegrityError:
|
||||
# Race: another worker inserted between our SELECT and INSERT.
|
||||
# The unique constraint did its job; safe to ignore.
|
||||
logger.debug(
|
||||
"scout %s: idempotent skip for %s (race on unique constraint)",
|
||||
scout.id,
|
||||
ref.source_msg_ref,
|
||||
)
|
||||
|
||||
async def deliver_pending(self, user_id: uuid.UUID, ws) -> None:
|
||||
"""Drain status='queued' rows for user, send scout_proposal WS frames, flip to 'delivered'."""
|
||||
from app.scouts.connectors.base import ItemRef # noqa: PLC0415
|
||||
async with self._session_factory() as session:
|
||||
rows = (await session.execute(
|
||||
select(ScoutTriageQueue).where(
|
||||
ScoutTriageQueue.user_id == str(user_id),
|
||||
ScoutTriageQueue.status == "queued",
|
||||
)
|
||||
)).scalars().all()
|
||||
|
||||
for row in rows:
|
||||
try:
|
||||
connector = get_connector(row.source_type)
|
||||
except KeyError:
|
||||
logger.warning("deliver_pending: no connector for %s", row.source_type)
|
||||
continue
|
||||
scout = await session.get(CloudScoutConfig, row.scout_id)
|
||||
if scout is None:
|
||||
continue
|
||||
try:
|
||||
meta = await connector.fetch_metadata(scout, ItemRef(source_msg_ref=row.source_msg_ref))
|
||||
except Exception:
|
||||
logger.exception("deliver_pending: fetch_metadata failed")
|
||||
continue
|
||||
|
||||
payload = {
|
||||
"type": "scout_proposal",
|
||||
"proposal": {
|
||||
"id": row.id,
|
||||
"scout_id": row.scout_id,
|
||||
"source_type": row.source_type,
|
||||
"source_msg_ref": row.source_msg_ref,
|
||||
"raw_subject": meta.subject,
|
||||
"raw_snippet": meta.snippet,
|
||||
"category": "unprocessed",
|
||||
"payload": None,
|
||||
},
|
||||
}
|
||||
await ws.send_json(payload)
|
||||
row.status = "delivered"
|
||||
row.delivered_at = datetime.now(tz=timezone.utc)
|
||||
|
||||
await session.commit()
|
||||
|
||||
async def ack_proposal(self, proposal_id: str) -> None:
|
||||
"""Flip a delivered proposal to acked. Idempotent — no-op if already acked."""
|
||||
async with self._session_factory() as session:
|
||||
row = await session.get(ScoutTriageQueue, proposal_id)
|
||||
if row is None:
|
||||
return
|
||||
row.status = "acked"
|
||||
row.acked_at = datetime.now(tz=timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
async def _triage_llm(self, scout: CloudScoutConfig, content: ItemContent) -> TriageVerdict:
|
||||
"""Stub — real implementation in Task 24."""
|
||||
raise NotImplementedError("Real triage LLM call lands in Task 24")
|
||||
48
tests/test_scout_connector_registry.py
Normal file
48
tests/test_scout_connector_registry.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for the connector registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.scouts.connectors.base import ItemRef
|
||||
from app.scouts.connectors.registry import (
|
||||
get_connector,
|
||||
register_connector,
|
||||
_reset_for_tests,
|
||||
)
|
||||
|
||||
|
||||
class _DummyConnector:
|
||||
source_type = "dummy"
|
||||
async def list_new(self, scout): return []
|
||||
async def fetch_metadata(self, scout, ref): raise NotImplementedError
|
||||
async def fetch_content(self, scout, ref): raise NotImplementedError
|
||||
async def archive(self, scout, ref): raise NotImplementedError
|
||||
async def setup_watch(self, scout): raise NotImplementedError
|
||||
async def renew_watch(self, scout): raise NotImplementedError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clean_registry():
|
||||
_reset_for_tests()
|
||||
yield
|
||||
_reset_for_tests()
|
||||
|
||||
|
||||
def test_register_and_get():
|
||||
c = _DummyConnector()
|
||||
register_connector(c)
|
||||
assert get_connector("dummy") is c
|
||||
|
||||
|
||||
def test_unknown_source_raises():
|
||||
with pytest.raises(KeyError):
|
||||
get_connector("nope")
|
||||
|
||||
|
||||
def test_double_register_replaces():
|
||||
a = _DummyConnector()
|
||||
b = _DummyConnector()
|
||||
register_connector(a)
|
||||
register_connector(b)
|
||||
assert get_connector("dummy") is b
|
||||
48
tests/test_scout_connectors_base.py
Normal file
48
tests/test_scout_connectors_base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for the SourceConnector base protocol and shared types."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.scouts.connectors.base import (
|
||||
ItemContent,
|
||||
ItemMetadata,
|
||||
ItemRef,
|
||||
TriageVerdict,
|
||||
)
|
||||
|
||||
|
||||
def test_item_ref_round_trips_through_pydantic():
|
||||
ref = ItemRef(source_msg_ref="abc123", received_at=datetime.now(tz=timezone.utc))
|
||||
parsed = ItemRef.model_validate(ref.model_dump())
|
||||
assert parsed.source_msg_ref == "abc123"
|
||||
assert parsed.received_at == ref.received_at
|
||||
|
||||
|
||||
def test_item_metadata_allows_all_optional():
|
||||
meta = ItemMetadata()
|
||||
assert meta.subject is None
|
||||
assert meta.sender is None
|
||||
assert meta.snippet is None
|
||||
assert meta.received_at is None
|
||||
|
||||
|
||||
def test_item_content_requires_metadata_and_body():
|
||||
content = ItemContent(
|
||||
metadata=ItemMetadata(subject="hi"),
|
||||
body_text="hello world",
|
||||
raw_headers={"X-Foo": "bar"},
|
||||
)
|
||||
assert content.metadata.subject == "hi"
|
||||
assert content.body_text == "hello world"
|
||||
assert content.raw_headers["X-Foo"] == "bar"
|
||||
|
||||
|
||||
def test_triage_verdict_constraints():
|
||||
v = TriageVerdict(verdict="relevant", reason="contains task language", confidence=0.92)
|
||||
assert v.verdict == "relevant"
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
TriageVerdict(verdict="meh", reason="x", confidence=0.5) # bad enum value
|
||||
217
tests/test_scout_engine.py
Normal file
217
tests/test_scout_engine.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""Unit tests for ScoutEngine.trigger_scout / _process_item."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
|
||||
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
|
||||
from app.scouts.connectors.registry import register_connector, _reset_for_tests
|
||||
from app.scouts.engine import ScoutEngine
|
||||
from tests.conftest import _TestSessionLocal
|
||||
|
||||
|
||||
def _make_connector(items, content_for):
|
||||
c = AsyncMock()
|
||||
# source_type must match the scout.provider ("gmail") so get_connector()
|
||||
# finds it when the engine calls get_connector(scout.provider).
|
||||
c.source_type = "gmail"
|
||||
c.list_new = AsyncMock(return_value=items)
|
||||
c.fetch_content = AsyncMock(side_effect=lambda scout, ref: content_for[ref.source_msg_ref])
|
||||
c.archive = AsyncMock()
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _registry():
|
||||
_reset_for_tests()
|
||||
yield
|
||||
_reset_for_tests()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_relevant_item_inserted_into_queue(monkeypatch):
|
||||
user_id = "00000000-0000-0000-0000-000000000003" # power tier seeded in conftest
|
||||
scout_id = str(uuid.uuid4())
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
scout = CloudScoutConfig(
|
||||
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
|
||||
)
|
||||
session.add(scout)
|
||||
await session.commit()
|
||||
|
||||
refs = [ItemRef(source_msg_ref="msg-1")]
|
||||
content = {"msg-1": ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")}
|
||||
connector = _make_connector(refs, content)
|
||||
register_connector(connector)
|
||||
|
||||
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_triage_llm",
|
||||
AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="task", confidence=0.9)),
|
||||
)
|
||||
|
||||
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||
assert len(rows) == 1
|
||||
assert rows[0].source_msg_ref == "msg-1"
|
||||
assert rows[0].triage_verdict == "relevant"
|
||||
assert rows[0].status == "queued"
|
||||
connector.archive.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
|
||||
user_id = "00000000-0000-0000-0000-000000000003"
|
||||
scout_id = str(uuid.uuid4())
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
scout = CloudScoutConfig(
|
||||
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||
enabled=True, auto_trash_spam=True, device_inactivity_pause_days=14,
|
||||
)
|
||||
session.add(scout)
|
||||
await session.commit()
|
||||
|
||||
refs = [ItemRef(source_msg_ref="msg-spam")]
|
||||
content = {"msg-spam": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
|
||||
connector = _make_connector(refs, content)
|
||||
register_connector(connector)
|
||||
|
||||
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_triage_llm",
|
||||
AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
|
||||
)
|
||||
|
||||
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||
assert rows == []
|
||||
connector.archive.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
|
||||
user_id = "00000000-0000-0000-0000-000000000003"
|
||||
scout_id = str(uuid.uuid4())
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
scout = CloudScoutConfig(
|
||||
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
|
||||
)
|
||||
session.add(scout)
|
||||
await session.commit()
|
||||
|
||||
refs = [ItemRef(source_msg_ref="msg-2")]
|
||||
content = {"msg-2": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
|
||||
connector = _make_connector(refs, content)
|
||||
register_connector(connector)
|
||||
|
||||
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_triage_llm",
|
||||
AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
|
||||
)
|
||||
|
||||
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||
assert rows == []
|
||||
connector.archive.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_idempotent_replay(monkeypatch):
|
||||
user_id = "00000000-0000-0000-0000-000000000003"
|
||||
scout_id = str(uuid.uuid4())
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
session.add(CloudScoutConfig(
|
||||
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
refs = [ItemRef(source_msg_ref="msg-3")]
|
||||
content = {"msg-3": ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")}
|
||||
connector = _make_connector(refs, content)
|
||||
register_connector(connector)
|
||||
|
||||
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||
monkeypatch.setattr(
|
||||
engine,
|
||||
"_triage_llm",
|
||||
AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="x", confidence=0.5)),
|
||||
)
|
||||
|
||||
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||
await engine.trigger_scout(uuid.UUID(scout_id))
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||
assert len(rows) == 1, "Replay must not create duplicate queue rows"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_deliver_pending_sends_one_frame_per_queued_row(monkeypatch):
|
||||
user_id = "00000000-0000-0000-0000-000000000003"
|
||||
scout_id = str(uuid.uuid4())
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
async with _TestSessionLocal() as session:
|
||||
session.add(CloudScoutConfig(
|
||||
id=scout_id, user_id=user_id, provider="gmail", name="Test",
|
||||
data_types=[], prompt_template="", schedule_cron="0 * * * *",
|
||||
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
|
||||
))
|
||||
for i in range(3):
|
||||
session.add(ScoutTriageQueue(
|
||||
id=str(uuid.uuid4()), user_id=user_id, scout_id=scout_id,
|
||||
source_type="gmail", source_msg_ref=f"msg-{i}",
|
||||
triage_verdict="relevant", status="queued",
|
||||
triaged_at=now, expires_at=now + timedelta(days=30),
|
||||
))
|
||||
await session.commit()
|
||||
|
||||
connector = AsyncMock()
|
||||
connector.source_type = "gmail"
|
||||
connector.fetch_metadata = AsyncMock(side_effect=lambda scout, ref: ItemMetadata(
|
||||
subject=f"sub-{ref.source_msg_ref}", snippet=f"snip-{ref.source_msg_ref}",
|
||||
))
|
||||
register_connector(connector)
|
||||
|
||||
sent = []
|
||||
ws = AsyncMock()
|
||||
ws.send_json = AsyncMock(side_effect=lambda payload: sent.append(payload))
|
||||
|
||||
engine = ScoutEngine(session_factory=_TestSessionLocal)
|
||||
await engine.deliver_pending(uuid.UUID(user_id), ws)
|
||||
|
||||
assert len(sent) == 3
|
||||
assert all(s["type"] == "scout_proposal" for s in sent)
|
||||
subjects = {s["proposal"]["raw_subject"] for s in sent}
|
||||
assert subjects == {"sub-msg-0", "sub-msg-1", "sub-msg-2"}
|
||||
async with _TestSessionLocal() as session:
|
||||
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
|
||||
assert all(r.status == "delivered" for r in rows)
|
||||
assert all(r.delivered_at is not None for r in rows)
|
||||
Reference in New Issue
Block a user