Compare commits
13 Commits
fbd308d288
...
develop
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0833db239c | ||
|
|
11b31e5814 | ||
|
|
cb274c9728 | ||
|
|
d3497a1908 | ||
|
|
0c0299808c | ||
|
|
d1016fd65a | ||
|
|
c559754532 | ||
|
|
9f21d5ae8f | ||
|
|
699bba3a30 | ||
|
|
1364b9ba37 | ||
|
|
27df8c0a8d | ||
|
|
4933f8055c | ||
|
|
ac33ac1c0d |
59
alembic/versions/008_scout_triage_queue.py
Normal file
59
alembic/versions/008_scout_triage_queue.py
Normal file
@@ -0,0 +1,59 @@
|
||||
"""Scout triage queue + cloud_scout_configs alterations.
|
||||
|
||||
Revision ID: 008
|
||||
Revises: 007
|
||||
Create Date: 2026-05-16
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
|
||||
revision: str = "008"
|
||||
down_revision: Union[str, None] = "007"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
    """Create ``scout_triage_queue`` and extend ``cloud_scout_configs``.

    Creates the queue table (with its unique constraint), adds two secondary
    indexes on it, then appends four new columns to ``cloud_scout_configs``.
    """

    def uuid_type():
        # UUID columns are exposed as strings on the Python side.
        return sa.Uuid(as_uuid=False)

    def tz_datetime():
        return sa.DateTime(timezone=True)

    queue_columns = [
        sa.Column("id", uuid_type(), primary_key=True),
        sa.Column(
            "user_id",
            uuid_type(),
            sa.ForeignKey("users.id", ondelete="CASCADE"),
            nullable=False,
            index=True,
        ),
        sa.Column(
            "scout_id",
            uuid_type(),
            sa.ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("source_type", sa.String(50), nullable=False),
        sa.Column("source_msg_ref", sa.String(255), nullable=False),
        sa.Column("triage_verdict", sa.String(20), nullable=False),
        sa.Column("triage_reason", sa.Text, nullable=True),
        sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
        sa.Column("triaged_at", tz_datetime(), nullable=False, server_default=sa.func.now()),
        sa.Column("delivered_at", tz_datetime(), nullable=True),
        sa.Column("acked_at", tz_datetime(), nullable=True),
        sa.Column("expires_at", tz_datetime(), nullable=False),
    ]
    op.create_table(
        "scout_triage_queue",
        *queue_columns,
        # One queue row per (scout, source message).
        sa.UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
    )

    op.create_index(
        "ix_scout_triage_queue_user_status",
        "scout_triage_queue",
        ["user_id", "status"],
    )
    # Partial index: only rows that have not been acked are indexed by expiry.
    op.create_index(
        "ix_scout_triage_queue_expires_active",
        "scout_triage_queue",
        ["expires_at"],
        postgresql_where=sa.text("status != 'acked'"),
    )

    new_config_columns = (
        sa.Column("auto_trash_spam", sa.Boolean(), nullable=False, server_default=sa.text("false")),
        sa.Column("gmail_history_id", sa.String(64), nullable=True),
        sa.Column("gmail_watch_expires_at", tz_datetime(), nullable=True),
        sa.Column("device_inactivity_pause_days", sa.Integer(), nullable=False, server_default="14"),
    )
    for column in new_config_columns:
        op.add_column("cloud_scout_configs", column)
|
||||
|
||||
|
||||
def downgrade() -> None:
    """Undo this revision.

    Drops the four columns added to ``cloud_scout_configs``, then the two
    ``scout_triage_queue`` indexes, and finally the table itself.
    """
    for column_name in (
        "device_inactivity_pause_days",
        "gmail_watch_expires_at",
        "gmail_history_id",
        "auto_trash_spam",
    ):
        op.drop_column("cloud_scout_configs", column_name)

    for index_name in (
        "ix_scout_triage_queue_expires_active",
        "ix_scout_triage_queue_user_status",
    ):
        op.drop_index(index_name, table_name="scout_triage_queue")

    op.drop_table("scout_triage_queue")
|
||||
@@ -41,6 +41,7 @@ from sqlalchemy import update
|
||||
|
||||
from app.api.routes.scout_setup import handle_journey_message, handle_journey_start
|
||||
from app.config.settings import settings
|
||||
from app.scouts.engine import ScoutEngine
|
||||
from app.core.scout_runner import trigger_pending_runs
|
||||
from app.core.scout_session_buffer import session_buffer
|
||||
from app.core.brief_agent import run_home_brief, run_project_brief
|
||||
@@ -118,6 +119,16 @@ async def device_ws(websocket: WebSocket) -> None:
|
||||
# Trigger any overdue agent runs now that the device is connected.
|
||||
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
|
||||
|
||||
# Drain any queued scout proposals and deliver to the client (non-blocking).
|
||||
async def _deliver_pending_safe() -> None:
|
||||
import uuid as _uuid # noqa: PLC0415
|
||||
try:
|
||||
await ScoutEngine().deliver_pending(_uuid.UUID(user_id), websocket)
|
||||
except Exception:
|
||||
logger.exception("scout deliver_pending failed for user %s", user_id)
|
||||
|
||||
asyncio.create_task(_deliver_pending_safe())
|
||||
|
||||
# ── 4. Concurrent message loop + heartbeat ────────────────────────
|
||||
try:
|
||||
await asyncio.gather(
|
||||
@@ -204,6 +215,14 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
|
||||
_handle_contextual_scope_update(websocket, user_id, frame)
|
||||
)
|
||||
|
||||
elif frame_type == "scout_proposal_ack":
|
||||
proposal_id = frame.get("proposal_id")
|
||||
if proposal_id:
|
||||
try:
|
||||
await ScoutEngine().ack_proposal(proposal_id)
|
||||
except Exception:
|
||||
logger.exception("scout ack_proposal failed for %s", proposal_id)
|
||||
|
||||
elif frame_type == "pong":
|
||||
# Heartbeat ack — nothing to do, connection is alive.
|
||||
pass
|
||||
|
||||
120
app/api/routes/scout_webhooks.py
Normal file
120
app/api/routes/scout_webhooks.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Gmail Pub/Sub push receiver.
|
||||
|
||||
Google Pub/Sub push subscriptions deliver Gmail watch notifications as POST
|
||||
requests with a JSON envelope. The body payload contains a base64-encoded
|
||||
JSON blob with ``emailAddress`` + ``historyId``. We resolve the user by
|
||||
email, look up their cloud_scout_configs row for provider='gmail', and
|
||||
hand off to ScoutEngine.trigger_scout.
|
||||
|
||||
Authentication: Pub/Sub push includes an OIDC JWT in the Authorization
|
||||
header. We verify it against Google's public keys with the audience
|
||||
configured in our Pub/Sub subscription.
|
||||
|
||||
Dev mode: when ``GMAIL_PUBSUB_AUDIENCE`` is empty, JWT verification is
|
||||
skipped and a warning is logged. Production must set this env var.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
|
||||
from fastapi import APIRouter, Header, HTTPException, Request, status
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.db import async_session
|
||||
from app.models import CloudScoutConfig, User
|
||||
from app.scouts.engine import ScoutEngine
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
router = APIRouter(prefix="/scouts/webhooks", tags=["scout-webhooks"])
|
||||
|
||||
|
||||
def _verify_pubsub_jwt(token: str) -> bool:
    """Check the OIDC JWT that Google Pub/Sub attaches to a push request.

    Args:
        token: The bearer token with the ``Bearer `` prefix already removed.

    Returns:
        ``False`` for an empty token or when verification raises;
        ``True`` when the token verifies against
        ``settings.GMAIL_PUBSUB_AUDIENCE``.  When that setting is empty the
        check is skipped (a warning is logged) and ``True`` is returned so
        local development works without a real subscription.
    """
    if not token:
        return False

    audience = settings.GMAIL_PUBSUB_AUDIENCE
    if not audience:
        logger.warning(
            "GMAIL_PUBSUB_AUDIENCE not set — skipping Pub/Sub JWT verification (dev mode only)"
        )
        return True

    try:
        # Imported lazily so the Google auth libraries are only needed when
        # verification is actually performed.
        from google.auth.transport import requests as g_requests  # noqa: PLC0415
        from google.oauth2 import id_token  # noqa: PLC0415

        id_token.verify_oauth2_token(token, g_requests.Request(), audience=audience)
    except Exception:
        logger.warning("pubsub jwt verification failed", exc_info=True)
        return False
    return True
|
||||
|
||||
|
||||
@router.post("/gmail", status_code=status.HTTP_204_NO_CONTENT)
async def gmail_pubsub(
    request: Request,
    authorization: str = Header(default=""),
) -> None:
    """Receive a Gmail Pub/Sub push notification.

    Verifies the OIDC JWT, decodes the Pub/Sub envelope, resolves the user
    by email, and calls ``ScoutEngine.trigger_scout`` for each enabled Gmail
    scout belonging to that user.

    Returns 204 No Content on success and on benign no-ops: malformed
    envelope, empty or undecodable message data, missing ``emailAddress``,
    unknown email.  Such messages can never succeed, so they are acked
    rather than answered with an error that would be redelivered.

    Raises:
        HTTPException: 401 when JWT verification fails; 500 when at least
            one scout trigger failed (after every scout has been attempted),
            so the push is answered with a non-2xx status.
    """
    token = authorization.removeprefix("Bearer ").strip()
    if not _verify_pubsub_jwt(token):
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid Pub/Sub JWT")

    try:
        body = await request.json()
    except ValueError:
        # Covers JSONDecodeError and UnicodeDecodeError.
        logger.warning("pubsub envelope is not valid JSON")
        return
    if not isinstance(body, dict):
        logger.warning("pubsub envelope is not a JSON object")
        return

    msg = body.get("message") or {}
    raw = msg.get("data") if isinstance(msg, dict) else None
    if not raw:
        return  # ack without action — empty message data

    try:
        decoded = json.loads(base64.b64decode(raw).decode())
    except Exception:
        logger.warning("pubsub payload decode failed")
        return
    if not isinstance(decoded, dict):
        logger.warning("pubsub payload decode failed")
        return

    email = decoded.get("emailAddress")
    if not email:
        return

    async with async_session() as session:
        user_q = await session.execute(select(User).where(User.email == email))
        user = user_q.scalar_one_or_none()
        if user is None:
            logger.info("pubsub: no user for %s — ignoring", email)
            return
        scouts_q = await session.execute(
            select(CloudScoutConfig).where(
                CloudScoutConfig.user_id == user.id,
                CloudScoutConfig.provider == "gmail",
                CloudScoutConfig.enabled == True,  # noqa: E712
            )
        )
        scouts = scouts_q.scalars().all()

    engine = ScoutEngine()
    failures = 0
    for scout in scouts:
        # Isolate each scout so one failure does not prevent the others
        # from being triggered.
        try:
            await engine.trigger_scout(uuid.UUID(str(scout.id)))
        except Exception:
            failures += 1
            logger.exception("pubsub: trigger_scout failed for scout %s", scout.id)

    if failures:
        # Still answer non-2xx, as an unhandled trigger error did before.
        raise HTTPException(
            status.HTTP_500_INTERNAL_SERVER_ERROR,
            "One or more scout triggers failed",
        )
|
||||
@@ -7,28 +7,40 @@ Backend responsibilities are intentionally minimal:
|
||||
|
||||
Scout configuration is owned by the Electron app and is not persisted
|
||||
in backend scout-config tables.
|
||||
|
||||
Gmail OAuth setup (scout-specific consent):
|
||||
GET /scouts/oauth/gmail/authorize — returns consent-screen URL
|
||||
GET /scouts/oauth/gmail/web-callback — bounces to deep link (excluded from schema)
|
||||
POST /scouts/oauth/gmail/callback — exchanges code, stores encrypted token
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import secrets
|
||||
import time
|
||||
import urllib.parse
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from fastapi.responses import RedirectResponse
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.api.deps import get_current_user
|
||||
from app.auth.oauth_providers import generate_pkce_pair
|
||||
from app.billing.tier_manager import FEATURES
|
||||
from app.config.settings import settings
|
||||
from app.core.scout_runner import is_agent_running, run_local_agent
|
||||
from app.core.device_manager import device_manager
|
||||
from app.core.note_summarizer import generate_note_summary
|
||||
from app.db import get_session
|
||||
from app.models import ScoutRunLog, LocalScoutConfig
|
||||
from app.integrations import encrypt_token
|
||||
from app.models import CloudScoutConfig, ScoutRunLog, LocalScoutConfig
|
||||
from app.schemas import (
|
||||
ScoutCatalogItem,
|
||||
ScoutCreationCheckRequest,
|
||||
@@ -255,3 +267,174 @@ async def summarize_note(
|
||||
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
|
||||
summary = await generate_note_summary(body.title, body.content)
|
||||
return NoteSummarizeResponse(summary=summary)
|
||||
|
||||
|
||||
# ── Gmail OAuth setup (scout-specific) ───────────────────────────────────────
|
||||
|
||||
# Scopes required for Gmail scout connectivity.
|
||||
_GMAIL_SCOUT_SCOPES = [
|
||||
"openid",
|
||||
"email",
|
||||
"https://www.googleapis.com/auth/gmail.readonly",
|
||||
"https://www.googleapis.com/auth/gmail.modify",
|
||||
]
|
||||
|
||||
# Google OAuth endpoints.
|
||||
_GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
|
||||
_GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
|
||||
# In-memory pending OAuth states for scout Gmail consent:
|
||||
# state → (code_verifier, scout_id, user_id, expires_at_epoch_s)
|
||||
# Production note: replace with Redis for multi-process deployments.
|
||||
_pending_scout_oauth_states: dict[str, tuple[str, str, str, float]] = {}
|
||||
_SCOUT_OAUTH_TTL_SECONDS = 600 # 10 minutes
|
||||
|
||||
|
||||
def _scout_gmail_redirect_uri() -> str:
    """Build the scout Gmail web-callback URI from ``settings.OAUTH_REDIRECT_URI``.

    ``OAUTH_REDIRECT_URI`` is the full login-OAuth callback URL
    (e.g. http://localhost:8000/api/v1/auth/oauth/google/web-callback).
    Only its scheme and host are kept; the scout callback path is appended.
    """
    login_callback = urllib.parse.urlparse(settings.OAUTH_REDIRECT_URI)
    origin = login_callback.scheme + "://" + login_callback.netloc
    return origin + "/api/v1/scouts/oauth/gmail/web-callback"
|
||||
|
||||
|
||||
class _ScoutGmailAuthorizeResponse(BaseModel):
    # Response body of the Gmail authorize endpoint.

    # Google consent-screen URL for the client to open.
    authorize_url: str
|
||||
|
||||
|
||||
class _ScoutGmailCallbackBody(BaseModel):
    # Request body of the Gmail OAuth callback endpoint.

    # Authorization code to exchange for tokens.
    code: str
    # Opaque state value used to look up the pending OAuth entry.
    state: str
|
||||
|
||||
|
||||
@router.get("/oauth/gmail/authorize", response_model=_ScoutGmailAuthorizeResponse)
async def scout_gmail_oauth_authorize(
    scout_id: str,
    current_user: UserProfile = Depends(get_current_user),
) -> _ScoutGmailAuthorizeResponse:
    """Begin the Gmail OAuth flow for one cloud scout.

    Generates a PKCE pair and a random ``state``, remembers them in the
    in-memory pending-state map together with ``scout_id`` and the caller's
    user id, and returns the Google consent-screen URL.  The client opens
    that URL in the system browser; after consent Google redirects to the
    web-callback, which bounces to the
    ``adiuvai://scout/oauth/gmail/callback`` deep link.

    Raises:
        HTTPException: 503 when the Google client id or secret is not set.
    """
    google_configured = settings.GOOGLE_AUTH_CLIENT_ID and settings.GOOGLE_AUTH_CLIENT_SECRET
    if not google_configured:
        raise HTTPException(
            status.HTTP_503_SERVICE_UNAVAILABLE,
            "Google OAuth is not configured on this server",
        )

    code_verifier, code_challenge = generate_pkce_pair()
    state = secrets.token_urlsafe(32)

    # Drop expired entries first so the map cannot grow without bound.
    now = time.time()
    stale_states = [
        pending_state
        for pending_state, pending in _pending_scout_oauth_states.items()
        if pending[3] < now
    ]
    for pending_state in stale_states:
        del _pending_scout_oauth_states[pending_state]

    expires_at = now + _SCOUT_OAUTH_TTL_SECONDS
    _pending_scout_oauth_states[state] = (code_verifier, scout_id, current_user.id, expires_at)

    redirect_uri = _scout_gmail_redirect_uri()
    query = urllib.parse.urlencode(
        {
            "client_id": settings.GOOGLE_AUTH_CLIENT_ID,
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "scope": " ".join(_GMAIL_SCOUT_SCOPES),
            "state": state,
            "code_challenge": code_challenge,
            "code_challenge_method": "S256",
            "access_type": "offline",
            "prompt": "consent",
        }
    )
    return _ScoutGmailAuthorizeResponse(authorize_url=f"{_GOOGLE_AUTH_URL}?{query}")
|
||||
|
||||
|
||||
@router.get("/oauth/gmail/web-callback", include_in_schema=False)
async def scout_gmail_oauth_web_callback(
    code: str | None = None,
    state: str | None = None,
    error: str | None = None,
) -> RedirectResponse:
    """Google redirects here after the Gmail consent screen.

    Immediately bounces to the Electron deep link so the desktop app
    receives the result.

    On success Google sends ``code`` and ``state``; both are forwarded
    unchanged.  When the user denies consent (or Google reports another
    failure) the redirect carries ``error`` instead of ``code``; that error
    is forwarded too, where previously the request was rejected with a
    422 and the desktop app was never notified.

    Raises:
        HTTPException: 422 when the request has neither an ``error`` nor
            both ``code`` and ``state``.
    """
    if error is not None:
        # NOTE(review): confirm the desktop app handles an ``error`` query
        # parameter on this deep link.
        forwarded = {"error": error}
        if state is not None:
            forwarded["state"] = state
    elif code is not None and state is not None:
        forwarded = {"code": code, "state": state}
    else:
        raise HTTPException(422, "Missing OAuth code or state")

    params = urllib.parse.urlencode(forwarded)
    deep_link = f"adiuvai://scout/oauth/gmail/callback?{params}"
    return RedirectResponse(url=deep_link, status_code=302)
|
||||
|
||||
|
||||
@router.post("/oauth/gmail/callback")
async def scout_gmail_oauth_callback(
    body: _ScoutGmailCallbackBody,
    db: AsyncSession = Depends(get_session),
    current_user: UserProfile = Depends(get_current_user),
) -> dict:
    """Exchange the Gmail authorization code and store the encrypted token on the scout.

    Called by the Electron app after it receives the deep-link callback with
    the ``code`` and ``state`` params.

    Steps:
        1. Consume the pending OAuth state; it must exist, be unexpired and
           belong to the calling user.
        2. Load the scout and check ownership *before* contacting Google, so
           a single-use authorization code is not spent on a request that
           would end in 404 anyway.
        3. Exchange the code (with the PKCE verifier) for tokens.
        4. Encrypt and persist the credentials on the scout.
        5. Best-effort: register the Gmail push watch.  Failures are logged
           and do not fail the request.

    Returns:
        ``{"ok": True}`` once the token has been stored.

    Raises:
        HTTPException: 401 for an invalid or expired state, 404 when the
            scout does not exist or is not owned by the caller, 502 when
            the token exchange fails or returns an unusable response.
    """
    entry = _pending_scout_oauth_states.pop(body.state, None)
    if entry is None or entry[3] < time.time() or entry[2] != current_user.id:
        raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")

    code_verifier, scout_id, _, _ = entry

    scout = await db.get(CloudScoutConfig, scout_id)
    if scout is None or scout.user_id != current_user.id:
        raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")

    redirect_uri = _scout_gmail_redirect_uri()

    import httpx

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                _GOOGLE_TOKEN_URL,
                data={
                    "client_id": settings.GOOGLE_AUTH_CLIENT_ID,
                    "client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
                    "code": body.code,
                    "code_verifier": code_verifier,
                    "grant_type": "authorization_code",
                    "redirect_uri": redirect_uri,
                },
            )
        response.raise_for_status()
    except httpx.HTTPStatusError as exc:
        logger.error("Gmail token exchange failed: %s", exc.response.text)
        raise HTTPException(
            status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code"
        ) from exc
    except httpx.RequestError as exc:
        # Connection, timeout and other transport failures.
        logger.error("Gmail token exchange request failed: %s", exc)
        raise HTTPException(
            status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code"
        ) from exc

    try:
        token_data = response.json()
    except ValueError as exc:
        logger.error("Gmail token exchange returned a non-JSON response")
        raise HTTPException(
            status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code"
        ) from exc

    access_token = token_data.get("access_token") if isinstance(token_data, dict) else None
    if not access_token:
        logger.error("Gmail token exchange response did not contain an access token")
        raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code")

    creds_dict: dict = {
        "token": access_token,
        "refresh_token": token_data.get("refresh_token"),
        "token_uri": _GOOGLE_TOKEN_URL,
        "client_id": settings.GOOGLE_AUTH_CLIENT_ID,
        "client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
        "scopes": [
            "https://www.googleapis.com/auth/gmail.readonly",
            "https://www.googleapis.com/auth/gmail.modify",
        ],
    }
    scout.oauth_token_encrypted = encrypt_token(creds_dict)
    await db.commit()

    # Attempt to set up Gmail push watch so we start receiving Pub/Sub notifications.
    from app.scouts.connectors.registry import get_connector

    try:
        connector = get_connector("gmail")
        await connector.setup_watch(scout)
        await db.commit()
    except KeyError:
        logger.warning("gmail connector not registered — skipping setup_watch for scout %s", scout_id)
    except Exception:
        logger.exception("setup_watch failed for scout %s", scout_id)

    return {"ok": True}
|
||||
|
||||
@@ -58,6 +58,16 @@ class Settings(BaseSettings):
|
||||
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
|
||||
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
|
||||
|
||||
# Gmail Pub/Sub topic for push notifications.
|
||||
# Full resource name, e.g. "projects/my-project/topics/gmail-push".
|
||||
# Leave empty in dev — setup_watch will skip registration gracefully.
|
||||
GMAIL_PUBSUB_TOPIC: str = ""
|
||||
# OIDC token audience for Pub/Sub push subscription JWT verification.
|
||||
# Set to the service account email or audience string configured in the
|
||||
# Pub/Sub push subscription. Leave empty in dev to skip verification
|
||||
# (a warning is logged — never silent in production).
|
||||
GMAIL_PUBSUB_AUDIENCE: str = ""
|
||||
|
||||
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
|
||||
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
|
||||
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()
|
||||
|
||||
101
app/main.py
101
app/main.py
@@ -77,8 +77,98 @@ async def _memory_cron_tick() -> None:
|
||||
_log.warning("memory cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
async def _scout_cron_tick() -> None:
    """Fallback poll of every enabled cloud scout (push is the primary path).

    A scout whose ``last_run_at`` is less than 5 minutes old is skipped so
    a push notification and this cron do not double-fire within the same
    window.  A failure to trigger one scout is logged and the loop moves on;
    any other failure aborts the tick with a warning.
    """
    import logging  # noqa: PLC0415
    import uuid  # noqa: PLC0415
    from datetime import datetime, timezone  # noqa: PLC0415

    tick_log = logging.getLogger(__name__)
    tick_log.info("scout cron tick: starting")
    try:
        from app.db import async_session  # noqa: PLC0415
        from app.models import CloudScoutConfig  # noqa: PLC0415
        from app.scouts.engine import ScoutEngine  # noqa: PLC0415
        from sqlalchemy import select  # noqa: PLC0415

        async with async_session() as session:
            enabled_query = select(CloudScoutConfig).where(
                CloudScoutConfig.enabled == True  # noqa: E712
            )
            result = await session.execute(enabled_query)
            scouts = result.scalars().all()

        engine = ScoutEngine()
        triggered = 0
        for scout in scouts:
            # Rate-limit guard: push is primary; skip if ran within 5 minutes.
            if scout.last_run_at:
                age = datetime.now(tz=timezone.utc) - scout.last_run_at
                if age.total_seconds() < 300:
                    continue
            try:
                await engine.trigger_scout(uuid.UUID(str(scout.id)))
            except Exception as exc:
                tick_log.warning("scout cron tick: trigger failed scout=%s: %s", scout.id, exc)
            else:
                triggered += 1

        tick_log.info("scout cron tick: done triggered=%d total=%d", triggered, len(scouts))
    except Exception as exc:
        tick_log.warning("scout cron tick: failed: %s", exc)
|
||||
|
||||
|
||||
async def _scout_watch_renewal_tick() -> None:
    """Renew Gmail watches that expire within the next 24 hours.

    Selects enabled Gmail scouts whose ``gmail_watch_expires_at`` is at or
    before now + 24h and calls the Gmail connector's ``renew_watch`` on
    each.  A failure for one scout (including a missing connector) is
    logged and the loop continues; the session is committed once at the end.
    """
    import logging  # noqa: PLC0415
    from datetime import datetime, timedelta, timezone  # noqa: PLC0415

    tick_log = logging.getLogger(__name__)
    tick_log.info("scout watch renewal tick: starting")
    try:
        from app.db import async_session  # noqa: PLC0415
        from app.models import CloudScoutConfig  # noqa: PLC0415
        from app.scouts.connectors.registry import get_connector  # noqa: PLC0415
        from sqlalchemy import select  # noqa: PLC0415

        threshold = datetime.now(tz=timezone.utc) + timedelta(hours=24)
        renewed = 0
        async with async_session() as session:
            expiring_query = select(CloudScoutConfig).where(
                CloudScoutConfig.enabled == True,  # noqa: E712
                CloudScoutConfig.provider == "gmail",
                CloudScoutConfig.gmail_watch_expires_at <= threshold,
            )
            result = await session.execute(expiring_query)
            scouts = result.scalars().all()

            for scout in scouts:
                try:
                    connector = get_connector("gmail")
                    await connector.renew_watch(scout)
                except Exception:
                    tick_log.exception("scout watch renewal tick: renew failed scout=%s", scout.id)
                else:
                    renewed += 1

            await session.commit()

        tick_log.info("scout watch renewal tick: done renewed=%d", renewed)
    except Exception as exc:
        tick_log.warning("scout watch renewal tick: failed: %s", exc)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
# Startup: register source connectors.
|
||||
from app.scouts.connectors.gmail import GmailConnector # noqa: PLC0415
|
||||
from app.scouts.connectors.registry import register_connector # noqa: PLC0415
|
||||
register_connector(GmailConnector())
|
||||
|
||||
# Startup: ensure agent tool modules are loaded.
|
||||
import app.agents # noqa: F401
|
||||
|
||||
@@ -89,6 +179,14 @@ async def lifespan(app: FastAPI):
|
||||
scheduler = AsyncIOScheduler()
|
||||
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
|
||||
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
|
||||
scheduler.add_job(
|
||||
_scout_cron_tick, "interval", minutes=15,
|
||||
id="scout_cron_tick", replace_existing=True,
|
||||
)
|
||||
scheduler.add_job(
|
||||
_scout_watch_renewal_tick, "interval", hours=24,
|
||||
id="scout_watch_renewal_tick", replace_existing=True,
|
||||
)
|
||||
scheduler.start()
|
||||
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
|
||||
|
||||
@@ -124,12 +222,13 @@ def create_app() -> FastAPI:
|
||||
app.add_middleware(SanitizerMiddleware)
|
||||
app.add_middleware(TierRateLimitMiddleware)
|
||||
|
||||
from app.api.routes import scouts, auth, billing, chat, device_ws, memory
|
||||
from app.api.routes import scouts, auth, billing, chat, device_ws, memory, scout_webhooks
|
||||
|
||||
app.include_router(auth.router, prefix="/api/v1")
|
||||
app.include_router(chat.router, prefix="/api/v1")
|
||||
app.include_router(billing.router, prefix="/api/v1")
|
||||
app.include_router(scouts.router, prefix="/api/v1")
|
||||
app.include_router(scout_webhooks.router, prefix="/api/v1")
|
||||
app.include_router(device_ws.router, prefix="/api/v1")
|
||||
app.include_router(memory.router, prefix="/api/v1")
|
||||
|
||||
|
||||
@@ -34,8 +34,10 @@ from sqlalchemy import (
|
||||
LargeBinary,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
Uuid,
|
||||
func,
|
||||
text,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
@@ -217,6 +219,10 @@ class CloudScoutConfig(Base):
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
|
||||
)
|
||||
auto_trash_spam: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("false"))
|
||||
gmail_history_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
|
||||
gmail_watch_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
|
||||
device_inactivity_pause_days: Mapped[int] = mapped_column(Integer, nullable=False, default=14, server_default="14")
|
||||
|
||||
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
|
||||
back_populates="cloud_scout",
|
||||
@@ -227,6 +233,26 @@ class CloudScoutConfig(Base):
|
||||
)
|
||||
|
||||
|
||||
class ScoutTriageQueue(Base):
    """ORM mapping for the ``scout_triage_queue`` table."""

    __tablename__ = "scout_triage_queue"
    __table_args__ = (
        # One row per (scout, source message).
        UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
    )

    # Identity and ownership; rows are removed with their user or scout.
    id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        primary_key=True,
        default=_uuid,
    )
    user_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("users.id", ondelete="CASCADE"),
        nullable=False,
        index=True,
    )
    scout_id: Mapped[str] = mapped_column(
        Uuid(as_uuid=False),
        ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"),
        nullable=False,
    )

    # Where the item came from and the triage outcome.
    source_type: Mapped[str] = mapped_column(String(50), nullable=False)
    source_msg_ref: Mapped[str] = mapped_column(String(255), nullable=False)
    triage_verdict: Mapped[str] = mapped_column(String(20), nullable=False)
    triage_reason: Mapped[str | None] = mapped_column(Text, nullable=True)

    # Status (new rows default to "queued") and lifecycle timestamps.
    status: Mapped[str] = mapped_column(
        String(20),
        nullable=False,
        default="queued",
        server_default="queued",
    )
    triaged_at: Mapped[datetime] = mapped_column(
        DateTime(timezone=True),
        nullable=False,
        server_default=func.now(),
    )
    delivered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    acked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
    expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
|
||||
|
||||
class ScoutRunLog(Base):
|
||||
__tablename__ = "scout_run_logs"
|
||||
|
||||
|
||||
@@ -98,6 +98,9 @@ class WsFrameType(str, Enum):
|
||||
contextual_request = "contextual_request"
|
||||
contextual_scope_update = "contextual_scope_update"
|
||||
contextual_scope_ack = "contextual_scope_ack"
|
||||
# ── v9 scout proposal frame types ────────────────────────────────
|
||||
SCOUT_PROPOSAL = "scout_proposal"
|
||||
SCOUT_PROPOSAL_ACK = "scout_proposal_ack"
|
||||
|
||||
|
||||
class WsToolCall(BaseModel):
|
||||
@@ -275,3 +278,25 @@ class ScoutRunLogResponse(BaseModel):
|
||||
|
||||
# ── Chatbot Journey ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
# ── Scout Proposal Frame Models ───────────────────────────────────────
|
||||
|
||||
class ScoutProposalPayload(BaseModel):
    # Body of a scout proposal carried inside a ScoutProposalFrame.

    # Identifiers.
    id: str
    scout_id: str

    # Origin of the proposed item.
    source_type: str
    source_msg_ref: str

    # Optional preview fields.
    raw_subject: str | None = None
    raw_snippet: str | None = None

    # "unprocessed" is the only accepted value.
    category: Literal["unprocessed"] = "unprocessed"

    # Optional free-form extra data.
    payload: dict | None = None
|
||||
|
||||
|
||||
class ScoutProposalFrame(BaseModel):
    # WebSocket frame that carries one scout proposal.

    # Frame discriminator; must be the SCOUT_PROPOSAL frame type.
    type: Literal[WsFrameType.SCOUT_PROPOSAL]
    proposal: ScoutProposalPayload
|
||||
|
||||
|
||||
class ScoutProposalAckFrame(BaseModel):
    # WebSocket frame acknowledging a scout proposal.

    # Frame discriminator; must be the SCOUT_PROPOSAL_ACK frame type.
    type: Literal[WsFrameType.SCOUT_PROPOSAL_ACK]
    # Id of the proposal being acknowledged.
    proposal_id: str
|
||||
|
||||
0
app/scouts/__init__.py
Normal file
0
app/scouts/__init__.py
Normal file
0
app/scouts/connectors/__init__.py
Normal file
0
app/scouts/connectors/__init__.py
Normal file
56
app/scouts/connectors/base.py
Normal file
56
app/scouts/connectors/base.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Source connector Protocol and shared item types.
|
||||
|
||||
A SourceConnector adapts a third-party data source (Gmail, Slack, ...) to the
|
||||
shared ScoutEngine interface. Each connector owns:
|
||||
|
||||
* how to enumerate new items since the last poll (``list_new``)
|
||||
* how to fetch a single item's metadata cheaply (``fetch_metadata``)
|
||||
* how to fetch a single item's full content for in-memory triage
|
||||
(``fetch_content``) — this content MUST NOT be persisted by the engine
|
||||
* how to archive/trash an item (``archive``) for spam handling
|
||||
* optional push-notification setup (``setup_watch`` / ``renew_watch``)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Literal, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ItemRef(BaseModel):
    # Reference to a single item in a source.

    # Reference string identifying the item.
    source_msg_ref: str
    # Receipt time, when available.
    received_at: datetime | None = None
|
||||
|
||||
|
||||
class ItemMetadata(BaseModel):
    # Metadata for one item; every field is optional.

    subject: str | None = None
    sender: str | None = None
    snippet: str | None = None
    received_at: datetime | None = None
|
||||
|
||||
|
||||
class ItemContent(BaseModel):
    # Full content of one item: its metadata plus body text and headers.

    metadata: ItemMetadata
    body_text: str
    # Header name -> value; a fresh empty dict per instance when omitted.
    raw_headers: dict[str, str] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TriageVerdict(BaseModel):
    # Outcome of triaging a single item.

    # Either "relevant" or "spam".
    verdict: Literal["relevant", "spam"]
    # Explanation for the verdict.
    reason: str
    # Validated to lie in the closed interval [0.0, 1.0].
    confidence: float = Field(ge=0.0, le=1.0)
|
||||
|
||||
|
||||
class SourceConnector(Protocol):
    """Adapter for a third-party data source (Gmail, Slack, ...)."""

    source_type: str  # e.g. "gmail"

    async def list_new(self, scout) -> list[ItemRef]:
        """Enumerate new items since the last poll."""

    async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata:
        """Fetch a single item's metadata cheaply."""

    async def fetch_content(self, scout, ref: ItemRef) -> ItemContent:
        """Fetch a single item's full content for in-memory triage."""

    async def archive(self, scout, ref: ItemRef) -> None:
        """Archive/trash an item (spam handling)."""

    async def setup_watch(self, scout) -> None:
        """Set up push notifications (optional)."""

    async def renew_watch(self, scout) -> None:
        """Renew push notifications (optional)."""
|
||||
213
app/scouts/connectors/gmail.py
Normal file
213
app/scouts/connectors/gmail.py
Normal file
@@ -0,0 +1,213 @@
|
||||
"""Gmail SourceConnector — wraps the existing GmailClient.
|
||||
|
||||
Responsibilities:
|
||||
* list_new: incremental fetch since the scout's stored gmail_history_id
|
||||
* fetch_metadata: subject + sender + snippet only (Gmail metadata format)
|
||||
* fetch_content: full body text — transient, never persisted by engine
|
||||
* archive: move a message to Gmail Trash (recoverable for 30 days)
|
||||
* setup_watch / renew_watch: Gmail push notifications via Pub/Sub
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
|
||||
from app.config.settings import settings
|
||||
from app.integrations import decrypt_token
|
||||
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _extract_plain_text_body(payload: dict) -> str:
|
||||
"""Recursively walk a Gmail message payload to find text/plain content."""
|
||||
import base64
|
||||
mime_type = payload.get("mimeType", "")
|
||||
if mime_type == "text/plain":
|
||||
data = payload.get("body", {}).get("data", "")
|
||||
if data:
|
||||
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return ""
|
||||
if mime_type.startswith("multipart/"):
|
||||
for part in payload.get("parts", []):
|
||||
text = _extract_plain_text_body(part)
|
||||
if text:
|
||||
return text
|
||||
# text/html fallback: strip tags rudimentarily if no text/plain part
|
||||
if mime_type == "text/html":
|
||||
data = payload.get("body", {}).get("data", "")
|
||||
if data:
|
||||
import re
|
||||
html = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
|
||||
return re.sub(r"<[^>]+>", " ", html)
|
||||
return ""
|
||||
|
||||
|
||||
def _get_gmail_service(scout):
    """Build a synchronous Gmail v1 API client from the scout's stored OAuth token.

    The encrypted token on ``scout.oauth_token_encrypted`` is decrypted and its
    fields are handed to ``google.oauth2.credentials.Credentials``. The Google
    libraries are imported lazily, inside the function.
    """
    from googleapiclient.discovery import build
    from google.oauth2.credentials import Credentials

    token_info = decrypt_token(scout.oauth_token_encrypted)

    # Fields copied through verbatim; a missing key simply becomes None.
    passthrough = ("token", "refresh_token", "client_id", "client_secret", "scopes")
    credential_kwargs = {field: token_info.get(field) for field in passthrough}
    credential_kwargs["token_uri"] = token_info.get(
        "token_uri", "https://oauth2.googleapis.com/token"
    )

    oauth_credentials = Credentials(**credential_kwargs)
    return build("gmail", "v1", credentials=oauth_credentials, cache_discovery=False)
|
||||
|
||||
|
||||
class GmailConnector:
    """SourceConnector implementation for Gmail.

    The Google API client is synchronous, so every request runs in a worker
    thread via ``asyncio.to_thread``. Attributes of ``scout`` are assigned only
    from the calling coroutine, never from the worker thread; the caller is
    responsible for persisting those changes.
    """

    source_type = "gmail"

    # ── helpers ───────────────────────────────────────────────────────────

    @staticmethod
    def _header(headers: dict[str, str], name: str) -> str | None:
        """Return the value of header ``name``, matching case-insensitively.

        Header field names are case-insensitive, and Gmail reports them as
        they appear in the message, so ``subject`` must match ``Subject``.
        """
        exact = headers.get(name)
        if exact is not None:
            return exact
        wanted = name.lower()
        for key, value in headers.items():
            if key.lower() == wanted:
                return value
        return None

    @staticmethod
    def _received_at(msg: dict) -> datetime | None:
        """Convert the message's ``internalDate`` (epoch milliseconds) to UTC.

        Returns None when the field is absent or cannot be parsed.
        """
        raw = msg.get("internalDate")
        if raw is None:
            return None
        try:
            return datetime.fromtimestamp(int(raw) / 1000, tz=timezone.utc)
        except (TypeError, ValueError, OverflowError, OSError):
            return None

    @classmethod
    def _build_metadata(cls, msg: dict, headers: dict[str, str]) -> ItemMetadata:
        """Assemble ItemMetadata from a Gmail message resource and its headers."""
        return ItemMetadata(
            subject=cls._header(headers, "Subject"),
            sender=cls._header(headers, "From"),
            snippet=msg.get("snippet"),
            received_at=cls._received_at(msg),
        )

    # ── list_new ──────────────────────────────────────────────────────────

    async def list_new(self, scout) -> list[ItemRef]:
        """Return new message refs since ``scout.gmail_history_id``.

        On first run (``gmail_history_id`` is None/empty), records the current
        historyId without backfilling — avoids flooding the user with old mail.
        All result pages are followed, and each message id is returned once
        even if it appears in several history records.
        Updates ``scout.gmail_history_id`` in-place (caller must persist to DB).
        """

        def _sync() -> tuple[list[ItemRef], str | None]:
            service = _get_gmail_service(scout)
            history_id = scout.gmail_history_id

            if not history_id:
                # First run: capture baseline history id without backfilling.
                profile = service.users().getProfile(userId="me").execute()
                return [], profile["historyId"]

            refs: list[ItemRef] = []
            seen_ids: set[str] = set()
            new_history_id = history_id
            page_token = None

            while True:
                params = {
                    "userId": "me",
                    "startHistoryId": history_id,
                    "historyTypes": ["messageAdded"],
                }
                if page_token:
                    params["pageToken"] = page_token
                resp = service.users().history().list(**params).execute()

                for entry in resp.get("history", []):
                    for added in entry.get("messagesAdded", []):
                        message_id = added["message"]["id"]
                        if message_id not in seen_ids:
                            seen_ids.add(message_id)
                            refs.append(ItemRef(source_msg_ref=message_id))

                new_history_id = resp.get("historyId", new_history_id)
                # Without following nextPageToken, messages beyond the first
                # page would be dropped while the cursor still moved past them.
                page_token = resp.get("nextPageToken")
                if not page_token:
                    break

            return refs, new_history_id

        refs, new_history_id = await asyncio.to_thread(_sync)
        if new_history_id and new_history_id != scout.gmail_history_id:
            scout.gmail_history_id = new_history_id
        return refs

    # ── fetch_metadata ────────────────────────────────────────────────────

    async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata:
        """Fetch subject, sender, snippet only — uses Gmail metadata format (no body)."""

        def _sync() -> ItemMetadata:
            service = _get_gmail_service(scout)
            msg = (
                service.users()
                .messages()
                .get(
                    userId="me",
                    id=ref.source_msg_ref,
                    format="metadata",
                    metadataHeaders=["Subject", "From", "Date"],
                )
                .execute()
            )
            headers = {
                h["name"]: h["value"]
                for h in msg.get("payload", {}).get("headers", [])
            }
            return self._build_metadata(msg, headers)

        return await asyncio.to_thread(_sync)

    # ── fetch_content ─────────────────────────────────────────────────────

    async def fetch_content(self, scout, ref: ItemRef) -> ItemContent:
        """Fetch full body text for a single message — transient, must not be persisted."""

        def _sync() -> ItemContent:
            service = _get_gmail_service(scout)
            msg = (
                service.users()
                .messages()
                .get(userId="me", id=ref.source_msg_ref, format="full")
                .execute()
            )
            payload = msg.get("payload", {})
            headers = {h["name"]: h["value"] for h in payload.get("headers", [])}
            return ItemContent(
                metadata=self._build_metadata(msg, headers),
                body_text=_extract_plain_text_body(payload),
                raw_headers=headers,
            )

        return await asyncio.to_thread(_sync)

    # ── archive ───────────────────────────────────────────────────────────

    async def archive(self, scout, ref: ItemRef) -> None:
        """Move the message to Gmail Trash (recoverable for 30 days)."""

        def _sync() -> None:
            service = _get_gmail_service(scout)
            service.users().messages().trash(
                userId="me", id=ref.source_msg_ref
            ).execute()

        await asyncio.to_thread(_sync)

    # ── watch management ──────────────────────────────────────────────────

    async def _register_watch(self, scout, *, keep_cursor: bool) -> None:
        """Call ``users.watch`` and record the result on ``scout``.

        Args:
            scout: Scout whose mailbox is watched; ``gmail_watch_expires_at``
                and possibly ``gmail_history_id`` are updated in-place.
            keep_cursor: When true, an existing ``gmail_history_id`` is kept so
                mail that arrived since the last ``list_new`` is not skipped.
        """
        topic = settings.GMAIL_PUBSUB_TOPIC
        if not topic:
            logger.warning(
                "setup_watch: GMAIL_PUBSUB_TOPIC is not configured — skipping watch setup"
            )
            return

        def _sync() -> dict:
            service = _get_gmail_service(scout)
            request_body = {
                "labelIds": ["INBOX"],
                "topicName": topic,
            }
            return service.users().watch(userId="me", body=request_body).execute()

        resp = await asyncio.to_thread(_sync)

        history_id = resp.get("historyId")
        # Never overwrite the cursor with None when the response lacks it.
        if history_id and not (keep_cursor and scout.gmail_history_id):
            scout.gmail_history_id = history_id

        expiration_ms = resp.get("expiration")
        if expiration_ms:
            scout.gmail_watch_expires_at = datetime.fromtimestamp(
                int(expiration_ms) / 1000, tz=timezone.utc
            )

    async def setup_watch(self, scout) -> None:
        """Register a Gmail Pub/Sub push watch for the INBOX label.

        Requires ``settings.GMAIL_PUBSUB_TOPIC`` to be set to the full topic
        resource name (e.g. ``projects/my-project/topics/gmail-push``).
        Logs a warning and returns without error if the topic is not configured.
        Sets ``scout.gmail_history_id`` to the historyId returned by Gmail and
        ``scout.gmail_watch_expires_at`` to the watch expiration.
        """
        await self._register_watch(scout, keep_cursor=False)

    async def renew_watch(self, scout) -> None:
        """Renew an existing Gmail Pub/Sub watch.

        Same request as ``setup_watch``, but an existing history cursor is
        preserved: resetting it to the mailbox's current historyId would skip
        messages that arrived since the last ``list_new``.
        """
        await self._register_watch(scout, keep_cursor=True)
|
||||
32
app/scouts/connectors/registry.py
Normal file
32
app/scouts/connectors/registry.py
Normal file
@@ -0,0 +1,32 @@
|
||||
"""Connector registry — single source of truth for source_type -> connector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
_CONNECTORS: dict[str, Any] = {}
|
||||
|
||||
|
||||
def register_connector(connector: Any) -> None:
    """Store ``connector`` in the registry, keyed by its ``source_type``.

    A second registration for the same ``source_type`` overwrites the first.
    That is convenient for tests and hot-reload; production code should
    register each connector exactly once at startup.

    Raises:
        ValueError: If the connector has no ``source_type`` or it is empty.
    """
    source_type = getattr(connector, "source_type", None)
    if source_type:
        _CONNECTORS[source_type] = connector
        return
    raise ValueError("Connector must declare a non-empty source_type")
|
||||
|
||||
|
||||
def get_connector(source_type: str) -> Any:
    """Look up the connector registered under ``source_type``.

    Raises:
        KeyError: If nothing is registered for ``source_type``; the message
            names the missing type and the original lookup error is chained.
    """
    try:
        connector = _CONNECTORS[source_type]
    except KeyError as missing:
        message = f"No connector registered for source_type {source_type!r}"
        raise KeyError(message) from missing
    return connector
|
||||
|
||||
|
||||
def _reset_for_tests() -> None:
    """Empty the connector registry in place; intended for pytest fixtures only."""
    # clear() keeps the same dict object, so existing references stay valid.
    _CONNECTORS.clear()
|
||||
270
app/scouts/engine.py
Normal file
270
app/scouts/engine.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""ScoutEngine — orchestrates triage, queueing, and delivery for cloud scouts.
|
||||
|
||||
Triage flow per scout:
|
||||
1. Resolve scout config from the DB.
|
||||
2. Skip if the scout is disabled. (The ``device_inactivity_pause_days`` pause is not enforced during triage yet.)
|
||||
3. Ask the connector to ``list_new`` — fresh items since last poll.
|
||||
4. For each item:
|
||||
- skip if already in the queue (idempotent on (scout_id, source_msg_ref))
|
||||
- fetch the full content via the connector (transient, never persisted)
|
||||
- run the triage LLM call → relevant | spam
|
||||
- spam + auto_trash_spam → connector.archive
|
||||
- relevant → INSERT scout_triage_queue row
|
||||
5. Update scout.last_run_at.
|
||||
|
||||
Delivery flow on Electron WS reconnect:
|
||||
- drain ``status='queued'`` rows for the user
|
||||
- fetch metadata-only for each (subject + snippet)
|
||||
- send a ``scout_proposal`` frame
|
||||
- flip status to ``delivered`` once the frame has been sent (``acked`` when the client acknowledges it)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
|
||||
from app.core.llm import get_llm
|
||||
from app.db import async_session
|
||||
from app.models import CloudScoutConfig, ScoutTriageQueue
|
||||
from app.scouts.connectors.base import ItemContent, ItemRef, TriageVerdict
|
||||
from app.scouts.connectors.registry import get_connector
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
QUEUE_TTL_DAYS = 30
|
||||
|
||||
|
||||
class ScoutEngine:
    """Runs triage for cloud scouts and delivers queued proposals to devices.

    Args:
        session_factory: Callable returning an async context manager that
            yields a database session. Defaults to ``async_session``.
    """

    def __init__(self, session_factory=None) -> None:
        self._session_factory = session_factory or async_session

    async def trigger_scout(self, scout_id: uuid.UUID) -> None:
        """Run one triage pass for the scout with id ``scout_id``.

        Does nothing when the scout does not exist or is disabled. When
        ``list_new`` fails, the error is logged and nothing is committed.
        Otherwise every new item is processed, ``last_run_at`` is stamped and
        the session is committed once at the end.

        Raises:
            KeyError: If no connector is registered for the scout's provider.
        """
        async with self._session_factory() as session:
            scout = await session.get(CloudScoutConfig, str(scout_id))
            if scout is None:
                logger.warning("trigger_scout: no such scout id=%s", scout_id)
                return
            if not scout.enabled:
                return

            # NOTE(review): no device-inactivity check happens here; the
            # device-online signal is consulted at delivery time instead.
            connector = get_connector(scout.provider)
            try:
                refs = await connector.list_new(scout)
            except Exception:
                logger.exception("scout %s: list_new failed", scout.id)
                return

            # NOTE(review): items whose fetch or triage fails are only logged;
            # if the connector advanced its cursor they will not be retried.
            for ref in refs:
                await self._process_item(session, scout, connector, ref)

            scout.last_run_at = datetime.now(tz=timezone.utc)
            await session.commit()

    async def _process_item(
        self,
        session,
        scout: CloudScoutConfig,
        connector,
        ref: ItemRef,
    ) -> None:
        """Triage one item and queue it when relevant.

        Items already queued for this scout are skipped. Failures while
        fetching, classifying or archiving are logged and the item is dropped
        for this pass. Spam is archived only when ``scout.auto_trash_spam`` is
        set, and is never queued. Nothing is committed here; the caller owns
        the transaction.
        """
        # Idempotency check
        existing = await session.execute(
            select(ScoutTriageQueue.id).where(
                ScoutTriageQueue.scout_id == scout.id,
                ScoutTriageQueue.source_msg_ref == ref.source_msg_ref,
            )
        )
        if existing.first() is not None:
            return

        try:
            content = await connector.fetch_content(scout, ref)
        except Exception:
            logger.exception("scout %s: fetch_content failed for %s", scout.id, ref.source_msg_ref)
            return

        try:
            verdict = await self._triage_llm(scout, content)
        except Exception:
            logger.exception("scout %s: triage_llm failed for %s", scout.id, ref.source_msg_ref)
            return

        if verdict.verdict == "spam":
            if scout.auto_trash_spam:
                try:
                    await connector.archive(scout, ref)
                except Exception:
                    logger.exception("scout %s: archive failed for %s", scout.id, ref.source_msg_ref)
            return

        now = datetime.now(tz=timezone.utc)
        row = ScoutTriageQueue(
            id=str(uuid.uuid4()),
            user_id=scout.user_id,
            scout_id=scout.id,
            source_type=connector.source_type,
            source_msg_ref=ref.source_msg_ref,
            triage_verdict=verdict.verdict,
            triage_reason=verdict.reason,
            status="queued",
            triaged_at=now,
            expires_at=now + timedelta(days=QUEUE_TTL_DAYS),
        )
        try:
            # The row must be added INSIDE the savepoint. begin_nested() flushes
            # pending state before emitting SAVEPOINT, so a row added earlier
            # would be inserted outside it, and an IntegrityError would poison
            # the outer transaction instead of rolling back the savepoint.
            async with session.begin_nested():
                session.add(row)
                await session.flush()
        except IntegrityError:
            # Race: another worker inserted between our SELECT and INSERT.
            # The unique constraint did its job; safe to ignore.
            logger.debug(
                "scout %s: idempotent skip for %s (race on unique constraint)",
                scout.id,
                ref.source_msg_ref,
            )

    async def deliver_pending(self, user_id: uuid.UUID, ws) -> None:
        """Send every ``status='queued'`` row of the user as a ``scout_proposal`` frame.

        Rows are sent oldest first and flipped to ``delivered`` as soon as
        their frame has been sent. Rows with no registered connector, a
        missing scout, or a failing metadata fetch are skipped and stay queued.

        Raises:
            Exception: Whatever ``ws.send_json`` raises. Rows delivered before
                the failure are committed first, so they are not sent again.
        """
        async with self._session_factory() as session:
            result = await session.execute(
                select(ScoutTriageQueue)
                .where(
                    ScoutTriageQueue.user_id == str(user_id),
                    ScoutTriageQueue.status == "queued",
                )
                .order_by(ScoutTriageQueue.triaged_at)
            )
            rows = result.scalars().all()

            for row in rows:
                try:
                    connector = get_connector(row.source_type)
                except KeyError:
                    logger.warning("deliver_pending: no connector for %s", row.source_type)
                    continue
                scout = await session.get(CloudScoutConfig, row.scout_id)
                if scout is None:
                    continue
                try:
                    meta = await connector.fetch_metadata(scout, ItemRef(source_msg_ref=row.source_msg_ref))
                except Exception:
                    logger.exception("deliver_pending: fetch_metadata failed")
                    continue

                payload = {
                    "type": "scout_proposal",
                    "proposal": {
                        "id": row.id,
                        "scout_id": row.scout_id,
                        "source_type": row.source_type,
                        "source_msg_ref": row.source_msg_ref,
                        "raw_subject": meta.subject,
                        "raw_snippet": meta.snippet,
                        "category": "unprocessed",
                        "payload": None,
                    },
                }
                try:
                    await ws.send_json(payload)
                except Exception:
                    # Persist the rows already delivered before propagating, so
                    # a dropped socket does not cause them to be re-sent.
                    await session.commit()
                    raise
                row.status = "delivered"
                row.delivered_at = datetime.now(tz=timezone.utc)

            await session.commit()

    async def ack_proposal(self, proposal_id: str) -> None:
        """Flip a delivered proposal to acked. Idempotent — no-op if already acked.

        Unknown ids are ignored. A repeated ack keeps the original ``acked_at``.
        """
        async with self._session_factory() as session:
            row = await session.get(ScoutTriageQueue, proposal_id)
            if row is None or row.status == "acked":
                return
            row.status = "acked"
            row.acked_at = datetime.now(tz=timezone.utc)
            await session.commit()

    async def _triage_llm(self, scout: CloudScoutConfig, content: ItemContent) -> TriageVerdict:
        """Call the scout-triage-system Langfuse prompt to classify an item as relevant or spam.

        Uses gpt-4o-mini with JSON mode. Wraps the LLM call in a Langfuse generation
        observation when Langfuse is configured. The item body is truncated to
        2000 characters before it is placed in the prompt.

        Raises:
            Exception: JSON decoding or validation errors when the model's reply
                is not a valid verdict, plus anything the LLM call raises.
        """
        import json  # noqa: PLC0415

        from langchain_core.messages import HumanMessage, SystemMessage  # noqa: PLC0415

        _TRIAGE_FALLBACK = (
            "You are a triage classifier for an executive-assistant scout that watches a "
            "{source_type} feed.\n"
            'The scout\'s purpose is: "{scout_purpose}".\n\n'
            "Given one item, decide whether it is RELEVANT (worth surfacing to the user as a "
            "potential task / event / note / project) or SPAM (advertising, mass marketing, "
            "phishing, bulk notifications with no actionable content).\n\n"
            "Item:\n"
            " - Subject: {item_subject}\n"
            " - From: {item_sender}\n"
            " - Body (truncated): {item_body_truncated_2k}\n\n"
            'Return JSON only, matching this schema:\n'
            ' {{"verdict": "relevant" | "spam", "reason": <short string>, "confidence": <0..1>}}\n\n'
            "Be conservative on \"spam\" — if a message could plausibly be a personal/work "
            "email, mark it relevant."
        )

        template, prompt_obj = get_prompt_or_fallback("scout-triage-system", _TRIAGE_FALLBACK)

        body_trunc = (content.body_text or "")[:2000]
        variables = dict(
            source_type=scout.provider,
            scout_purpose=scout.prompt_template or "",
            item_subject=content.metadata.subject or "",
            item_sender=content.metadata.sender or "",
            item_body_truncated_2k=body_trunc,
        )

        if prompt_obj is not None:
            try:
                system_text = prompt_obj.compile(**variables)
                if isinstance(system_text, list):
                    system_text = "\n".join(
                        m.get("content", "") for m in system_text if isinstance(m, dict)
                    )
            except Exception as exc:
                logger.warning("scout triage: compile failed: %s", exc)
                # Manual substitution of the ``{{name}}`` placeholders, in the
                # order the variables were declared above.
                system_text = template
                for key, value in variables.items():
                    system_text = system_text.replace("{{" + key + "}}", value)
        else:
            # The built-in fallback uses str.format placeholders.
            system_text = template.format(**variables)

        llm = get_llm(model="gpt-4o-mini", temperature=0)
        llm_json = llm.bind(response_format={"type": "json_object"})  # type: ignore[attr-defined]

        messages = [
            SystemMessage(content=system_text),
            HumanMessage(content="Classify this item."),
        ]

        lf = get_langfuse()
        if lf:
            with lf.start_as_current_observation(
                as_type="generation",
                name="scout-triage",
                model="gpt-4o-mini",
                prompt=prompt_obj,
                input=messages,
            ) as gen:
                response = await llm_json.ainvoke(messages)
                gen.update(output=response.content, usage=extract_usage(response))
        else:
            response = await llm_json.ainvoke(messages)

        data = json.loads(response.content)
        return TriageVerdict(**data)
|
||||
48
tests/test_scout_connector_registry.py
Normal file
48
tests/test_scout_connector_registry.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for the connector registry."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from app.scouts.connectors.base import ItemRef
|
||||
from app.scouts.connectors.registry import (
|
||||
get_connector,
|
||||
register_connector,
|
||||
_reset_for_tests,
|
||||
)
|
||||
|
||||
|
||||
class _DummyConnector:
|
||||
source_type = "dummy"
|
||||
async def list_new(self, scout): return []
|
||||
async def fetch_metadata(self, scout, ref): raise NotImplementedError
|
||||
async def fetch_content(self, scout, ref): raise NotImplementedError
|
||||
async def archive(self, scout, ref): raise NotImplementedError
|
||||
async def setup_watch(self, scout): raise NotImplementedError
|
||||
async def renew_watch(self, scout): raise NotImplementedError
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _clean_registry():
    """Give every test in this module an empty connector registry."""
    _reset_for_tests()  # start from an empty registry
    yield
    _reset_for_tests()  # drop whatever the test registered
|
||||
|
||||
|
||||
def test_register_and_get():
    """A registered connector is returned, by identity, for its source_type."""
    connector = _DummyConnector()
    register_connector(connector)
    fetched = get_connector("dummy")
    assert fetched is connector
|
||||
|
||||
|
||||
def test_unknown_source_raises():
    """Looking up a source_type nobody registered raises KeyError."""
    unknown_source = "nope"
    with pytest.raises(KeyError):
        get_connector(unknown_source)
|
||||
|
||||
|
||||
def test_double_register_replaces():
    """Registering a second connector under the same source_type replaces the first."""
    first, second = _DummyConnector(), _DummyConnector()
    for connector in (first, second):
        register_connector(connector)
    assert get_connector("dummy") is second
|
||||
48
tests/test_scout_connectors_base.py
Normal file
48
tests/test_scout_connectors_base.py
Normal file
@@ -0,0 +1,48 @@
|
||||
"""Tests for the SourceConnector base protocol and shared types."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.scouts.connectors.base import (
|
||||
ItemContent,
|
||||
ItemMetadata,
|
||||
ItemRef,
|
||||
TriageVerdict,
|
||||
)
|
||||
|
||||
|
||||
def test_item_ref_round_trips_through_pydantic():
    """Validating an ItemRef's dump reproduces both of its fields."""
    stamp = datetime.now(tz=timezone.utc)
    original = ItemRef(source_msg_ref="abc123", received_at=stamp)

    restored = ItemRef.model_validate(original.model_dump())

    assert restored.source_msg_ref == "abc123"
    assert restored.received_at == original.received_at
|
||||
|
||||
|
||||
def test_item_metadata_allows_all_optional():
    """ItemMetadata can be built with no arguments; every field is then None."""
    empty = ItemMetadata()
    for field_name in ("subject", "sender", "snippet", "received_at"):
        assert getattr(empty, field_name) is None
|
||||
|
||||
|
||||
def test_item_content_requires_metadata_and_body():
    """ItemContent keeps the metadata, body text and raw headers it was given."""
    summary = ItemMetadata(subject="hi")
    item = ItemContent(metadata=summary, body_text="hello world", raw_headers={"X-Foo": "bar"})

    assert item.metadata.subject == "hi"
    assert item.body_text == "hello world"
    assert item.raw_headers["X-Foo"] == "bar"
|
||||
|
||||
|
||||
def test_triage_verdict_constraints():
    """TriageVerdict validates both the verdict label and the confidence range."""
    v = TriageVerdict(verdict="relevant", reason="contains task language", confidence=0.92)
    assert v.verdict == "relevant"

    with pytest.raises(ValueError):
        TriageVerdict(verdict="meh", reason="x", confidence=0.5)  # bad enum value

    # The bounds themselves are valid values ...
    for edge in (0.0, 1.0):
        assert TriageVerdict(verdict="spam", reason="x", confidence=edge).confidence == edge

    # ... anything outside [0, 1] is rejected.
    for out_of_range in (-0.1, 1.5):
        with pytest.raises(ValueError):
            TriageVerdict(verdict="spam", reason="x", confidence=out_of_range)
|
||||
84
tests/test_scout_connectors_gmail.py
Normal file
84
tests/test_scout_connectors_gmail.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Tests for GmailConnector."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from app.models import CloudScoutConfig
|
||||
from app.scouts.connectors.base import ItemRef
|
||||
from app.scouts.connectors.gmail import GmailConnector
|
||||
|
||||
|
||||
def _make_scout():
    """Build an unsaved, enabled Gmail scout with a fresh id and history id "100"."""
    fields = {
        "id": str(uuid.uuid4()),
        "user_id": "00000000-0000-0000-0000-000000000003",
        "provider": "gmail",
        "name": "Inbox",
        "data_types": [],
        "prompt_template": "",
        "oauth_token_encrypted": "encrypted-blob",
        "schedule_cron": "0 * * * *",
        "enabled": True,
        "auto_trash_spam": False,
        "device_inactivity_pause_days": 14,
        "gmail_history_id": "100",
    }
    return CloudScoutConfig(**fields)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_metadata_returns_subject_and_snippet():
    """fetch_metadata maps the Subject/From headers and the snippet onto ItemMetadata."""
    header_rows = [
        {"name": "Subject", "value": "Hello"},
        {"name": "From", "value": "alice@example.com"},
        {"name": "Date", "value": "Wed, 14 May 2026 10:00:00 +0000"},
    ]
    gmail_response = {
        "id": "msg-1",
        "snippet": "preview text",
        "payload": {"headers": header_rows},
    }
    connector = GmailConnector()
    scout = _make_scout()

    with patch("app.scouts.connectors.gmail._get_gmail_service") as service_factory:
        # Every step of the mocked call chain returns the same child mock,
        # so this wires up users().messages().get(...).execute().
        service_factory.return_value.users().messages().get().execute.return_value = gmail_response
        result = await connector.fetch_metadata(scout, ItemRef(source_msg_ref="msg-1"))

    assert result.subject == "Hello"
    assert result.sender == "alice@example.com"
    assert result.snippet == "preview text"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_fetch_content_returns_body_text():
    """fetch_content decodes a base64url text/plain body and keeps the subject."""
    import base64

    encoded_body = base64.urlsafe_b64encode(b"hello world").decode()
    gmail_payload = {
        "mimeType": "text/plain",
        "headers": [
            {"name": "Subject", "value": "S"},
            {"name": "From", "value": "a@b"},
        ],
        "body": {"data": encoded_body},
    }
    gmail_response = {"id": "msg-1", "snippet": "hello world", "payload": gmail_payload}
    connector = GmailConnector()
    scout = _make_scout()

    with patch("app.scouts.connectors.gmail._get_gmail_service") as service_factory:
        service_factory.return_value.users().messages().get().execute.return_value = gmail_response
        result = await connector.fetch_content(scout, ItemRef(source_msg_ref="msg-1"))

    assert result.body_text == "hello world"
    assert result.metadata.subject == "S"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_archive_calls_trash():
    """archive() trashes exactly the referenced message and executes the request."""
    scout = _make_scout()
    conn = GmailConnector()
    with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
        await conn.archive(scout, ItemRef(source_msg_ref="msg-1"))

    messages = mock_svc.return_value.users.return_value.messages.return_value
    # Pin the arguments: a bare assert_called() would pass for any message id.
    messages.trash.assert_called_once_with(userId="me", id="msg-1")
    # Building the request is not enough; it must also be executed.
    messages.trash.return_value.execute.assert_called_once_with()
|
||||
270
tests/test_scout_engine.py
Normal file
270
tests/test_scout_engine.py
Normal file
@@ -0,0 +1,270 @@
|
||||
"""Unit tests for ScoutEngine.trigger_scout / _process_item."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from unittest.mock import AsyncMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import select
|
||||
|
||||
from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
|
||||
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
|
||||
from app.scouts.connectors.registry import register_connector, _reset_for_tests
|
||||
from app.scouts.engine import ScoutEngine
|
||||
from tests.conftest import _TestSessionLocal
|
||||
|
||||
|
||||
def _make_connector(items, content_for):
|
||||
c = AsyncMock()
|
||||
# source_type must match the scout.provider ("gmail") so get_connector()
|
||||
# finds it when the engine calls get_connector(scout.provider).
|
||||
c.source_type = "gmail"
|
||||
c.list_new = AsyncMock(return_value=items)
|
||||
c.fetch_content = AsyncMock(side_effect=lambda scout, ref: content_for[ref.source_msg_ref])
|
||||
c.archive = AsyncMock()
|
||||
return c
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def _registry():
    """Empty the connector registry before and after every test in this module."""
    _reset_for_tests()  # nothing registered when the test starts
    yield
    _reset_for_tests()  # remove the connector the test registered
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_relevant_item_inserted_into_queue(monkeypatch):
    """A relevant verdict yields exactly one queued row and no archive call."""
    user_id = "00000000-0000-0000-0000-000000000003"  # power tier seeded in conftest
    scout_id = str(uuid.uuid4())

    async with _TestSessionLocal() as session:
        scout = CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
        )
        session.add(scout)
        await session.commit()

    refs = [ItemRef(source_msg_ref="msg-1")]
    content = {"msg-1": ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")}
    connector = _make_connector(refs, content)
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    monkeypatch.setattr(
        engine,
        "_triage_llm",
        AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="task", confidence=0.9)),
    )

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        # Scope to this test's scout so rows left behind by other tests
        # cannot change the outcome.
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
    assert len(rows) == 1
    assert rows[0].source_msg_ref == "msg-1"
    assert rows[0].triage_verdict == "relevant"
    assert rows[0].status == "queued"
    connector.archive.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
    """Spam on a scout with auto_trash_spam is archived once and never queued."""
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())

    async with _TestSessionLocal() as session:
        scout = CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=True, device_inactivity_pause_days=14,
        )
        session.add(scout)
        await session.commit()

    refs = [ItemRef(source_msg_ref="msg-spam")]
    content = {"msg-spam": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
    connector = _make_connector(refs, content)
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    monkeypatch.setattr(
        engine,
        "_triage_llm",
        AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
    )

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        # Scope to this test's scout so rows left behind by other tests
        # cannot make the emptiness check fail.
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
    assert rows == []
    connector.archive.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
    """Spam on a scout without auto_trash_spam is neither archived nor queued."""
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())

    async with _TestSessionLocal() as session:
        scout = CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
        )
        session.add(scout)
        await session.commit()

    refs = [ItemRef(source_msg_ref="msg-2")]
    content = {"msg-2": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
    connector = _make_connector(refs, content)
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    monkeypatch.setattr(
        engine,
        "_triage_llm",
        AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
    )

    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        # Scope to this test's scout so rows left behind by other tests
        # cannot make the emptiness check fail.
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
    assert rows == []
    connector.archive.assert_not_awaited()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_idempotent_replay(monkeypatch):
    """Triggering the same scout twice with the same item queues it only once."""
    user_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())

    async with _TestSessionLocal() as session:
        session.add(CloudScoutConfig(
            id=scout_id, user_id=user_id, provider="gmail", name="Test",
            data_types=[], prompt_template="", schedule_cron="0 * * * *",
            enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
        ))
        await session.commit()

    refs = [ItemRef(source_msg_ref="msg-3")]
    content = {"msg-3": ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")}
    connector = _make_connector(refs, content)
    register_connector(connector)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    monkeypatch.setattr(
        engine,
        "_triage_llm",
        AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="x", confidence=0.5)),
    )

    await engine.trigger_scout(uuid.UUID(scout_id))
    await engine.trigger_scout(uuid.UUID(scout_id))

    async with _TestSessionLocal() as session:
        # Scope to this test's scout so rows left behind by other tests
        # are not counted as duplicates.
        rows = (await session.execute(
            select(ScoutTriageQueue).where(ScoutTriageQueue.scout_id == scout_id)
        )).scalars().all()
    assert len(rows) == 1, "Replay must not create duplicate queue rows"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_triage_llm_parses_json_response(monkeypatch):
    """Run the real ``_triage_llm`` against a stubbed LLM and check the parsed verdict.

    ``get_llm``, ``get_langfuse`` and ``get_prompt_or_fallback`` are patched in
    ``app.scouts.engine``; the stubbed ``ainvoke`` returns a JSON string whose
    fields must come back on the returned verdict.
    """
    from unittest.mock import MagicMock  # noqa: PLC0415

    from app.models import CloudScoutConfig  # noqa: PLC0415

    scout_fields = {
        "id": str(uuid.uuid4()),
        "user_id": "00000000-0000-0000-0000-000000000003",
        "provider": "gmail",
        "name": "test-scout",
        "data_types": [],
        "prompt_template": "watch invoices and project updates",
        "schedule_cron": "0 * * * *",
        "enabled": True,
        "auto_trash_spam": False,
        "device_inactivity_pause_days": 14,
    }
    scout = CloudScoutConfig(**scout_fields)

    item_metadata = ItemMetadata(subject="Invoice 42", sender="billing@acme.com")
    content = ItemContent(
        metadata=item_metadata,
        body_text="Payment of €1 200 is due on 2026-06-01. Please confirm receipt.",
    )

    # Stand-in for the LangChain response: ``.content`` carries valid JSON.
    llm_reply = MagicMock(
        content='{"verdict": "relevant", "reason": "invoice due", "confidence": 0.92}',
        usage_metadata={"input_tokens": 10, "output_tokens": 5, "total_tokens": 15},
    )

    # Stand-in for the LLM: ``.bind()`` hands back the same mock, whose
    # ``ainvoke`` resolves to the canned reply.
    llm_stub = MagicMock()
    llm_stub.ainvoke = AsyncMock(return_value=llm_reply)
    llm_stub.bind.return_value = llm_stub

    def _stub_get_llm(**kwargs):
        return llm_stub

    def _no_langfuse():
        return None

    def _fallback_prompt(name, fallback):
        return fallback, None

    # Route the engine's collaborators to the stubs above: fake LLM, no
    # Langfuse client, and the fallback prompt instead of a managed one.
    monkeypatch.setattr("app.scouts.engine.get_llm", _stub_get_llm)
    monkeypatch.setattr("app.scouts.engine.get_langfuse", _no_langfuse)
    monkeypatch.setattr("app.scouts.engine.get_prompt_or_fallback", _fallback_prompt)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    verdict = await engine._triage_llm(scout, content)

    assert verdict.verdict == "relevant"
    assert verdict.reason == "invoice due"
    assert abs(verdict.confidence - 0.92) < 1e-6
    llm_stub.ainvoke.assert_awaited_once()
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_deliver_pending_sends_one_frame_per_queued_row(monkeypatch):
    """``deliver_pending`` emits one ``scout_proposal`` frame per queued row.

    Three queued rows are seeded for one scout. After delivery, three frames
    must have been sent, each carrying the subject produced by the connector's
    ``fetch_metadata``, and every row must be marked ``delivered`` with a
    ``delivered_at`` timestamp.
    """
    owner_id = "00000000-0000-0000-0000-000000000003"
    scout_id = str(uuid.uuid4())
    now = datetime.now(tz=timezone.utc)
    expiry = now + timedelta(days=30)

    async with _TestSessionLocal() as session:
        scout_config = CloudScoutConfig(
            id=scout_id,
            user_id=owner_id,
            provider="gmail",
            name="Test",
            data_types=[],
            prompt_template="",
            schedule_cron="0 * * * *",
            enabled=True,
            auto_trash_spam=False,
            device_inactivity_pause_days=14,
        )
        session.add(scout_config)
        for index in range(3):
            queued_row = ScoutTriageQueue(
                id=str(uuid.uuid4()),
                user_id=owner_id,
                scout_id=scout_id,
                source_type="gmail",
                source_msg_ref=f"msg-{index}",
                triage_verdict="relevant",
                status="queued",
                triaged_at=now,
                expires_at=expiry,
            )
            session.add(queued_row)
        await session.commit()

    def _metadata_for(scout, ref):
        # Derive subject/snippet from the ref so each frame is distinguishable.
        return ItemMetadata(
            subject=f"sub-{ref.source_msg_ref}",
            snippet=f"snip-{ref.source_msg_ref}",
        )

    connector = AsyncMock()
    connector.source_type = "gmail"
    connector.fetch_metadata = AsyncMock(side_effect=_metadata_for)
    register_connector(connector)

    sent = []

    def _capture(payload):
        # Record every frame pushed to the fake websocket.
        return sent.append(payload)

    ws = AsyncMock()
    ws.send_json = AsyncMock(side_effect=_capture)

    engine = ScoutEngine(session_factory=_TestSessionLocal)
    await engine.deliver_pending(uuid.UUID(owner_id), ws)

    assert len(sent) == 3
    assert all(frame["type"] == "scout_proposal" for frame in sent)
    delivered_subjects = {frame["proposal"]["raw_subject"] for frame in sent}
    assert delivered_subjects == {"sub-msg-0", "sub-msg-1", "sub-msg-2"}

    async with _TestSessionLocal() as session:
        result = await session.execute(select(ScoutTriageQueue))
        rows = result.scalars().all()
        assert all(row.status == "delivered" for row in rows)
        assert all(row.delivered_at is not None for row in rows)
|
||||
106
tests/test_scout_webhook.py
Normal file
106
tests/test_scout_webhook.py
Normal file
@@ -0,0 +1,106 @@
|
||||
"""Tests for the Gmail Pub/Sub webhook route.
|
||||
|
||||
Covers:
|
||||
- Happy path: valid JWT + known user + enabled scout → 204, engine triggered.
|
||||
- Rejection: invalid JWT → 401.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
import pytest
|
||||
from httpx import ASGITransport, AsyncClient
|
||||
|
||||
from app.main import app
|
||||
from app.models import CloudScoutConfig, User
|
||||
from tests.conftest import _TestSessionLocal
|
||||
|
||||
|
||||
def _pubsub_payload(
    email: str,
    history_id: str,
    *,
    message_id: str = "m1",
    subscription: str = "projects/x/subscriptions/gmail-watch-sub",
) -> dict:
    """Build a minimal Pub/Sub push envelope.

    Args:
        email: Value stored under ``emailAddress`` in the inner notification.
        history_id: Value stored under ``historyId`` in the inner notification.
        message_id: Value for the envelope's ``messageId``; defaults to the
            previously hard-coded ``"m1"``.
        subscription: Value for the envelope's ``subscription``; defaults to
            the previously hard-coded subscription path.

    Returns:
        A dict with a ``message`` entry (base64-encoded JSON ``data`` plus
        ``messageId``) and a ``subscription`` entry.
    """
    notification = json.dumps({"emailAddress": email, "historyId": history_id})
    # The inner JSON travels base64-encoded; decode back to str so the whole
    # envelope is JSON-serialisable.
    encoded = base64.b64encode(notification.encode()).decode()
    return {
        "message": {"data": encoded, "messageId": message_id},
        "subscription": subscription,
    }
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_webhook_triggers_scout_for_matching_user():
    """204 returned and ScoutEngine.trigger_scout awaited for the matching scout.

    The seeded user's email is rewritten to the address carried in the Pub/Sub
    payload and an enabled gmail scout is added for that user. JWT
    verification, the route's session factory and ``trigger_scout`` are all
    patched, so only the webhook route's own logic is exercised.
    """
    user_id = "00000000-0000-0000-0000-000000000003"  # seeded 'power' user
    scout_id = str(uuid.uuid4())

    # Mutate the seeded user email so the webhook can resolve it,
    # and add a cloud scout config for gmail.
    async with _TestSessionLocal() as session:
        user = await session.get(User, user_id)
        # ``session.get`` yields None for an unknown primary key; fail with a
        # clear message instead of an AttributeError on ``user.email``.
        assert user is not None, f"Seeded user {user_id} is missing from the test database"
        user.email = "alice@example.com"
        session.add(
            CloudScoutConfig(
                id=scout_id,
                user_id=user_id,
                provider="gmail",
                name="Inbox",
                data_types=[],
                prompt_template="",
                schedule_cron="0 * * * *",
                enabled=True,
                auto_trash_spam=False,
                device_inactivity_pause_days=14,
            )
        )
        await session.commit()

    payload = _pubsub_payload("alice@example.com", "200")

    with (
        # Accept any bearer token.
        patch(
            "app.api.routes.scout_webhooks._verify_pubsub_jwt",
            return_value=True,
        ),
        # Point the route at the test session factory.
        patch(
            "app.api.routes.scout_webhooks.async_session",
            _TestSessionLocal,
        ),
        # Replace the engine entry point so no scout actually runs.
        patch(
            "app.scouts.engine.ScoutEngine.trigger_scout",
            new=AsyncMock(),
        ) as mock_trigger,
    ):
        async with AsyncClient(
            transport=ASGITransport(app=app), base_url="http://test"
        ) as client:
            resp = await client.post(
                "/api/v1/scouts/webhooks/gmail",
                json=payload,
                headers={"Authorization": "Bearer fake-google-jwt"},
            )

    assert resp.status_code == 204
    mock_trigger.assert_awaited_once_with(uuid.UUID(scout_id))
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_webhook_rejects_unverified_jwt():
    """The Gmail webhook answers 401 when ``_verify_pubsub_jwt`` returns False."""
    envelope = _pubsub_payload("alice@example.com", "200")
    auth_headers = {"Authorization": "Bearer bogus"}

    # Force the route's JWT check to fail regardless of the token sent.
    failing_jwt_check = patch(
        "app.api.routes.scout_webhooks._verify_pubsub_jwt",
        return_value=False,
    )

    with failing_jwt_check:
        transport = ASGITransport(app=app)
        async with AsyncClient(transport=transport, base_url="http://test") as client:
            resp = await client.post(
                "/api/v1/scouts/webhooks/gmail",
                json=envelope,
                headers=auth_headers,
            )

    assert resp.status_code == 401
|
||||
Reference in New Issue
Block a user