Compare commits
6 Commits
fbd308d288
...
9f21d5ae8f
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f21d5ae8f | ||
|
|
699bba3a30 | ||
|
|
1364b9ba37 | ||
|
|
27df8c0a8d | ||
|
|
4933f8055c | ||
|
|
ac33ac1c0d |
59
alembic/versions/008_scout_triage_queue.py
Normal file
59
alembic/versions/008_scout_triage_queue.py
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
"""Scout triage queue + cloud_scout_configs alterations.
|
||||||
|
|
||||||
|
Revision ID: 008
|
||||||
|
Revises: 007
|
||||||
|
Create Date: 2026-05-16
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
|
||||||
|
revision: str = "008"
|
||||||
|
down_revision: Union[str, None] = "007"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Create ``scout_triage_queue`` with its indexes and extend ``cloud_scout_configs``."""
    queue_table = "scout_triage_queue"
    configs_table = "cloud_scout_configs"

    def _uuid_col() -> sa.Uuid:
        # UUIDs are handled as strings on the Python side (as_uuid=False).
        return sa.Uuid(as_uuid=False)

    def _tz_datetime() -> sa.DateTime:
        return sa.DateTime(timezone=True)

    op.create_table(
        queue_table,
        sa.Column("id", _uuid_col(), primary_key=True),
        sa.Column(
            "user_id",
            _uuid_col(),
            sa.ForeignKey("users.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "scout_id",
            _uuid_col(),
            sa.ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("source_type", sa.String(50), nullable=False),
        sa.Column("source_msg_ref", sa.String(255), nullable=False),
        sa.Column("triage_verdict", sa.String(20), nullable=False),
        sa.Column("triage_reason", sa.Text, nullable=True),
        sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
        sa.Column("triaged_at", _tz_datetime(), nullable=False, server_default=sa.func.now()),
        sa.Column("delivered_at", _tz_datetime(), nullable=True),
        sa.Column("acked_at", _tz_datetime(), nullable=True),
        sa.Column("expires_at", _tz_datetime(), nullable=False),
        # One queue row per (scout, source message).
        sa.UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
    )

    op.create_index("ix_scout_triage_queue_user_status", queue_table, ["user_id", "status"])
    # Partial index (PostgreSQL): only rows that have not been acked.
    op.create_index(
        "ix_scout_triage_queue_expires_active",
        queue_table,
        ["expires_at"],
        postgresql_where=sa.text("status != 'acked'"),
    )

    # New per-scout settings; added in this order so downgrade() can mirror it.
    new_config_columns = (
        sa.Column("auto_trash_spam", sa.Boolean(), nullable=False, server_default=sa.text("false")),
        sa.Column("gmail_history_id", sa.String(64), nullable=True),
        sa.Column("gmail_watch_expires_at", _tz_datetime(), nullable=True),
        sa.Column("device_inactivity_pause_days", sa.Integer(), nullable=False, server_default="14"),
    )
    for column in new_config_columns:
        op.add_column(configs_table, column)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the columns added to ``cloud_scout_configs``, then the queue indexes and table."""
    # Columns are removed in the reverse of the order they were added.
    for column_name in (
        "device_inactivity_pause_days",
        "gmail_watch_expires_at",
        "gmail_history_id",
        "auto_trash_spam",
    ):
        op.drop_column("cloud_scout_configs", column_name)

    queue_table = "scout_triage_queue"
    for index_name in (
        "ix_scout_triage_queue_expires_active",
        "ix_scout_triage_queue_user_status",
    ):
        op.drop_index(index_name, table_name=queue_table)
    op.drop_table(queue_table)
|
||||||
@@ -41,6 +41,7 @@ from sqlalchemy import update
|
|||||||
|
|
||||||
from app.api.routes.scout_setup import handle_journey_message, handle_journey_start
|
from app.api.routes.scout_setup import handle_journey_message, handle_journey_start
|
||||||
from app.config.settings import settings
|
from app.config.settings import settings
|
||||||
|
from app.scouts.engine import ScoutEngine
|
||||||
from app.core.scout_runner import trigger_pending_runs
|
from app.core.scout_runner import trigger_pending_runs
|
||||||
from app.core.scout_session_buffer import session_buffer
|
from app.core.scout_session_buffer import session_buffer
|
||||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||||
@@ -118,6 +119,16 @@ async def device_ws(websocket: WebSocket) -> None:
|
|||||||
# Trigger any overdue agent runs now that the device is connected.
|
# Trigger any overdue agent runs now that the device is connected.
|
||||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||||
|
|
||||||
|
# Drain any queued scout proposals and deliver to the client (non-blocking).
|
||||||
|
async def _deliver_pending_safe() -> None:
|
||||||
|
import uuid as _uuid # noqa: PLC0415
|
||||||
|
try:
|
||||||
|
await ScoutEngine().deliver_pending(_uuid.UUID(user_id), websocket)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("scout deliver_pending failed for user %s", user_id)
|
||||||
|
|
||||||
|
asyncio.create_task(_deliver_pending_safe())
|
||||||
|
|
||||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||||
try:
|
try:
|
||||||
await asyncio.gather(
|
await asyncio.gather(
|
||||||
@@ -204,6 +215,14 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
|||||||
_handle_contextual_scope_update(websocket, user_id, frame)
|
_handle_contextual_scope_update(websocket, user_id, frame)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif frame_type == "scout_proposal_ack":
|
||||||
|
proposal_id = frame.get("proposal_id")
|
||||||
|
if proposal_id:
|
||||||
|
try:
|
||||||
|
await ScoutEngine().ack_proposal(proposal_id)
|
||||||
|
except Exception:
|
||||||
|
logger.exception("scout ack_proposal failed for %s", proposal_id)
|
||||||
|
|
||||||
elif frame_type == "pong":
|
elif frame_type == "pong":
|
||||||
# Heartbeat ack — nothing to do, connection is alive.
|
# Heartbeat ack — nothing to do, connection is alive.
|
||||||
pass
|
pass
|
||||||
|
|||||||
@@ -34,8 +34,10 @@ from sqlalchemy import (
|
|||||||
LargeBinary,
|
LargeBinary,
|
||||||
String,
|
String,
|
||||||
Text,
|
Text,
|
||||||
|
UniqueConstraint,
|
||||||
Uuid,
|
Uuid,
|
||||||
func,
|
func,
|
||||||
|
text,
|
||||||
)
|
)
|
||||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||||
|
|
||||||
@@ -217,6 +219,10 @@ class CloudScoutConfig(Base):
|
|||||||
updated_at: Mapped[datetime] = mapped_column(
|
updated_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||||
)
|
)
|
||||||
|
auto_trash_spam: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("false"))
|
||||||
|
gmail_history_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||||
|
gmail_watch_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||||
|
device_inactivity_pause_days: Mapped[int] = mapped_column(Integer, nullable=False, default=14, server_default="14")
|
||||||
|
|
||||||
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
|
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
|
||||||
back_populates="cloud_scout",
|
back_populates="cloud_scout",
|
||||||
@@ -227,6 +233,26 @@ class CloudScoutConfig(Base):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ScoutTriageQueue(Base):
    """ORM mapping for the ``scout_triage_queue`` table.

    ``(scout_id, source_msg_ref)`` is unique, so a given source message can
    appear at most once per scout.
    """

    __tablename__ = "scout_triage_queue"
    __table_args__ = (
        UniqueConstraint(
            "scout_id",
            "source_msg_ref",
            name="uq_scout_triage_queue_scout_msg",
        ),
    )

    id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)

    # Owning user and originating scout config; both rows cascade on delete.
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    scout_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"),
        nullable=False,
    )

    # Where the item came from and its reference within that source.
    source_type: Mapped[str] = mapped_column(String(50), nullable=False)
    source_msg_ref: Mapped[str] = mapped_column(String(255), nullable=False)

    # Triage outcome; the reason text is optional.
    triage_verdict: Mapped[str] = mapped_column(String(20), nullable=False)
    triage_reason: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Defaults to "queued" both in Python and at the database level.
    status: Mapped[str] = mapped_column(
        String(20),
        nullable=False,
        default="queued",
        server_default="queued",
    )

    # Timezone-aware timestamps; delivered_at / acked_at stay NULL until set.
    triaged_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    delivered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    acked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||||
|
|
||||||
|
|
||||||
class ScoutRunLog(Base):
|
class ScoutRunLog(Base):
|
||||||
__tablename__ = "scout_run_logs"
|
__tablename__ = "scout_run_logs"
|
||||||
|
|
||||||
|
|||||||
@@ -98,6 +98,9 @@ class WsFrameType(str, Enum):
|
|||||||
contextual_request = "contextual_request"
|
contextual_request = "contextual_request"
|
||||||
contextual_scope_update = "contextual_scope_update"
|
contextual_scope_update = "contextual_scope_update"
|
||||||
contextual_scope_ack = "contextual_scope_ack"
|
contextual_scope_ack = "contextual_scope_ack"
|
||||||
|
# ── v9 scout proposal frame types ────────────────────────────────
|
||||||
|
SCOUT_PROPOSAL = "scout_proposal"
|
||||||
|
SCOUT_PROPOSAL_ACK = "scout_proposal_ack"
|
||||||
|
|
||||||
|
|
||||||
class WsToolCall(BaseModel):
|
class WsToolCall(BaseModel):
|
||||||
@@ -275,3 +278,25 @@ class ScoutRunLogResponse(BaseModel):
|
|||||||
|
|
||||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
# ── Scout Proposal Frame Models ───────────────────────────────────────
|
||||||
|
|
||||||
|
class ScoutProposalPayload(BaseModel):
    """Body of a ``scout_proposal`` frame describing one proposed item."""

    # Identity of the proposal and where it originated.
    id: str
    scout_id: str
    source_type: str
    source_msg_ref: str

    # Optional display metadata; both default to None.
    raw_subject: str | None = None
    raw_snippet: str | None = None

    # Only the literal "unprocessed" is accepted.
    category: Literal["unprocessed"] = "unprocessed"

    # Optional free-form extra data.
    payload: dict | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ScoutProposalFrame(BaseModel):
    """WebSocket frame of type ``scout_proposal`` wrapping one proposal payload."""

    # Discriminator: must equal WsFrameType.SCOUT_PROPOSAL.
    type: Literal[WsFrameType.SCOUT_PROPOSAL]
    proposal: ScoutProposalPayload
|
||||||
|
|
||||||
|
|
||||||
|
class ScoutProposalAckFrame(BaseModel):
    """WebSocket frame of type ``scout_proposal_ack`` naming the acknowledged proposal."""

    # Discriminator: must equal WsFrameType.SCOUT_PROPOSAL_ACK.
    type: Literal[WsFrameType.SCOUT_PROPOSAL_ACK]
    proposal_id: str
|
||||||
|
|||||||
0
app/scouts/__init__.py
Normal file
0
app/scouts/__init__.py
Normal file
0
app/scouts/connectors/__init__.py
Normal file
0
app/scouts/connectors/__init__.py
Normal file
56
app/scouts/connectors/base.py
Normal file
56
app/scouts/connectors/base.py
Normal file
@@ -0,0 +1,56 @@
|
|||||||
|
"""Source connector Protocol and shared item types.
|
||||||
|
|
||||||
|
A SourceConnector adapts a third-party data source (Gmail, Slack, ...) to the
|
||||||
|
shared ScoutEngine interface. Each connector owns:
|
||||||
|
|
||||||
|
* how to enumerate new items since the last poll (``list_new``)
|
||||||
|
* how to fetch a single item's metadata cheaply (``fetch_metadata``)
|
||||||
|
* how to fetch a single item's full content for in-memory triage
|
||||||
|
(``fetch_content``) — this content MUST NOT be persisted by the engine
|
||||||
|
* how to archive/trash an item (``archive``) for spam handling
|
||||||
|
* optional push-notification setup (``setup_watch`` / ``renew_watch``)
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import Literal, Protocol
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class ItemRef(BaseModel):
    """Reference to a single item in a source, identified by ``source_msg_ref``."""

    source_msg_ref: str
    # Optional receive time; None when not supplied.
    received_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ItemMetadata(BaseModel):
    """Descriptive fields for an item; every field is optional and defaults to None."""

    subject: str | None = None
    sender: str | None = None
    snippet: str | None = None
    received_at: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
|
class ItemContent(BaseModel):
    """Full content of an item: its metadata, body text, and raw headers."""

    metadata: ItemMetadata
    body_text: str
    # A fresh empty dict per instance when no headers are given.
    raw_headers: dict[str, str] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
|
class TriageVerdict(BaseModel):
    """Result of triaging one item."""

    # Only "relevant" or "spam" validate.
    verdict: Literal["relevant", "spam"]
    reason: str
    # Must lie within [0.0, 1.0] inclusive.
    confidence: float = Field(ge=0.0, le=1.0)
|
||||||
|
|
||||||
|
|
||||||
|
class SourceConnector(Protocol):
    """Structural interface every source adapter (Gmail, Slack, ...) implements."""

    # Identifier of the source this connector handles, e.g. "gmail".
    source_type: str

    async def list_new(self, scout) -> list[ItemRef]:
        """Enumerate new items since the last poll."""
        ...

    async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata:
        """Fetch a single item's metadata cheaply."""
        ...

    async def fetch_content(self, scout, ref: ItemRef) -> ItemContent:
        """Fetch a single item's full content for in-memory triage (must not be persisted)."""
        ...

    async def archive(self, scout, ref: ItemRef) -> None:
        """Archive/trash an item as part of spam handling."""
        ...

    async def setup_watch(self, scout) -> None:
        """Set up push notifications (optional)."""
        ...

    async def renew_watch(self, scout) -> None:
        """Renew push notifications (optional)."""
        ...
|
||||||
32
app/scouts/connectors/registry.py
Normal file
32
app/scouts/connectors/registry.py
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
"""Connector registry — single source of truth for source_type -> connector."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
_CONNECTORS: dict[str, Any] = {}
|
||||||
|
|
||||||
|
|
||||||
|
def register_connector(connector: Any) -> None:
    """Store ``connector`` in the registry, keyed by its ``source_type``.

    Registering a second connector with the same ``source_type`` overwrites
    the first. That is convenient for tests and hot-reload; production code
    should register each connector once at startup.

    Raises:
        ValueError: If ``connector`` has no ``source_type`` or it is empty.
    """
    has_source_type = bool(getattr(connector, "source_type", None))
    if not has_source_type:
        raise ValueError("Connector must declare a non-empty source_type")
    _CONNECTORS[connector.source_type] = connector
|
||||||
|
|
||||||
|
|
||||||
|
def get_connector(source_type: str) -> Any:
    """Look up the connector registered for ``source_type``.

    Raises:
        KeyError: If nothing is registered under ``source_type``.
    """
    try:
        found = _CONNECTORS[source_type]
    except KeyError as missing:
        raise KeyError(f"No connector registered for source_type {source_type!r}") from missing
    return found
|
||||||
|
|
||||||
|
|
||||||
|
def _reset_for_tests() -> None:
    """Empty the registry in place; intended for pytest fixtures only."""
    _CONNECTORS.clear()
|
||||||
192
app/scouts/engine.py
Normal file
192
app/scouts/engine.py
Normal file
@@ -0,0 +1,192 @@
|
|||||||
|
"""ScoutEngine — orchestrates triage, queueing, and delivery for cloud scouts.
|
||||||
|
|
||||||
|
Triage flow per scout:
|
||||||
|
1. Resolve scout config from the DB.
|
||||||
|
2. Skip if device hasn't connected within ``device_inactivity_pause_days`` (planned — ``trigger_scout`` does not enforce this yet).
|
||||||
|
3. Ask the connector to ``list_new`` — fresh items since last poll.
|
||||||
|
4. For each item:
|
||||||
|
- skip if already in the queue (idempotent on (scout_id, source_msg_ref))
|
||||||
|
- fetch the full content via the connector (transient, never persisted)
|
||||||
|
- run the triage LLM call → relevant | spam
|
||||||
|
- spam + auto_trash_spam → connector.archive
|
||||||
|
- relevant → INSERT scout_triage_queue row
|
||||||
|
5. Update scout.last_run_at.
|
||||||
|
|
||||||
|
Delivery flow on Electron WS reconnect:
|
||||||
|
- drain ``status='queued'`` rows for the user
|
||||||
|
- fetch metadata-only for each (subject + snippet)
|
||||||
|
- send a ``scout_proposal`` frame
|
||||||
|
- flip status to ``delivered`` once the frame has been sent; the client's ``scout_proposal_ack`` later flips it to ``acked``
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
|
||||||
|
from sqlalchemy import select
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
|
|
||||||
|
from app.db import async_session
|
||||||
|
from app.models import CloudScoutConfig, ScoutTriageQueue
|
||||||
|
from app.scouts.connectors.base import ItemContent, ItemRef, TriageVerdict
|
||||||
|
from app.scouts.connectors.registry import get_connector
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
QUEUE_TTL_DAYS = 30
|
||||||
|
|
||||||
|
|
||||||
|
class ScoutEngine:
    """Orchestrates triage, queueing, and delivery of items found by cloud scouts.

    Args:
        session_factory: Zero-argument callable returning an async session
            usable as an ``async with`` context manager. Defaults to the
            application's ``async_session``.
    """

    def __init__(self, session_factory=None) -> None:
        self._session_factory = session_factory or async_session

    async def trigger_scout(self, scout_id: uuid.UUID) -> None:
        """Poll one scout's connector and triage every new item it reports.

        Returns without side effects when the scout does not exist, is
        disabled, has no registered connector, or when ``list_new`` raises.
        Otherwise each reported item goes through ``_process_item`` and
        ``last_run_at`` is stamped and committed.

        Args:
            scout_id: Primary key of the ``CloudScoutConfig`` to run.
        """
        async with self._session_factory() as session:
            scout = await session.get(CloudScoutConfig, str(scout_id))
            if scout is None:
                logger.warning("trigger_scout: no such scout id=%s", scout_id)
                return
            if not scout.enabled:
                return

            # NOTE: the ``device_inactivity_pause_days`` pause is not enforced
            # here yet. The original note says the device-online signal lives
            # in the DeviceConnectionManager and is consulted at delivery time.
            try:
                connector = get_connector(scout.provider)
            except KeyError:
                # Same handling as deliver_pending: a missing connector is a
                # configuration problem, not a reason to crash the caller.
                logger.warning(
                    "trigger_scout: no connector for provider %s (scout %s)",
                    scout.provider,
                    scout.id,
                )
                return

            try:
                refs = await connector.list_new(scout)
            except Exception:
                logger.exception("scout %s: list_new failed", scout.id)
                return

            for ref in refs:
                await self._process_item(session, scout, connector, ref)

            scout.last_run_at = datetime.now(tz=timezone.utc)
            await session.commit()

    async def _process_item(
        self,
        session,
        scout: CloudScoutConfig,
        connector,
        ref: ItemRef,
    ) -> None:
        """Triage one item and queue it when the verdict is not spam.

        Items already queued for this scout are skipped. Failures in
        ``fetch_content`` or ``_triage_llm`` are logged and skip the item.
        Spam is never queued; it is archived through the connector only when
        ``scout.auto_trash_spam`` is set. The caller commits the session.

        Args:
            session: Open async session shared with the caller.
            scout: Scout configuration being run.
            connector: Connector registered for the scout's provider.
            ref: Reference to the item to triage.
        """
        # Idempotency check on (scout_id, source_msg_ref).
        existing = await session.execute(
            select(ScoutTriageQueue.id).where(
                ScoutTriageQueue.scout_id == scout.id,
                ScoutTriageQueue.source_msg_ref == ref.source_msg_ref,
            )
        )
        if existing.first() is not None:
            return

        try:
            content = await connector.fetch_content(scout, ref)
        except Exception:
            logger.exception("scout %s: fetch_content failed for %s", scout.id, ref.source_msg_ref)
            return

        try:
            verdict = await self._triage_llm(scout, content)
        except Exception:
            logger.exception("scout %s: triage_llm failed for %s", scout.id, ref.source_msg_ref)
            return

        if verdict.verdict == "spam":
            if scout.auto_trash_spam:
                try:
                    await connector.archive(scout, ref)
                except Exception:
                    logger.exception("scout %s: archive failed for %s", scout.id, ref.source_msg_ref)
            return

        now = datetime.now(tz=timezone.utc)
        row = ScoutTriageQueue(
            id=str(uuid.uuid4()),
            user_id=scout.user_id,
            scout_id=scout.id,
            source_type=connector.source_type,
            source_msg_ref=ref.source_msg_ref,
            triage_verdict=verdict.verdict,
            triage_reason=verdict.reason,
            status="queued",
            triaged_at=now,
            expires_at=now + timedelta(days=QUEUE_TTL_DAYS),
        )
        try:
            # begin_nested() flushes anything already pending BEFORE it opens
            # the SAVEPOINT, so the row must be added inside the block. Adding
            # it earlier would run the INSERT outside the savepoint, and an
            # IntegrityError would invalidate the outer transaction.
            async with session.begin_nested():
                session.add(row)
                await session.flush()
        except IntegrityError:
            # Race: another worker inserted between our SELECT and INSERT.
            # The unique constraint did its job; safe to ignore.
            logger.debug(
                "scout %s: idempotent skip for %s (race on unique constraint)",
                scout.id,
                ref.source_msg_ref,
            )

    async def deliver_pending(self, user_id: uuid.UUID, ws) -> None:
        """Send every ``status='queued'`` row for the user as a ``scout_proposal`` frame.

        A row is marked ``delivered`` once its frame has been sent. Rows whose
        connector is missing, whose scout no longer exists, or whose metadata
        fetch fails are left queued. If sending a frame raises, rows already
        marked delivered are committed first and the exception is re-raised.

        Args:
            user_id: Owner of the rows to drain.
            ws: Object exposing an awaitable ``send_json(payload)``.
        """
        async with self._session_factory() as session:
            result = await session.execute(
                select(ScoutTriageQueue).where(
                    ScoutTriageQueue.user_id == str(user_id),
                    ScoutTriageQueue.status == "queued",
                )
            )
            rows = result.scalars().all()

            for row in rows:
                try:
                    connector = get_connector(row.source_type)
                except KeyError:
                    logger.warning("deliver_pending: no connector for %s", row.source_type)
                    continue

                scout = await session.get(CloudScoutConfig, row.scout_id)
                if scout is None:
                    continue

                try:
                    meta = await connector.fetch_metadata(
                        scout, ItemRef(source_msg_ref=row.source_msg_ref)
                    )
                except Exception:
                    logger.exception(
                        "deliver_pending: fetch_metadata failed for %s", row.source_msg_ref
                    )
                    continue

                payload = {
                    "type": "scout_proposal",
                    "proposal": {
                        "id": row.id,
                        "scout_id": row.scout_id,
                        "source_type": row.source_type,
                        "source_msg_ref": row.source_msg_ref,
                        "raw_subject": meta.subject,
                        "raw_snippet": meta.snippet,
                        "category": "unprocessed",
                        "payload": None,
                    },
                }
                try:
                    await ws.send_json(payload)
                except Exception:
                    # Persist the rows already sent so they are not delivered
                    # a second time on the next attempt, then let the caller
                    # see the failure.
                    await session.commit()
                    raise

                row.status = "delivered"
                row.delivered_at = datetime.now(tz=timezone.utc)

            await session.commit()

    async def ack_proposal(
        self,
        proposal_id: str,
        user_id: str | uuid.UUID | None = None,
    ) -> None:
        """Mark a proposal as acked.

        Idempotent: an already-acked row is left untouched, so its original
        ``acked_at`` is preserved. Unknown or malformed ids are ignored.

        Args:
            proposal_id: Primary key of the queue row, as sent by the client.
            user_id: Optional owner to enforce. When given, a row belonging to
                a different user is not modified. When omitted, no ownership
                check is performed.
        """
        # proposal_id arrives from the client; reject anything that is not a
        # UUID before it reaches the database.
        try:
            proposal_key = str(uuid.UUID(str(proposal_id)))
        except ValueError:
            logger.warning("ack_proposal: ignoring malformed proposal id %r", proposal_id)
            return

        async with self._session_factory() as session:
            row = await session.get(ScoutTriageQueue, proposal_key)
            if row is None or row.status == "acked":
                return
            if user_id is not None and str(row.user_id) != str(user_id):
                logger.warning(
                    "ack_proposal: proposal %s is not owned by user %s", proposal_key, user_id
                )
                return
            row.status = "acked"
            row.acked_at = datetime.now(tz=timezone.utc)
            await session.commit()

    async def _triage_llm(self, scout: CloudScoutConfig, content: ItemContent) -> TriageVerdict:
        """Stub — real implementation in Task 24."""
        raise NotImplementedError("Real triage LLM call lands in Task 24")
|
||||||
48
tests/test_scout_connector_registry.py
Normal file
48
tests/test_scout_connector_registry.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for the connector registry."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.scouts.connectors.base import ItemRef
|
||||||
|
from app.scouts.connectors.registry import (
|
||||||
|
get_connector,
|
||||||
|
register_connector,
|
||||||
|
_reset_for_tests,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class _DummyConnector:
    """Minimal connector stand-in: only ``source_type`` and ``list_new`` are usable."""

    source_type = "dummy"

    async def list_new(self, scout):
        return []

    async def fetch_metadata(self, scout, ref):
        raise NotImplementedError

    async def fetch_content(self, scout, ref):
        raise NotImplementedError

    async def archive(self, scout, ref):
        raise NotImplementedError

    async def setup_watch(self, scout):
        raise NotImplementedError

    async def renew_watch(self, scout):
        raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _clean_registry():
    """Start each test with an empty registry and empty it again afterwards."""
    _reset_for_tests()  # setup
    yield
    _reset_for_tests()  # teardown
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_and_get():
    """A registered connector is returned, by identity, for its source_type."""
    connector = _DummyConnector()
    register_connector(connector)
    looked_up = get_connector("dummy")
    assert looked_up is connector
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_source_raises():
    """Looking up an unregistered source_type raises KeyError."""
    unknown_source = "nope"
    with pytest.raises(KeyError):
        get_connector(unknown_source)
|
||||||
|
|
||||||
|
|
||||||
|
def test_double_register_replaces():
    """Registering twice under one source_type keeps only the latest connector."""
    first, second = _DummyConnector(), _DummyConnector()
    for connector in (first, second):
        register_connector(connector)
    assert get_connector("dummy") is second
|
||||||
48
tests/test_scout_connectors_base.py
Normal file
48
tests/test_scout_connectors_base.py
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
"""Tests for the SourceConnector base protocol and shared types."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
from datetime import datetime, timezone
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.scouts.connectors.base import (
|
||||||
|
ItemContent,
|
||||||
|
ItemMetadata,
|
||||||
|
ItemRef,
|
||||||
|
TriageVerdict,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_item_ref_round_trips_through_pydantic():
    """Dumping an ItemRef and validating the dump preserves both fields."""
    stamp = datetime.now(tz=timezone.utc)
    original = ItemRef(source_msg_ref="abc123", received_at=stamp)
    dumped = original.model_dump()
    restored = ItemRef.model_validate(dumped)
    assert restored.source_msg_ref == "abc123"
    assert restored.received_at == original.received_at
|
||||||
|
|
||||||
|
|
||||||
|
def test_item_metadata_allows_all_optional():
    """ItemMetadata can be built with no arguments; every field is then None."""
    empty = ItemMetadata()
    assert empty.subject is None
    assert empty.sender is None
    assert empty.snippet is None
    assert empty.received_at is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_item_content_requires_metadata_and_body():
    """ItemContent exposes the metadata, body text and headers it was built with."""
    headers = {"X-Foo": "bar"}
    item = ItemContent(
        metadata=ItemMetadata(subject="hi"),
        body_text="hello world",
        raw_headers=headers,
    )
    assert item.metadata.subject == "hi"
    assert item.body_text == "hello world"
    assert item.raw_headers["X-Foo"] == "bar"
|
||||||
|
|
||||||
|
|
||||||
|
def test_triage_verdict_constraints():
    """A valid verdict is accepted; a verdict outside the allowed literals is rejected."""
    accepted = TriageVerdict(verdict="relevant", reason="contains task language", confidence=0.92)
    assert accepted.verdict == "relevant"

    # "meh" is not one of the allowed verdict values.
    with pytest.raises(ValueError):
        TriageVerdict(verdict="meh", reason="x", confidence=0.5)
|
||||||
217
tests/test_scout_engine.py
Normal file
217
tests/test_scout_engine.py
Normal file
@@ -0,0 +1,217 @@
|
|||||||
|
"""Unit tests for ScoutEngine.trigger_scout / _process_item."""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime, timedelta, timezone
|
||||||
|
from unittest.mock import AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
|
||||||
|
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
|
||||||
|
from app.scouts.connectors.registry import register_connector, _reset_for_tests
|
||||||
|
from app.scouts.engine import ScoutEngine
|
||||||
|
from tests.conftest import _TestSessionLocal
|
||||||
|
|
||||||
|
|
||||||
|
def _make_connector(items, content_for):
    """Build an AsyncMock connector for the "gmail" source.

    ``list_new`` returns ``items``; ``fetch_content`` looks the item up in
    ``content_for`` by its ``source_msg_ref``; ``archive`` is a bare mock.
    """

    def _lookup_content(scout, ref):
        return content_for[ref.source_msg_ref]

    connector = AsyncMock()
    # Must equal scout.provider ("gmail"): the engine resolves the connector
    # with get_connector(scout.provider).
    connector.source_type = "gmail"
    connector.list_new = AsyncMock(return_value=items)
    connector.fetch_content = AsyncMock(side_effect=_lookup_content)
    connector.archive = AsyncMock()
    return connector
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
def _registry():
    """Reset the connector registry before and after every test in this module."""
    _reset_for_tests()  # setup
    yield
    _reset_for_tests()  # teardown
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_relevant_item_inserted_into_queue(monkeypatch):
    """A "relevant" verdict produces exactly one queued row and no archive call."""
    owner_id = "00000000-0000-0000-0000-000000000003"  # power tier seeded in conftest
    scout_id = str(uuid.uuid4())
    scout_fields = dict(
        id=scout_id, user_id=owner_id, provider="gmail", name="Test",
        data_types=[], prompt_template="", schedule_cron="0 * * * *",
        enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
    )

    async with _TestSessionLocal() as session:
        session.add(CloudScoutConfig(**scout_fields))
        await session.commit()

    message = ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")
    connector = _make_connector([ItemRef(source_msg_ref="msg-1")], {"msg-1": message})
    register_connector(connector)

    # Replace the stubbed LLM call with a fixed "relevant" verdict.
    fixed_verdict = TriageVerdict(verdict="relevant", reason="task", confidence=0.9)
    engine = ScoutEngine(session_factory=_TestSessionLocal)
    monkeypatch.setattr(engine, "_triage_llm", AsyncMock(return_value=fixed_verdict))

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        queued = (await session.execute(select(ScoutTriageQueue))).scalars().all()
        assert len(queued) == 1
        only_row = queued[0]
        assert only_row.source_msg_ref == "msg-1"
        assert only_row.triage_verdict == "relevant"
        assert only_row.status == "queued"
    connector.archive.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
    """With auto_trash_spam on, a "spam" verdict archives the item and queues nothing."""
    owner_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())
    scout_fields = dict(
        id=scout_id, user_id=owner_id, provider="gmail", name="Test",
        data_types=[], prompt_template="", schedule_cron="0 * * * *",
        enabled=True, auto_trash_spam=True, device_inactivity_pause_days=14,
    )

    async with _TestSessionLocal() as session:
        session.add(CloudScoutConfig(**scout_fields))
        await session.commit()

    message = ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")
    connector = _make_connector([ItemRef(source_msg_ref="msg-spam")], {"msg-spam": message})
    register_connector(connector)

    # Replace the stubbed LLM call with a fixed "spam" verdict.
    fixed_verdict = TriageVerdict(verdict="spam", reason="bait", confidence=0.99)
    engine = ScoutEngine(session_factory=_TestSessionLocal)
    monkeypatch.setattr(engine, "_triage_llm", AsyncMock(return_value=fixed_verdict))

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        queued = (await session.execute(select(ScoutTriageQueue))).scalars().all()
        assert queued == []
    connector.archive.assert_awaited_once()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
    """With auto_trash_spam off, a "spam" verdict neither archives nor queues."""
    owner_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())
    scout_fields = dict(
        id=scout_id, user_id=owner_id, provider="gmail", name="Test",
        data_types=[], prompt_template="", schedule_cron="0 * * * *",
        enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
    )

    async with _TestSessionLocal() as session:
        session.add(CloudScoutConfig(**scout_fields))
        await session.commit()

    message = ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")
    connector = _make_connector([ItemRef(source_msg_ref="msg-2")], {"msg-2": message})
    register_connector(connector)

    # Replace the stubbed LLM call with a fixed "spam" verdict.
    fixed_verdict = TriageVerdict(verdict="spam", reason="bait", confidence=0.99)
    engine = ScoutEngine(session_factory=_TestSessionLocal)
    monkeypatch.setattr(engine, "_triage_llm", AsyncMock(return_value=fixed_verdict))

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        queued = (await session.execute(select(ScoutTriageQueue))).scalars().all()
        assert queued == []
    connector.archive.assert_not_awaited()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_idempotent_replay(monkeypatch):
    """Triggering the same scout twice over the same item leaves exactly one queue row."""
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())

    scout_config = CloudScoutConfig(
        id=scout_id,
        user_id=user_id,
        provider="gmail",
        name="Test",
        data_types=[],
        prompt_template="",
        schedule_cron="0 * * * *",
        enabled=True,
        auto_trash_spam=False,
        device_inactivity_pause_days=14,
    )
    async with _TestSessionLocal() as db:
        db.add(scout_config)
        await db.commit()

    # The connector serves the same single item on every run.
    item = ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")
    connector = _make_connector([ItemRef(source_msg_ref="msg-3")], {"msg-3": item})
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    relevant_verdict = TriageVerdict(verdict="relevant", reason="x", confidence=0.5)
    monkeypatch.setattr(engine, "_triage_llm", AsyncMock(return_value=relevant_verdict))

    # First run, then a replay over identical input.
    scout_uuid = uuid.UUID(scout_id)
    for _ in range(2):
        await engine.trigger_scout(scout_uuid)

    async with _TestSessionLocal() as db:
        result = await db.execute(select(ScoutTriageQueue))
        queue_rows = result.scalars().all()
        assert len(queue_rows) == 1, "Replay must not create duplicate queue rows"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
async def test_deliver_pending_sends_one_frame_per_queued_row(monkeypatch):
    """Three queued rows produce three ``scout_proposal`` frames and end up delivered.

    Seeds three ``queued`` rows for one user, calls ``deliver_pending`` with a
    mocked websocket, and checks the frames sent and the resulting row state.
    """
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())
    now = datetime.now(tz=timezone.utc)

    async with _TestSessionLocal() as session:
        session.add(CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
        ))
        for i in range(3):
            session.add(ScoutTriageQueue(
                id=str(uuid.uuid4()), user_id=user_id, scout_id=scout_id,
                source_type="gmail", source_msg_ref=f"msg-{i}",
                triage_verdict="relevant", status="queued",
                triaged_at=now, expires_at=now + timedelta(days=30),
            ))
        await session.commit()

    # Connector stub: metadata is derived from the ref so each frame is distinguishable.
    connector = AsyncMock()
    connector.source_type = "gmail"
    connector.fetch_metadata = AsyncMock(side_effect=lambda scout, ref: ItemMetadata(
        subject=f"sub-{ref.source_msg_ref}", snippet=f"snip-{ref.source_msg_ref}",
    ))
    register_connector(connector)

    # Websocket stub that records every payload passed to send_json.
    sent = []
    ws = AsyncMock()
    ws.send_json = AsyncMock(side_effect=lambda payload: sent.append(payload))

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    await engine.deliver_pending(uuid.UUID(user_id), ws)

    assert len(sent) == 3
    assert all(s["type"] == "scout_proposal" for s in sent)
    subjects = {s["proposal"]["raw_subject"] for s in sent}
    assert subjects == {"sub-msg-0", "sub-msg-1", "sub-msg-2"}

    async with _TestSessionLocal() as session:
        rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
        # all() is vacuously true on an empty list; pin the row count first so
        # the status checks below are tied to the three seeded rows.
        assert len(rows) == 3
        assert all(r.status == "delivered" for r in rows)
        assert all(r.delivered_at is not None for r in rows)
|
||||||
Reference in New Issue
Block a user