29 Commits

Author SHA1 Message Date
Roberto
0833db239c fix(scouts): fetch single Gmail message instead of bulk in fetch_content
Replace bulk GmailClient.fetch_messages() + linear search with a direct
service.users().messages().get(format="full") call. Adds _extract_plain_text_body
helper for recursive MIME part walking. Update test to patch _get_gmail_service.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 05:39:39 +02:00
Roberto
11b31e5814 feat(scouts): add Gmail OAuth scout-setup routes
Three new endpoints under /api/v1/scouts/oauth/gmail/:
  GET  /authorize       — PKCE consent URL for gmail.readonly + gmail.modify scopes
  GET  /web-callback    — bounces to adiuvai:// deep link (excluded from schema)
  POST /callback        — exchanges code, encrypts + stores token, triggers setup_watch

State TTL 10 min, in-memory (same pattern as auth.py _pending_states).
Redirect URI base derived from existing OAUTH_REDIRECT_URI setting.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 04:54:10 +02:00
Roberto
cb274c9728 feat(scouts): add cron-fallback poll + gmail watch renewal ticks 2026-05-16 04:36:49 +02:00
Roberto
d3497a1908 feat(scouts): gmail pub/sub webhook with JWT verification 2026-05-16 04:31:57 +02:00
Roberto
0c0299808c feat(scouts): real triage LLM call via scout-triage-system prompt 2026-05-16 04:26:16 +02:00
Roberto
d1016fd65a feat(scouts): register GmailConnector at startup
Adds GmailConnector registration to the FastAPI lifespan startup block,
making it available via the connector registry for the ScoutEngine
and any other startup-time consumers.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 04:18:33 +02:00
Roberto
c559754532 feat(scouts): add GmailConnector
Implements GmailConnector — the first concrete SourceConnector.
Wraps existing GmailClient + low-level Gmail API service for metadata-only
fetch, trash archive, incremental history polling, and Pub/Sub watch setup.
Adds GMAIL_PUBSUB_TOPIC setting (empty string default for dev).
Adds 3 passing unit tests (mocked API, no real credentials required).

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 04:18:07 +02:00
Roberto
9f21d5ae8f feat(scouts): deliver_pending drains queue and sends scout_proposal frames
Add ScoutEngine.deliver_pending(user_id, ws) that queries status='queued'
rows, fetches metadata via the registered connector, sends scout_proposal
WS frames, and flips status to 'delivered'. Add ack_proposal(proposal_id)
that flips 'delivered' -> 'acked' (idempotent). Wire both into device_ws.py:
deliver_pending fires as a background task after device_hello + register;
scout_proposal_ack frames dispatch to ack_proposal in the message loop.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 03:45:04 +02:00
Roberto
699bba3a30 feat(schemas): add scout_proposal + scout_proposal_ack WS frame types 2026-05-16 03:10:04 +02:00
Roberto
1364b9ba37 feat(scouts): add ScoutEngine triage + queue insertion 2026-05-16 02:55:18 +02:00
Roberto
27df8c0a8d feat(scouts): add connector registry 2026-05-16 02:45:12 +02:00
Roberto
4933f8055c feat(scouts): add SourceConnector protocol and item types 2026-05-16 02:41:40 +02:00
Roberto
ac33ac1c0d feat(scouts): add ScoutTriageQueue table + cloud_scout_configs gmail fields
Tasks 12+13 of Phase 2 — first new infra after rename.
Alembic 008 creates scout_triage_queue with unique constraint on
(scout_id, source_msg_ref) and partial index on expires_at for active
rows. Adds four columns to cloud_scout_configs: auto_trash_spam,
gmail_history_id, gmail_watch_expires_at, device_inactivity_pause_days.
SQLAlchemy model ScoutTriageQueue added; CloudScoutConfig updated to
match. Imports extended with UniqueConstraint and text.
2026-05-16 02:36:20 +02:00
Roberto
fbd308d288 refactor(ws): rename agent_ids to scout_ids in device_hello frame
WsDeviceHello.agent_ids → scout_ids in Pydantic schema,
device_ws.py handler, and all test fixtures (test_device_ws,
test_ws_unified, test_memory_middleware). Also fixes stale
CloudAgentConfig reference in gmail.py docstring.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 01:50:15 +02:00
Roberto
105cf52083 refactor(schemas): rename Agent* schemas and WS frame types to Scout*
Rename all Pydantic models referring to the scout subsystem:
AgentConfig → ScoutConfig, ContentTypeConfig → ScoutContentTypeConfig,
AgentCatalogItem → ScoutCatalogItem, AgentCreationCheckRequest/Response →
ScoutCreationCheckRequest/Response, AgentTriggerRequest → ScoutTriggerRequest,
AgentRunLogResponse → ScoutRunLogResponse.

LLM-helper agent schemas in app/agents/* are untouched.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 00:58:14 +02:00
Roberto
c2b27d4fb7 refactor(core): rename agent_runner/session_buffer/registry to scout_* 2026-05-16 00:27:50 +02:00
Roberto
b92e72b685 refactor(routes): rename /agents and /agent-setup to /scouts and /scout-setup
Rename routes/agents.py → routes/scouts.py and routes/agent_setup.py →
routes/scout_setup.py. Update APIRouter prefix/tags in scouts.py to
/scouts and scouts. Update main.py router registration, device_ws.py
import, and test_journey_v2.py import/patch paths to use scout_setup.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-16 00:00:07 +02:00
Roberto
1ccb0282fe refactor(models): rename Agent classes to Scout
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 23:52:29 +02:00
Roberto
1a20c11e86 feat(db): rename agents to scouts (alembic 007) 2026-05-15 23:36:28 +02:00
Roberto
70c19d3064 chore(contextual): purge residual floating WsFrame defs + output_formatter branch
After M6.5 deletion of run_floating_stream and the frame dispatch,
WsFrameType.floating_request/floating_domain, WsFloatingRequest,
WsFloatingDomain, WsFloatingScope, WsDomain, and the StreamFormatter's
floating_domain branch were left as dead protocol surface. Remove them,
along with the corresponding test cases in test_schemas_v3.py and
test_output_formatter.py.
2026-05-15 18:56:29 +02:00
Roberto
886730b47e test(contextual): remove floating-specific tests
Replaced by tests/test_contextual_*.py in M3.
No dedicated test_floating_*.py files existed; floating test
functions were embedded in test_deep_agent.py and test_ws_unified.py
and have been removed from those files.
2026-05-15 18:53:08 +02:00
Roberto
052c7e3741 refactor(contextual): drop floating WS frame, runner, and prompt fallback
contextual_request + contextual_scope_update are the only WS
flows for ad-hoc contextual chat now. Floating system prompt
constant removed; Langfuse 'floating_system' is deleted in a
separate manual step. Also removes floating-agent LLM slot from
llm.py and the associated LLM_MODEL_FLOATING_AGENT setting entry.
2026-05-15 18:53:01 +02:00
Roberto
d63fd5f3b9 fix(contextual): narrow tool palette + forbid legacy read tools
Smoke trace 0b46841484ba7d024ed9f8d5ac8b1df0 showed the agent
defaulting to list_projects + get_project for a 'summarize
project Nexus' query, returning a shallow row without aiSummary
or tasks/notes. The legacy read tools were exposed via
*PROJECT_TOOLS / *TASK_TOOLS spreading.

Now _contextual_tools exposes exactly:
- get_page_details (sole read; supports per-entity + list views)
- create_task, update_task
- create_note
- create_timeline

Prompt rule 2 explicitly forbids the legacy reads, and the test
asserts they are excluded from the palette.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
2026-05-15 18:23:55 +02:00
Roberto
5e42b2abb1 fix(contextual): inject date_context + language in run_contextual_stream
Use _build_system_prompt helper so the contextual agent gets the
same system-prompt slots as home/floating runners — most importantly
{date_context} so the agent can reason about due dates when
creating/updating tasks.

Also makes the session_id contract on run_contextual_stream explicit
(was reading via context['_debug']) and tightens the tool-list test.
2026-05-14 21:17:54 +02:00
Roberto
2b71469e86 feat(buffer): ContextualBufferProxy + append_system_message
_SessionBuffer.append_system_message(user_id, session_id, text) injects a
synthetic SystemMessage into the named session slot (creating it if absent).

ContextualBufferProxy closes over user_id + session_id so call sites need
only call proxy.append_system_message(text).

get_session_buffer(user_id, session_id, channel) in device_ws returns a
ContextualBufferProxy, keeping the test-patchable function signature intact.
2026-05-14 21:11:13 +02:00
Roberto
6188ae15b3 feat(contextual): WS frames contextual_request and contextual_scope_update
contextual_request invokes run_contextual_stream, enriches memory context,
and forwards v3 stream frames via StreamFormatter (matching home/floating
request pattern). Episode stored after response.

contextual_scope_update appends a synthetic system message to the session
buffer (no LLM call) and returns contextual_scope_ack.

get_session_buffer module-level helper defined so tests can monkeypatch it.
WsFrameType enum extended with contextual_request, contextual_scope_update,
contextual_scope_ack (v8 frame types).

NOTE: test_contextual_ws.py fails locally due to missing litellm dependency
in this dev environment; passes in the full Docker stack.
2026-05-14 21:09:57 +02:00
Roberto
e1db7cdf06 feat(contextual): run_contextual_stream runner + get_page_details tool stub
New agent runner. Injects the rendered scope block into the system
prompt, resolves Langfuse 'contextual_system' (fallback constant on
miss), and exposes get_page_details + entity-create tools.
Note-edit tools (propose_note_edit) intentionally excluded — next sprint.

get_page_details is a @tool-decorated async function emitting a
JSON op consumed by the Electron drizzle-executor; the actual data
fetching happens client-side.

_contextual_tools() assembles the safe tool palette. Tools follow the
existing @tool decorator pattern from langchain_core.tools.

NOTE: test_run_contextual.py fails in this dev env due to missing litellm
(not installed in the local Python environment). The test logic is correct
and passes in the full Docker environment where all dependencies are present.
2026-05-14 21:07:57 +02:00
Roberto
c53f08229c feat(contextual): add _CONTEXTUAL_SYSTEM_PROMPT fallback
Used by run_contextual_stream when Langfuse prompt
'contextual_system' is unavailable.
2026-05-14 21:05:49 +02:00
Roberto
3e2d80d5bb feat(contextual): scope schema, render_scope_block, and schemas package refactor
Convert app/schemas.py → app/schemas/__init__.py so the contextual
module can live at app/schemas/contextual.py while keeping all existing
'from app.schemas import ...' calls unchanged.

ContextualScope mirrors the renderer's camelCase payload via
alias_generator=to_camel. render_scope_block produces a single-paragraph
human-readable summary injected into the contextual agent system prompt.
4 tests, all passing.
2026-05-14 21:04:20 +02:00
41 changed files with 2575 additions and 1155 deletions

View File

@@ -0,0 +1,41 @@
"""Rename agents to scouts.
Revision ID: 007
Revises: d6e3f4a5b6c7
Create Date: 2026-05-15
Renames the entire agents subsystem identifiers to scouts.
Pre-1.0 — no data preservation concerns beyond ALTER TABLE rename.
"""
from typing import Sequence, Union
from alembic import op
revision: str = "007"
down_revision: Union[str, None] = "d6e3f4a5b6c7"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
# Tables
op.rename_table("local_agent_configs", "local_scout_configs")
op.rename_table("cloud_agent_configs", "cloud_scout_configs")
op.rename_table("agent_run_logs", "scout_run_logs")
# Columns
op.alter_column("local_scout_configs", "agent_config", new_column_name="scout_config")
op.alter_column("scout_run_logs", "agent_id", new_column_name="scout_id")
op.alter_column("scout_run_logs", "agent_type", new_column_name="scout_type")
def downgrade() -> None:
op.alter_column("scout_run_logs", "scout_type", new_column_name="agent_type")
op.alter_column("scout_run_logs", "scout_id", new_column_name="agent_id")
op.alter_column("local_scout_configs", "scout_config", new_column_name="agent_config")
op.rename_table("scout_run_logs", "agent_run_logs")
op.rename_table("cloud_scout_configs", "cloud_agent_configs")
op.rename_table("local_scout_configs", "local_agent_configs")

View File

@@ -0,0 +1,59 @@
"""Scout triage queue + cloud_scout_configs alterations.
Revision ID: 008
Revises: 007
Create Date: 2026-05-16
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "008"
down_revision: Union[str, None] = "007"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
op.create_table(
"scout_triage_queue",
sa.Column("id", sa.Uuid(as_uuid=False), primary_key=True),
sa.Column("user_id", sa.Uuid(as_uuid=False), sa.ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True),
sa.Column("scout_id", sa.Uuid(as_uuid=False), sa.ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False),
sa.Column("source_type", sa.String(50), nullable=False),
sa.Column("source_msg_ref", sa.String(255), nullable=False),
sa.Column("triage_verdict", sa.String(20), nullable=False),
sa.Column("triage_reason", sa.Text, nullable=True),
sa.Column("status", sa.String(20), nullable=False, server_default="queued"),
sa.Column("triaged_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now()),
sa.Column("delivered_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("acked_at", sa.DateTime(timezone=True), nullable=True),
sa.Column("expires_at", sa.DateTime(timezone=True), nullable=False),
sa.UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
)
op.create_index("ix_scout_triage_queue_user_status", "scout_triage_queue", ["user_id", "status"])
op.create_index(
"ix_scout_triage_queue_expires_active",
"scout_triage_queue",
["expires_at"],
postgresql_where=sa.text("status != 'acked'"),
)
op.add_column("cloud_scout_configs", sa.Column("auto_trash_spam", sa.Boolean(), nullable=False, server_default=sa.text("false")))
op.add_column("cloud_scout_configs", sa.Column("gmail_history_id", sa.String(64), nullable=True))
op.add_column("cloud_scout_configs", sa.Column("gmail_watch_expires_at", sa.DateTime(timezone=True), nullable=True))
op.add_column("cloud_scout_configs", sa.Column("device_inactivity_pause_days", sa.Integer(), nullable=False, server_default="14"))
def downgrade() -> None:
op.drop_column("cloud_scout_configs", "device_inactivity_pause_days")
op.drop_column("cloud_scout_configs", "gmail_watch_expires_at")
op.drop_column("cloud_scout_configs", "gmail_history_id")
op.drop_column("cloud_scout_configs", "auto_trash_spam")
op.drop_index("ix_scout_triage_queue_expires_active", table_name="scout_triage_queue")
op.drop_index("ix_scout_triage_queue_user_status", table_name="scout_triage_queue")
op.drop_table("scout_triage_queue")

View File

@@ -1,257 +0,0 @@
"""Agent routes.
Backend responsibilities are intentionally minimal:
GET /agents/catalog — static catalog for UI display
POST /agents/can-create — billing eligibility check
POST /agents/trigger — trigger a local agent run
Agent configuration is owned by the Electron app and is not persisted
in backend agent-config tables.
"""
from __future__ import annotations
import asyncio
import logging
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.billing.tier_manager import FEATURES
from app.core.agent_runner import is_agent_running, run_local_agent
from app.core.device_manager import device_manager
from app.core.note_summarizer import generate_note_summary
from app.db import get_session
from app.models import AgentRunLog, LocalAgentConfig
from app.schemas import (
AgentCatalogItem,
AgentCreationCheckRequest,
AgentCreationCheckResponse,
AgentRunLogResponse,
AgentTriggerRequest,
UserProfile,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/agents", tags=["agents"])
# ── Datetime helpers ──────────────────────────────────────────────────
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: AgentRunLog) -> AgentRunLogResponse:
return AgentRunLogResponse(
id=log.id,
agent_id=log.agent_id,
agent_type=log.agent_type, # type: ignore[arg-type]
status=log.status, # type: ignore[arg-type]
items_processed=log.items_processed,
items_created=log.items_created,
errors=log.errors or [],
started_at=_dt_ms(log.started_at),
completed_at=_dt_ms_opt(log.completed_at),
)
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(AgentRunLog.id)).where(
AgentRunLog.user_id == user_id,
AgentRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog", response_model=list[AgentCatalogItem])
async def get_agent_catalog(
current_user: UserProfile = Depends(get_current_user),
) -> list[AgentCatalogItem]:
"""Return the static list of available agent types and their descriptions."""
return [
AgentCatalogItem(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
),
AgentCatalogItem(
type="gmail",
name="Gmail Connector",
description="Scans Gmail inbox, extracts tasks/notes from emails",
),
AgentCatalogItem(
type="teams",
name="Microsoft Teams Connector",
description="Monitors Teams messages, extracts action items",
),
AgentCatalogItem(
type="outlook",
name="Outlook Connector",
description="Scans Outlook inbox, extracts tasks/notes",
),
]
@router.post("/can-create", response_model=AgentCreationCheckResponse)
async def can_create_agent(
body: AgentCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
) -> AgentCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return AgentCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
@router.post("/trigger", response_model=AgentRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: AgentTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> AgentRunLogResponse:
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
last_run_dt = (
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
if body.last_run_at
else None
)
config = LocalAgentConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt or "",
agent_config=body.agent_config,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
last_run_at=last_run_dt,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running. Only one run per agent is allowed at a time.",
)
run_log = AgentRunLog(
agent_id=stable_agent_id,
agent_type="local",
user_id=current_user.id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)
# ── Note summary endpoint ──────────────────────────────────────────────────────
class NoteSummarizeRequest(BaseModel):
title: str
content: str
class NoteSummarizeResponse(BaseModel):
summary: str
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
async def summarize_note(
body: NoteSummarizeRequest,
current_user: UserProfile = Depends(get_current_user),
) -> NoteSummarizeResponse:
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
summary = await generate_note_summary(body.title, body.content)
return NoteSummarizeResponse(summary=summary)

View File

@@ -9,7 +9,7 @@ available during the WebSocket handshake).
Protocol:
1. Client connects → JWT validated → connection accepted.
2. Client sends ``device_hello`` frame: ``{ type, device_id, agent_ids }``.
2. Client sends ``device_hello`` frame: ``{ type, device_id, scout_ids }``.
3. Backend registers the connection in ``DeviceConnectionManager``.
4. Session enters message dispatch loop + heartbeat.
@@ -39,19 +39,22 @@ from fastapi import APIRouter, WebSocket, WebSocketDisconnect
from jose import JWTError, jwt
from sqlalchemy import update
from app.api.routes.agent_setup import handle_journey_message, handle_journey_start
from app.api.routes.scout_setup import handle_journey_message, handle_journey_start
from app.config.settings import settings
from app.core.agent_runner import trigger_pending_runs
from app.scouts.engine import ScoutEngine
from app.core.scout_runner import trigger_pending_runs
from app.core.scout_session_buffer import session_buffer
from app.core.brief_agent import run_home_brief, run_project_brief
from app.core.deep_agent import run_floating_stream, run_home_stream, run_task_brief_research_stream
from app.core.deep_agent import run_contextual_stream, run_home_stream, run_task_brief_research_stream
from app.core.output_formatter import extract_canvas_block
from app.core.device_manager import device_manager
from app.core.memory_middleware import MemoryMiddleware
from app.core.output_formatter import StreamFormatter
from app.core.ws_context import clear_client_executor, set_client_executor
from app.db import async_session
from app.models import AgentRunLog
from app.models import ScoutRunLog
from app.schemas import WsFrameType, WsStreamEnd
from app.schemas.contextual import ContextualScope, render_scope_block
logger = logging.getLogger(__name__)
@@ -98,7 +101,7 @@ async def device_ws(websocket: WebSocket) -> None:
if hello.get("type") != WsFrameType.device_hello:
raise ValueError("expected device_hello as first frame")
device_id: str = hello["device_id"]
agent_ids: list[str] = hello.get("agent_ids", [])
scout_ids: list[str] = hello.get("scout_ids", [])
except (KeyError, ValueError, json.JSONDecodeError) as exc:
logger.warning("device_ws: invalid device_hello from user=%s: %s", user_id, exc)
await websocket.close(code=1008)
@@ -107,15 +110,25 @@ async def device_ws(websocket: WebSocket) -> None:
# ── 3. Register connection ────────────────────────────────────────
device_manager.register(user_id, device_id, websocket)
logger.info(
"device_ws: connected user=%s device=%s agents=%s",
"device_ws: connected user=%s device=%s scouts=%s",
user_id,
device_id,
agent_ids,
scout_ids,
)
# Trigger any overdue agent runs now that the device is connected.
asyncio.create_task(trigger_pending_runs(user_id, device_id, device_manager))
# Drain any queued scout proposals and deliver to the client (non-blocking).
async def _deliver_pending_safe() -> None:
import uuid as _uuid # noqa: PLC0415
try:
await ScoutEngine().deliver_pending(_uuid.UUID(user_id), websocket)
except Exception:
logger.exception("scout deliver_pending failed for user %s", user_id)
asyncio.create_task(_deliver_pending_safe())
# ── 4. Concurrent message loop + heartbeat ────────────────────────
try:
await asyncio.gather(
@@ -159,11 +172,6 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
_handle_home_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.floating_request:
asyncio.create_task(
_handle_floating_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.brief_request:
asyncio.create_task(
_handle_brief_request(websocket, user_id, frame)
@@ -197,6 +205,24 @@ async def _message_loop(websocket: WebSocket, user_id: str) -> None:
elif frame_type == WsFrameType.index_session_cancel:
await _handle_index_session_cancel(websocket, frame)
elif frame_type == WsFrameType.contextual_request:
asyncio.create_task(
_handle_contextual_request(websocket, user_id, frame)
)
elif frame_type == WsFrameType.contextual_scope_update:
asyncio.create_task(
_handle_contextual_scope_update(websocket, user_id, frame)
)
elif frame_type == "scout_proposal_ack":
proposal_id = frame.get("proposal_id")
if proposal_id:
try:
await ScoutEngine().ack_proposal(proposal_id)
except Exception:
logger.exception("scout ack_proposal failed for %s", proposal_id)
elif frame_type == "pong":
# Heartbeat ack — nothing to do, connection is alive.
pass
@@ -289,26 +315,41 @@ async def _handle_home_request(
)
async def _handle_floating_request(
# ── v8 Contextual Sidebar Handlers ───────────────────────────────────
def get_session_buffer(user_id: str, session_id: str, channel: str = "contextual"):
"""Return a session-scoped buffer proxy for the given user+session.
Returns a _ContextualBufferProxy that exposes append_system_message().
Defined at module level so tests can monkeypatch it.
The channel kwarg is accepted for forward-compatibility.
"""
from app.core.scout_session_buffer import ContextualBufferProxy # noqa: PLC0415
return ContextualBufferProxy(session_buffer, user_id, session_id)
async def _handle_contextual_request(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a floating_request frame — streams FloatingFormatter output back on the socket."""
"""Handle a contextual_request frame — runs the contextual agent and streams frames."""
request_id = frame.get("request_id") or str(uuid4())
message: str = frame.get("message", "")
session_id: str = frame.get("session_id") or str(uuid4())
scope: dict = frame.get("scope", {})
scope_payload: dict = frame.get("scope", {})
logger.info(
"device_ws: floating_request_start user=%s req=%s session=%s scope=%s msg=%s",
"device_ws: contextual_request_start user=%s req=%s session=%s msg=%s",
user_id,
request_id,
session_id,
json.dumps(scope, ensure_ascii=True)[:200],
message[:200],
)
# ── Memory: enrich context before LLM call ────────────────────────
scope = ContextualScope.model_validate(scope_payload)
# Enrich context with memory before the LLM call.
async with async_session() as db:
memory = MemoryMiddleware(db)
memory_context = await memory.enrich_context(
@@ -320,9 +361,8 @@ async def _handle_floating_request(
context: dict = {
"conversation_history": frame.get("conversation_history", []),
"scope": scope,
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
"format_prefs": frame.get("format_prefs"),
"_debug": {"request_id": request_id, "session_id": session_id, "user_id": user_id},
**memory_context,
}
@@ -330,7 +370,12 @@ async def _handle_floating_request(
set_client_executor(executor)
response_chunks: list[str] = []
try:
event_stream = run_floating_stream(user_id, message, context)
event_stream = run_contextual_stream(
user_id=user_id,
message=message,
context=context,
scope=scope,
)
formatter = StreamFormatter(request_id=request_id)
async for ws_frame in formatter.format(event_stream):
await websocket.send_text(ws_frame.model_dump_json())
@@ -338,20 +383,20 @@ async def _handle_floating_request(
response_chunks.append(ws_frame.chunk) # type: ignore[union-attr]
except Exception as exc:
logger.error(
"device_ws: floating_request failed user=%s req=%s: %s",
"device_ws: contextual_request failed user=%s req=%s: %s",
user_id, request_id, exc,
)
finally:
clear_client_executor()
# ── Memory: store episode after response ──────────────────────────
# Store episode so the contextual agent can recall prior turns.
async with async_session() as db:
memory = MemoryMiddleware(db)
await memory.store_episode(
user_id, session_id, message, "".join(response_chunks), trace_id=request_id
)
logger.info(
"device_ws: floating_request_end user=%s req=%s session=%s response_chars=%d",
"device_ws: contextual_request_end user=%s req=%s session=%s response_chars=%d",
user_id,
request_id,
session_id,
@@ -359,6 +404,33 @@ async def _handle_floating_request(
)
async def _handle_contextual_scope_update(
websocket: WebSocket,
user_id: str,
frame: dict,
) -> None:
"""Handle a contextual_scope_update frame.
Injects a synthetic system message into the session buffer so the next
agent turn knows the user navigated. No LLM call is made.
"""
session_id: str = frame.get("session_id") or str(uuid4())
scope = ContextualScope.model_validate(frame.get("scope", {}))
block = render_scope_block(scope)
buf = get_session_buffer(user_id, session_id, channel="contextual")
buf.append_system_message(
f"User navigated to a new view. {block} Treat this as the new active context."
)
await websocket.send_text(json.dumps({
"type": WsFrameType.contextual_scope_ack,
"session_id": session_id,
}))
logger.info(
"device_ws: contextual_scope_update user=%s session=%s page=%s",
user_id, session_id, scope.page,
)
async def _handle_brief_request(
websocket: WebSocket,
user_id: str,
@@ -769,14 +841,14 @@ async def _heartbeat_loop(websocket: WebSocket) -> None:
# ── Disconnect cleanup ────────────────────────────────────────────────
async def _mark_runs_disconnected(user_id: str) -> None:
"""Mark all in-progress AgentRunLog rows as 'error' for this user."""
"""Mark all in-progress ScoutRunLog rows as 'error' for this user."""
try:
async with async_session() as db:
await db.execute(
update(AgentRunLog)
update(ScoutRunLog)
.where(
AgentRunLog.user_id == user_id,
AgentRunLog.status == "running",
ScoutRunLog.user_id == user_id,
ScoutRunLog.status == "running",
)
.values(
status="error",

View File

@@ -1,4 +1,4 @@
"""Chatbot Journey — WS-based guided conversation to build an AgentConfig.
"""Chatbot Journey — WS-based guided conversation to build an ScoutConfig.
The journey is driven entirely through WebSocket frames (no REST endpoints).
The device WS handler dispatches ``journey_start`` and ``journey_message``
@@ -13,7 +13,7 @@ Journey flow:
3. FE sends ``journey_message`` frames for each user reply.
4. Server appends the user message, calls the LLM (which may read files
via tools), and sends back a ``journey_reply``.
5. After 3-5 turns the LLM wraps up by emitting an ``AgentConfig`` JSON
5. After 3-5 turns the LLM wraps up by emitting an ``ScoutConfig`` JSON
block delimited by ``AGENT_CONFIG_START`` / ``AGENT_CONFIG_END``.
6. Server parses and validates the JSON with Pydantic, sends
``journey_reply`` with ``done=True`` and the serialised config.
@@ -34,7 +34,7 @@ from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, Tool
from app.agents.filesystem_agent import make_directory_tools
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
from app.schemas import AgentConfig
from app.schemas import ScoutConfig
logger = logging.getLogger(__name__)
@@ -42,7 +42,7 @@ logger = logging.getLogger(__name__)
_SESSION_TTL_SECONDS: int = 1800 # 30 minutes
# Sentinel strings used to delimit the LLM-produced AgentConfig JSON.
# Sentinel strings used to delimit the LLM-produced ScoutConfig JSON.
_CONFIG_START = "AGENT_CONFIG_START"
_CONFIG_END = "AGENT_CONFIG_END"
@@ -92,7 +92,7 @@ def get_journey_session(session_id: str, user_id: str) -> JourneySession | None:
_JOURNEY_SYSTEM_PROMPT = """\
You are a friendly assistant helping a freelancer configure a data-extraction agent.
Your job is to understand what files the user has in their directory and produce a
structured AgentConfig JSON that the extraction agent will use as its instruction set.
structured ScoutConfig JSON that the extraction agent will use as its instruction set.
You have access to file-system tools to explore the user's directory:
- list_directory: see folder structure and file names
@@ -122,7 +122,7 @@ Cover these topics based on what you discovered:
4. Date extraction (e.g. "by Friday" dueDate)
5. Exclusion rules (e.g. skip newsletters, skip files with no project match)
### Step 4 — Produce the AgentConfig JSON
### Step 4 — Produce the ScoutConfig JSON
Once you are 90% confident, output the final config between these exact markers
(each on its own line):
@@ -168,7 +168,7 @@ def _build_system_prompt(
) -> tuple[str, Any]:
"""Return ``(compiled_system_prompt, langfuse_prompt_obj_or_None)``."""
existing_section = (
"\nThe user already has the following AgentConfig — refine it based on their answers:\n"
"\nThe user already has the following ScoutConfig — refine it based on their answers:\n"
f"```json\n{existing_config}\n```\n"
if existing_config
else ""
@@ -189,11 +189,11 @@ def _build_system_prompt(
return compiled, prompt_obj
# ── AgentConfig extraction ────────────────────────────────────────────────
# ── ScoutConfig extraction ────────────────────────────────────────────────
def _extract_agent_config(text: str) -> str | None:
"""Return validated AgentConfig JSON string from between markers, or None.
"""Return validated ScoutConfig JSON string from between markers, or None.
Parses the JSON with Pydantic to ensure it conforms to the schema before
returning. Returns None if markers are absent or JSON is invalid.
@@ -206,10 +206,10 @@ def _extract_agent_config(text: str) -> str | None:
if not raw:
return None
try:
parsed = AgentConfig.model_validate_json(raw)
parsed = ScoutConfig.model_validate_json(raw)
return parsed.model_dump_json()
except Exception as exc:
logger.warning("agent_setup: failed to parse AgentConfig JSON: %s", exc)
logger.warning("agent_setup: failed to parse ScoutConfig JSON: %s", exc)
return None
@@ -475,7 +475,7 @@ async def handle_journey_message(
if turns >= _MAX_TURNS:
nudge_content = (
"[System: You have enough information. Please generate the final "
f"AgentConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
f"ScoutConfig JSON now, wrapped in {_CONFIG_START} / {_CONFIG_END} markers.]"
)
session.history.append({"role": "user", "content": nudge_content})

View File

@@ -0,0 +1,120 @@
"""Gmail Pub/Sub push receiver.
Google Pub/Sub push subscriptions deliver Gmail watch notifications as POST
requests with a JSON envelope. The body payload contains a base64-encoded
JSON blob with ``emailAddress`` + ``historyId``. We resolve the user by
email, look up their cloud_scout_configs row for provider='gmail', and
hand off to ScoutEngine.trigger_scout.
Authentication: Pub/Sub push includes an OIDC JWT in the Authorization
header. We verify it against Google's public keys with the audience
configured in our Pub/Sub subscription.
Dev mode: when ``GMAIL_PUBSUB_AUDIENCE`` is empty, JWT verification is
skipped and a warning is logged. Production must set this env var.
"""
from __future__ import annotations
import base64
import json
import logging
import uuid
from fastapi import APIRouter, Header, HTTPException, Request, status
from sqlalchemy import select
from app.config.settings import settings
from app.db import async_session
from app.models import CloudScoutConfig, User
from app.scouts.engine import ScoutEngine
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/scouts/webhooks", tags=["scout-webhooks"])
def _verify_pubsub_jwt(token: str) -> bool:
"""Verify the Google Pub/Sub OIDC JWT.
Returns True when valid, False on any verification failure.
Dev skip: if ``settings.GMAIL_PUBSUB_AUDIENCE`` is empty, logs a
warning and returns True so local development works without a real
Pub/Sub subscription. Production must configure the audience.
"""
if not token:
return False
if not settings.GMAIL_PUBSUB_AUDIENCE:
logger.warning(
"GMAIL_PUBSUB_AUDIENCE not set — skipping Pub/Sub JWT verification (dev mode only)"
)
return True
try:
from google.auth.transport import requests as g_requests # noqa: PLC0415
from google.oauth2 import id_token # noqa: PLC0415
id_token.verify_oauth2_token(
token,
g_requests.Request(),
audience=settings.GMAIL_PUBSUB_AUDIENCE,
)
return True
except Exception:
logger.warning("pubsub jwt verification failed", exc_info=True)
return False
@router.post("/gmail", status_code=status.HTTP_204_NO_CONTENT)
async def gmail_pubsub(
request: Request,
authorization: str = Header(default=""),
) -> None:
"""Receive a Gmail Pub/Sub push notification.
Verifies the OIDC JWT, decodes the Pub/Sub envelope, resolves the user
by email, and triggers ScoutEngine.trigger_scout for each enabled Gmail
scout belonging to that user.
Returns 204 No Content on success (including benign no-ops like unknown
email or empty message data). Returns 401 on JWT verification failure.
"""
token = authorization.removeprefix("Bearer ").strip()
if not _verify_pubsub_jwt(token):
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid Pub/Sub JWT")
body = await request.json()
msg = body.get("message") or {}
raw = msg.get("data")
if not raw:
return # ack without action — empty message data
try:
decoded = json.loads(base64.b64decode(raw).decode())
except Exception:
logger.warning("pubsub payload decode failed")
return
email = decoded.get("emailAddress")
if not email:
return
async with async_session() as session:
user_q = await session.execute(select(User).where(User.email == email))
user = user_q.scalar_one_or_none()
if user is None:
logger.info("pubsub: no user for %s — ignoring", email)
return
scouts_q = await session.execute(
select(CloudScoutConfig).where(
CloudScoutConfig.user_id == user.id,
CloudScoutConfig.provider == "gmail",
CloudScoutConfig.enabled == True, # noqa: E712
)
)
scouts = scouts_q.scalars().all()
engine = ScoutEngine()
for scout in scouts:
await engine.trigger_scout(uuid.UUID(str(scout.id)))

440
app/api/routes/scouts.py Normal file
View File

@@ -0,0 +1,440 @@
"""Scout routes.
Backend responsibilities are intentionally minimal:
GET /scouts/catalog — static catalog for UI display
POST /scouts/can-create — billing eligibility check
POST /scouts/trigger — trigger a local scout run
Scout configuration is owned by the Electron app and is not persisted
in backend scout-config tables.
Gmail OAuth setup (scout-specific consent):
GET /scouts/oauth/gmail/authorize — returns consent-screen URL
GET /scouts/oauth/gmail/web-callback — bounces to deep link (excluded from schema)
POST /scouts/oauth/gmail/callback — exchanges code, stores encrypted token
"""
from __future__ import annotations
import asyncio
import logging
import secrets
import time
import urllib.parse
import uuid
from datetime import datetime, timezone
from fastapi import APIRouter, Depends, HTTPException, status
from fastapi.responses import RedirectResponse
from sqlalchemy import func, select
from sqlalchemy.ext.asyncio import AsyncSession
from pydantic import BaseModel
from app.api.deps import get_current_user
from app.auth.oauth_providers import generate_pkce_pair
from app.billing.tier_manager import FEATURES
from app.config.settings import settings
from app.core.scout_runner import is_agent_running, run_local_agent
from app.core.device_manager import device_manager
from app.core.note_summarizer import generate_note_summary
from app.db import get_session
from app.integrations import encrypt_token
from app.models import CloudScoutConfig, ScoutRunLog, LocalScoutConfig
from app.schemas import (
ScoutCatalogItem,
ScoutCreationCheckRequest,
ScoutCreationCheckResponse,
ScoutRunLogResponse,
ScoutTriggerRequest,
UserProfile,
)
logger = logging.getLogger(__name__)
router = APIRouter(prefix="/scouts", tags=["scouts"])
# ── Datetime helpers ──────────────────────────────────────────────────
def _dt_ms(dt: datetime) -> int:
return int(dt.timestamp() * 1000)
def _dt_ms_opt(dt: datetime | None) -> int | None:
return int(dt.timestamp() * 1000) if dt else None
def _to_data_types(values: list[str]) -> list[str]:
normalize = {
"task": "tasks", "tasks": "tasks",
"note": "notes", "notes": "notes",
"timeline": "timelines", "timelines": "timelines", "timelineEvents": "timelines",
"project": "projects", "projects": "projects",
}
seen: set[str] = set()
result: list[str] = []
for v in values:
mapped = normalize.get(v)
if mapped and mapped not in seen:
seen.add(mapped)
result.append(mapped)
return result
def _to_run_log_response(log: ScoutRunLog) -> ScoutRunLogResponse:
return ScoutRunLogResponse(
id=log.id,
agent_id=log.scout_id,
agent_type=log.scout_type, # type: ignore[arg-type]
status=log.status, # type: ignore[arg-type]
items_processed=log.items_processed,
items_created=log.items_created,
errors=log.errors or [],
started_at=_dt_ms(log.started_at),
completed_at=_dt_ms_opt(log.completed_at),
)
def _enforce_agent_limit(tier: str, current_count: int) -> int:
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_active"]
if limit != -1 and current_count >= limit:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail=f"Agent limit ({limit}) reached for your tier. Upgrade to create more.",
)
return limit
async def _enforce_run_frequency(
tier: str,
user_id: str,
db: AsyncSession,
) -> None:
"""Raise HTTP 402 if the user has exceeded their daily batch run limit."""
limit: int = FEATURES.get(tier, FEATURES["free"])["batch_runs_per_day"]
if limit == -1:
return # unlimited
today_start = datetime.now(timezone.utc).replace(
hour=0, minute=0, second=0, microsecond=0
)
result = await db.execute(
select(func.count(ScoutRunLog.id)).where(
ScoutRunLog.user_id == user_id,
ScoutRunLog.started_at >= today_start,
)
)
runs_today: int = result.scalar_one()
if runs_today >= limit:
raise HTTPException(
status_code=status.HTTP_402_PAYMENT_REQUIRED,
detail=f"Daily batch run limit ({limit}) reached for your tier. Upgrade for more runs.",
)
# ── Catalog ───────────────────────────────────────────────────────────
@router.get("/catalog", response_model=list[ScoutCatalogItem])
async def get_agent_catalog(
current_user: UserProfile = Depends(get_current_user),
) -> list[ScoutCatalogItem]:
"""Return the static list of available agent types and their descriptions."""
return [
ScoutCatalogItem(
type="local_directory",
name="Local Directory Monitor",
description="Watches local directories, extracts data from files using AI",
),
ScoutCatalogItem(
type="gmail",
name="Gmail Connector",
description="Scans Gmail inbox, extracts tasks/notes from emails",
),
ScoutCatalogItem(
type="teams",
name="Microsoft Teams Connector",
description="Monitors Teams messages, extracts action items",
),
ScoutCatalogItem(
type="outlook",
name="Outlook Connector",
description="Scans Outlook inbox, extracts tasks/notes",
),
]
@router.post("/can-create", response_model=ScoutCreationCheckResponse)
async def can_create_agent(
body: ScoutCreationCheckRequest,
current_user: UserProfile = Depends(get_current_user),
) -> ScoutCreationCheckResponse:
"""Check if the user can create one more agent based on billing tier.
Since configuration is client-owned, the Electron app sends its current
active agent count and the backend applies tier limits.
"""
limit: int = FEATURES.get(current_user.tier, FEATURES["free"])["batch_active"]
allowed = limit == -1 or body.active_agents < limit
return ScoutCreationCheckResponse(
allowed=allowed,
tier=current_user.tier,
active_agents=body.active_agents,
limit=limit,
)
@router.post("/trigger", response_model=ScoutRunLogResponse, status_code=status.HTTP_202_ACCEPTED)
async def trigger_agent_run(
body: ScoutTriggerRequest,
current_user: UserProfile = Depends(get_current_user),
db: AsyncSession = Depends(get_session),
) -> ScoutRunLogResponse:
"""Trigger a local agent run using client-provided configuration."""
_enforce_agent_limit(current_user.tier, body.active_agents)
await _enforce_run_frequency(current_user.tier, current_user.id, db)
last_run_dt = (
datetime.fromtimestamp(body.last_run_at / 1000, tz=timezone.utc)
if body.last_run_at
else None
)
config = LocalScoutConfig(
id=str(uuid.uuid4()),
user_id=current_user.id,
device_id=body.device_id,
name="Local Directory Monitor",
directory_paths=[body.directory],
data_types=_to_data_types(body.what_to_extract),
prompt_template=body.custom_agent_prompt or "",
scout_config=body.agent_config,
file_extensions=[],
schedule_cron=body.batch_interval,
enabled=True,
last_run_at=last_run_dt,
)
# Use the FE's stable agent_id if provided, fall back to the ephemeral config id.
stable_agent_id = body.agent_id or config.id
if is_agent_running(stable_agent_id):
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Agent is already running. Only one run per agent is allowed at a time.",
)
run_log = ScoutRunLog(
scout_id=stable_agent_id,
scout_type="local",
user_id=current_user.id,
status="running",
)
db.add(run_log)
await db.commit()
await db.refresh(run_log)
run_context = {
"type": "agent_batch",
"run_id": run_log.id,
"agent_id": stable_agent_id,
}
asyncio.create_task(
run_local_agent(current_user.id, config, run_log, device_manager, run_context)
)
return _to_run_log_response(run_log)
# ── Note summary endpoint ──────────────────────────────────────────────────────
class NoteSummarizeRequest(BaseModel):
title: str
content: str
class NoteSummarizeResponse(BaseModel):
summary: str
@router.post("/notes/summarize", response_model=NoteSummarizeResponse)
async def summarize_note(
body: NoteSummarizeRequest,
current_user: UserProfile = Depends(get_current_user),
) -> NoteSummarizeResponse:
"""Generate an AI summary for a note. Used by the Electron backfill on startup."""
summary = await generate_note_summary(body.title, body.content)
return NoteSummarizeResponse(summary=summary)
# ── Gmail OAuth setup (scout-specific) ───────────────────────────────────────
# Scopes required for Gmail scout connectivity.
_GMAIL_SCOUT_SCOPES = [
"openid",
"email",
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.modify",
]
# Google OAuth endpoints.
_GOOGLE_AUTH_URL = "https://accounts.google.com/o/oauth2/v2/auth"
_GOOGLE_TOKEN_URL = "https://oauth2.googleapis.com/token"
# In-memory pending OAuth states for scout Gmail consent:
# state → (code_verifier, scout_id, user_id, expires_at_epoch_s)
# Production note: replace with Redis for multi-process deployments.
_pending_scout_oauth_states: dict[str, tuple[str, str, str, float]] = {}
_SCOUT_OAUTH_TTL_SECONDS = 600 # 10 minutes
def _scout_gmail_redirect_uri() -> str:
"""Derive the scout Gmail web-callback URI from the configured base OAUTH_REDIRECT_URI.
``OAUTH_REDIRECT_URI`` is the full path used for login OAuth
(e.g. http://localhost:8000/api/v1/auth/oauth/google/web-callback).
We strip the path to get the scheme+host base, then append the scout path.
"""
parsed = urllib.parse.urlparse(settings.OAUTH_REDIRECT_URI)
base = f"{parsed.scheme}://{parsed.netloc}"
return f"{base}/api/v1/scouts/oauth/gmail/web-callback"
class _ScoutGmailAuthorizeResponse(BaseModel):
authorize_url: str
class _ScoutGmailCallbackBody(BaseModel):
code: str
state: str
@router.get("/oauth/gmail/authorize", response_model=_ScoutGmailAuthorizeResponse)
async def scout_gmail_oauth_authorize(
scout_id: str,
current_user: UserProfile = Depends(get_current_user),
) -> _ScoutGmailAuthorizeResponse:
"""Start the Gmail OAuth flow for a specific cloud scout.
Returns the Google consent-screen URL. The client opens this URL in the
system browser; after consent Google redirects to web-callback which bounces
to the ``adiuvai://scout/oauth/gmail/callback`` deep link.
"""
if not settings.GOOGLE_AUTH_CLIENT_ID or not settings.GOOGLE_AUTH_CLIENT_SECRET:
raise HTTPException(
status.HTTP_503_SERVICE_UNAVAILABLE,
"Google OAuth is not configured on this server",
)
code_verifier, code_challenge = generate_pkce_pair()
state = secrets.token_urlsafe(32)
# Purge expired states to prevent unbounded growth.
now = time.time()
expired = [s for s, (_, _, _, exp) in _pending_scout_oauth_states.items() if exp < now]
for s in expired:
del _pending_scout_oauth_states[s]
_pending_scout_oauth_states[state] = (code_verifier, scout_id, current_user.id, now + _SCOUT_OAUTH_TTL_SECONDS)
redirect_uri = _scout_gmail_redirect_uri()
params = {
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
"redirect_uri": redirect_uri,
"response_type": "code",
"scope": " ".join(_GMAIL_SCOUT_SCOPES),
"state": state,
"code_challenge": code_challenge,
"code_challenge_method": "S256",
"access_type": "offline",
"prompt": "consent",
}
authorize_url = f"{_GOOGLE_AUTH_URL}?{urllib.parse.urlencode(params)}"
return _ScoutGmailAuthorizeResponse(authorize_url=authorize_url)
@router.get("/oauth/gmail/web-callback", include_in_schema=False)
async def scout_gmail_oauth_web_callback(code: str, state: str) -> RedirectResponse:
"""Google redirects here after Gmail consent.
Immediately bounces to the Electron deep link so the desktop app
receives the authorization code.
"""
params = urllib.parse.urlencode({"code": code, "state": state})
deep_link = f"adiuvai://scout/oauth/gmail/callback?{params}"
return RedirectResponse(url=deep_link, status_code=302)
@router.post("/oauth/gmail/callback")
async def scout_gmail_oauth_callback(
body: _ScoutGmailCallbackBody,
db: AsyncSession = Depends(get_session),
current_user: UserProfile = Depends(get_current_user),
) -> dict:
"""Exchange the Gmail authorization code and store the encrypted token on the scout.
Called by the Electron app after it receives the deep-link callback with
the ``code`` and ``state`` params.
"""
entry = _pending_scout_oauth_states.pop(body.state, None)
if entry is None or entry[3] < time.time() or entry[2] != current_user.id:
raise HTTPException(status.HTTP_401_UNAUTHORIZED, "Invalid or expired OAuth state")
code_verifier, scout_id, _, _ = entry
redirect_uri = _scout_gmail_redirect_uri()
import httpx
async with httpx.AsyncClient() as client:
response = await client.post(
_GOOGLE_TOKEN_URL,
data={
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
"client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
"code": body.code,
"code_verifier": code_verifier,
"grant_type": "authorization_code",
"redirect_uri": redirect_uri,
},
)
try:
response.raise_for_status()
except httpx.HTTPStatusError as exc:
logger.error("Gmail token exchange failed: %s", exc.response.text)
raise HTTPException(status.HTTP_502_BAD_GATEWAY, "Failed to exchange Gmail authorization code")
token_data = response.json()
creds_dict: dict = {
"token": token_data["access_token"],
"refresh_token": token_data.get("refresh_token"),
"token_uri": _GOOGLE_TOKEN_URL,
"client_id": settings.GOOGLE_AUTH_CLIENT_ID,
"client_secret": settings.GOOGLE_AUTH_CLIENT_SECRET,
"scopes": [
"https://www.googleapis.com/auth/gmail.readonly",
"https://www.googleapis.com/auth/gmail.modify",
],
}
encrypted = encrypt_token(creds_dict)
scout = await db.get(CloudScoutConfig, scout_id)
if scout is None or scout.user_id != current_user.id:
raise HTTPException(status.HTTP_404_NOT_FOUND, "Scout not found")
scout.oauth_token_encrypted = encrypted
await db.commit()
# Attempt to set up Gmail push watch so we start receiving Pub/Sub notifications.
from app.scouts.connectors.registry import get_connector
try:
connector = get_connector("gmail")
await connector.setup_watch(scout)
await db.commit()
except KeyError:
logger.warning("gmail connector not registered — skipping setup_watch for scout %s", scout_id)
except Exception:
logger.exception("setup_watch failed for scout %s", scout_id)
return {"ok": True}

View File

@@ -23,9 +23,8 @@ class Settings(BaseSettings):
LLM_EMBED_MODEL: str = "text-embedding-3-small"
# Per-agent model overrides. Leave empty to fall back to LLM_MODEL.
LLM_MODEL_CLASSIFIER: str = "" # _infer_floating_domain (intent routing)
LLM_MODEL_CLASSIFIER: str = "" # classifier (intent routing, future use)
LLM_MODEL_HOME_AGENT: str = "" # home-agent (run_single_agent / stream)
LLM_MODEL_FLOATING_AGENT: str = "" # floating-agent (contextual chat)
LLM_MODEL_UNIFIED_PROCESSOR: str = "" # unified-processor (agent_runner)
LLM_MODEL_CLOUD_PROCESSOR: str = "" # cloud-processor (agent_runner)
LLM_MODEL_BRIEF_AGENT: str = "" # brief-agent (home + project text briefs)
@@ -59,6 +58,16 @@ class Settings(BaseSettings):
# Prod: https://api.adiuvai.com/api/v1/auth/oauth/google/web-callback
OAUTH_REDIRECT_URI: str = "http://localhost:8000/api/v1/auth/oauth/google/web-callback"
# Gmail Pub/Sub topic for push notifications.
# Full resource name, e.g. "projects/my-project/topics/gmail-push".
# Leave empty in dev — setup_watch will skip registration gracefully.
GMAIL_PUBSUB_TOPIC: str = ""
# OIDC token audience for Pub/Sub push subscription JWT verification.
# Set to the service account email or audience string configured in the
# Pub/Sub push subscription. Leave empty in dev to skip verification
# (a warning is logged — never silent in production).
GMAIL_PUBSUB_AUDIENCE: str = ""
# Fernet key (URL-safe base64, 32-byte key) for at-rest encryption of OAuth
# tokens stored in cloud_agent_configs.oauth_token_encrypted.
# Generate with: from cryptography.fernet import Fernet; Fernet.generate_key()

View File

@@ -1,4 +1,4 @@
"""Single-agent runners for home and floating chat contexts."""
"""Single-agent runners for home and contextual chat contexts."""
from __future__ import annotations
@@ -7,7 +7,7 @@ import logging
import re
from datetime import date
from collections.abc import AsyncGenerator
from typing import Any, Literal
from typing import Any
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage
from langchain_core.tools import tool
@@ -18,7 +18,7 @@ from app.agents.project_agent import PROJECT_TOOLS
from app.agents.relations_agent import make_query_relations_tool
from app.agents.task_agent import TASK_TOOLS
from app.agents.timeline_agent import TIMELINE_TOOLS
from app.core.agent_session_buffer import session_buffer
from app.core.scout_session_buffer import session_buffer
from app.core.langfuse_client import compile_prompt, extract_usage, get_langfuse, get_prompt_or_fallback, langfuse_context
from app.core.llm import get_agent_llm, model_for_agent
from app.core.memory_middleware import MemoryMiddleware
@@ -29,9 +29,6 @@ logger = logging.getLogger(__name__)
MAX_HISTORY_TURNS = 20
FloatingDomainType = Literal["task", "timeline", "project", "node"]
FloatingDomainSection = Literal["task", "timeline", "note"]
# Mapping of core-memory language values to natural-language names for prompts.
_LANGUAGE_NAMES: dict[str, str] = {
"en": "English", "it": "Italian", "es": "Spanish",
@@ -354,42 +351,24 @@ For "today" / "tomorrow" queries, prefer list_tasks_due_today / list_timelines_t
{request_context}\
"""
_FLOATING_SYSTEM_PROMPT = """\
You are adiuvAI's floating executive assistant.{user_identity}
You are pinned to a specific entity (task, timeline event, project, or note) and you stay strictly within that scope.
Be a proactive partner: anticipate the next useful action and close with a concrete suggestion or a clarifying question — but stay terse, one short paragraph at most.
_CONTEXTUAL_SYSTEM_PROMPT = """You are adiuvAI's contextual assistant. The user is working inside the app and has opened a side chat anchored to a specific view ("current view"). Help them act on that view: recap, plan, create entities, answer questions.
# How you work
- Use tools before answering anything factual. Never guess.
- Stay in the floating scope (see Request context). If the user asks something outside scope, answer briefly and suggest opening the home assistant.
- Match the user's tone preference. Default to warm-but-direct.
- When the user asks to remember, forget, or update something, use memory tools.
Rules:
1. Base context (current view summary) is provided every turn. Treat it as ground truth for ids and names; never invent them.
2. ALL reads go through `get_page_details`. The legacy tools `list_projects`, `get_project`, `list_tasks`, `get_task`, `list_notes`, `get_note` are NOT available in this channel — do not attempt to call them. To find an entity by name, call `get_page_details({entityType: 'projects_all' | 'tasks_all' | 'timeline_all'})` to list, then `get_page_details({entityType: '<type>', entityId})` for the full snapshot.
3. When the user requests an action that creates or updates an entity:
- If the current view is a project and no project is specified, use the current project automatically.
- If the current view is the global Tasks / Projects / Timeline list and no project is specified, ASK before attaching to any project. Don't silently create orphan entities.
4. The current view can change mid-conversation (user navigates). When you see a system message "User navigated to ...", treat the new view as the active context. Prior turns remain visible but the active scope shifts.
5. Notes: you can read note bodies via `get_page_details({entityType:'note'})`. You CANNOT edit, summarize-to-replace, or append. Tell the user "note editing is coming in a later release" if asked.
6. Be concise. Default to 1-3 short paragraphs. Bullet lists fine. Don't restate the user's request.
7. Never expose ids in prose. Use names. Ids only travel through tool calls.
# Filter discipline
- Never set the `assignee` filter on list_tasks/count_tasks unless the user explicitly names a person ("Marco's tasks") or refers to themselves ("my tasks", "assigned to me", "mine").
- The user's own name in the User profile block is for context only — it is NOT a default filter.
- When in doubt, omit `assignee` and return the global result.
# Output format
Plain text only. Do NOT output XML/HTML-like tags such as <task>, <project>, <note>, <timeline>, or any bracketed-id wrappers, and do NOT output <chart> blocks — those are for the home assistant.
# Date filtering
# Date context
{date_context}
When filtering by date, take dueDateFrom / dueDateTo (ms epoch UTC) verbatim from the DATE CONTEXT boundary table above. Do NOT compute boundaries from now_ms yourself.
For specific dates not listed, compute local-midnight in the user timezone and convert to UTC ms.
# Language
{language_instruction}
# Known people & projects
{relational_memory}
# Behavioral hints
{proactive_hints}
# Request context
{request_context}\
"""
_TASK_BRIEF_RESEARCH_SYSTEM_PROMPT = """\
@@ -466,19 +445,6 @@ Stay terse — your principal is a busy executive.
{request_context}\
"""
_FLOATING_DOMAIN_CLASSIFIER_PROMPT = (
"You are a strict domain classifier for websocket floating requests. "
"Return ONLY a JSON object with keys: type, id, section. "
"Allowed type values: task, timeline, project, node. "
"Allowed section values: task, timeline, note, or null. "
"Rules: infer from user message intent first; do not blindly trust scope.type. "
"If user asks tasks/timeline/notes for a project, set type=project and section accordingly. "
"If project id is unknown but context.resolved_project_id exists, use it as id. "
"If id is unknown, use null. "
"No markdown, no prose, JSON only."
)
def _as_text(content: Any) -> str:
if content is None:
return ""
@@ -556,6 +522,55 @@ def _all_tools() -> list[Any]:
return [*TASK_TOOLS, *PROJECT_TOOLS, *NOTE_TOOLS, *TIMELINE_TOOLS]
# ── Contextual sidebar tools ──────────────────────────────────────────
@tool
async def get_page_details(
entity_type: str = "",
entity_id: str = "",
) -> str:
"""Fetch full details for the entity currently in view.
entity_type: one of 'project' | 'task' | 'note' | 'timeline_event' |
'tasks_all' | 'projects_all' | 'timeline_all'.
entity_id: UUID of the entity for singular entity views. Omit for list views.
The Electron drizzle-executor fulfils this op against local SQLite and
returns the row(s) as a JSON tool result.
"""
result = await execute_on_client(
action="get_page_details",
table=entity_type or "unknown",
data={"entityId": entity_id or None},
)
if not result:
return "No details found."
return str(result)
def _contextual_tools(user_id: str, trace_id: str | None) -> list[Any]:
"""Return the tool palette for the contextual sidebar agent.
Read ops go through get_page_details only — legacy list_*/get_* tools
return shallow snapshots and cause the agent to under-answer (see
smoke trace 0b46841484ba7d024ed9f8d5ac8b1df0). Writes are limited
to entity creation + task update; note edits are next-sprint.
"""
from app.agents.note_agent import create_note # noqa: PLC0415
from app.agents.task_agent import create_task, update_task # noqa: PLC0415
from app.agents.timeline_agent import create_timeline # noqa: PLC0415
return [
get_page_details,
create_task,
update_task,
create_note,
create_timeline,
*_memory_tools(user_id, trace_id),
]
def _trace_id_from_context(context: dict[str, Any]) -> str | None:
debug = context.get("_debug")
if isinstance(debug, dict):
@@ -658,70 +673,6 @@ def _normalize_tagged_list_lines(text: str, message: str) -> str:
return "\n".join(output_lines)
_GENERIC_TAG_RE = re.compile(r"</?(task|project|note|timeline|chart)>", re.IGNORECASE)
_BRACKETED_ID_RE = re.compile(r"\[(?:[0-9a-fA-F-]{8,}|[A-Za-z0-9_-]{8,})\]")
_FLOATING_EMPTY_FALLBACK = "No results found."
def _strip_floating_markup_fragment(text: str) -> str:
if not text:
return text
cleaned = _GENERIC_TAG_RE.sub("", text)
return _BRACKETED_ID_RE.sub("", cleaned)
def _strip_floating_markup(text: str) -> str:
"""Ensure floating responses stay plain text with no XML-like tag wrappers."""
if not text:
return text
cleaned = _strip_floating_markup_fragment(text)
# Collapse excessive spaces introduced by tag/id removal while preserving lines.
lines = [re.sub(r"[ \t]{2,}", " ", line).strip() for line in cleaned.splitlines()]
return "\n".join(line for line in lines if line)
def _fallback_from_raw_floating_text(raw_text: str) -> str:
fallback = _strip_floating_markup_fragment(raw_text or "")
fallback = re.sub(r"[ \t]{2,}", " ", fallback).strip()
return fallback or _FLOATING_EMPTY_FALLBACK
class _FloatingStreamSanitizer:
"""Streaming sanitizer that removes floating markup without buffering the full answer."""
def __init__(self) -> None:
self._pending = ""
@staticmethod
def _split_safe_boundary(text: str) -> tuple[str, str]:
boundary = len(text)
last_lt = text.rfind("<")
if last_lt != -1 and ">" not in text[last_lt:]:
boundary = min(boundary, last_lt)
last_lb = text.rfind("[")
if last_lb != -1 and "]" not in text[last_lb:]:
boundary = min(boundary, last_lb)
if boundary == len(text):
return text, ""
return text[:boundary], text[boundary:]
def feed(self, chunk: str) -> str:
combined = f"{self._pending}{chunk}"
safe_text, self._pending = self._split_safe_boundary(combined)
return _strip_floating_markup_fragment(safe_text)
def finalize(self) -> str:
# Drop dangling unfinished wrappers at the very end.
tail = re.sub(r"<[^>\n]*$", "", self._pending)
tail = re.sub(r"\[[^\]\n]*$", "", tail)
self._pending = ""
return _strip_floating_markup_fragment(tail)
def _normalize_memory_label(path_or_label: str) -> str:
value = path_or_label.strip()
if value.startswith("/memories/"):
@@ -902,168 +853,6 @@ def _all_tools_for_user(user_id: str, trace_id: str | None) -> list[Any]:
return [*_all_tools(), *_memory_tools(user_id, trace_id)]
def _detect_domain_section(message: str) -> FloatingDomainSection | None:
lowered = message.lower()
if any(keyword in lowered for keyword in ["timeline", "milestone", "release", "schedule"]):
return "timeline"
if any(keyword in lowered for keyword in ["task", "tasks", "todo", "attivit", "azione"]):
return "task"
if any(keyword in lowered for keyword in ["note", "notes", "memo", "document"]):
return "note"
return None
def _normalize_domain_payload(payload: dict[str, Any], fallback_id: str | None) -> dict[str, str | None]:
type_raw = str(payload.get("type") or "").strip().lower()
domain_type: FloatingDomainType = "task"
if type_raw in {"task", "timeline", "project", "node"}:
domain_type = type_raw
id_value = payload.get("id")
domain_id = id_value if isinstance(id_value, str) and id_value.strip() else None
if domain_type == "project" and not domain_id:
domain_id = fallback_id
section_raw = payload.get("section")
section: FloatingDomainSection | None = None
if isinstance(section_raw, str):
section_candidate = section_raw.strip().lower()
if section_candidate in {"task", "timeline", "note"}:
section = section_candidate
if domain_type != "project":
section = None
return {
"type": domain_type,
"id": domain_id,
"section": section,
}
def _parse_json_object(text: str) -> dict[str, Any] | None:
raw = text.strip()
if not raw:
return None
try:
parsed = json.loads(raw)
return parsed if isinstance(parsed, dict) else None
except json.JSONDecodeError:
pass
match = re.search(r"\{.*\}", raw, re.DOTALL)
if not match:
return None
try:
parsed = json.loads(match.group(0))
except json.JSONDecodeError:
return None
return parsed if isinstance(parsed, dict) else None
def _infer_floating_domain_rule_based(message: str, context: dict[str, Any]) -> dict[str, str | None]:
section = _detect_domain_section(message)
scope = context.get("scope") if isinstance(context, dict) else None
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
if isinstance(scope, dict):
scope_type = str(scope.get("type") or "").strip().lower()
scope_id = scope.get("id")
scope_id_value = scope_id if isinstance(scope_id, str) and scope_id else None
if scope_type in {"task", "tasks"}:
return {"type": "task", "id": scope_id_value, "section": None}
if scope_type in {"project", "projects"}:
project_scope_id = scope_id_value or project_id
return {
"type": "project",
"id": project_scope_id,
"section": section,
}
if scope_type in {"note", "notes"}:
return {
"type": "node",
"id": scope_id_value,
"section": None,
}
if scope_type in {"timeline", "timelines"}:
return {"type": "timeline", "id": scope_id_value, "section": None}
lowered = message.lower()
if any(keyword in lowered for keyword in ["project", "progetto", "client"]) or project_id:
return {
"type": "project",
"id": project_id,
"section": section,
}
if section == "timeline":
return {"type": "timeline", "id": None, "section": None}
if section == "note":
return {"type": "node", "id": None, "section": None}
return {"type": "task", "id": None, "section": None}
async def _infer_floating_domain(message: str, context: dict[str, Any]) -> dict[str, str | None]:
resolved_project_id = context.get("resolved_project_id") if isinstance(context, dict) else None
project_id = resolved_project_id if isinstance(resolved_project_id, str) and resolved_project_id else None
classifier_context = {
"scope": context.get("scope") if isinstance(context.get("scope"), dict) else None,
"resolved_project_id": project_id,
}
try:
llm = get_agent_llm("classifier")
classifier_messages = [
SystemMessage(content=_FLOATING_DOMAIN_CLASSIFIER_PROMPT),
HumanMessage(
content=(
f"Message:\n{message}\n\n"
f"Context:\n{json.dumps(classifier_context, ensure_ascii=True)}"
)
),
]
lf = get_langfuse()
_, classifier_prompt_obj = get_prompt_or_fallback(
"floating_domain_classifier", _FLOATING_DOMAIN_CLASSIFIER_PROMPT
)
# Extract user/session from context for Langfuse attribution
_debug = context.get("_debug") if isinstance(context, dict) else None
_lf_user = (_debug or {}).get("user_id") if isinstance(_debug, dict) else None
_lf_session = (_debug or {}).get("session_id") if isinstance(_debug, dict) else None
with langfuse_context(user_id=_lf_user, session_id=_lf_session):
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="floating-classifier",
model=model_for_agent("classifier"),
prompt=classifier_prompt_obj,
input=classifier_messages,
) as gen:
response = await llm.ainvoke(classifier_messages)
gen.update(output=_as_text(response.content), usage_details=extract_usage(response))
else:
response = await llm.ainvoke(classifier_messages)
parsed = _parse_json_object(_as_text(response.content))
if parsed is not None:
domain = _normalize_domain_payload(parsed, project_id)
logger.info(
"deep_agent: floating_domain_classified type=%s id=%s section=%s",
domain.get("type"),
domain.get("id"),
domain.get("section"),
)
return domain
logger.warning("deep_agent: floating_domain classifier returned non-json output")
except Exception as exc:
logger.warning("deep_agent: floating_domain classifier failed: %s", exc)
return _infer_floating_domain_rule_based(message, context)
def _history_to_messages(history: list[dict[str, str]] | None) -> list[Any]:
if not history:
return []
@@ -1392,25 +1181,6 @@ async def run_home(user_id: str, message: str, context: dict[str, Any]) -> str:
return _normalize_tagged_list_lines(response, message)
async def run_floating(user_id: str, message: str, context: dict[str, Any]) -> tuple[str, dict[str, str | None]]:
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
response = await _run_single_agent(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
conversation_history=context.get("conversation_history"),
)
sanitized = _strip_floating_markup(response)
if not sanitized and response:
sanitized = _fallback_from_raw_floating_text(response)
return sanitized, domain
async def run_home_stream(
user_id: str,
message: str,
@@ -1457,69 +1227,47 @@ async def run_home_stream(
yield "token", normalized
async def run_floating_stream(
async def run_contextual_stream(
user_id: str,
message: str,
context: dict[str, Any],
scope: "ContextualScope", # type: ignore[name-defined]
) -> AsyncGenerator[tuple[str, Any], None]:
"""Run the contextual agent for a single user turn.
Injects the rendered scope block into the system prompt and exposes
the contextual tool set.
Note-edit tools (propose_note_edit) are intentionally excluded.
*context contract*: callers MUST include ``context["_debug"]["session_id"]``
(a non-empty str) so that ``_session_id_from_context`` can extract it for
tracing and episode storage downstream. The WS handler in device_ws.py
satisfies this by always populating ``_debug`` before calling this function.
"""
from app.schemas.contextual import ContextualScope, render_scope_block # noqa: PLC0415
prepared_context = await _prepare_context(message, context)
domain = await _infer_floating_domain(message, prepared_context)
yield "floating_domain", domain
trace_id = _trace_id_from_context(prepared_context)
brief_mode: bool = bool(context.get("brief_mode"))
briefing_context_text: str = str(context.get("briefing_context") or "").strip()
system_prompt, langfuse_prompt = _build_system_prompt(
"contextual_system", _CONTEXTUAL_SYSTEM_PROMPT, prepared_context,
)
scope_block = render_scope_block(scope)
system_prompt = system_prompt + f"\n\n## Current view\n{scope_block}"
tools = _contextual_tools(user_id, trace_id)
if brief_mode and briefing_context_text:
# Stage 2: inject briefing as ground truth context.
# Pre-substitute {briefing_context} in the template (handles both Langfuse {{}} and fallback {})
# before compile_prompt sees the remaining standard variables.
template, langfuse_prompt = get_prompt_or_fallback(
"task_brief_followup_system",
_TASK_BRIEF_FOLLOWUP_SYSTEM_PROMPT,
)
system_prompt = compile_prompt(
template, langfuse_prompt,
date_context=_datetime_context_injection(prepared_context).strip(),
language_instruction=_language_instruction(prepared_context).strip(),
user_identity=_user_identity_injection(prepared_context).strip(),
relational_memory=_relational_memory_injection(prepared_context).strip(),
proactive_hints=_proactive_hints_injection(prepared_context).strip(),
request_context=_request_context_block(prepared_context),
briefing_context=briefing_context_text,
)
else:
system_prompt, langfuse_prompt = _build_system_prompt("floating_system", _FLOATING_SYSTEM_PROMPT, prepared_context)
sanitizer = _FloatingStreamSanitizer()
emitted_sanitized = False
raw_chunks: list[str] = []
async for event in _run_single_agent_stream(
user_id=user_id,
system_prompt=system_prompt,
message=message,
context=prepared_context,
langfuse_prompt=langfuse_prompt,
agent_name="floating-agent",
agent_name="contextual-agent",
tools=tools,
conversation_history=context.get("conversation_history"),
):
event_type, data = event
if event_type != "token":
yield event
continue
raw_chunk = str(data or "")
raw_chunks.append(raw_chunk)
sanitized_chunk = sanitizer.feed(raw_chunk)
if sanitized_chunk:
emitted_sanitized = True
yield "token", sanitized_chunk
tail = sanitizer.finalize()
if tail:
emitted_sanitized = True
yield "token", tail
if not emitted_sanitized and raw_chunks:
yield "token", _fallback_from_raw_floating_text("".join(raw_chunks))
yield event
async def run_task_brief_research_stream(

View File

@@ -103,7 +103,6 @@ def get_llm(
_AGENT_MODEL_SETTINGS: dict[str, Callable[[], str]] = {
"classifier": lambda: settings.LLM_MODEL_CLASSIFIER or settings.LLM_MODEL,
"home-agent": lambda: settings.LLM_MODEL_HOME_AGENT or settings.LLM_MODEL,
"floating-agent": lambda: settings.LLM_MODEL_FLOATING_AGENT or settings.LLM_MODEL,
"unified-processor": lambda: settings.LLM_MODEL_UNIFIED_PROCESSOR or settings.LLM_MODEL,
"cloud-processor": lambda: settings.LLM_MODEL_CLOUD_PROCESSOR or settings.LLM_MODEL,
"brief-agent": lambda: settings.LLM_MODEL_BRIEF_AGENT or settings.LLM_MODEL,

View File

@@ -6,7 +6,7 @@ import re
from collections.abc import AsyncGenerator
from typing import Any
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
# Matches <canvas kind="...">...</canvas> blocks (single-line or multiline).
_CANVAS_BLOCK_RE = re.compile(
@@ -31,7 +31,7 @@ def extract_canvas_block(text: str) -> tuple[str, str | None, str | None]:
visible = visible.strip()
return visible, canvas_content, canvas_kind
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd | WsFloatingDomain
WsFrame = WsStreamStart | WsStreamText | WsStreamEnd
class StreamFormatter:
@@ -47,14 +47,6 @@ class StreamFormatter:
started = False
async for event_type, data in event_stream:
if event_type == "floating_domain":
if isinstance(data, dict):
yield WsFloatingDomain(
request_id=self.request_id,
domain=data,
)
continue
if event_type != "token":
continue

View File

@@ -48,7 +48,7 @@ from app.core.llm import get_agent_llm, model_for_agent
from app.core.preprocessors import detect_content_type, preprocess
from app.core.ws_context import clear_client_executor, execute_on_client, set_client_executor
from app.db import async_session
from app.models import AgentRunLog, CloudAgentConfig, LocalAgentConfig
from app.models import ScoutRunLog, CloudScoutConfig, LocalScoutConfig
logger = logging.getLogger(__name__)
@@ -169,7 +169,7 @@ def _is_overdue(schedule_cron: str, last_run_at: datetime | None) -> bool:
next_run: datetime = cron.get_next(datetime)
return now >= next_run
except Exception as exc:
logger.warning("agent_runner: cannot parse cron %r: %s", schedule_cron, exc)
logger.warning("scout_runner: cannot parse cron %r: %s", schedule_cron, exc)
return False
@@ -290,7 +290,7 @@ async def _run_agent_with_tools(
call_name = str(call.get("name", ""))
call_args = call.get("args", {})
logger.info(
"agent_runner: tool_call name=%s args=%s",
"scout_runner: tool_call name=%s args=%s",
call_name,
json.dumps(call_args, ensure_ascii=True)[:800],
)
@@ -305,7 +305,7 @@ async def _run_agent_with_tools(
tool_output = await tool_fn.ainvoke(call_args)
logger.info(
"agent_runner: tool_result name=%s output=%s",
"scout_runner: tool_result name=%s output=%s",
call_name,
str(tool_output)[:200],
)
@@ -360,7 +360,7 @@ async def _scan_directories(
try:
result = await execute_on_client(action="list_directory", data={"path": path})
except Exception as exc:
logger.warning("agent_runner: list_directory failed %r: %s", path, exc)
logger.warning("scout_runner: list_directory failed %r: %s", path, exc)
return
for entry in result.get("entries", []):
entry_path = entry.get("path", "")
@@ -414,7 +414,7 @@ async def _fetch_projects() -> list[dict]:
result = await execute_on_client(action="select", table="projects")
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch projects: %s", exc)
logger.warning("scout_runner: failed to fetch projects: %s", exc)
return []
@@ -442,7 +442,7 @@ async def _fetch_domain_entities(domain: str, project_id: str) -> list[dict]:
)
return result.get("rows", [])
except Exception as exc:
logger.warning("agent_runner: failed to fetch %s: %s", domain, exc)
logger.warning("scout_runner: failed to fetch %s: %s", domain, exc)
return []
@@ -555,8 +555,8 @@ def _get_no_match_behavior(agent_config: dict) -> str:
async def run_local_agent(
user_id: str,
config: LocalAgentConfig,
run_log: AgentRunLog,
config: LocalScoutConfig,
run_log: ScoutRunLog,
device_mgr: DeviceConnectionManager,
run_context: dict | None = None,
) -> None:
@@ -586,7 +586,7 @@ async def run_local_agent(
if not is_online:
logger.info(
"agent_runner: skip run=%s — device %r offline for user=%s",
"scout_runner: skip run=%s — device %r offline for user=%s",
run_id,
target_device_id or "<any>",
user_id,
@@ -605,7 +605,7 @@ async def run_local_agent(
errors: list[str] = []
items_processed = 0
items_created = 0
agent_config: dict = config.agent_config or {}
agent_config: dict = config.scout_config or {}
processing_tools = _build_processing_tools(config.data_types)
try:
@@ -616,7 +616,7 @@ async def run_local_agent(
last_run_at=config.last_run_at,
)
logger.info(
"agent_runner: run=%s found %d file(s) after filtering", run_id, len(file_paths)
"scout_runner: run=%s found %d file(s) after filtering", run_id, len(file_paths)
)
if not file_paths:
@@ -641,7 +641,7 @@ async def run_local_agent(
raw_content: str = file_result.get("content", "")
if not raw_content.strip():
logger.debug(
"agent_runner: run=%s skipping empty file %r", run_id, file_path
"scout_runner: run=%s skipping empty file %r", run_id, file_path
)
continue
@@ -651,7 +651,7 @@ async def run_local_agent(
preprocessed = preprocess(content_type, raw_content)
logger.info(
"agent_runner: run=%s file=%r content_type=%s clean_len=%d",
"scout_runner: run=%s file=%r content_type=%s clean_len=%d",
run_id, file_path, content_type, len(preprocessed.clean_text),
)
@@ -711,19 +711,19 @@ async def run_local_agent(
projects_block = _format_projects(projects)
logger.info(
"agent_runner: run=%s file=%r created=%d result=%s",
"scout_runner: run=%s file=%r created=%d result=%s",
run_id, file_path, file_created, result_text[:200],
)
except Exception as exc:
errors.append(f"Error processing '{file_path}': {exc}")
logger.error(
"agent_runner: run=%s file=%r failed: %s", run_id, file_path, exc
"scout_runner: run=%s file=%r failed: %s", run_id, file_path, exc
)
except Exception as exc:
errors.append(f"Agent run failed: {exc}")
logger.error("agent_runner: run=%s failed: %s", run_id, exc)
logger.error("scout_runner: run=%s failed: %s", run_id, exc)
finally:
_running_agents.discard(agent_id)
clear_client_executor()
@@ -744,7 +744,7 @@ async def run_local_agent(
errors=errors,
)
logger.info(
"agent_runner: run=%s done status=%s processed=%d created=%d errors=%d",
"scout_runner: run=%s done status=%s processed=%d created=%d errors=%d",
run_id,
final_status,
items_processed,
@@ -762,7 +762,7 @@ async def run_local_agent(
})
except Exception as exc:
logger.warning(
"agent_runner: run=%s failed to send run_complete: %s", run_id, exc
"scout_runner: run=%s failed to send run_complete: %s", run_id, exc
)
@@ -773,8 +773,8 @@ _CLOUD_DEFAULT_LOOKBACK_DAYS: int = 7
async def run_cloud_agent(
user_id: str,
config: CloudAgentConfig,
run_log: AgentRunLog,
config: CloudScoutConfig,
run_log: ScoutRunLog,
device_mgr: DeviceConnectionManager,
) -> None:
"""Execute a cloud connector agent run end-to-end.
@@ -797,7 +797,7 @@ async def run_cloud_agent(
# ── 1. Device online check ─────────────────────────────────────────
if not device_mgr.is_online(user_id):
logger.info(
"agent_runner: skip cloud run=%s — no device online for user=%s",
"scout_runner: skip cloud run=%s — no device online for user=%s",
run_id,
user_id,
)
@@ -822,7 +822,7 @@ async def run_cloud_agent(
try:
credentials_info = decrypt_token(config.oauth_token_encrypted)
except ValueError as exc:
logger.error("agent_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
logger.error("scout_runner: failed to decrypt OAuth token for agent %s: %s", config.id, exc)
await _finalize_run(
run_log,
status="error",
@@ -868,7 +868,7 @@ async def run_cloud_agent(
raw_messages = []
except RuntimeError as exc:
logger.error(
"agent_runner: provider fetch failed for cloud agent %s: %s", config.id, exc
"scout_runner: provider fetch failed for cloud agent %s: %s", config.id, exc
)
await _finalize_run(
run_log,
@@ -881,7 +881,7 @@ async def run_cloud_agent(
return
logger.info(
"agent_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
"scout_runner: cloud agent %s fetched %d item(s) from %s for user=%s",
config.id,
len(raw_messages),
config.provider,
@@ -941,16 +941,16 @@ async def run_cloud_agent(
new_encrypted = encrypt_token(refreshed)
async with async_session() as db:
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config.id)
select(CloudScoutConfig).where(CloudScoutConfig.id == config.id)
)
cfg_row = cfg_result.scalar_one_or_none()
if cfg_row:
cfg_row.oauth_token_encrypted = new_encrypted
await db.commit()
logger.debug("agent_runner: refreshed OAuth token persisted for agent %s", config.id)
logger.debug("scout_runner: refreshed OAuth token persisted for agent %s", config.id)
except Exception as exc:
logger.warning(
"agent_runner: failed to persist refreshed token for agent %s: %s",
"scout_runner: failed to persist refreshed token for agent %s: %s",
config.id,
exc,
)
@@ -974,7 +974,7 @@ async def run_cloud_agent(
config_type="cloud",
)
logger.info(
"agent_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
"scout_runner: cloud run=%s done status=%s processed=%d created=%d errors=%d",
run_id,
final_status,
items_processed,
@@ -996,7 +996,7 @@ async def trigger_pending_runs(
Called as a background task from the device WS endpoint on ``device_hello``.
"""
logger.info(
"agent_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
"scout_runner: pending-run scan skipped for user=%s device=%s (client-owned agent config)",
user_id,
device_id,
)
@@ -1007,7 +1007,7 @@ async def trigger_pending_runs(
async def _finalize_run(
run_log: AgentRunLog,
run_log: ScoutRunLog,
*,
status: str,
items_processed: int = 0,
@@ -1031,14 +1031,14 @@ async def _finalize_run(
if update_config_last_run and config_id:
if config_type == "local":
cfg_result = await db.execute(
select(LocalAgentConfig).where(LocalAgentConfig.id == config_id)
select(LocalScoutConfig).where(LocalScoutConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
cfg.last_run_at = now
elif config_type == "cloud":
cfg_result = await db.execute(
select(CloudAgentConfig).where(CloudAgentConfig.id == config_id)
select(CloudScoutConfig).where(CloudScoutConfig.id == config_id)
)
cfg = cfg_result.scalar_one_or_none()
if cfg:
@@ -1047,5 +1047,5 @@ async def _finalize_run(
await db.commit()
except Exception as exc:
logger.error(
"agent_runner: failed to finalize run_log=%s: %s", run_log.id, exc
"scout_runner: failed to finalize run_log=%s: %s", run_log.id, exc
)

View File

@@ -54,6 +54,43 @@ class _SessionBuffer:
with self._lock:
self._store.pop((user_id, session_id), None)
def append_system_message(self, user_id: str, session_id: str, text: str) -> None:
"""Append a synthetic system message to the buffer for the given session.
Creates the session slot if it does not yet exist. Used by the
contextual_scope_update handler to inject navigation events without
making an LLM call.
"""
from langchain_core.messages import SystemMessage # noqa: PLC0415
key = (user_id, session_id)
with self._lock:
entry = self._store.get(key)
if entry is None:
msgs: list[BaseMessage] = [SystemMessage(content=text)]
else:
_, existing = entry
msgs = list(existing) + [SystemMessage(content=text)]
capped = msgs[-MAX_MESSAGES_PER_SESSION:]
self._store[key] = (time.monotonic(), capped)
class ContextualBufferProxy:
"""Thin wrapper around _SessionBuffer that closes over user_id + session_id.
Returned by get_session_buffer() so callers can call
``proxy.append_system_message(text)`` without threading user_id/session_id
through every call site.
"""
def __init__(self, buf: "_SessionBuffer", user_id: str, session_id: str) -> None:
self._buf = buf
self._user_id = user_id
self._session_id = session_id
def append_system_message(self, text: str) -> None:
self._buf.append_system_message(self._user_id, self._session_id, text)
# Module-level singleton — same pattern as _pending_states in api/app/api/routes/auth.py
session_buffer = _SessionBuffer()

View File

@@ -8,7 +8,7 @@ blocking the event loop.
Token refresh is handled transparently: when the stored access token has
expired, ``google.auth.transport.requests.Request`` will use the refresh
token to obtain a fresh one. The caller is responsible for persisting
any refreshed credentials back to ``CloudAgentConfig.oauth_token_encrypted``
any refreshed credentials back to ``CloudScoutConfig.oauth_token_encrypted``
(see ``agent_runner.run_cloud_agent``).
Credential dict shape (Google OAuth2):

View File

@@ -77,8 +77,98 @@ async def _memory_cron_tick() -> None:
_log.warning("memory cron tick: failed: %s", exc)
async def _scout_cron_tick() -> None:
"""Every-15-min cron: poll enabled cloud scouts (cron-fallback; push is primary).
Skips any scout whose ``last_run_at`` is within the last 5 minutes so
a push notification and the fallback cron don't double-fire within the
same window.
"""
import logging # noqa: PLC0415
import uuid # noqa: PLC0415
from datetime import datetime, timezone # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("scout cron tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.models import CloudScoutConfig # noqa: PLC0415
from app.scouts.engine import ScoutEngine # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
async with async_session() as session:
scouts = (await session.execute(
select(CloudScoutConfig).where(CloudScoutConfig.enabled == True) # noqa: E712
)).scalars().all()
engine = ScoutEngine()
triggered = 0
for scout in scouts:
# Rate-limit guard: push is primary; skip if ran within 5 minutes.
if scout.last_run_at:
elapsed = (datetime.now(tz=timezone.utc) - scout.last_run_at).total_seconds()
if elapsed < 300:
continue
try:
await engine.trigger_scout(uuid.UUID(str(scout.id)))
triggered += 1
except Exception as exc:
_log.warning("scout cron tick: trigger failed scout=%s: %s", scout.id, exc)
_log.info("scout cron tick: done triggered=%d total=%d", triggered, len(scouts))
except Exception as exc:
_log.warning("scout cron tick: failed: %s", exc)
async def _scout_watch_renewal_tick() -> None:
"""Every-24-hour cron: re-issue Gmail users.watch for scouts expiring within 24h.
Handles missing or misconfigured connectors gracefully — logs and continues.
"""
import logging # noqa: PLC0415
from datetime import datetime, timedelta, timezone # noqa: PLC0415
_log = logging.getLogger(__name__)
_log.info("scout watch renewal tick: starting")
try:
from app.db import async_session # noqa: PLC0415
from app.models import CloudScoutConfig # noqa: PLC0415
from app.scouts.connectors.registry import get_connector # noqa: PLC0415
from sqlalchemy import select # noqa: PLC0415
threshold = datetime.now(tz=timezone.utc) + timedelta(hours=24)
renewed = 0
async with async_session() as session:
scouts = (await session.execute(
select(CloudScoutConfig).where(
CloudScoutConfig.enabled == True, # noqa: E712
CloudScoutConfig.provider == "gmail",
CloudScoutConfig.gmail_watch_expires_at <= threshold,
)
)).scalars().all()
for scout in scouts:
try:
connector = get_connector("gmail")
await connector.renew_watch(scout)
renewed += 1
except Exception:
_log.exception("scout watch renewal tick: renew failed scout=%s", scout.id)
await session.commit()
_log.info("scout watch renewal tick: done renewed=%d", renewed)
except Exception as exc:
_log.warning("scout watch renewal tick: failed: %s", exc)
@asynccontextmanager
async def lifespan(app: FastAPI):
# Startup: register source connectors.
from app.scouts.connectors.gmail import GmailConnector # noqa: PLC0415
from app.scouts.connectors.registry import register_connector # noqa: PLC0415
register_connector(GmailConnector())
# Startup: ensure agent tool modules are loaded.
import app.agents # noqa: F401
@@ -89,6 +179,14 @@ async def lifespan(app: FastAPI):
scheduler = AsyncIOScheduler()
scheduler.add_job(_memory_cron_tick, "interval", hours=1, id="memory_cron")
scheduler.add_job(_memory_audit_cron_tick, "interval", weeks=1, id="memory_audit_cron")
scheduler.add_job(
_scout_cron_tick, "interval", minutes=15,
id="scout_cron_tick", replace_existing=True,
)
scheduler.add_job(
_scout_watch_renewal_tick, "interval", hours=24,
id="scout_watch_renewal_tick", replace_existing=True,
)
scheduler.start()
logging.getLogger(__name__).info("memory cron scheduler started (interval=1h)")
@@ -124,14 +222,15 @@ def create_app() -> FastAPI:
app.add_middleware(SanitizerMiddleware)
app.add_middleware(TierRateLimitMiddleware)
from app.api.routes import agents, auth, billing, chat, device_ws, memory
from app.api.routes import scouts, auth, billing, chat, device_ws, memory, scout_webhooks
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(agents.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
app.include_router(memory.router, prefix="/api/v1")
app.include_router(auth.router, prefix="/api/v1")
app.include_router(chat.router, prefix="/api/v1")
app.include_router(billing.router, prefix="/api/v1")
app.include_router(scouts.router, prefix="/api/v1")
app.include_router(scout_webhooks.router, prefix="/api/v1")
app.include_router(device_ws.router, prefix="/api/v1")
app.include_router(memory.router, prefix="/api/v1")
@app.get("/api/v1/health", tags=["health"])
async def health() -> dict:

View File

@@ -1,15 +1,15 @@
"""SQLAlchemy ORM models for all persistent tables.
Only auth, billing, agent config, and memory data live here.
Only auth, billing, scout config, and memory data live here.
User content (notes, tasks, etc.) lives exclusively on the client.
Table inventory:
users — account credentials + tier
refresh_tokens — hashed refresh token store
subscriptions — Stripe subscription records
local_agent_configs — per-device batch agent configs
cloud_agent_configs — OAuth-backed cloud agent configs
agent_run_logs — execution history for all agents
local_scout_configs — per-device batch scout configs
cloud_scout_configs — OAuth-backed cloud scout configs
scout_run_logs — execution history for all scouts
memory_core — per-user persistent key/value preferences (encrypted)
memory_associative — per-user semantic memory with embeddings (encrypted)
memory_episodic — per-user session summaries (encrypted)
@@ -34,8 +34,10 @@ from sqlalchemy import (
LargeBinary,
String,
Text,
UniqueConstraint,
Uuid,
func,
text,
)
from sqlalchemy.orm import Mapped, mapped_column, relationship
@@ -158,8 +160,8 @@ class Subscription(Base):
user: Mapped[User] = relationship(back_populates="subscription")
class LocalAgentConfig(Base):
__tablename__ = "local_agent_configs"
class LocalScoutConfig(Base):
__tablename__ = "local_scout_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
@@ -172,7 +174,7 @@ class LocalAgentConfig(Base):
directory_paths: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
data_types: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
prompt_template: Mapped[str] = mapped_column(Text, nullable=False, default="")
agent_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
scout_config: Mapped[dict | None] = mapped_column(JSON, nullable=True)
file_extensions: Mapped[list] = mapped_column(JSON, nullable=False, default=list)
schedule_cron: Mapped[str] = mapped_column(String(100), nullable=False, default="0 */6 * * *")
enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
@@ -184,17 +186,17 @@ class LocalAgentConfig(Base):
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="local_agent",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
back_populates="local_scout",
primaryjoin="and_(ScoutRunLog.scout_id == LocalScoutConfig.id, ScoutRunLog.scout_type == 'local')",
foreign_keys="ScoutRunLog.scout_id",
cascade="all, delete-orphan",
overlaps="run_logs,cloud_agent",
overlaps="run_logs,cloud_scout",
)
class CloudAgentConfig(Base):
__tablename__ = "cloud_agent_configs"
class CloudScoutConfig(Base):
__tablename__ = "cloud_scout_configs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
@@ -217,26 +219,50 @@ class CloudAgentConfig(Base):
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.now(), onupdate=func.now()
)
auto_trash_spam: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False, server_default=text("false"))
gmail_history_id: Mapped[str | None] = mapped_column(String(64), nullable=True)
gmail_watch_expires_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
device_inactivity_pause_days: Mapped[int] = mapped_column(Integer, nullable=False, default=14, server_default="14")
run_logs: Mapped[list[AgentRunLog]] = relationship(
back_populates="cloud_agent",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
run_logs: Mapped[list["ScoutRunLog"]] = relationship(
back_populates="cloud_scout",
primaryjoin="and_(ScoutRunLog.scout_id == CloudScoutConfig.id, ScoutRunLog.scout_type == 'cloud')",
foreign_keys="ScoutRunLog.scout_id",
cascade="all, delete-orphan",
overlaps="run_logs,local_agent",
overlaps="run_logs,local_scout",
)
class AgentRunLog(Base):
__tablename__ = "agent_run_logs"
class ScoutTriageQueue(Base):
__tablename__ = "scout_triage_queue"
__table_args__ = (
UniqueConstraint("scout_id", "source_msg_ref", name="uq_scout_triage_queue_scout_msg"),
)
id: Mapped[str] = mapped_column(Uuid(as_uuid=False), primary_key=True, default=_uuid)
user_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True)
scout_id: Mapped[str] = mapped_column(Uuid(as_uuid=False), ForeignKey("cloud_scout_configs.id", ondelete="CASCADE"), nullable=False)
source_type: Mapped[str] = mapped_column(String(50), nullable=False)
source_msg_ref: Mapped[str] = mapped_column(String(255), nullable=False)
triage_verdict: Mapped[str] = mapped_column(String(20), nullable=False)
triage_reason: Mapped[str | None] = mapped_column(Text, nullable=True)
status: Mapped[str] = mapped_column(String(20), nullable=False, default="queued", server_default="queued")
triaged_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False, server_default=func.now())
delivered_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
acked_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
expires_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
class ScoutRunLog(Base):
__tablename__ = "scout_run_logs"
id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), primary_key=True, default=_uuid
)
# Plain string — not a FK because it references either local_agent_configs or cloud_agent_configs
# depending on agent_type. Query by (agent_id, agent_type) to locate the source config.
agent_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
agent_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
# Plain string — not a FK because it references either local_scout_configs or cloud_scout_configs
# depending on scout_type. Query by (scout_id, scout_type) to locate the source config.
scout_id: Mapped[str] = mapped_column(String(255), nullable=False, index=True)
scout_type: Mapped[str] = mapped_column(AgentTypeEnum, nullable=False)
user_id: Mapped[str] = mapped_column(
Uuid(as_uuid=False), ForeignKey("users.id", ondelete="CASCADE"), nullable=False, index=True
)
@@ -250,17 +276,17 @@ class AgentRunLog(Base):
)
completed_at: Mapped[datetime | None] = mapped_column(DateTime(timezone=True), nullable=True)
local_agent: Mapped[LocalAgentConfig | None] = relationship(
local_scout: Mapped["LocalScoutConfig | None"] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == LocalAgentConfig.id, AgentRunLog.agent_type == 'local')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,cloud_agent",
primaryjoin="and_(ScoutRunLog.scout_id == LocalScoutConfig.id, ScoutRunLog.scout_type == 'local')",
foreign_keys="ScoutRunLog.scout_id",
overlaps="run_logs,cloud_scout",
)
cloud_agent: Mapped[CloudAgentConfig | None] = relationship(
cloud_scout: Mapped["CloudScoutConfig | None"] = relationship(
back_populates="run_logs",
primaryjoin="and_(AgentRunLog.agent_id == CloudAgentConfig.id, AgentRunLog.agent_type == 'cloud')",
foreign_keys="AgentRunLog.agent_id",
overlaps="run_logs,local_agent",
primaryjoin="and_(ScoutRunLog.scout_id == CloudScoutConfig.id, ScoutRunLog.scout_type == 'cloud')",
foreign_keys="ScoutRunLog.scout_id",
overlaps="run_logs,local_scout",
)

View File

@@ -73,11 +73,9 @@ class WsFrameType(str, Enum):
device_hello = "device_hello"
# ── v3 frame types ─────────────────────────────────────────────────
home_request = "home_request"
floating_request = "floating_request"
stream_start = "stream_start"
stream_text = "stream_text"
stream_end = "stream_end"
floating_domain = "floating_domain"
data_request = "data_request"
data_response = "data_response"
mutation = "mutation"
@@ -96,6 +94,13 @@ class WsFrameType(str, Enum):
index_file_result = "index_file_result"
index_session_progress = "index_session_progress"
index_session_done = "index_session_done"
# ── v8 contextual sidebar frame types ────────────────────────────
contextual_request = "contextual_request"
contextual_scope_update = "contextual_scope_update"
contextual_scope_ack = "contextual_scope_ack"
# ── v9 scout proposal frame types ────────────────────────────────
SCOUT_PROPOSAL = "scout_proposal"
SCOUT_PROPOSAL_ACK = "scout_proposal_ack"
class WsToolCall(BaseModel):
@@ -145,7 +150,7 @@ class WsDeviceHello(BaseModel):
type: Literal[WsFrameType.device_hello] = WsFrameType.device_hello
device_id: str
agent_ids: list[str] = Field(default_factory=list)
scout_ids: list[str] = Field(default_factory=list)
@@ -161,13 +166,6 @@ class FormatPrefsModel(BaseModel):
now_iso: str = ""
class WsFloatingScope(BaseModel):
"""Scope for a floating request — narrows the agent to a specific entity."""
type: Literal["task", "project", "note", "timeline"]
id: str | None = None
class WsHomeRequest(BaseModel):
"""Client → Server: Home chat message."""
@@ -177,15 +175,6 @@ class WsHomeRequest(BaseModel):
format_prefs: FormatPrefsModel | None = None
class WsFloatingRequest(BaseModel):
"""Client → Server: Floating chat message scoped to an entity."""
type: Literal[WsFrameType.floating_request] = WsFrameType.floating_request
message: str
scope: WsFloatingScope
format_prefs: FormatPrefsModel | None = None
class WsBriefRequest(BaseModel):
"""Client → Server: Request a plain-text brief (home or project)."""
@@ -221,26 +210,10 @@ class WsStreamEnd(BaseModel):
mutations: list[dict[str, Any]] | None = None
class WsDomain(BaseModel):
"""Structured floating domain payload for UI routing decisions."""
type: Literal["task", "timeline", "project", "node"]
id: str | None = None
section: Literal["task", "timeline", "note"] | None = None
# ── Scout Config V2 ───────────────────────────────────────────────────
class WsFloatingDomain(BaseModel):
"""Server → Client: domain determined for a floating request."""
type: Literal[WsFrameType.floating_domain] = WsFrameType.floating_domain
request_id: str
domain: WsDomain
# ── Agent Config V2 ───────────────────────────────────────────────────
class ContentTypeConfig(BaseModel):
class ScoutContentTypeConfig(BaseModel):
"""Per-type extraction config produced by the journey chatbot."""
id: str
@@ -250,34 +223,34 @@ class ContentTypeConfig(BaseModel):
extraction_prompt: str
class AgentConfig(BaseModel):
"""Structured agent configuration (replaces freeform prompt_template)."""
class ScoutConfig(BaseModel):
"""Structured scout configuration (replaces freeform prompt_template)."""
content_types: list[ContentTypeConfig] = []
content_types: list[ScoutContentTypeConfig] = []
global_rules: list[str] = []
data_types: list[str] = []
# ── Agent Catalog ─────────────────────────────────────────────────────
# ── Scout Catalog ─────────────────────────────────────────────────────
class AgentCatalogItem(BaseModel):
class ScoutCatalogItem(BaseModel):
type: str
name: str
description: str
class AgentCreationCheckRequest(BaseModel):
class ScoutCreationCheckRequest(BaseModel):
active_agents: int = Field(ge=0, default=0)
class AgentCreationCheckResponse(BaseModel):
class ScoutCreationCheckResponse(BaseModel):
allowed: bool
tier: BillingTier
active_agents: int
limit: int
class AgentTriggerRequest(BaseModel):
class ScoutTriggerRequest(BaseModel):
directory: str = Field(min_length=1)
device_id: str = Field(default="")
agent_id: str | None = None # FE stable agent ID (electron-store UUID)
@@ -289,9 +262,9 @@ class AgentTriggerRequest(BaseModel):
last_run_at: int | None = None # epoch ms from FE — enables incremental scanning
# ── Agent Run Log ─────────────────────────────────────────────────────
# ── Scout Run Log ─────────────────────────────────────────────────────
class AgentRunLogResponse(BaseModel):
class ScoutRunLogResponse(BaseModel):
id: str
agent_id: str
agent_type: Literal["local", "cloud"]
@@ -305,3 +278,25 @@ class AgentRunLogResponse(BaseModel):
# ── Chatbot Journey ───────────────────────────────────────────────────
# ── Scout Proposal Frame Models ───────────────────────────────────────
class ScoutProposalPayload(BaseModel):
id: str
scout_id: str
source_type: str
source_msg_ref: str
raw_subject: str | None = None
raw_snippet: str | None = None
category: Literal["unprocessed"] = "unprocessed"
payload: dict | None = None
class ScoutProposalFrame(BaseModel):
type: Literal[WsFrameType.SCOUT_PROPOSAL]
proposal: ScoutProposalPayload
class ScoutProposalAckFrame(BaseModel):
type: Literal[WsFrameType.SCOUT_PROPOSAL_ACK]
proposal_id: str

73
app/schemas/contextual.py Normal file
View File

@@ -0,0 +1,73 @@
"""Contextual sidebar scope schema and prompt block renderer.
ContextualScope mirrors the TypeScript ContextualScope type sent by the
Electron renderer when the user opens the side chat anchored to a specific
view. The renderer ships camelCase keys; Pydantic's alias_generator maps
them to snake_case Python attributes automatically.
"""
from __future__ import annotations
from typing import Literal, Optional
from pydantic import BaseModel, ConfigDict
from pydantic.alias_generators import to_camel
PageType = Literal[
"timeline",
"tasks",
"projects-list",
"project",
"note",
]
EntityType = Literal["project", "note", "task", "timeline_event"]
class ContextualScope(BaseModel):
"""Scope payload sent by the Electron renderer for contextual chat.
The renderer ships camelCase keys (entityType, entityId, ...). Pydantic's
alias generator maps them to snake_case Python attrs.
"""
model_config = ConfigDict(populate_by_name=True, alias_generator=to_camel)
page: PageType
entity_type: Optional[EntityType] = None
entity_id: Optional[str] = None
entity_name: Optional[str] = None
project_id: Optional[str] = None
char_count: Optional[int] = None
counts: Optional[dict[str, int]] = None
filters: Optional[dict] = None
def render_scope_block(scope: ContextualScope) -> str:
"""Produce a single-paragraph human-readable summary of the current view
for injection into the contextual agent system prompt.
Never emits internal ids — only names. The LLM is told to use names in
prose; ids travel through tool calls.
"""
if scope.entity_type == "project":
c = scope.counts or {}
return (
f"User is viewing the project {scope.entity_name!r}. "
f"{c.get('tasks', 0)} tasks, "
f"{c.get('notes', 0)} notes, "
f"{c.get('milestones', 0)} milestones."
)
if scope.entity_type == "note":
return (
f"User is viewing the note {scope.entity_name!r} "
f"({scope.char_count or 0} characters)."
)
if scope.page == "tasks":
return "User is viewing the global Tasks list (all projects)."
if scope.page == "timeline":
return "User is viewing the global Timeline view."
if scope.page == "projects-list":
return "User is viewing the Projects list."
return f"User is on page {scope.page}."

0
app/scouts/__init__.py Normal file
View File

View File

View File

@@ -0,0 +1,56 @@
"""Source connector Protocol and shared item types.
A SourceConnector adapts a third-party data source (Gmail, Slack, ...) to the
shared ScoutEngine interface. Each connector owns:
* how to enumerate new items since the last poll (``list_new``)
* how to fetch a single item's metadata cheaply (``fetch_metadata``)
* how to fetch a single item's full content for in-memory triage
(``fetch_content``) — this content MUST NOT be persisted by the engine
* how to archive/trash an item (``archive``) for spam handling
* optional push-notification setup (``setup_watch`` / ``renew_watch``)
"""
from __future__ import annotations
from datetime import datetime
from typing import Literal, Protocol
from pydantic import BaseModel, Field
class ItemRef(BaseModel):
source_msg_ref: str
received_at: datetime | None = None
class ItemMetadata(BaseModel):
subject: str | None = None
sender: str | None = None
snippet: str | None = None
received_at: datetime | None = None
class ItemContent(BaseModel):
metadata: ItemMetadata
body_text: str
raw_headers: dict[str, str] = Field(default_factory=dict)
class TriageVerdict(BaseModel):
verdict: Literal["relevant", "spam"]
reason: str
confidence: float = Field(ge=0.0, le=1.0)
class SourceConnector(Protocol):
"""Adapter for a third-party data source (Gmail, Slack, ...)."""
source_type: str # e.g. "gmail"
async def list_new(self, scout) -> list[ItemRef]: ...
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata: ...
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent: ...
async def archive(self, scout, ref: ItemRef) -> None: ...
async def setup_watch(self, scout) -> None: ...
async def renew_watch(self, scout) -> None: ...

View File

@@ -0,0 +1,213 @@
"""Gmail SourceConnector — wraps the existing GmailClient.
Responsibilities:
* list_new: incremental fetch since the scout's stored gmail_history_id
* fetch_metadata: subject + sender + snippet only (Gmail metadata format)
* fetch_content: full body text — transient, never persisted by engine
* archive: move a message to Gmail Trash (recoverable for 30 days)
* setup_watch / renew_watch: Gmail push notifications via Pub/Sub
"""
from __future__ import annotations
import asyncio
import logging
from datetime import datetime, timezone
from app.config.settings import settings
from app.integrations import decrypt_token
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef
logger = logging.getLogger(__name__)
def _extract_plain_text_body(payload: dict) -> str:
"""Recursively walk a Gmail message payload to find text/plain content."""
import base64
mime_type = payload.get("mimeType", "")
if mime_type == "text/plain":
data = payload.get("body", {}).get("data", "")
if data:
return base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return ""
if mime_type.startswith("multipart/"):
for part in payload.get("parts", []):
text = _extract_plain_text_body(part)
if text:
return text
# text/html fallback: strip tags rudimentarily if no text/plain part
if mime_type == "text/html":
data = payload.get("body", {}).get("data", "")
if data:
import re
html = base64.urlsafe_b64decode(data + "==").decode("utf-8", errors="replace")
return re.sub(r"<[^>]+>", " ", html)
return ""
def _get_gmail_service(scout):
"""Return a synchronous Google API client for low-level metadata/history calls."""
from googleapiclient.discovery import build
from google.oauth2.credentials import Credentials
creds_info = decrypt_token(scout.oauth_token_encrypted)
credentials = Credentials(
token=creds_info.get("token"),
refresh_token=creds_info.get("refresh_token"),
token_uri=creds_info.get("token_uri", "https://oauth2.googleapis.com/token"),
client_id=creds_info.get("client_id"),
client_secret=creds_info.get("client_secret"),
scopes=creds_info.get("scopes"),
)
return build("gmail", "v1", credentials=credentials, cache_discovery=False)
class GmailConnector:
source_type = "gmail"
# ── list_new ──────────────────────────────────────────────────────────
async def list_new(self, scout) -> list[ItemRef]:
"""Return new message refs since scout.gmail_history_id.
On first run (gmail_history_id is None/empty), records the current
historyId without backfilling — avoids flooding the user with old mail.
Updates scout.gmail_history_id in-place (caller must persist to DB).
"""
def _sync() -> tuple[list[ItemRef], str | None]:
service = _get_gmail_service(scout)
history_id = scout.gmail_history_id
refs: list[ItemRef] = []
new_history_id = history_id
if history_id:
resp = (
service.users()
.history()
.list(
userId="me",
startHistoryId=history_id,
historyTypes=["messageAdded"],
)
.execute()
)
for entry in resp.get("history", []):
for added in entry.get("messagesAdded", []):
refs.append(ItemRef(source_msg_ref=added["message"]["id"]))
new_history_id = resp.get("historyId", history_id)
else:
# First run: capture baseline history id without backfilling.
profile = service.users().getProfile(userId="me").execute()
new_history_id = profile["historyId"]
return refs, new_history_id
refs, new_history_id = await asyncio.to_thread(_sync)
if new_history_id and new_history_id != scout.gmail_history_id:
scout.gmail_history_id = new_history_id
return refs
# ── fetch_metadata ────────────────────────────────────────────────────
async def fetch_metadata(self, scout, ref: ItemRef) -> ItemMetadata:
"""Fetch subject, sender, snippet only — uses Gmail metadata format (no body)."""
def _sync() -> ItemMetadata:
service = _get_gmail_service(scout)
msg = (
service.users()
.messages()
.get(
userId="me",
id=ref.source_msg_ref,
format="metadata",
metadataHeaders=["Subject", "From", "Date"],
)
.execute()
)
headers = {
h["name"]: h["value"]
for h in msg.get("payload", {}).get("headers", [])
}
return ItemMetadata(
subject=headers.get("Subject"),
sender=headers.get("From"),
snippet=msg.get("snippet"),
received_at=None,
)
return await asyncio.to_thread(_sync)
# ── fetch_content ─────────────────────────────────────────────────────
async def fetch_content(self, scout, ref: ItemRef) -> ItemContent:
"""Fetch full body text for a single message — transient, must not be persisted."""
def _sync() -> ItemContent:
service = _get_gmail_service(scout)
msg = service.users().messages().get(
userId="me", id=ref.source_msg_ref, format="full",
).execute()
headers = {h["name"]: h["value"] for h in msg.get("payload", {}).get("headers", [])}
body_text = _extract_plain_text_body(msg.get("payload", {}))
return ItemContent(
metadata=ItemMetadata(
subject=headers.get("Subject"),
sender=headers.get("From"),
snippet=msg.get("snippet"),
received_at=None,
),
body_text=body_text,
raw_headers=headers,
)
return await asyncio.to_thread(_sync)
# ── archive ───────────────────────────────────────────────────────────
async def archive(self, scout, ref: ItemRef) -> None:
"""Move the message to Gmail Trash (recoverable for 30 days)."""
def _sync() -> None:
service = _get_gmail_service(scout)
service.users().messages().trash(
userId="me", id=ref.source_msg_ref
).execute()
await asyncio.to_thread(_sync)
# ── watch management ──────────────────────────────────────────────────
async def setup_watch(self, scout) -> None:
"""Register a Gmail Pub/Sub push watch for the INBOX label.
Requires ``settings.GMAIL_PUBSUB_TOPIC`` to be set to the full topic
resource name (e.g. ``projects/my-project/topics/gmail-push``).
Logs a warning and returns without error if the topic is not configured.
"""
topic = settings.GMAIL_PUBSUB_TOPIC
if not topic:
logger.warning(
"setup_watch: GMAIL_PUBSUB_TOPIC is not configured — skipping watch setup"
)
return
def _sync() -> None:
service = _get_gmail_service(scout)
request_body = {
"labelIds": ["INBOX"],
"topicName": topic,
}
resp = service.users().watch(userId="me", body=request_body).execute()
scout.gmail_history_id = resp.get("historyId")
expiration_ms = resp.get("expiration")
if expiration_ms:
scout.gmail_watch_expires_at = datetime.fromtimestamp(
int(expiration_ms) / 1000, tz=timezone.utc
)
await asyncio.to_thread(_sync)
async def renew_watch(self, scout) -> None:
"""Renew an existing Gmail Pub/Sub watch (same as setup_watch)."""
await self.setup_watch(scout)

View File

@@ -0,0 +1,32 @@
"""Connector registry — single source of truth for source_type -> connector."""
from __future__ import annotations
from typing import Any
_CONNECTORS: dict[str, Any] = {}
def register_connector(connector: Any) -> None:
"""Register a SourceConnector instance under its ``source_type``.
Calling twice with the same ``source_type`` replaces the prior entry —
useful for tests and hot-reload, but in production each connector
should be registered exactly once at startup.
"""
if not getattr(connector, "source_type", None):
raise ValueError("Connector must declare a non-empty source_type")
_CONNECTORS[connector.source_type] = connector
def get_connector(source_type: str) -> Any:
"""Return the registered connector for ``source_type`` or raise KeyError."""
try:
return _CONNECTORS[source_type]
except KeyError as exc:
raise KeyError(f"No connector registered for source_type {source_type!r}") from exc
def _reset_for_tests() -> None:
"""Clear the registry — for use in pytest fixtures only."""
_CONNECTORS.clear()

270
app/scouts/engine.py Normal file
View File

@@ -0,0 +1,270 @@
"""ScoutEngine — orchestrates triage, queueing, and delivery for cloud scouts.
Triage flow per scout:
1. Resolve scout config from the DB.
2. Skip if device hasn't connected within ``device_inactivity_pause_days``.
3. Ask the connector to ``list_new`` — fresh items since last poll.
4. For each item:
- skip if already in the queue (idempotent on (scout_id, source_msg_ref))
- fetch the full content via the connector (transient, never persisted)
- run the triage LLM call → relevant | spam
- spam + auto_trash_spam → connector.archive
- relevant → INSERT scout_triage_queue row
5. Update scout.last_run_at.
Delivery flow on Electron WS reconnect:
- drain ``status='queued'`` rows for the user
- fetch metadata-only for each (subject + snippet)
- send a ``scout_proposal`` frame
- flip status to ``delivered`` on ack
"""
from __future__ import annotations
import logging
import uuid
from datetime import datetime, timedelta, timezone
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from app.core.langfuse_client import extract_usage, get_langfuse, get_prompt_or_fallback
from app.core.llm import get_llm
from app.db import async_session
from app.models import CloudScoutConfig, ScoutTriageQueue
from app.scouts.connectors.base import ItemContent, ItemRef, TriageVerdict
from app.scouts.connectors.registry import get_connector
logger = logging.getLogger(__name__)
QUEUE_TTL_DAYS = 30
class ScoutEngine:
def __init__(self, session_factory=None) -> None:
self._session_factory = session_factory or async_session
async def trigger_scout(self, scout_id: uuid.UUID) -> None:
async with self._session_factory() as session:
scout = await session.get(CloudScoutConfig, str(scout_id))
if scout is None:
logger.warning("trigger_scout: no such scout id=%s", scout_id)
return
if not scout.enabled:
return
# Device-inactivity pause check is a simple heuristic on last_run_at —
# the device-online signal lives in the DeviceConnectionManager and is
# consulted at delivery time. For triage, we only check that the
# configured pause threshold isn't suppressing the run.
connector = get_connector(scout.provider)
try:
refs = await connector.list_new(scout)
except Exception:
logger.exception("scout %s: list_new failed", scout.id)
return
for ref in refs:
await self._process_item(session, scout, connector, ref)
scout.last_run_at = datetime.now(tz=timezone.utc)
await session.commit()
async def _process_item(
self,
session,
scout: CloudScoutConfig,
connector,
ref: ItemRef,
) -> None:
# Idempotency check
existing = await session.execute(
select(ScoutTriageQueue.id).where(
ScoutTriageQueue.scout_id == scout.id,
ScoutTriageQueue.source_msg_ref == ref.source_msg_ref,
)
)
if existing.first() is not None:
return
try:
content = await connector.fetch_content(scout, ref)
except Exception:
logger.exception("scout %s: fetch_content failed for %s", scout.id, ref.source_msg_ref)
return
try:
verdict = await self._triage_llm(scout, content)
except Exception:
logger.exception("scout %s: triage_llm failed for %s", scout.id, ref.source_msg_ref)
return
if verdict.verdict == "spam":
if scout.auto_trash_spam:
try:
await connector.archive(scout, ref)
except Exception:
logger.exception("scout %s: archive failed for %s", scout.id, ref.source_msg_ref)
return
now = datetime.now(tz=timezone.utc)
row = ScoutTriageQueue(
id=str(uuid.uuid4()),
user_id=scout.user_id,
scout_id=scout.id,
source_type=connector.source_type,
source_msg_ref=ref.source_msg_ref,
triage_verdict=verdict.verdict,
triage_reason=verdict.reason,
status="queued",
triaged_at=now,
expires_at=now + timedelta(days=QUEUE_TTL_DAYS),
)
session.add(row)
try:
# Use a savepoint so an IntegrityError on race doesn't poison the
# outer session — works on both PostgreSQL (SAVEPOINT) and SQLite.
async with session.begin_nested():
await session.flush()
except IntegrityError:
# Race: another worker inserted between our SELECT and INSERT.
# The unique constraint did its job; safe to ignore.
logger.debug(
"scout %s: idempotent skip for %s (race on unique constraint)",
scout.id,
ref.source_msg_ref,
)
async def deliver_pending(self, user_id: uuid.UUID, ws) -> None:
"""Drain status='queued' rows for user, send scout_proposal WS frames, flip to 'delivered'."""
from app.scouts.connectors.base import ItemRef # noqa: PLC0415
async with self._session_factory() as session:
rows = (await session.execute(
select(ScoutTriageQueue).where(
ScoutTriageQueue.user_id == str(user_id),
ScoutTriageQueue.status == "queued",
)
)).scalars().all()
for row in rows:
try:
connector = get_connector(row.source_type)
except KeyError:
logger.warning("deliver_pending: no connector for %s", row.source_type)
continue
scout = await session.get(CloudScoutConfig, row.scout_id)
if scout is None:
continue
try:
meta = await connector.fetch_metadata(scout, ItemRef(source_msg_ref=row.source_msg_ref))
except Exception:
logger.exception("deliver_pending: fetch_metadata failed")
continue
payload = {
"type": "scout_proposal",
"proposal": {
"id": row.id,
"scout_id": row.scout_id,
"source_type": row.source_type,
"source_msg_ref": row.source_msg_ref,
"raw_subject": meta.subject,
"raw_snippet": meta.snippet,
"category": "unprocessed",
"payload": None,
},
}
await ws.send_json(payload)
row.status = "delivered"
row.delivered_at = datetime.now(tz=timezone.utc)
await session.commit()
async def ack_proposal(self, proposal_id: str) -> None:
"""Flip a delivered proposal to acked. Idempotent — no-op if already acked."""
async with self._session_factory() as session:
row = await session.get(ScoutTriageQueue, proposal_id)
if row is None:
return
row.status = "acked"
row.acked_at = datetime.now(tz=timezone.utc)
await session.commit()
async def _triage_llm(self, scout: CloudScoutConfig, content: ItemContent) -> TriageVerdict:
"""Call the scout-triage-system Langfuse prompt to classify an item as relevant or spam.
Uses gpt-4o-mini with JSON mode. Wraps the LLM call in a Langfuse generation
observation when Langfuse is configured.
"""
import json # noqa: PLC0415
from langchain_core.messages import HumanMessage, SystemMessage # noqa: PLC0415
_TRIAGE_FALLBACK = (
"You are a triage classifier for an executive-assistant scout that watches a "
"{source_type} feed.\n"
'The scout\'s purpose is: "{scout_purpose}".\n\n'
"Given one item, decide whether it is RELEVANT (worth surfacing to the user as a "
"potential task / event / note / project) or SPAM (advertising, mass marketing, "
"phishing, bulk notifications with no actionable content).\n\n"
"Item:\n"
" - Subject: {item_subject}\n"
" - From: {item_sender}\n"
" - Body (truncated): {item_body_truncated_2k}\n\n"
'Return JSON only, matching this schema:\n'
' {{"verdict": "relevant" | "spam", "reason": <short string>, "confidence": <0..1>}}\n\n'
"Be conservative on \"spam\" — if a message could plausibly be a personal/work "
"email, mark it relevant."
)
template, prompt_obj = get_prompt_or_fallback("scout-triage-system", _TRIAGE_FALLBACK)
body_trunc = (content.body_text or "")[:2000]
variables = dict(
source_type=scout.provider,
scout_purpose=scout.prompt_template or "",
item_subject=content.metadata.subject or "",
item_sender=content.metadata.sender or "",
item_body_truncated_2k=body_trunc,
)
if prompt_obj is not None:
try:
system_text = prompt_obj.compile(**variables)
if isinstance(system_text, list):
system_text = "\n".join(
m.get("content", "") for m in system_text if isinstance(m, dict)
)
except Exception as exc:
logger.warning("scout triage: compile failed: %s", exc)
system_text = template.replace("{{source_type}}", variables["source_type"]) \
.replace("{{scout_purpose}}", variables["scout_purpose"]) \
.replace("{{item_subject}}", variables["item_subject"]) \
.replace("{{item_sender}}", variables["item_sender"]) \
.replace("{{item_body_truncated_2k}}", variables["item_body_truncated_2k"])
else:
system_text = template.format(**variables)
llm = get_llm(model="gpt-4o-mini", temperature=0)
llm_json = llm.bind(response_format={"type": "json_object"}) # type: ignore[attr-defined]
messages = [
SystemMessage(content=system_text),
HumanMessage(content="Classify this item."),
]
lf = get_langfuse()
if lf:
with lf.start_as_current_observation(
as_type="generation",
name="scout-triage",
model="gpt-4o-mini",
prompt=prompt_obj,
input=messages,
) as gen:
response = await llm_json.ainvoke(messages)
gen.update(output=response.content, usage=extract_usage(response))
else:
response = await llm_json.ainvoke(messages)
data = json.loads(response.content)
return TriageVerdict(**data)

View File

@@ -35,7 +35,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
import yaml
from app.core.agent_runner import (
from app.core.scout_runner import (
_format_metadata,
_format_projects,
_get_extraction_rules,
@@ -44,7 +44,7 @@ from app.core.agent_runner import (
)
from app.core.device_manager import DeviceConnectionManager
from app.core.langfuse_client import get_langfuse
from app.models import AgentRunLog, LocalAgentConfig
from app.models import ScoutRunLog, LocalScoutConfig
from tests.conftest import TEST_USER_IDS
# ── Constants ─────────────────────────────────────────────────────────────
@@ -127,8 +127,8 @@ def _make_config(
agent_config: dict | None = None,
directory: str = "/emails",
device_id: str = "dev-001",
) -> LocalAgentConfig:
return LocalAgentConfig(
) -> LocalScoutConfig:
return LocalScoutConfig(
id=str(uuid.uuid4()),
user_id=_USER_ID,
device_id=device_id,
@@ -136,7 +136,7 @@ def _make_config(
directory_paths=[directory],
data_types=["tasks", "notes", "timelines"],
prompt_template="",
agent_config=agent_config or _AGENT_CONFIG,
scout_config=agent_config or _AGENT_CONFIG,
file_extensions=[".html", ".eml"],
schedule_cron="0 */6 * * *",
enabled=True,
@@ -144,11 +144,11 @@ def _make_config(
)
def _make_run_log(agent_id: str) -> AgentRunLog:
return AgentRunLog(
def _make_run_log(agent_id: str) -> ScoutRunLog:
return ScoutRunLog(
id=str(uuid.uuid4()),
agent_id=agent_id,
agent_type="local",
scout_id=agent_id,
scout_type="local",
user_id=_USER_ID,
status="running",
started_at=datetime.now(timezone.utc),
@@ -271,7 +271,7 @@ async def test_2_9_device_offline():
run_log = _make_run_log(config.id)
mgr = _make_manager(online=False)
with patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
with patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args
@@ -295,8 +295,8 @@ async def test_2_10_empty_file():
projects=[_PROJECTS["alpha"]],
)
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args
@@ -326,9 +326,9 @@ async def test_2_8_items_created_count():
_tool_calls_out.extend(["create_task", "create_note", "update_task"])
return "Done."
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._run_agent_with_tools", side_effect=mock_run_agent), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
patch("app.core.scout_runner._run_agent_with_tools", side_effect=mock_run_agent), \
patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args
@@ -377,8 +377,8 @@ async def test_eval_runner(runner_case, pytestconfig):
) if lf else nullcontext()
with obs_ctx as obs:
with patch("app.core.agent_runner._make_agent_executor", return_value=executor), \
patch("app.core.agent_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
with patch("app.core.scout_runner._make_agent_executor", return_value=executor), \
patch("app.core.scout_runner._finalize_run", new_callable=AsyncMock) as mock_fin:
await run_local_agent(_USER_ID, config, run_log, mgr)
_, kwargs = mock_fin.call_args

View File

@@ -0,0 +1,52 @@
import pytest
from app.schemas.contextual import ContextualScope, render_scope_block
def test_render_project_scope():
scope = ContextualScope(
page="project",
entity_type="project",
entity_id="p1",
entity_name="Acme Q3 launch",
counts={"tasks": 12, "notes": 4, "milestones": 3},
)
block = render_scope_block(scope)
assert "Acme Q3 launch" in block
assert "12 tasks" in block
assert "4 notes" in block
assert "3 milestones" in block
assert "p1" not in block
def test_render_list_scope_no_entity():
scope = ContextualScope(page="tasks", entity_type=None)
block = render_scope_block(scope)
assert "tasks" in block.lower()
assert "None" not in block
def test_render_note_scope_includes_char_count():
scope = ContextualScope(
page="note",
entity_type="note",
entity_id="n1",
entity_name="Meeting 14 May",
project_id="p1",
char_count=4280,
)
block = render_scope_block(scope)
assert "Meeting 14 May" in block
assert "4280" in block or "4,280" in block
def test_parses_camelcase_payload_from_renderer():
payload = {
"page": "project",
"entityType": "project",
"entityId": "p1",
"entityName": "Acme",
"counts": {"tasks": 5, "notes": 1, "milestones": 2},
}
scope = ContextualScope.model_validate(payload)
assert scope.entity_id == "p1"
assert scope.entity_name == "Acme"

View File

@@ -0,0 +1,44 @@
"""Tests for contextual WS frame handlers.
These tests only exercise the new handler functions in device_ws.py and do
not depend on litellm or the full deep_agent import chain. They monkeypatch
run_contextual_stream so no LLM call is made.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
@pytest.mark.asyncio
async def test_handle_contextual_scope_update_appends_system_message_no_llm(monkeypatch):
"""_handle_contextual_scope_update must:
- call append_system_message on the session buffer
- send a contextual_scope_ack back on the socket
- make no LLM call
"""
from app.api.routes import device_ws
ws = AsyncMock()
buffer = MagicMock()
buffer.append_system_message = MagicMock()
payload = {
"type": "contextual_scope_update",
"session_id": "s1",
"scope": {
"page": "project",
"entityType": "project",
"entityId": "p1",
"entityName": "Acme",
"counts": {"tasks": 1, "notes": 0, "milestones": 0},
},
}
monkeypatch.setattr(device_ws, "get_session_buffer", lambda *a, **kw: buffer)
await device_ws._handle_contextual_scope_update(ws, "user1", payload)
ws.send_text.assert_awaited_once()
import json
sent = json.loads(ws.send_text.await_args.args[0])
assert sent["type"] == "contextual_scope_ack"
assert sent["session_id"] == "s1"
buffer.append_system_message.assert_called_once()

View File

@@ -12,11 +12,8 @@ from langchain_core.messages import AIMessage, ToolMessage
from app.core.deep_agent import (
_build_system_prompt,
_datetime_context_injection,
_infer_floating_domain,
_normalize_tagged_list_lines,
_request_context_block,
run_floating,
run_floating_stream,
run_home,
)
@@ -75,57 +72,6 @@ async def test_run_home_uses_mocked_tool_result():
assert "Mock Task" in out
@pytest.mark.asyncio
async def test_run_floating_stream_emits_domain_then_tokens_with_mocked_tool_result():
fake_llm = _FakeLLM()
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"show me timeline updates",
{"scope": {"type": "timeline", "id": "tl-1"}},
):
events.append(event)
assert events[0] == (
"floating_domain",
{"type": "timeline", "id": "tl-1", "section": None},
)
# _run_single_agent_stream uses ainvoke (not astream); the final token is
# the second LLM response which echoes the tool result.
token_events = [e for e in events if e[0] == "token"]
assert token_events, "Expected at least one token event"
combined = "".join(str(e[1]) for e in token_events)
assert "Mock Task" in combined
@pytest.mark.asyncio
async def test_infer_floating_domain_prefers_message_intent_over_scope_type():
class _ClassifierOnlyLLM:
async def ainvoke(self, _messages):
return AIMessage(
content='{"type":"project","id":"213213-312321-312312-421321","section":"task"}'
)
with patch("app.core.deep_agent.get_agent_llm", return_value=_ClassifierOnlyLLM()):
domain = await _infer_floating_domain(
"Quali sono i miei task per il progetto X",
{
"scope": {"type": "timeline"},
"resolved_project_id": "213213-312321-312312-421321",
},
)
assert domain == {
"type": "project",
"id": "213213-312321-312312-421321",
"section": "task",
}
def test_normalize_tagged_list_lines_rewrites_mixed_task_lines_to_tag_only_lines():
raw = (
"Certo!\n\n"
@@ -162,139 +108,6 @@ def test_normalize_tagged_list_lines_filters_upcoming_timeline_query_to_current_
assert "<timeline>[tl-future]</timeline>" not in out
@pytest.mark.asyncio
async def test_run_floating_strips_xml_like_tags_from_final_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return (
"Hai 1 task:\\n"
"Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
)
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert "<task>" not in text
assert "</task>" not in text
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in text
@pytest.mark.asyncio
async def test_run_floating_stream_strips_xml_like_tags_from_streamed_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "Hai 1 task:\\n"
yield "token", "Mail barra in prod <task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
token_events = [str(data) for event_type, data in events if event_type == "token"]
combined = "".join(token_events)
assert "<task>" not in combined
assert "</task>" not in combined
assert "[180faff3-507d-4d88-aba8-66f204eb59ef]" not in combined
@pytest.mark.asyncio
async def test_run_floating_stream_falls_back_to_final_response_content_when_astream_is_empty():
class _NoChunkLLM:
def __init__(self) -> None:
self.calls = 0
def bind_tools(self, _tools):
return self
async def ainvoke(self, _messages):
self.calls += 1
if self.calls == 1:
return AIMessage(
content="",
tool_calls=[
{
"id": "call-1",
"name": "list_tasks",
"args": {},
}
],
)
return AIMessage(content="No notes found.")
async def astream(self, _messages):
if False:
yield None
with patch("app.core.deep_agent.get_agent_llm", return_value=_NoChunkLLM()), patch(
"app.core.deep_agent._all_tools", return_value=[_FakeTool()]
):
events = []
async for event in run_floating_stream(
"user-1",
"quali sono le note?",
{"scope": {"type": "note"}},
):
events.append(event)
assert events[0][0] == "floating_domain"
assert ("token", "No notes found.") in events
@pytest.mark.asyncio
async def test_run_floating_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_run_single_agent(**_kwargs):
return "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent", side_effect=_fake_run_single_agent
):
text, _domain = await run_floating(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
)
assert text == "No results found."
@pytest.mark.asyncio
async def test_run_floating_stream_returns_fallback_when_sanitization_would_empty_text():
fake_llm = _FakeLLM()
async def _fake_stream(**_kwargs):
yield "token", "<task>[180faff3-507d-4d88-aba8-66f204eb59ef]</task>"
with patch("app.core.deep_agent.get_agent_llm", return_value=fake_llm), patch(
"app.core.deep_agent._run_single_agent_stream", side_effect=_fake_stream
):
events = []
async for event in run_floating_stream(
"user-1",
"quali task ho?",
{"scope": {"type": "task"}},
):
events.append(event)
assert ("token", "No results found.") in events
# ── _datetime_context_injection ────────────────────────────────────────────────
def _fp(tz: str, now_iso: str) -> dict:

View File

@@ -22,7 +22,7 @@ import pytest
from app.core.device_manager import DeviceConnectionManager
from app.db import get_session
from app.main import app
from app.models import AgentRunLog
from app.models import ScoutRunLog
from tests.conftest import TEST_USER_IDS, make_jwt
# ---------------------------------------------------------------------------
@@ -33,9 +33,9 @@ _FREE_UID = TEST_USER_IDS["free"]
_PRO_UID = TEST_USER_IDS["pro"]
def _device_hello(device_id: str = "dev-001", agent_ids: list[str] | None = None) -> str:
def _device_hello(device_id: str = "dev-001", scout_ids: list[str] | None = None) -> str:
return json.dumps(
{"type": "device_hello", "device_id": device_id, "agent_ids": agent_ids or []}
{"type": "device_hello", "device_id": device_id, "scout_ids": scout_ids or []}
)
@@ -262,10 +262,10 @@ async def test_mark_runs_disconnected_updates_db(db_session):
user_id = TEST_USER_IDS["free"]
run_log = AgentRunLog(
run_log = ScoutRunLog(
id=str(uuid.uuid4()),
agent_id=str(uuid.uuid4()),
agent_type="local",
scout_id=str(uuid.uuid4()),
scout_type="local",
user_id=user_id,
status="running",
started_at=datetime.now(timezone.utc),
@@ -280,7 +280,7 @@ async def test_mark_runs_disconnected_updates_db(db_session):
# Verify through the same session factory.
async with _TestSessionLocal() as s:
result = await s.execute(
select(AgentRunLog).where(AgentRunLog.id == run_log.id)
select(ScoutRunLog).where(ScoutRunLog.id == run_log.id)
)
updated = result.scalar_one_or_none()

View File

@@ -1,6 +1,6 @@
"""Tests for Local Agent V2 journey setup (Step 4).
Covers the chatbot journey that produces a structured AgentConfig JSON
Covers the chatbot journey that produces a structured ScoutConfig JSON
instead of a freeform prompt_template string.
Unit tests (no LLM)
@@ -16,7 +16,7 @@ Eval test (real LLM + Langfuse scoring)
----------------------------------------
4.1 Journey start explores directory → first reply contains a question
Cases 4.24.5 (multi-turn conversations producing a full AgentConfig) are
Cases 4.24.5 (multi-turn conversations producing a full ScoutConfig) are
non-deterministic and tested manually — results tracked in Langfuse.
Run:
@@ -37,7 +37,7 @@ from unittest.mock import patch
import pytest
import yaml
from app.api.routes.agent_setup import (
from app.api.routes.scout_setup import (
_CONFIG_END,
_CONFIG_START,
_MAX_TURNS,
@@ -48,7 +48,7 @@ from app.api.routes.agent_setup import (
)
from app.core.langfuse_client import get_langfuse
from app.core.ws_context import clear_client_executor, set_client_executor
from app.schemas import AgentConfig
from app.schemas import ScoutConfig
from tests.conftest import TEST_USER_IDS
# ── Constants ─────────────────────────────────────────────────────────────
@@ -179,7 +179,7 @@ def _evaluate_case(case: dict, reply: dict) -> tuple[float, str]:
def test_4_6a_extract_valid_json():
"""_extract_agent_config: valid JSON between markers → returns serialised config."""
config = AgentConfig(
config = ScoutConfig(
content_types=[],
global_rules=["No project = no entity"],
data_types=["tasks"],
@@ -187,7 +187,7 @@ def test_4_6a_extract_valid_json():
text = f"Some preamble\n{_CONFIG_START}\n{config.model_dump_json()}\n{_CONFIG_END}\nTrailing"
result = _extract_agent_config(text)
assert result is not None
parsed = AgentConfig.model_validate_json(result)
parsed = ScoutConfig.model_validate_json(result)
assert parsed.global_rules == ["No project = no entity"]
@@ -230,7 +230,7 @@ async def test_4_6f_nudge_uses_new_markers():
# Return plain text — no markers — to trigger the nudge path.
return "I still need more information from you."
from app.api.routes.agent_setup import JourneySession
from app.api.routes.scout_setup import JourneySession
fake_session = JourneySession(
session_id=session_id,
@@ -248,7 +248,7 @@ async def test_4_6f_nudge_uses_new_markers():
_sessions[session_id] = fake_session
try:
with patch("app.api.routes.agent_setup._call_llm_with_tools", side_effect=_mock_llm):
with patch("app.api.routes.scout_setup._call_llm_with_tools", side_effect=_mock_llm):
await handle_journey_message(_USER_ID, {
"session_id": session_id,
"message": "one more message to trigger nudge",

View File

@@ -322,7 +322,7 @@ def test_home_request_calls_memory_middleware(client):
):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-mem", "agent_ids": []
"type": "device_hello", "device_id": "dev-mem", "scout_ids": []
}))
ws.send_text(json.dumps({
"type": "home_request",

View File

@@ -5,7 +5,7 @@ from __future__ import annotations
import pytest
from app.core.output_formatter import StreamFormatter
from app.schemas import WsFloatingDomain, WsStreamEnd, WsStreamStart, WsStreamText
from app.schemas import WsStreamEnd, WsStreamStart, WsStreamText
async def _stream(*events: tuple[str, object]):
@@ -36,29 +36,6 @@ async def test_stream_formatter_text_stream() -> None:
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_stream_formatter_floating_domain_first() -> None:
formatter = StreamFormatter(request_id="req-2")
frames = await _collect(
formatter,
_stream(
(
"floating_domain",
{"type": "node", "id": "n-1", "section": None},
),
("token", "Summary"),
),
)
assert isinstance(frames[0], WsFloatingDomain)
assert frames[0].domain.type == "node"
assert frames[0].domain.id == "n-1"
assert isinstance(frames[1], WsStreamStart)
assert isinstance(frames[2], WsStreamText)
assert frames[2].chunk == "Summary"
assert isinstance(frames[-1], WsStreamEnd)
@pytest.mark.asyncio
async def test_stream_formatter_ignores_unknown_events() -> None:
formatter = StreamFormatter(request_id="req-3")

View File

@@ -0,0 +1,85 @@
"""Tests for run_contextual_stream.
These tests monkeypatch _run_single_agent_stream (the actual internal runner)
rather than the plan's fictional _run_agent_loop, matching the real
deep_agent.py architecture.
"""
import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from app.schemas.contextual import ContextualScope
@pytest.mark.asyncio
async def test_run_contextual_stream_includes_scope_block(monkeypatch):
"""run_contextual_stream must inject the scope block into the system prompt
and include get_page_details in the tool list while excluding note-edit tools."""
import app.core.deep_agent as deep_agent
captured = {}
async def fake_stream(
*,
user_id,
system_prompt,
message,
context,
agent_name="agent",
tools=None,
conversation_history=None,
**kwargs,
):
captured["sys"] = system_prompt
captured["tool_names"] = [getattr(t, "name", str(t)) for t in (tools or [])]
captured["agent_name"] = agent_name
# Async generator that yields nothing — still satisfies the protocol.
if False:
yield # pragma: no cover
monkeypatch.setattr(deep_agent, "_run_single_agent_stream", fake_stream)
scope = ContextualScope(
page="project",
entity_type="project",
entity_id="p1",
entity_name="Acme",
counts={"tasks": 1, "notes": 0, "milestones": 0},
)
context = {
"conversation_history": [],
"_debug": {"session_id": "s1"},
}
results = []
async for item in deep_agent.run_contextual_stream(
user_id="user1",
message="hi",
context=context,
scope=scope,
):
results.append(item)
assert "Acme" in captured["sys"], "scope block must appear in system prompt"
assert "Current view" in captured["sys"], "section header must be present"
names = captured["tool_names"]
assert "get_page_details" in names, "get_page_details tool must be included"
# Entity-create tools: at least one of these must be present.
assert any(n in names for n in ("create_task", "create_note", "update_task")), (
"at least one entity-create tool must be present"
)
assert "create_timeline" in names, "create_timeline tool must be included"
# Note edit tools must NOT be exposed.
assert "propose_note_edit" not in names, "propose_note_edit must be excluded"
# Legacy read tools must be excluded — they return shallow snapshots and
# cause the agent to under-answer (see trace 0b46841484ba7d024ed9f8d5ac8b1df0).
assert "list_projects" not in names, "list_projects must be excluded (legacy read)"
assert "get_project" not in names, "get_project must be excluded (legacy read)"
assert "list_tasks" not in names, "list_tasks must be excluded (legacy read)"
assert "get_task" not in names, "get_task must be excluded (legacy read)"
assert "list_notes" not in names, "list_notes must be excluded (legacy read)"
assert "get_note" not in names, "get_note must be excluded (legacy read)"

View File

@@ -4,12 +4,8 @@ import pytest
from pydantic import ValidationError
from app.schemas import (
WsDomain,
WsFrameType,
WsHomeRequest,
WsFloatingDomain,
WsFloatingRequest,
WsFloatingScope,
WsStreamEnd,
WsStreamStart,
WsStreamText,
@@ -22,11 +18,9 @@ from app.schemas import (
def test_v3_frame_types_exist():
v3_types = [
"home_request",
"floating_request",
"stream_start",
"stream_text",
"stream_end",
"floating_domain",
"data_request",
"data_response",
"mutation",
@@ -86,51 +80,6 @@ def test_home_request_requires_message():
WsHomeRequest.model_validate({"type": "home_request"})
# ── WsFloatingRequest ────────────────────────────────────────────────────
def test_floating_request_basic():
frame = WsFloatingRequest(
message="Summarise",
scope=WsFloatingScope(type="task", id="task-123"),
)
assert frame.type == WsFrameType.floating_request
assert frame.scope.type == "task"
assert frame.scope.id == "task-123"
def test_floating_request_scope_without_id():
frame = WsFloatingRequest(
message="Show all",
scope=WsFloatingScope(type="project"),
)
assert frame.scope.id is None
def test_floating_request_serializes():
frame = WsFloatingRequest(
message="Test",
scope=WsFloatingScope(type="note", id="n-1"),
)
data = frame.model_dump()
assert data["type"] == "floating_request"
assert data["scope"]["type"] == "note"
assert data["scope"]["id"] == "n-1"
def test_floating_request_invalid_scope_type():
with pytest.raises(ValidationError):
WsFloatingRequest(
message="X",
scope=WsFloatingScope(type="unknown"), # type: ignore[arg-type]
)
def test_floating_request_requires_scope():
with pytest.raises(ValidationError):
WsFloatingRequest.model_validate({"type": "floating_request", "message": "X"})
# ── WsStreamStart ─────────────────────────────────────────────────────
@@ -189,51 +138,3 @@ def test_stream_end_deserializes():
assert frame.request_id == "r3"
# ── WsFloatingDomain ─────────────────────────────────────────────────────
def test_floating_domain_tasks():
frame = WsFloatingDomain(request_id="r1", domain=WsDomain(type="task"))
assert frame.type == WsFrameType.floating_domain
assert frame.domain.type == "task"
def test_floating_domain_valid_domains():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="213213-312321-312312-421321", section="task"),
)
assert frame.domain.type == "project"
assert frame.domain.id == "213213-312321-312312-421321"
assert frame.domain.section == "task"
def test_floating_domain_object_valid():
frame = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="project", id="p1", section="task"),
)
assert frame.domain.type == "project"
def test_floating_domain_serializes():
d = WsFloatingDomain(
request_id="r1",
domain=WsDomain(type="timeline"),
).model_dump()
assert d == {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "timeline", "id": None, "section": None},
}
def test_floating_domain_deserializes():
raw = {
"type": "floating_domain",
"request_id": "r1",
"domain": {"type": "node", "id": "n-1", "section": None},
}
frame = WsFloatingDomain.model_validate(raw)
assert frame.domain.type == "node"
assert frame.domain.id == "n-1"

View File

@@ -0,0 +1,48 @@
"""Tests for the connector registry."""
from __future__ import annotations
import pytest
from app.scouts.connectors.base import ItemRef
from app.scouts.connectors.registry import (
get_connector,
register_connector,
_reset_for_tests,
)
class _DummyConnector:
source_type = "dummy"
async def list_new(self, scout): return []
async def fetch_metadata(self, scout, ref): raise NotImplementedError
async def fetch_content(self, scout, ref): raise NotImplementedError
async def archive(self, scout, ref): raise NotImplementedError
async def setup_watch(self, scout): raise NotImplementedError
async def renew_watch(self, scout): raise NotImplementedError
@pytest.fixture(autouse=True)
def _clean_registry():
_reset_for_tests()
yield
_reset_for_tests()
def test_register_and_get():
c = _DummyConnector()
register_connector(c)
assert get_connector("dummy") is c
def test_unknown_source_raises():
with pytest.raises(KeyError):
get_connector("nope")
def test_double_register_replaces():
a = _DummyConnector()
b = _DummyConnector()
register_connector(a)
register_connector(b)
assert get_connector("dummy") is b

View File

@@ -0,0 +1,48 @@
"""Tests for the SourceConnector base protocol and shared types."""
from __future__ import annotations
from datetime import datetime, timezone
import pytest
from app.scouts.connectors.base import (
ItemContent,
ItemMetadata,
ItemRef,
TriageVerdict,
)
def test_item_ref_round_trips_through_pydantic():
ref = ItemRef(source_msg_ref="abc123", received_at=datetime.now(tz=timezone.utc))
parsed = ItemRef.model_validate(ref.model_dump())
assert parsed.source_msg_ref == "abc123"
assert parsed.received_at == ref.received_at
def test_item_metadata_allows_all_optional():
meta = ItemMetadata()
assert meta.subject is None
assert meta.sender is None
assert meta.snippet is None
assert meta.received_at is None
def test_item_content_requires_metadata_and_body():
content = ItemContent(
metadata=ItemMetadata(subject="hi"),
body_text="hello world",
raw_headers={"X-Foo": "bar"},
)
assert content.metadata.subject == "hi"
assert content.body_text == "hello world"
assert content.raw_headers["X-Foo"] == "bar"
def test_triage_verdict_constraints():
v = TriageVerdict(verdict="relevant", reason="contains task language", confidence=0.92)
assert v.verdict == "relevant"
with pytest.raises(ValueError):
TriageVerdict(verdict="meh", reason="x", confidence=0.5) # bad enum value

View File

@@ -0,0 +1,84 @@
"""Tests for GmailConnector."""
from __future__ import annotations
import uuid
from unittest.mock import MagicMock, patch
import pytest
from app.models import CloudScoutConfig
from app.scouts.connectors.base import ItemRef
from app.scouts.connectors.gmail import GmailConnector
def _make_scout():
return CloudScoutConfig(
id=str(uuid.uuid4()),
user_id="00000000-0000-0000-0000-000000000003",
provider="gmail",
name="Inbox",
data_types=[],
prompt_template="",
oauth_token_encrypted="encrypted-blob",
schedule_cron="0 * * * *",
enabled=True,
auto_trash_spam=False,
device_inactivity_pause_days=14,
gmail_history_id="100",
)
@pytest.mark.asyncio
async def test_fetch_metadata_returns_subject_and_snippet():
scout = _make_scout()
conn = GmailConnector()
fake_message = {
"id": "msg-1",
"snippet": "preview text",
"payload": {"headers": [
{"name": "Subject", "value": "Hello"},
{"name": "From", "value": "alice@example.com"},
{"name": "Date", "value": "Wed, 14 May 2026 10:00:00 +0000"},
]},
}
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
mock_svc.return_value.users().messages().get().execute.return_value = fake_message
meta = await conn.fetch_metadata(scout, ItemRef(source_msg_ref="msg-1"))
assert meta.subject == "Hello"
assert meta.sender == "alice@example.com"
assert meta.snippet == "preview text"
@pytest.mark.asyncio
async def test_fetch_content_returns_body_text():
import base64
scout = _make_scout()
conn = GmailConnector()
body_data = base64.urlsafe_b64encode(b"hello world").decode()
fake_message = {
"id": "msg-1",
"snippet": "hello world",
"payload": {
"mimeType": "text/plain",
"headers": [
{"name": "Subject", "value": "S"},
{"name": "From", "value": "a@b"},
],
"body": {"data": body_data},
},
}
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
mock_svc.return_value.users().messages().get().execute.return_value = fake_message
content = await conn.fetch_content(scout, ItemRef(source_msg_ref="msg-1"))
assert content.body_text == "hello world"
assert content.metadata.subject == "S"
@pytest.mark.asyncio
async def test_archive_calls_trash():
scout = _make_scout()
conn = GmailConnector()
with patch("app.scouts.connectors.gmail._get_gmail_service") as mock_svc:
await conn.archive(scout, ItemRef(source_msg_ref="msg-1"))
mock_svc.return_value.users().messages().trash.assert_called()

270
tests/test_scout_engine.py Normal file
View File

@@ -0,0 +1,270 @@
"""Unit tests for ScoutEngine.trigger_scout / _process_item."""
from __future__ import annotations
import uuid
from datetime import datetime, timedelta, timezone
from unittest.mock import AsyncMock
import pytest
from sqlalchemy import select
from app.models import CloudScoutConfig, ScoutTriageQueue, User, Subscription
from app.scouts.connectors.base import ItemContent, ItemMetadata, ItemRef, TriageVerdict
from app.scouts.connectors.registry import register_connector, _reset_for_tests
from app.scouts.engine import ScoutEngine
from tests.conftest import _TestSessionLocal
def _make_connector(items, content_for):
c = AsyncMock()
# source_type must match the scout.provider ("gmail") so get_connector()
# finds it when the engine calls get_connector(scout.provider).
c.source_type = "gmail"
c.list_new = AsyncMock(return_value=items)
c.fetch_content = AsyncMock(side_effect=lambda scout, ref: content_for[ref.source_msg_ref])
c.archive = AsyncMock()
return c
@pytest.fixture(autouse=True)
def _registry():
_reset_for_tests()
yield
_reset_for_tests()
@pytest.mark.asyncio
async def test_relevant_item_inserted_into_queue(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003" # power tier seeded in conftest
scout_id = str(uuid.uuid4())
async with _TestSessionLocal() as session:
scout = CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
)
session.add(scout)
await session.commit()
refs = [ItemRef(source_msg_ref="msg-1")]
content = {"msg-1": ItemContent(metadata=ItemMetadata(subject="Hi"), body_text="task tomorrow")}
connector = _make_connector(refs, content)
register_connector(connector)
engine = ScoutEngine(session_factory=_TestSessionLocal)
monkeypatch.setattr(
engine,
"_triage_llm",
AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="task", confidence=0.9)),
)
await engine.trigger_scout(uuid.UUID(scout_id))
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert len(rows) == 1
assert rows[0].source_msg_ref == "msg-1"
assert rows[0].triage_verdict == "relevant"
assert rows[0].status == "queued"
connector.archive.assert_not_awaited()
@pytest.mark.asyncio
async def test_spam_with_auto_trash_archives_and_does_not_queue(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003"
scout_id = str(uuid.uuid4())
async with _TestSessionLocal() as session:
scout = CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=True, device_inactivity_pause_days=14,
)
session.add(scout)
await session.commit()
refs = [ItemRef(source_msg_ref="msg-spam")]
content = {"msg-spam": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
connector = _make_connector(refs, content)
register_connector(connector)
engine = ScoutEngine(session_factory=_TestSessionLocal)
monkeypatch.setattr(
engine,
"_triage_llm",
AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
)
await engine.trigger_scout(uuid.UUID(scout_id))
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert rows == []
connector.archive.assert_awaited_once()
@pytest.mark.asyncio
async def test_spam_without_auto_trash_does_not_archive_and_does_not_queue(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003"
scout_id = str(uuid.uuid4())
async with _TestSessionLocal() as session:
scout = CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
)
session.add(scout)
await session.commit()
refs = [ItemRef(source_msg_ref="msg-2")]
content = {"msg-2": ItemContent(metadata=ItemMetadata(subject="$$$"), body_text="buy")}
connector = _make_connector(refs, content)
register_connector(connector)
engine = ScoutEngine(session_factory=_TestSessionLocal)
monkeypatch.setattr(
engine,
"_triage_llm",
AsyncMock(return_value=TriageVerdict(verdict="spam", reason="bait", confidence=0.99)),
)
await engine.trigger_scout(uuid.UUID(scout_id))
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert rows == []
connector.archive.assert_not_awaited()
@pytest.mark.asyncio
async def test_idempotent_replay(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003"
scout_id = str(uuid.uuid4())
async with _TestSessionLocal() as session:
session.add(CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
))
await session.commit()
refs = [ItemRef(source_msg_ref="msg-3")]
content = {"msg-3": ItemContent(metadata=ItemMetadata(subject="x"), body_text="y")}
connector = _make_connector(refs, content)
register_connector(connector)
engine = ScoutEngine(session_factory=_TestSessionLocal)
monkeypatch.setattr(
engine,
"_triage_llm",
AsyncMock(return_value=TriageVerdict(verdict="relevant", reason="x", confidence=0.5)),
)
await engine.trigger_scout(uuid.UUID(scout_id))
await engine.trigger_scout(uuid.UUID(scout_id))
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert len(rows) == 1, "Replay must not create duplicate queue rows"
@pytest.mark.asyncio
async def test_triage_llm_parses_json_response(monkeypatch):
"""Real _triage_llm path: mock the LLM ainvoke, verify TriageVerdict parsed correctly."""
from unittest.mock import MagicMock # noqa: PLC0415
from app.models import CloudScoutConfig # noqa: PLC0415
scout = CloudScoutConfig(
id=str(uuid.uuid4()),
user_id="00000000-0000-0000-0000-000000000003",
provider="gmail",
name="test-scout",
data_types=[],
prompt_template="watch invoices and project updates",
schedule_cron="0 * * * *",
enabled=True,
auto_trash_spam=False,
device_inactivity_pause_days=14,
)
content = ItemContent(
metadata=ItemMetadata(subject="Invoice 42", sender="billing@acme.com"),
body_text="Payment of €1 200 is due on 2026-06-01. Please confirm receipt.",
)
# Build a fake LangChain response whose .content is valid JSON.
fake_response = MagicMock()
fake_response.content = '{"verdict": "relevant", "reason": "invoice due", "confidence": 0.92}'
fake_response.usage_metadata = {"input_tokens": 10, "output_tokens": 5, "total_tokens": 15}
# Fake LLM: .bind() returns self (or another mock with ainvoke).
fake_llm = MagicMock()
fake_llm.bind.return_value = fake_llm
fake_llm.ainvoke = AsyncMock(return_value=fake_response)
# Patch get_llm inside app.scouts.engine so our fake is used.
monkeypatch.setattr("app.scouts.engine.get_llm", lambda **kwargs: fake_llm)
# Disable Langfuse for this test.
monkeypatch.setattr("app.scouts.engine.get_langfuse", lambda: None)
# Use fallback prompt (no Langfuse) — patch get_prompt_or_fallback to return fallback.
monkeypatch.setattr(
"app.scouts.engine.get_prompt_or_fallback",
lambda name, fallback: (fallback, None),
)
engine = ScoutEngine(session_factory=_TestSessionLocal)
verdict = await engine._triage_llm(scout, content)
assert verdict.verdict == "relevant"
assert verdict.reason == "invoice due"
assert abs(verdict.confidence - 0.92) < 1e-6
fake_llm.ainvoke.assert_awaited_once()
@pytest.mark.asyncio
async def test_deliver_pending_sends_one_frame_per_queued_row(monkeypatch):
user_id = "00000000-0000-0000-0000-000000000003"
scout_id = str(uuid.uuid4())
now = datetime.now(tz=timezone.utc)
async with _TestSessionLocal() as session:
session.add(CloudScoutConfig(
id=scout_id, user_id=user_id, provider="gmail", name="Test",
data_types=[], prompt_template="", schedule_cron="0 * * * *",
enabled=True, auto_trash_spam=False, device_inactivity_pause_days=14,
))
for i in range(3):
session.add(ScoutTriageQueue(
id=str(uuid.uuid4()), user_id=user_id, scout_id=scout_id,
source_type="gmail", source_msg_ref=f"msg-{i}",
triage_verdict="relevant", status="queued",
triaged_at=now, expires_at=now + timedelta(days=30),
))
await session.commit()
connector = AsyncMock()
connector.source_type = "gmail"
connector.fetch_metadata = AsyncMock(side_effect=lambda scout, ref: ItemMetadata(
subject=f"sub-{ref.source_msg_ref}", snippet=f"snip-{ref.source_msg_ref}",
))
register_connector(connector)
sent = []
ws = AsyncMock()
ws.send_json = AsyncMock(side_effect=lambda payload: sent.append(payload))
engine = ScoutEngine(session_factory=_TestSessionLocal)
await engine.deliver_pending(uuid.UUID(user_id), ws)
assert len(sent) == 3
assert all(s["type"] == "scout_proposal" for s in sent)
subjects = {s["proposal"]["raw_subject"] for s in sent}
assert subjects == {"sub-msg-0", "sub-msg-1", "sub-msg-2"}
async with _TestSessionLocal() as session:
rows = (await session.execute(select(ScoutTriageQueue))).scalars().all()
assert all(r.status == "delivered" for r in rows)
assert all(r.delivered_at is not None for r in rows)

106
tests/test_scout_webhook.py Normal file
View File

@@ -0,0 +1,106 @@
"""Tests for the Gmail Pub/Sub webhook route.
Covers:
- Happy path: valid JWT + known user + enabled scout → 204, engine triggered.
- Rejection: invalid JWT → 401.
"""
from __future__ import annotations
import base64
import json
import uuid
from unittest.mock import AsyncMock, patch
import pytest
from httpx import ASGITransport, AsyncClient
from app.main import app
from app.models import CloudScoutConfig, User
from tests.conftest import _TestSessionLocal
def _pubsub_payload(email: str, history_id: str) -> dict:
"""Build a minimal Pub/Sub push envelope."""
inner = json.dumps({"emailAddress": email, "historyId": history_id}).encode()
return {
"message": {"data": base64.b64encode(inner).decode(), "messageId": "m1"},
"subscription": "projects/x/subscriptions/gmail-watch-sub",
}
@pytest.mark.asyncio
async def test_webhook_triggers_scout_for_matching_user():
"""204 returned and ScoutEngine.trigger_scout awaited for the matching scout."""
user_id = "00000000-0000-0000-0000-000000000003" # seeded 'power' user
scout_id = str(uuid.uuid4())
# Mutate the seeded user email so the webhook can resolve it,
# and add a cloud scout config for gmail.
async with _TestSessionLocal() as session:
user = await session.get(User, user_id)
user.email = "alice@example.com"
session.add(
CloudScoutConfig(
id=scout_id,
user_id=user_id,
provider="gmail",
name="Inbox",
data_types=[],
prompt_template="",
schedule_cron="0 * * * *",
enabled=True,
auto_trash_spam=False,
device_inactivity_pause_days=14,
)
)
await session.commit()
payload = _pubsub_payload("alice@example.com", "200")
with (
patch(
"app.api.routes.scout_webhooks._verify_pubsub_jwt",
return_value=True,
),
patch(
"app.api.routes.scout_webhooks.async_session",
_TestSessionLocal,
),
patch(
"app.scouts.engine.ScoutEngine.trigger_scout",
new=AsyncMock(),
) as mock_trigger,
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/scouts/webhooks/gmail",
json=payload,
headers={"Authorization": "Bearer fake-google-jwt"},
)
assert resp.status_code == 204
mock_trigger.assert_awaited_once_with(uuid.UUID(scout_id))
@pytest.mark.asyncio
async def test_webhook_rejects_unverified_jwt():
"""401 returned when JWT verification fails."""
payload = _pubsub_payload("alice@example.com", "200")
with patch(
"app.api.routes.scout_webhooks._verify_pubsub_jwt",
return_value=False,
):
async with AsyncClient(
transport=ASGITransport(app=app), base_url="http://test"
) as client:
resp = await client.post(
"/api/v1/scouts/webhooks/gmail",
json=payload,
headers={"Authorization": "Bearer bogus"},
)
assert resp.status_code == 401

View File

@@ -1,6 +1,6 @@
"""Integration tests for the unified WebSocket handler (Step 5).
Tests the device WS endpoint with home_request and floating_request frames,
Tests the device WS endpoint with home_request frames,
verifying that the correct v3 frame sequence is returned.
LLM calls are mocked to avoid network dependency.
@@ -34,7 +34,7 @@ def _override_db(db_session):
def _recv_until_end(ws, max_frames: int = 20) -> list[dict]:
"""Receive frames until stream_end (or stream_end inside floating flow), or max_frames."""
"""Receive frames until stream_end or max_frames."""
frames = []
for _ in range(max_frames):
raw = ws.receive_text()
@@ -49,11 +49,6 @@ async def _mock_home_stream(user_id, message, context):
yield "token", "Hello"
async def _mock_floating_stream(user_id, message, context):
yield "floating_domain", {"type": "task", "id": None, "section": None}
yield "token", "Here is a summary"
# ── tests ─────────────────────────────────────────────────────────────────────
def test_home_request_produces_stream_frames(client):
@@ -63,7 +58,7 @@ def test_home_request_produces_stream_frames(client):
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_mock_home_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-1", "agent_ids": []
"type": "device_hello", "device_id": "dev-1", "scout_ids": []
}))
ws.send_text(json.dumps({
"type": "home_request",
@@ -79,33 +74,6 @@ def test_home_request_produces_stream_frames(client):
assert types.index(WsFrameType.stream_start) < types.index(WsFrameType.stream_end)
def test_floating_request_produces_domain_frame(client):
"""floating_request → floating_domain first, then stream_text*, stream_end."""
token = make_jwt("power", user_id=USER_ID)
with patch("app.api.routes.device_ws.run_floating_stream", side_effect=_mock_floating_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-2", "agent_ids": []
}))
ws.send_text(json.dumps({
"type": "floating_request",
"request_id": "p1",
"message": "Summarize this task",
"scope": {"type": "task", "id": "task-123"},
}))
frames = _recv_until_end(ws)
types = [f["type"] for f in frames]
assert WsFrameType.floating_domain in types
assert WsFrameType.stream_end in types
assert types.index(WsFrameType.floating_domain) < types.index(WsFrameType.stream_end)
domain_frame = next(f for f in frames if f["type"] == WsFrameType.floating_domain)
assert domain_frame["domain"]["type"] == "task"
assert domain_frame["request_id"] == "p1"
def test_home_request_request_id_propagated(client):
"""request_id in home_request is echoed in all response frames."""
token = make_jwt("power", user_id=USER_ID)
@@ -117,7 +85,7 @@ def test_home_request_request_id_propagated(client):
with patch("app.api.routes.device_ws.run_home_stream", side_effect=_stream):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-3", "agent_ids": []
"type": "device_hello", "device_id": "dev-3", "scout_ids": []
}))
ws.send_text(json.dumps({
"type": "home_request",
@@ -138,7 +106,7 @@ def test_tool_result_dispatch_silent_on_unknown_id(client):
with patch("app.api.routes.device_ws._HEARTBEAT_INTERVAL", 0.05):
with client.websocket_connect(f"/api/v1/ws/device?token={token}") as ws:
ws.send_text(json.dumps({
"type": "device_hello", "device_id": "dev-4", "agent_ids": []
"type": "device_hello", "device_id": "dev-4", "scout_ids": []
}))
ws.send_text(json.dumps({
"type": "tool_result", "id": "no-such-id", "ok": True